Model.Train

Click here to go back to the reference.

def _TrainingTP(cell_encoder, drug_encoder, fusion_module, data_loader, loss_func, optimizer):
    """Run one optimisation epoch for the three-part (cell / drug / fusion) model.

    Every batch produced by ``data_loader`` is unpacked as
    ``(cell_ft, drug_ft, response, cell_name, drug_name)``. The features are
    encoded by ``cell_encoder`` and ``drug_encoder``, combined by
    ``fusion_module`` (whose first return value is used as the prediction) and
    one optimiser step is taken per batch.

    Args:
        cell_encoder: Module applied to the cell features.
        drug_encoder: Module applied to the drug features.
        fusion_module: Module called with both encodings; must return a
            3-tuple whose first element is the prediction.
        data_loader: Iterable of 5-tuples as described above.
        loss_func: Callable ``loss_func(prediction, response)`` returning a
            scalar tensor.
        optimizer: Optimiser that is zeroed and stepped once per batch.

    Returns:
        tuple: ``(cell_encoder, drug_encoder, fusion_module, optimizer,
        epoch_loss, real, pre, cell_ls, drug_ls)`` where ``epoch_loss`` is the
        batch losses averaged with per-batch sample counts as weights, ``real``
        / ``pre`` are the collected targets / predictions and ``cell_ls`` /
        ``drug_ls`` the collected names, all in iteration order.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    def to_list(tensor):
        # squeeze() collapses a batch of size one to a 0-d tensor, and
        # tolist() then returns a bare float instead of a list; restore one
        # dimension so the result can always be extended and measured.
        flat = torch.squeeze(tensor.detach())
        if flat.dim() == 0:
            flat = flat.reshape(1)
        return flat.cpu().tolist()

    cell_encoder.train()
    drug_encoder.train()
    fusion_module.train()
    real = []
    pre = []
    cell_ls = []
    drug_ls = []
    epoch_loss = 0
    length = 0
    for cell_ft, drug_ft, response, cell_name, drug_name in data_loader:
        cell_ls += cell_name
        drug_ls += drug_name
        cell_ft, drug_ft, response = cell_ft.to(device), drug_ft.to(device), response.to(device)
        prediction, _, _ = fusion_module(cell_encoder(cell_ft), drug_encoder(drug_ft))
        loss = loss_func(prediction, response)
        batch_real = to_list(response)
        real += batch_real
        pre += to_list(prediction)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its sample count so a short final batch
        # does not skew the epoch average.
        batch_size = len(batch_real)
        length += batch_size
        epoch_loss += loss.item() * batch_size
    epoch_loss /= length
    return cell_encoder, drug_encoder, fusion_module, optimizer, epoch_loss, real, pre, cell_ls, drug_ls


def _InferenceTP(cell_encoder, drug_encoder, fusion_module, data_loader, loss_func):
    """Evaluate the three-part (cell / drug / fusion) model on one data loader.

    The three modules are switched to eval mode and the whole pass runs under
    ``torch.no_grad()``. Every batch is unpacked as
    ``(cell_ft, drug_ft, response, cell_name, drug_name)``.

    Args:
        cell_encoder: Module applied to the cell features.
        drug_encoder: Module applied to the drug features.
        fusion_module: Module called with both encodings; must return a
            3-tuple whose first element is the prediction.
        data_loader: Iterable of 5-tuples as described above.
        loss_func: Callable ``loss_func(prediction, response)`` returning a
            scalar tensor.

    Returns:
        tuple: ``(epoch_loss, real, pre, cell_ls, drug_ls)`` where
        ``epoch_loss`` is the batch losses averaged with per-batch sample
        counts as weights and the four lists hold targets, predictions, cell
        names and drug names in iteration order.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    def to_list(tensor):
        # squeeze() collapses a batch of size one to a 0-d tensor, and
        # tolist() then returns a bare float instead of a list; restore one
        # dimension so the result can always be extended and measured.
        flat = torch.squeeze(tensor)
        if flat.dim() == 0:
            flat = flat.reshape(1)
        return flat.cpu().tolist()

    cell_encoder.eval()
    drug_encoder.eval()
    fusion_module.eval()
    real = []
    pre = []
    cell_ls = []
    drug_ls = []
    epoch_loss = 0
    length = 0
    with torch.no_grad():
        for cell_ft, drug_ft, response, cell_name, drug_name in data_loader:
            cell_ls += cell_name
            drug_ls += drug_name
            cell_ft, drug_ft, response = cell_ft.to(device), drug_ft.to(device), response.to(device)
            prediction, _, _ = fusion_module(cell_encoder(cell_ft), drug_encoder(drug_ft))
            loss = loss_func(prediction, response)
            batch_real = to_list(response)
            real += batch_real
            pre += to_list(prediction)
            batch_size = len(batch_real)
            length += batch_size
            epoch_loss += loss.item() * batch_size
    epoch_loss /= length
    return epoch_loss, real, pre, cell_ls, drug_ls


