Source code for context_builder.encoders

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, embedding, hidden_size, num_layers=1,
                 bidirectional=False, LSTM=False):
        """Encoder part for encoding sequences.

            Parameters
            ----------
            embedding : nn.Embedding
                Embedding layer to use

            hidden_size : int
                Size of hidden dimension

            num_layers : int, default=1
                Number of recurrent layers to use

            bidirectional : boolean, default=False
                If True, use bidirectional recurrent layer

            LSTM : boolean, default=False
                If True, use LSTM instead of GRU
            """
        # Call super
        super(Encoder, self).__init__()

        # Set hidden size
        self.hidden_size = hidden_size
        # Set number of layers
        self.num_layers = num_layers
        # Set bidirectional
        self.bidirectional = bidirectional
        # Set LSTM
        self.LSTM = LSTM

        # Initialise layers
        self.embedding = embedding
        self.recurrent = (nn.LSTM if LSTM else nn.GRU)(
            input_size    = self.embedding.embedding_dim,
            hidden_size   = self.hidden_size,
            num_layers    = self.num_layers,
            batch_first   = True,
            bidirectional = bidirectional,
        )

        # Set embedding dimension
        self.embedding_size = self.embedding.embedding_dim
    def forward(self, input, hidden=None):
        """Encode data

            Parameters
            ----------
            input : torch.Tensor
                Tensor to use as input

            hidden : torch.Tensor, optional
                Tensor to use as hidden input (for continuing sequences);
                for an LSTM this is a tuple of (hidden, cell) states

            Returns
            -------
            output : torch.Tensor
                Embedded input sequence (duplicated along the feature
                dimension when bidirectional=True)

            hidden : torch.Tensor
                Hidden state to supply to next input
            """
        # Initialise hidden if not given
        if hidden is None:
            hidden = self.initHidden(input)

        # Get input as embedding
        embedded = self.embedding(input)

        # Pass through recurrent layer, keeping only the final hidden state
        _, hidden = self.recurrent(embedded, hidden)

        # Duplicate the embedding along the feature dimension so the output
        # width mirrors that of a bidirectional recurrent layer
        if self.bidirectional:
            embedded = torch.cat((embedded, embedded), dim=2)

        # Return the embedded input as output, plus the final hidden state
        return embedded, hidden
    def initHidden(self, X):
        """Create initial hidden vector for data X

            Parameters
            ----------
            X : torch.Tensor
                Tensor for which to create a hidden state

            Returns
            -------
            result : torch.Tensor
                Initial hidden tensor
            """
        # Hidden state shape: (num_layers * num_directions, batch, hidden)
        shape = (
            self.num_layers * (2 if self.bidirectional else 1),
            X.shape[0],
            self.hidden_size,
        )

        # LSTM requires both a hidden state and a cell state
        if self.LSTM:
            return (torch.zeros(*shape, device=X.device),
                    torch.zeros(*shape, device=X.device))
        # GRU only requires a hidden state
        else:
            return torch.zeros(*shape, device=X.device)
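# ---------------------------------------------------------------------
# Usage sketch (illustration only, not part of the module): assumes a
# hypothetical vocabulary of 10 event types embedded in 32 dimensions,
# and shows how the returned hidden state can be fed back in to encode
# a continuing sequence.
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical embedding layer and encoder configuration
    embedding = nn.Embedding(num_embeddings=10, embedding_dim=32)
    encoder   = Encoder(embedding, hidden_size=64, num_layers=1, LSTM=True)

    # Batch of 4 sequences, each 5 events long
    X = torch.randint(0, 10, (4, 5))

    # First call initialises the hidden state internally
    output, hidden = encoder(X)
    print(output.shape)     # torch.Size([4, 5, 32]) - embedded input
    print(hidden[0].shape)  # torch.Size([1, 4, 64]) - LSTM hidden state

    # Subsequent calls can reuse the hidden state to continue a sequence
    output, hidden = encoder(X, hidden)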