Data.DrCollate

Click here to go back to the reference.

class DrCollate:
    """"""

    def __init__(self, follow_batch=None,
                 exclude_keys=None):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch):
        cell_ft = torch.stack([g.cell_ft for g in batch])
        if type(batch[0].drug_ft) == torch.Tensor:
            drug_ft = torch.stack([g.drug_ft for g in batch])
        else:
            drug_ft = Batch.from_data_list([g.drug_ft for g in batch], self.follow_batch, self.exclude_keys)
        response = torch.stack([g.response for g in batch])
        return cell_ft, drug_ft, response, [g.cell_name for g in batch], [g.drug_name for g in batch]