def _TrainTP(model_tp, epochs: int, lr: float, train_loader, val_loader=None, test_loader=None,
             loss_func=None, optimizer=None, ratio: list = None, classify: bool = False,
             save_path_prediction: str = None, save_path_model: str = None, save_path_log: str = None,
             no_wandb: bool = True, project=None, name=None, config=None, early_stop=None):
    """Train a three-part model given as ``(cell_encoder, drug_encoder, fusion_module)``.

    Args:
        model_tp: Sequence of exactly three modules, unpacked as cell encoder,
            drug encoder and fusion module.
        epochs: Number of training epochs.
        lr: Base learning rate for the default Adam optimiser.
        train_loader: Loader passed to ``_TrainingTP`` every epoch.
        val_loader: Optional loader evaluated with ``_InferenceTP`` every epoch.
        test_loader: Optional loader evaluated with ``_InferenceTP`` every epoch.
        loss_func: Loss callable. Defaults to ``nn.BCEWithLogitsLoss`` when
            ``classify`` is true, otherwise ``nn.MSELoss``.
        optimizer: Optimiser. When omitted, Adam is built with one parameter
            group per component whose learning rate is ``ratio[i] * lr``.
        ratio: Three learning-rate multipliers (cell, drug, fusion); defaults
            to ``[1, 1, 1]``. Only used when ``optimizer`` is omitted.
        classify: Selects the default loss and suppresses the R2 / PCC columns
            in the log lines.
        save_path_prediction: ``.csv`` path; per-epoch predictions are written
            next to it with ``_val`` / ``_test`` inserted before the extension.
        save_path_model: ``.pkl`` path the checkpoint is dumped to with joblib.
        save_path_log: ``.txt`` path every log line is appended to.
        no_wandb: When false, losses are reported to Weights & Biases.
        project, name, config: Forwarded to ``wandb.init``.
        early_stop: ``'val'``, ``'test'`` or ``None``. When set, the returned /
            saved model is the snapshot with the lowest monitored loss among
            the epochs from ``epochs // 2`` onwards. When ``None``, the final
            weights are returned and the checkpoint is refreshed every epoch
            from ``epochs // 2`` onwards.

    Returns:
        tuple: ``(trained_model, loss_func, optimizer, val_epoch_loss_ls,
        test_epoch_loss_ls)``; ``trained_model`` is a 3-tuple of modules.

    Raises:
        AssertionError: If a save path has the wrong extension, ``ratio`` does
            not have three entries or ``early_stop`` is not a supported value.
        ValueError: If ``early_stop`` names a split whose loader was not given.
    """
    import copy  # function-scope import: only the early-stopping snapshot needs it

    if save_path_prediction is not None:
        assert save_path_prediction[-4:] == '.csv'
    if save_path_model is not None:
        assert save_path_model[-4:] == '.pkl'
    if save_path_log is not None:
        assert save_path_log[-4:] == '.txt'
    if ratio is None:
        ratio = [1, 1, 1]
    assert len(ratio) == 3
    assert early_stop in ['val', 'test', None]
    # Fail before training starts instead of with an IndexError on the empty
    # loss history half-way through the run.
    if early_stop == 'val' and val_loader is None:
        raise ValueError("early_stop='val' requires val_loader")
    if early_stop == 'test' and test_loader is None:
        raise ValueError("early_stop='test' requires test_loader")

    cell_encoder, drug_encoder, fusion_module = model_tp
    cell_encoder = cell_encoder.to(device)
    drug_encoder = drug_encoder.to(device)
    fusion_module = fusion_module.to(device)
    if loss_func is None:
        loss_func = nn.BCEWithLogitsLoss() if classify else nn.MSELoss()
    if optimizer is None:
        params = [
            {'params': cell_encoder.parameters(), 'lr': ratio[0] * lr},
            {'params': drug_encoder.parameters(), 'lr': ratio[1] * lr},
            {'params': fusion_module.parameters(), 'lr': ratio[2] * lr}
        ]
        optimizer = optim.Adam(params, lr=lr)

    def log(message):
        # Echo one line to stdout and, when requested, append it to the log file.
        print(message)
        if save_path_log is not None:
            with open(save_path_log, 'a') as file:
                print(message, file=file)

    def summarize(prefix, loss, real, pre):
        # Build one report line; regression runs additionally show R2 and PCC.
        line = prefix + '{:.6f}'.format(loss)
        if not classify:
            line += '  r2 {:.4f}  pcc {:.4f}'.format(Metric.R2(real, pre), Metric.PCC(real, pre))
        return line

    def evaluate(loader, prefix, tag, epoch, history, table):
        # Score one held-out split, log it and refresh its prediction CSV.
        loss, real, pre, cell_ls, drug_ls = _InferenceTP(cell_encoder, drug_encoder, fusion_module, loader, loss_func)
        log(summarize(prefix, loss, real, pre))
        history.append(loss)
        if save_path_prediction is not None:
            # NOTE(review): names and targets are stored for epoch 0 only, so
            # later columns line up only if the loader yields samples in the
            # same order every epoch - confirm the loader is not shuffled.
            if epoch == 0:
                table['cell'] = cell_ls
                table['drug'] = drug_ls
                table['real'] = real
            table['epoch_{}'.format(epoch)] = pre
            dataframe = pd.DataFrame(table)
            dataframe.to_csv(save_path_prediction[:-4] + tag + save_path_prediction[-4:], index=False, sep=',')
        return loss

    real_pre_val_dict = dict()
    real_pre_test_dict = dict()
    val_epoch_loss_ls = []
    test_epoch_loss_ls = []
    # Loss history that drives early stopping (None when it is disabled).
    monitored = {'val': val_epoch_loss_ls, 'test': test_epoch_loss_ls}.get(early_stop)
    half = epochs // 2
    # Bound up front so that epochs == 0 still returns a model.
    trained_model = (cell_encoder, drug_encoder, fusion_module)

    if not no_wandb:
        wandb.init(
            project=project,
            name=name,
            config=config,
            mode="online"
        )
    print('Start training!')
    for epoch in range(epochs):
        start = time.time()
        loss_dict = dict()

        cell_encoder, drug_encoder, fusion_module, optimizer, epoch_loss, real, pre, _, _ = _TrainingTP(
            cell_encoder, drug_encoder, fusion_module, train_loader, loss_func, optimizer)
        log(summarize('Epoch {}, train loss '.format(epoch), epoch_loss, real, pre))
        loss_dict['train_loss'] = epoch_loss

        if val_loader is not None:
            loss_dict['val_loss'] = evaluate(val_loader, '         val loss   ', '_val', epoch,
                                             val_epoch_loss_ls, real_pre_val_dict)

        if test_loader is not None:
            loss_dict['test_loss'] = evaluate(test_loader, '         test loss  ', '_test', epoch,
                                              test_epoch_loss_ls, real_pre_test_dict)

        if epoch >= half:
            if monitored is None:
                # No early stopping: keep pointing at the live modules and
                # refresh the checkpoint with the current weights.
                trained_model = (cell_encoder.cpu(), drug_encoder.cpu(), fusion_module.cpu())
                if save_path_model is not None:
                    joblib.dump(trained_model, save_path_model)
                cell_encoder = cell_encoder.to(device)
                drug_encoder = drug_encoder.to(device)
                fusion_module = fusion_module.to(device)
            elif monitored[-1] == min(monitored[half:]):
                # New best epoch: freeze an independent copy. The live modules
                # keep training, so a plain reference would silently drift to
                # the final weights.
                trained_model = tuple(copy.deepcopy(part).cpu()
                                      for part in (cell_encoder, drug_encoder, fusion_module))
                if save_path_model is not None:
                    joblib.dump(trained_model, save_path_model)

        if not no_wandb:
            wandb.log(loss_dict)

        end = time.time()
        log('         time consumed {:.6f}'.format(end - start))

    if not no_wandb:
        wandb.finish()
    if early_stop is not None:
        # Hand the best snapshot back on the training device, matching where
        # the returned model lives when early stopping is off.
        trained_model = tuple(part.to(device) for part in trained_model)
    print('Training completed!')
    return trained_model, loss_func, optimizer, val_epoch_loss_ls, test_epoch_loss_ls


