Source code for context_builder.embedding

import torch
import torch.nn as nn
import torch.nn.functional as F

[docs]class EmbeddingOneHot(nn.Module): """Embedder using simple one hot encoding."""
[docs] def __init__(self, input_size): """Embedder using simple one hot encoding. Parameters ---------- input_size : int Maximum number of inputs to one_hot encode """ super().__init__() self.input_size = input_size self.embedding_dim = input_size
[docs] def forward(self, X): """Create one-hot encoding of input Parameters ---------- X : torch.Tensor of shape=(n_samples,) Input to encode. Returns ------- result : torch.Tensor of shape=(n_samples, input_size) One-hot encoded version of input """ return F.one_hot(X, self.input_size).to(torch.float)