DrugEncoder.MPG

Click here to go back to the reference.

class MPG(nn.Module):
    """Drug encoder that yields node features from MolGNet (MPG) embeddings.

    Two feature sources are supported, selected by ``freeze``:

    * ``freeze`` truthy: no network is built; ``forward`` reads pre-computed
      features from ``g.mpg_ft``.
    * ``freeze`` falsy: a ``MolGNet`` is built and initialised from the
      state-dict checkpoint at ``pt_path``; ``forward`` runs it on ``g``.

    When ``conv`` is truthy, the features are additionally passed through a
    ``GCNConv(MPG_dim, ft_dim)`` layer over ``g.edge_index``.

    Args:
        ft_dim: Output channels of the optional ``GCNConv`` head.
        MPG_dim: Input channels of the optional ``GCNConv`` head. It must equal
            the width of the incoming features. NOTE(review): when ``freeze``
            is falsy this presumably has to match ``emb_dim`` — confirm
            against MolGNet's output width.
        freeze: Use pre-computed ``g.mpg_ft`` instead of running ``MolGNet``.
        conv: Append the ``GCNConv`` head.
        num_layer: Forwarded to ``MolGNet`` (only used when ``freeze`` is falsy).
        emb_dim: Forwarded to ``MolGNet`` (only used when ``freeze`` is falsy).
        heads: Forwarded to ``MolGNet`` (only used when ``freeze`` is falsy).
        num_message_passing: Forwarded to ``MolGNet`` (only used when
            ``freeze`` is falsy).
        drop_ratio: Forwarded to ``MolGNet`` (only used when ``freeze`` is falsy).
        pt_path: Path to a ``MolGNet`` state-dict checkpoint. Required when
            ``freeze`` is falsy.

    Raises:
        ValueError: If ``freeze`` is falsy and ``pt_path`` is ``None``.
    """

    def __init__(self, ft_dim: int = 768, MPG_dim: int = 768, freeze: bool = True, conv: bool = True,
                 num_layer: int = 5, emb_dim: int = 768, heads: int = 12, num_message_passing: int = 3,
                 drop_ratio: float = 0, pt_path=None):
        super().__init__()
        # Explicit exception instead of ``assert`` (stripped under ``python -O``).
        # Truthiness is used so that falsy non-bool values (0, None) are
        # validated too, rather than crashing later inside torch.load(None).
        if not freeze and pt_path is None:
            raise ValueError("pt_path must be provided when freeze is False")
        self.freeze = freeze
        self.conv = conv
        # Same truthiness test as forward(), so a net is built exactly when it is used.
        if not self.freeze:
            self.net = MolGNet(num_layer=num_layer, emb_dim=emb_dim, heads=heads,
                               num_message_passing=num_message_passing, drop_ratio=drop_ratio)
            # Map tensors to CPU so checkpoints saved on GPU also load on CPU-only
            # hosts; load_state_dict copies them into this module's own parameters,
            # which still follow any later .to(device) call.
            state_dict = torch.load(pt_path, map_location="cpu")
            self.net.load_state_dict(state_dict)
        if self.conv:
            self.output = GCNConv(MPG_dim, ft_dim)

    def forward(self, g):
        """Encode the graph ``g``.

        Args:
            g: Graph object. It must provide ``mpg_ft`` when ``freeze`` is
                truthy, and ``edge_index`` when ``conv`` is truthy. When
                ``freeze`` is falsy it is passed as-is to ``MolGNet``
                (presumably a torch_geometric ``Data``/``Batch`` — confirm).

        Returns:
            A tuple ``(x, g)``: the node features and the unchanged input graph.
        """
        x = g.mpg_ft if self.freeze else self.net(g)
        if self.conv:
            x = self.output(x, g.edge_index)
        return x, g