def _Training(model, data_loader, loss_func, optimizer):
    """Run one optimisation epoch for a single-module model.

    Every batch produced by ``data_loader`` is unpacked as
    ``(cell_ft, drug_ft, response, cell_name, drug_name)``; ``model`` is called
    with the two feature tensors and the first of its three return values is
    used as the prediction. One optimiser step is taken per batch.

    Args:
        model: Module called as ``model(cell_ft, drug_ft)``; must return a
            3-tuple whose first element is the prediction.
        data_loader: Iterable of 5-tuples as described above.
        loss_func: Callable ``loss_func(prediction, response)`` returning a
            scalar tensor.
        optimizer: Optimiser that is zeroed and stepped once per batch.

    Returns:
        tuple: ``(model, optimizer, epoch_loss, real, pre, cell_ls, drug_ls)``
        where ``epoch_loss`` is the batch losses averaged with per-batch sample
        counts as weights and the four lists hold targets, predictions, cell
        names and drug names in iteration order.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    def to_list(tensor):
        # squeeze() collapses a batch of size one to a 0-d tensor, and
        # tolist() then returns a bare float instead of a list; restore one
        # dimension so the result can always be extended and measured.
        flat = torch.squeeze(tensor.detach())
        if flat.dim() == 0:
            flat = flat.reshape(1)
        return flat.cpu().tolist()

    model.train()
    real = []
    pre = []
    cell_ls = []
    drug_ls = []
    epoch_loss = 0
    length = 0
    for cell_ft, drug_ft, response, cell_name, drug_name in data_loader:
        cell_ls += cell_name
        drug_ls += drug_name
        cell_ft, drug_ft, response = cell_ft.to(device), drug_ft.to(device), response.to(device)
        prediction, _, _ = model(cell_ft, drug_ft)
        loss = loss_func(prediction, response)
        batch_real = to_list(response)
        real += batch_real
        pre += to_list(prediction)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its sample count so a short final batch
        # does not skew the epoch average.
        batch_size = len(batch_real)
        length += batch_size
        epoch_loss += loss.item() * batch_size
    epoch_loss /= length
    return model, optimizer, epoch_loss, real, pre, cell_ls, drug_ls


def _Inference(model, data_loader, loss_func):
    """Evaluate a single-module model on one data loader.

    ``model`` is switched to eval mode and the whole pass runs under
    ``torch.no_grad()``. Every batch is unpacked as
    ``(cell_ft, drug_ft, response, cell_name, drug_name)``.

    Args:
        model: Module called as ``model(cell_ft, drug_ft)``; must return a
            3-tuple whose first element is the prediction.
        data_loader: Iterable of 5-tuples as described above.
        loss_func: Callable ``loss_func(prediction, response)`` returning a
            scalar tensor.

    Returns:
        tuple: ``(epoch_loss, real, pre, cell_ls, drug_ls)`` where
        ``epoch_loss`` is the batch losses averaged with per-batch sample
        counts as weights and the four lists hold targets, predictions, cell
        names and drug names in iteration order.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    def to_list(tensor):
        # squeeze() collapses a batch of size one to a 0-d tensor, and
        # tolist() then returns a bare float instead of a list; restore one
        # dimension so the result can always be extended and measured.
        flat = torch.squeeze(tensor)
        if flat.dim() == 0:
            flat = flat.reshape(1)
        return flat.cpu().tolist()

    model.eval()
    real = []
    pre = []
    cell_ls = []
    drug_ls = []
    epoch_loss = 0
    length = 0
    with torch.no_grad():
        for cell_ft, drug_ft, response, cell_name, drug_name in data_loader:
            cell_ls += cell_name
            drug_ls += drug_name
            cell_ft, drug_ft, response = cell_ft.to(device), drug_ft.to(device), response.to(device)
            prediction, _, _ = model(cell_ft, drug_ft)
            loss = loss_func(prediction, response)
            batch_real = to_list(response)
            real += batch_real
            pre += to_list(prediction)
            batch_size = len(batch_real)
            length += batch_size
            epoch_loss += loss.item() * batch_size
    epoch_loss /= length
    return epoch_loss, real, pre, cell_ls, drug_ls


