DrugEncoder.LSTM
Click here to go back to the reference.
class LSTM(nn.Module):
    """Embed integer token ids and encode them with a (bi)directional LSTM.

    The per-time-step LSTM outputs are returned channel-first, i.e. with the
    feature axis moved in front of the sequence axis.
    """

    def __init__(self, num_embedding: int = _num_embedding, embedding_dim: int = _drug_dim, ft_dim: int = _drug_dim,
                 dropout: float = _dropout, bidirectional: bool = _bidirectional, num_layers: int = 2):
        """Build the embedding table and the LSTM stack.

        Args:
            num_embedding: Vocabulary size of the ``nn.Embedding`` table.
            embedding_dim: Width of each token embedding (the LSTM input size).
            ft_dim: Width of the LSTM output features. When ``bidirectional`` is
                true, each direction gets ``ft_dim // 2`` hidden units so the
                concatenated output is still ``ft_dim`` wide.
            dropout: Dropout between stacked LSTM layers. Passed as 0 when
                ``num_layers == 1``, since ``nn.LSTM`` only applies it between
                layers and warns about a non-zero value otherwise.
            bidirectional: Whether to run the LSTM in both directions.
            num_layers: Number of stacked LSTM layers; must be >= 1.

        Raises:
            ValueError: if ``num_layers < 1``, or if ``bidirectional`` is true
                and ``ft_dim`` is odd.
        """
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # Only a bidirectional LSTM splits ft_dim across two directions, so only
        # then must it be even.
        if bidirectional and ft_dim % 2 != 0:
            raise ValueError(f"ft_dim must be even when bidirectional=True, got {ft_dim}")
        super().__init__()

        num_directions = 2 if bidirectional else 1
        self.embedding = nn.Embedding(num_embedding, embedding_dim)
        self.encode_lstm = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=ft_dim // num_directions,
            num_layers=num_layers,
            # Inter-layer dropout is a no-op (and warns) with a single layer.
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=bidirectional,
            batch_first=True,
        )

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        """Encode a batch of token-id sequences.

        Args:
            f: Integer token ids; expected to be 2-D ``(batch, seq_len)``, since
                the LSTM is built with ``batch_first=True`` and the final
                ``permute`` needs a 3-D LSTM output.

        Returns:
            Tensor of shape ``(batch, ft_dim, seq_len)``: the LSTM output at
            every time step with the feature axis moved to dim 1.
        """
        f = self.embedding(f)        # -> (batch, seq_len, embedding_dim)
        f, _ = self.encode_lstm(f)   # -> (batch, seq_len, ft_dim); final (h_n, c_n) discarded
        # Channel-first layout; contiguous() materialises the permuted view.
        return f.permute(0, 2, 1).contiguous()