DrugEncoder.TrimNet

Click here to go back to the reference.

class TrimNet(nn.Module):
    def __init__(self, x_num_embedding: int = _x_num_embedding, edge_num_embedding: int = _edge_num_embedding,
                 embedding_dim: int = _drug_dim, ft_dim: int = 2, num_heads: int = 4, dropout: float = 0.1,
                 hid_dim: int = 32, depth: int = 3):
        """"""
        super(TrimNet, self).__init__()
        self.dropout = dropout
        self.x_embedding = nn.Embedding(x_num_embedding, embedding_dim)
        self.edge_embedding = nn.Embedding(edge_num_embedding, embedding_dim)
        self.reset_parameters()
        self.lin0 = nn.Linear(embedding_dim, hid_dim)
        self.convs = nn.ModuleList([Block(hid_dim, embedding_dim, num_heads) for _ in range(depth)])
        self.set2set = Set2Set(hid_dim, processing_steps=3)
        self.out = nn.Sequential(nn.Linear(2 * hid_dim, 512), nn.LayerNorm(512), nn.ReLU(inplace=True),
                                 nn.Dropout(p=self.dropout), nn.Linear(512, ft_dim))

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.x_embedding.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding.weight.data)

    def forward(self, g):
        x, edge_index, edge_attr, batch = g.x, g.edge_index, g.edge_attr, g.batch
        x = self.x_embedding(x).sum(1)
        edge_attr = self.edge_embedding(edge_attr).sum(1)
        x = F.celu(self.lin0(x))
        for conv in self.convs:
            x = x + F.dropout(conv(x, edge_index, edge_attr), p=self.dropout, training=self.training)
        x = self.set2set(x, batch)
        x = self.out(F.dropout(x, p=self.dropout, training=self.training))
        return x