def _Train(model, epochs: int, lr: float, train_loader, val_loader=None, test_loader=None,
           loss_func=None, optimizer=None, classify: bool = False,
           save_path_prediction: str = None, save_path_model: str = None, save_path_log: str = None,
           no_wandb: bool = True, project=None, name=None, config=None, early_stop=None):
    """Train a single-module model.

    Args:
        model: Module called as ``model(cell_ft, drug_ft)``.
        epochs: Number of training epochs.
        lr: Learning rate for the default Adam optimiser.
        train_loader: Loader passed to ``_Training`` every epoch.
        val_loader: Optional loader evaluated with ``_Inference`` every epoch.
        test_loader: Optional loader evaluated with ``_Inference`` every epoch.
        loss_func: Loss callable. Defaults to ``nn.BCEWithLogitsLoss`` when
            ``classify`` is true, otherwise ``nn.MSELoss``.
        optimizer: Optimiser. When omitted, Adam over ``model.parameters()``.
        classify: Selects the default loss and suppresses the R2 / PCC columns
            in the log lines.
        save_path_prediction: ``.csv`` path; per-epoch predictions are written
            next to it with ``_val`` / ``_test`` inserted before the extension.
        save_path_model: ``.pkl`` path the checkpoint is dumped to with joblib.
        save_path_log: ``.txt`` path every log line is appended to.
        no_wandb: When false, losses are reported to Weights & Biases.
        project, name, config: Forwarded to ``wandb.init``.
        early_stop: ``'val'``, ``'test'`` or ``None``. When set, the returned /
            saved model is the snapshot with the lowest monitored loss among
            the epochs from ``epochs // 2`` onwards. When ``None``, the final
            weights are returned and the checkpoint is refreshed every epoch
            from ``epochs // 2`` onwards.

    Returns:
        tuple: ``(trained_model, loss_func, optimizer, val_epoch_loss_ls,
        test_epoch_loss_ls)``.

    Raises:
        AssertionError: If a save path has the wrong extension or
            ``early_stop`` is not a supported value.
        ValueError: If ``early_stop`` names a split whose loader was not given.
    """
    import copy  # function-scope import: only the early-stopping snapshot needs it

    if save_path_prediction is not None:
        assert save_path_prediction[-4:] == '.csv'
    if save_path_model is not None:
        assert save_path_model[-4:] == '.pkl'
    if save_path_log is not None:
        assert save_path_log[-4:] == '.txt'
    assert early_stop in ['val', 'test', None]
    # Fail before training starts instead of with an IndexError on the empty
    # loss history half-way through the run.
    if early_stop == 'val' and val_loader is None:
        raise ValueError("early_stop='val' requires val_loader")
    if early_stop == 'test' and test_loader is None:
        raise ValueError("early_stop='test' requires test_loader")

    model = model.to(device)
    if loss_func is None:
        loss_func = nn.BCEWithLogitsLoss() if classify else nn.MSELoss()
    if optimizer is None:
        params = [
            {'params': model.parameters(), 'lr': lr}
        ]
        optimizer = optim.Adam(params, lr=lr)

    def log(message):
        # Echo one line to stdout and, when requested, append it to the log file.
        print(message)
        if save_path_log is not None:
            with open(save_path_log, 'a') as file:
                print(message, file=file)

    def summarize(prefix, loss, real, pre):
        # Build one report line; regression runs additionally show R2 and PCC.
        line = prefix + '{:.6f}'.format(loss)
        if not classify:
            line += '  r2 {:.4f}  pcc {:.4f}'.format(Metric.R2(real, pre), Metric.PCC(real, pre))
        return line

    def evaluate(loader, prefix, tag, epoch, history, table):
        # Score one held-out split, log it and refresh its prediction CSV.
        loss, real, pre, cell_ls, drug_ls = _Inference(model, loader, loss_func)
        log(summarize(prefix, loss, real, pre))
        history.append(loss)
        if save_path_prediction is not None:
            # NOTE(review): names and targets are stored for epoch 0 only, so
            # later columns line up only if the loader yields samples in the
            # same order every epoch - confirm the loader is not shuffled.
            if epoch == 0:
                table['cell'] = cell_ls
                table['drug'] = drug_ls
                table['real'] = real
            table['epoch_{}'.format(epoch)] = pre
            dataframe = pd.DataFrame(table)
            dataframe.to_csv(save_path_prediction[:-4] + tag + save_path_prediction[-4:], index=False, sep=',')
        return loss

    real_pre_val_dict = dict()
    real_pre_test_dict = dict()
    val_epoch_loss_ls = []
    test_epoch_loss_ls = []
    # Loss history that drives early stopping (None when it is disabled).
    monitored = {'val': val_epoch_loss_ls, 'test': test_epoch_loss_ls}.get(early_stop)
    half = epochs // 2
    # Bound up front so that epochs == 0 still returns a model.
    trained_model = model

    if not no_wandb:
        wandb.init(
            project=project,
            name=name,
            config=config,
            mode="online"
        )
    print('Start training!')
    for epoch in range(epochs):
        start = time.time()
        loss_dict = dict()

        model, optimizer, epoch_loss, real, pre, _, _ = _Training(model, train_loader, loss_func, optimizer)
        log(summarize('Epoch {}, train loss '.format(epoch), epoch_loss, real, pre))
        loss_dict['train_loss'] = epoch_loss

        if val_loader is not None:
            loss_dict['val_loss'] = evaluate(val_loader, '         val loss   ', '_val', epoch,
                                             val_epoch_loss_ls, real_pre_val_dict)

        if test_loader is not None:
            loss_dict['test_loss'] = evaluate(test_loader, '         test loss  ', '_test', epoch,
                                              test_epoch_loss_ls, real_pre_test_dict)

        if epoch >= half:
            if monitored is None:
                # No early stopping: keep pointing at the live model and
                # refresh the checkpoint with the current weights.
                trained_model = model.cpu()
                if save_path_model is not None:
                    joblib.dump(trained_model, save_path_model)
                model = model.to(device)
            elif monitored[-1] == min(monitored[half:]):
                # New best epoch: freeze an independent copy. The live model
                # keeps training, so a plain reference would silently drift to
                # the final weights.
                trained_model = copy.deepcopy(model).cpu()
                if save_path_model is not None:
                    joblib.dump(trained_model, save_path_model)

        if not no_wandb:
            wandb.log(loss_dict)

        end = time.time()
        log('         time consumed {:.6f}'.format(end - start))

    if not no_wandb:
        wandb.finish()
    if early_stop is not None:
        # Hand the best snapshot back on the training device, matching where
        # the returned model lives when early stopping is off.
        trained_model = trained_model.to(device)
    print('Training completed!')
    return trained_model, loss_func, optimizer, val_epoch_loss_ls, test_epoch_loss_ls


