Model.DrModel

Click here to go back to the reference.

class _Model(nn.Module):
    def __init__(self, cell_encoder,
                 drug_encoder,
                 fusion_module,
                 cell_encoder_pt_path: str,
                 drug_encoder_pt_path: str):
        """"""

        super(_Model, self).__init__()
        self.CellEncoder = cell_encoder
        self.DrugEncoder = drug_encoder
        self.FusionModule = fusion_module
        if cell_encoder_pt_path is not None:
            self.CellEncoder.load_state_dict(torch.load(cell_encoder_pt_path))
        if drug_encoder_pt_path is not None:
            self.DrugEncoder.load_state_dict(torch.load(drug_encoder_pt_path))

    def forward(self, cell_ft, drug_ft):
        cell_ft = self.CellEncoder(cell_ft)
        drug_ft = self.DrugEncoder(drug_ft)
        res, cell_ft, drug_ft = self.FusionModule(cell_ft, drug_ft)
        return res, cell_ft, drug_ft


def _ModelTP(cell_encoder,
             drug_encoder,
             fusion_module,
             cell_encoder_pt_path: str,
             drug_encoder_pt_path: str):
    """"""

    if cell_encoder_pt_path is not None:
        cell_encoder.load_state_dict(torch.load(cell_encoder_pt_path))
    if drug_encoder_pt_path is not None:
        drug_encoder.load_state_dict(torch.load(drug_encoder_pt_path))
    return cell_encoder, drug_encoder, fusion_module


def DrModel(cell_encoder,
            drug_encoder,
            fusion_module,
            integrate: bool = True,
            cell_encoder_pt_path: "str | None" = None,
            drug_encoder_pt_path: "str | None" = None):
    """Build the model from its three parts, optionally loading encoder weights.

    Args:
        cell_encoder: Module that encodes the cell input.
        drug_encoder: Module that encodes the drug input.
        fusion_module: Module combining the two encoded inputs.
        integrate: If truthy (default), wrap everything in a single ``_Model``
            ``nn.Module``. Otherwise return the parts separately via
            ``_ModelTP``.
        cell_encoder_pt_path: Optional path to a saved ``state_dict`` for
            ``cell_encoder``. ``None`` keeps its current weights.
        drug_encoder_pt_path: Optional path to a saved ``state_dict`` for
            ``drug_encoder``. ``None`` keeps its current weights.

    Returns:
        A ``_Model`` instance when ``integrate`` is truthy, otherwise the tuple
        ``(cell_encoder, drug_encoder, fusion_module)`` returned by
        ``_ModelTP``.
    """
    # Both builders share the same positional signature, so pick one and call it.
    builder = _Model if integrate else _ModelTP
    return builder(cell_encoder, drug_encoder, fusion_module,
                   cell_encoder_pt_path, drug_encoder_pt_path)