FusionModule.MHA

Click here to view source code.

class MHA(cell_dim: int, drug_dim: int, hid_dim_ls: list = None, dropout: float = 0.3, num_heads: int = 8,
mix_pool: bool = True, concat: bool = True, classify: bool = False)

It can be used for fusion cell and drug feature to predict drug response.

PARAMETERS:

  • cell_dim (int) - Dimension of cell feature.

  • drug_dim (int) - Dimension of drug feature.

  • hid_dim_ls (int, optional) - Dimension of hidden layers of DNN. (default: None)

    When the value is None, [512, 256, 128] is used.

  • dropout (float, optional) - The dropout rate of DNN. (default: 0.3)

  • num_heads (int, optional) - The number of heads of the multi-head attention network. (default: 8)

  • mix_pool (bool, optional) - The pooling method of drug feature. (default: True)

    If the value is True, the method is ‘mix’ (mean, max, and attention). If pool is False, the method is ‘attention’.

  • concat (bool, optional) - Whether to concat for fusion. (default: False)

    Cell and drug feature are added instead of concat when the value is True.

  • classify (bool, optional) - Whether classification task. (default: False)

forward(f: torch.Tensor, x)

INPUTS:

  • f (torch.Tensor) - The cell feature. [batch_size, cell_dim]

  • x (torch.Tensor) - The drug feature.

    If the drug feature is ‘ECFP’ or ‘Image’, its shape is [batch_size, drug_dim]. If the drug feature is ‘SMILES’, its shape is [batch_size, drug_dim, seq_len]. If the drug feature is ‘Graph’, its format is (encoded_graph.x, graph).

OUTPUTS:

  • f (torch.Tensor) - The predicted drug response. [batch_size, 1]

  • cell_ft (torch.Tensor) - The cell feature. [batch_size, cell_dim]

  • drug_ft (torch.Tensor) - The drug feature. [batch_size, drug_dim]