def Train(model, epochs: int, lr: float, train_loader, val_loader=None, test_loader=None,
          loss_func=None, optimizer=None, ratio: list = None, classify: bool = False,
          save_path_prediction: str = None, save_path_model: str = None, save_path_log: str = None,
          no_wandb: bool = True, project=None, name=None, config=None, early_stop=None):
    """Train either a single model or a three-part model.

    A tuple or list ``model`` is treated as
    ``(cell_encoder, drug_encoder, fusion_module)`` and handed to ``_TrainTP``;
    anything else is handed to ``_Train``. ``ratio`` (per-component
    learning-rate multipliers) is only used for the three-part case; every
    other argument is forwarded unchanged to the selected routine.

    Returns:
        tuple: ``(model, loss_func, optimizer, val_epoch_loss_ls,
        test_epoch_loss_ls)`` exactly as produced by the selected routine.
    """
    # Forward by keyword so the long argument lists cannot silently fall out
    # of step with the callee signatures.
    shared = dict(
        val_loader=val_loader, test_loader=test_loader,
        loss_func=loss_func, optimizer=optimizer, classify=classify,
        save_path_prediction=save_path_prediction, save_path_model=save_path_model,
        save_path_log=save_path_log, no_wandb=no_wandb,
        project=project, name=name, config=config, early_stop=early_stop,
    )
    # isinstance also accepts tuple subclasses (e.g. namedtuples) and lists,
    # which the previous exact type comparison sent down the single-model path.
    if isinstance(model, (tuple, list)):
        return _TrainTP(model, epochs, lr, train_loader, ratio=ratio, **shared)
    return _Train(model, epochs, lr, train_loader, **shared)