FusionModule.MHA

Click here to go back to the reference.

class MHA(nn.Module):
    """Thin wrapper around a ``DNN`` encoder with an optional sigmoid output.

    The module builds a single ``DNN`` sub-module and delegates the forward
    pass to it. When ``classify`` is enabled, the first of the three values
    returned by the encoder is passed through a sigmoid.
    """

    def __init__(self, cell_dim: int, drug_dim: int, hid_dim_ls: list = None,
                 dropout: float = _dropout, num_heads: int = _num_heads,
                 mix_pool: bool = True, concat: bool = True,
                 classify: bool = False):
        """Build the wrapped encoder.

        Args:
            cell_dim: Forwarded to ``DNN`` as its first positional argument
                (presumably the cell-feature width -- confirm against DNN).
            drug_dim: Forwarded to ``DNN`` as its second positional argument
                (presumably the drug-feature width -- confirm against DNN).
            hid_dim_ls: Forwarded unchanged to ``DNN``; ``None`` is passed
                through as-is, so any defaulting happens inside ``DNN``.
            dropout: Forwarded unchanged to ``DNN``.
            num_heads: Forwarded unchanged to ``DNN``.
            mix_pool: If truthy the pooling mode string handed to ``DNN`` is
                ``'mix'``; otherwise it is ``'attention'``.
            concat: Forwarded unchanged to ``DNN``.
            classify: If truthy, ``forward`` applies a sigmoid to the first
                encoder output.
        """
        super().__init__()
        self.classify = classify

        # Translate the boolean switch into the mode string DNN receives.
        if mix_pool:
            pool = 'mix'
        else:
            pool = 'attention'

        self.encode_dnn = DNN(cell_dim, drug_dim, hid_dim_ls, dropout, num_heads, pool, concat)

    def forward(self, f, x):
        """Run the encoder and optionally squash its first output.

        Both arguments are handed to the wrapped ``DNN`` untouched.

        Returns:
            A 3-tuple ``(out, cell_ft, drug_ft)`` as produced by the encoder,
            with ``out`` passed through a sigmoid when ``self.classify`` is
            truthy.
        """
        out, cell_ft, drug_ft = self.encode_dnn(f, x)
        if not self.classify:
            return out, cell_ft, drug_ft
        return F.sigmoid(out), cell_ft, drug_ft