DrugEncoder.LSTM

Click here to go back to the reference.

class LSTM(nn.Module):
    """Encode integer token sequences (e.g. drug strings) with an embedding + stacked LSTM.

    The output is laid out channels-first: ``(batch, ft_dim, seq_len)``.
    """

    def __init__(self, num_embedding: int = _num_embedding, embedding_dim: int = _drug_dim, ft_dim: int = _drug_dim,
                 dropout: float = _dropout, bidirectional: bool = _bidirectional, num_layers: int = 2):
        """Build the token embedding table and the LSTM.

        Args:
            num_embedding: Number of rows in the embedding table (vocabulary size).
            embedding_dim: Width of each token embedding; used as the LSTM input size.
            ft_dim: Output feature width per sequence position. When ``bidirectional``
                is true each direction gets ``ft_dim // 2`` hidden units so the
                concatenated output is ``ft_dim`` wide; ``ft_dim`` must then be even.
            dropout: Dropout applied between stacked LSTM layers. PyTorch only uses it
                when ``num_layers > 1``, so it is not forwarded for a single layer.
            bidirectional: Whether to run the LSTM in both directions.
            num_layers: Number of stacked LSTM layers; must be >= 1.

        Raises:
            ValueError: If ``num_layers < 1``, or if ``bidirectional`` is true and
                ``ft_dim`` is odd.
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # Only the bidirectional case splits ft_dim across two directions.
        if bidirectional and ft_dim % 2 != 0:
            raise ValueError(f"ft_dim must be even when bidirectional=True, got {ft_dim}")
        num_directions = 2 if bidirectional else 1
        self.embedding = nn.Embedding(num_embedding, embedding_dim)
        # LSTM dropout acts only between stacked layers; with a single layer it has no
        # effect and PyTorch merely warns, so pass 0.0 in that case.
        self.encode_lstm = torch.nn.LSTM(input_size=embedding_dim, hidden_size=ft_dim // num_directions,
                                         num_layers=num_layers, dropout=dropout if num_layers > 1 else 0.0,
                                         bidirectional=bidirectional, batch_first=True)

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        """Encode token ids into channels-first per-position features.

        Args:
            f: Integer token ids, assumed ``(batch, seq_len)`` — the LSTM is
                batch-first and the output is permuted as a 3-D tensor.

        Returns:
            Tensor of shape ``(batch, ft_dim, seq_len)``: the LSTM's per-step
            outputs with the feature and sequence axes swapped.
        """
        f = self.embedding(f)  # (batch, seq_len, embedding_dim)
        # Per-step outputs are (batch, seq_len, ft_dim); the final (h, c) states are discarded.
        # NOTE(review): padding positions (if any) are fed through the LSTM unmasked.
        f, _ = self.encode_lstm(f)
        # Channels-first layout, presumably for Conv1d-style consumers — TODO confirm.
        return f.permute(0, 2, 1).contiguous()