Data.DrDataset

Click here to go back to the reference.

class DrDataset(Dataset, ABC):
    """"""

    def __init__(self, dr_data: DrData,
                 radius: int = 2,
                 nBits: int = 512,
                 max_len: int = 230,
                 char_dict: dict = None,
                 right: bool = True,
                 mpg: bool = True):

        super().__init__()
        self._pair_ls = dr_data.pair_ls
        self._cell_ft = dr_data.cell_ft
        self._drug_ft = dr_data.drug_ft
        self._radius = radius
        self._nBits = nBits
        self._max_len = max_len
        self._char_dict = char_dict
        self._right = right
        self._SMILES_dict = dr_data.smiles_dict
        self._MPG_dict = dr_data.mpg_dict
        self._MPG = mpg
        self._data = DrDataset.preprocess(self)

    def __getitem__(self, idx):
        data = self._data[idx]
        return data

    def __len__(self):
        return len(self._data)

    def preprocess(self):

        if type(self._cell_ft) == str:
            assert self._cell_ft in ['EXP', 'PES', 'MUT', 'CNV']
            cell_dict = joblib.load(os.path.join(os.path.split(__file__)[0], 'DefaultData/GDSC_{}.pkl'.format(self._cell_ft)))
        else:
            cell_dict = self._cell_ft

        if type(self._drug_ft) == str:
            assert self._drug_ft in ['ECFP', 'SMILES', 'Graph', 'Image']
            if self._SMILES_dict is None:
                drug_dict = joblib.load(os.path.join(os.path.split(__file__)[0], 'DefaultData/SMILES_dict.pkl'))
            else:
                drug_dict = self._SMILES_dict
        else:
            drug_dict = self._drug_ft

        if self._MPG_dict is None:
            self._MPG_dict = joblib.load(os.path.join(os.path.split(__file__)[0], 'DefaultData/MPG_dict.pkl'))

        data = []
        for i in tqdm(range(len(self._pair_ls))):
            each_pair = self._pair_ls[i]
            if type(cell_dict[each_pair[0]]) == torch.Tensor:
                cell_ft = cell_dict[each_pair[0]]
            else:
                cell_ft = torch.tensor(cell_dict[each_pair[0]], dtype=torch.float32)
            cell_ft = (cell_ft - cell_ft.mean()) / cell_ft.std(dim=0)

            if self._drug_ft == 'ECFP':
                drug_ft = PreEcfp(drug_dict[each_pair[1]], self._radius, self._nBits)
            elif self._drug_ft == 'SMILES':
                drug_ft = PreSmiles(drug_dict[each_pair[1]], self._max_len, self._char_dict, self._right)
            elif self._drug_ft == 'Graph':
                drug_ft = PreGraph(drug_dict[each_pair[1]])
                if self._MPG:
                    try:
                        drug_ft = _Add_seg_id(_Self_loop(Data(x=drug_ft.x, edge_index=drug_ft.edge_index,
                                                              edge_attr=drug_ft.edge_attr,
                                                              mpg_ft=self._MPG_dict[each_pair[1]])))
                    except KeyError:
                        print('MPG feature missing! Set MPG=False or run Data.DrRead.FeatDrug')
            elif self._drug_ft == 'Image':
                drug_ft = ImageDataset([drug_dict[each_pair[1]]])[0]
            else:
                if type(drug_dict[each_pair[1]]) == torch.Tensor:
                    drug_ft = drug_dict[each_pair[1]]
                else:
                    drug_ft = torch.tensor(drug_dict[each_pair[1]], dtype=torch.float32)

            data.append(Data(cell_ft=cell_ft, drug_ft=drug_ft,
                             response=torch.tensor([each_pair[2]], dtype=torch.float32),
                             cell_name=each_pair[0], drug_name=each_pair[1]))
        return data