Model.Train
def Train(model, epochs: int, lr: float, train_loader, val_loader=None, test_loader=None,
loss_func=None, optimizer=None, ratio: list = None, classify: bool = False,
save_path_prediction: str = None, save_path_model: str = None, save_path_log: str = None,
no_wandb: bool = True, project=None, name=None, config=None, early_stop=None):
Trains the drug response prediction model.
PARAMETERS:
model - The model built by Model.DrModel.
epochs (int) - The number of epochs.
lr (float) - The learning rate.
train_loader - The training loader obtained from Data.DrDataLoader.
val_loader (optional) - The validation loader obtained from Data.DrDataLoader. (default: None)
test_loader (optional) - The test loader obtained from Data.DrDataLoader. (default: None)
loss_func (optional) - The loss function. When the value is None, torch.nn.MSELoss() is used if classify=False, and torch.nn.BCEWithLogitsLoss() is used if classify=True. (default: None)
optimizer (optional) - The optimizer. torch.optim.Adam is used when the value is None. (default: None)
ratio (list, optional) - The learning rate ratio of the cell encoder, drug encoder, and fusion module. [1, 1, 1] is used when the value is None. (default: None)
classify (bool, optional) - Whether the task is a classification task. (default: False)
save_path_prediction (str, optional) - Save path of the predictions. It is expected to end in ".csv". (default: None)
save_path_model (str, optional) - Save path of the trained model. It is expected to end in ".pkl". (default: None)
save_path_log (str, optional) - Save path of the training log. It is expected to end in ".txt". (default: None)
no_wandb (bool, optional) - Whether to disable wandb. wandb is not used when the value is True. (default: True)
project (optional) - Parameter of wandb. (default: None)
name (optional) - Parameter of wandb. (default: None)
config (optional) - Parameter of wandb. (default: None)
early_stop (optional) - The dataset monitored for early stopping. "val", "test", or None is available. (default: None)
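The constraints listed above (expected file suffixes, the allowed early_stop values, one ratio entry per module) can be checked before a long training run starts. The following is a minimal sketch, not part of the library: check_train_args and scaled_learning_rates are hypothetical helper names, and the reading of ratio as a per-module multiplier on lr is an assumption inferred from the parameter description rather than confirmed behavior.

```python
from pathlib import Path

# Expected suffixes, taken from the parameter descriptions above.
EXPECTED_SUFFIX = {
    "save_path_prediction": ".csv",
    "save_path_model": ".pkl",
    "save_path_log": ".txt",
}


def check_train_args(ratio=None, early_stop=None, **save_paths):
    """Return a list of problems; an empty list means the arguments look fine.

    Hypothetical helper for illustration, not part of the library.
    """
    problems = []
    # Assumption: ratio has one entry each for cell encoder, drug encoder, fusion module.
    if ratio is not None and len(ratio) != 3:
        problems.append("ratio needs three entries")
    if early_stop not in ("val", "test", None):
        problems.append('early_stop must be "val", "test", or None')
    for name, path in save_paths.items():
        suffix = EXPECTED_SUFFIX.get(name)
        if suffix is None:
            problems.append(f"unknown save path argument: {name}")
        elif path is not None and Path(path).suffix != suffix:
            problems.append(f"{name} should end in {suffix!r}")
    return problems


def scaled_learning_rates(lr, ratio=None):
    """Assumed reading of ratio: each module's rate is lr times its ratio entry."""
    ratio = [1, 1, 1] if ratio is None else ratio
    names = ("cell_encoder", "drug_encoder", "fusion_module")
    return dict(zip(names, (lr * r for r in ratio)))
```

Returning a list of problems instead of raising on the first one lets all argument mistakes be reported in a single pass.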
OUTPUTS:
model - The trained model.
loss_func - The loss function.
optimizer - The optimizer.
val_epoch_loss_ls (list) - The loss values on the validation set.
test_epoch_loss_ls (list) - The loss values on the test set.
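A call might be wired up roughly as follows. This is a sketch under two assumptions that the text above does not state explicitly: that the five outputs are returned as a tuple in the listed order, and that val_epoch_loss_ls holds one value per completed epoch. The function train_and_pick_best is a hypothetical wrapper; train_fn stands in for Model.Train and is passed in as an argument so the sketch does not assume how the library is imported.

```python
def train_and_pick_best(train_fn, model, train_loader, val_loader, epochs=10, lr=1e-4):
    """Call a Train-like function and find the epoch with the lowest validation loss.

    Hypothetical wrapper for illustration; train_fn stands in for Model.Train.
    """
    # Assumption: the outputs come back as a 5-tuple in the documented order.
    model, loss_func, optimizer, val_losses, test_losses = train_fn(
        model,
        epochs=epochs,
        lr=lr,
        train_loader=train_loader,
        val_loader=val_loader,
        early_stop="val",
    )
    # Assumption: val_losses has one entry per completed epoch (zero-based index).
    if val_losses:
        best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
    else:
        best_epoch = None
    return model, best_epoch, val_losses
```

With the real library, the first argument would be Model.Train and the loaders would come from Data.DrDataLoader.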