Model.Predict
Click here to go back to the reference.
def _Eval(model,
          data_loader,
          save_path_prediction: str):
    """Run inference over ``data_loader`` and optionally save the predictions.

    Args:
        model: Either a single model object, or a tuple
            ``(cell_encoder, drug_encoder, fusion_module)``. Every part is
            moved to the module-level ``device`` before inference. A tuple is
            dispatched to ``_InferenceTP``; anything else to ``_Inference``.
        data_loader: Loader handed unchanged to the inference helper.
        save_path_prediction: Destination of the CSV file holding the
            ``cell``, ``drug`` and ``pre`` columns, or ``None`` to skip
            saving. Must end with ``.csv``.

    Returns:
        The tuple ``(pre, cell_ls, drug_ls)`` produced by the inference helper.

    Raises:
        ValueError: If ``save_path_prediction`` is given but does not end
            with ``.csv``.
    """
    # Validate with an explicit raise: an ``assert`` would be stripped under
    # ``python -O`` and a wrongly named file would be written silently.
    if save_path_prediction is not None and not save_path_prediction.endswith('.csv'):
        raise ValueError(
            "save_path_prediction must end with '.csv', got {!r}".format(save_path_prediction))

    # isinstance (rather than an exact type comparison) also accepts tuple
    # subclasses such as namedtuples.
    is_three_part = isinstance(model, tuple)
    if is_three_part:
        cell_encoder, drug_encoder, fusion_module = model
        cell_encoder = cell_encoder.to(device)
        drug_encoder = drug_encoder.to(device)
        fusion_module = fusion_module.to(device)
    else:
        model = model.to(device)

    # The inference helpers take a loss function; the loss they return is not
    # used by this function.
    loss_func = nn.MSELoss()

    print('Start prediction!')
    start = time.time()
    if is_three_part:
        _, _, pre, cell_ls, drug_ls = _InferenceTP(
            cell_encoder, drug_encoder, fusion_module, data_loader, loss_func)
    else:
        _, _, pre, cell_ls, drug_ls = _Inference(model, data_loader, loss_func)

    if save_path_prediction is not None:
        dataframe = pd.DataFrame({'cell': cell_ls, 'drug': drug_ls, 'pre': pre})
        dataframe.to_csv(save_path_prediction, index=False, sep=',')

    # The reported time covers inference plus the optional CSV write.
    end = time.time()
    print('Time consumed {:.6f}'.format(end - start))
    print('Prediction completed!')
    return pre, cell_ls, drug_ls
def Predict(model,
            data,
            save_path_prediction: str = None):
    """Predict with ``model`` on ``data`` and optionally save the result as CSV.

    Every entry of ``data.pair_ls`` that has exactly two elements gets ``0.0``
    appended in place (presumably a placeholder label for unlabeled pairs --
    TODO confirm). The data is then wrapped in a ``DrDataset`` and served by a
    non-shuffled ``DrDataLoader`` with batch size 64.

    Args:
        model: Model object (or tuple of model parts) forwarded to ``_Eval``.
        data: Object exposing ``pair_ls`` and ``len()``; it is modified in
            place by the padding step.
        save_path_prediction: Optional CSV path forwarded to ``_Eval``.

    Returns:
        Whatever ``_Eval`` returns for the constructed loader.
    """
    for idx in range(len(data)):
        # Only two-element pairs are padded; everything else is extended by
        # an empty list.
        if len(data.pair_ls[idx]) == 2:
            padding = [0.0]
        else:
            padding = []
        data.pair_ls[idx] += padding

    dataset = DrDataset(data)
    loader = DrDataLoader(dataset, batch_size=64, shuffle=False)
    return _Eval(model, loader, save_path_prediction)