# Source listing for FusionModule.DNN
# (See the API reference page for usage documentation.)
class DNN(nn.Module):
    """Fuse cell features with pooled drug features and predict one output.

    The drug input is first reduced to one feature vector per sample (see
    ``pool``). Both modalities are then projected with a Linear + ReLU and
    fused, either by concatenation or by addition. The result goes through a
    stack of Linear -> ReLU -> Dropout layers and a final Linear layer with a
    single output unit.

    Args:
        cell_dim: Size of the last dimension of the cell features ``f``
            (``in_features`` of ``input_cell``).
        drug_dim: Size of the last dimension of the pooled drug features
            (``in_features`` of ``input_drug``).
        hid_dim_ls: Widths of the hidden layers. Defaults to
            ``[512, 256, 128]``. Any sequence of ints is accepted.
        dropout: Dropout probability applied after every hidden layer. It is
            also forwarded to the attention pooling module.
        num_heads: Number of heads forwarded to the attention pooling module.
        pool: Reduction applied to the drug input. One of ``'mean'``,
            ``'max'``, ``'attention'``, or ``'mix'`` (the sum of mean, max,
            and attention pooling).
        concat: If True, the projected cell and drug features (widths
            ``cell_dim`` and ``drug_dim``) are concatenated. If False, both
            are projected to ``drug_dim`` and summed.
        classify: If True, a sigmoid is applied to the final output.

    Raises:
        ValueError: If ``pool`` is not one of the supported modes.
    """

    _POOL_MODES = ('attention', 'mean', 'max', 'mix')

    def __init__(self, cell_dim: int, drug_dim: int, hid_dim_ls: list = None, dropout: float = _dropout,
                 num_heads: int = _num_heads, pool: str = 'mean', concat: bool = True, classify: bool = False):
        super().__init__()
        if hid_dim_ls is None:
            hid_dim_ls = [512, 256, 128]
        # Use an explicit exception instead of ``assert``. Asserts are stripped
        # under ``python -O``, and an unknown mode would then reach the 'mix'
        # branch of the pooling step and fail on the missing ``self.attention``.
        if pool not in self._POOL_MODES:
            raise ValueError(f"pool must be one of {self._POOL_MODES}, got {pool!r}")
        self.pool = pool
        self.concat = concat
        self.classify = classify
        # The attention module is only built when a pooling mode needs it.
        if self.pool in ('attention', 'mix'):
            self.attention = _MHA(cell_dim=cell_dim, drug_dim=drug_dim, dropout=dropout, num_heads=num_heads)
        # dim_ls = [fused input width, *hidden widths, 1 output unit]
        if self.concat:
            # Each modality keeps its own width, then the two are concatenated.
            dim_ls = [cell_dim + drug_dim] + list(hid_dim_ls) + [1]
            self.input_cell = nn.Linear(cell_dim, cell_dim)
            self.input_drug = nn.Linear(drug_dim, drug_dim)
        else:
            # Both modalities are projected to drug_dim so they can be summed.
            dim_ls = [drug_dim] + list(hid_dim_ls) + [1]
            self.input_cell = nn.Linear(cell_dim, dim_ls[0])
            self.input_drug = nn.Linear(drug_dim, dim_ls[0])
        # Hidden layers map dim_ls[i] -> dim_ls[i + 1] for every pair except
        # the last one, which is handled by ``self.output``.
        self.encode_dnn = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(dim_ls[:-2], dim_ls[1:-1])]
        )
        self.dropout = nn.ModuleList([nn.Dropout(p=dropout) for _ in range(len(dim_ls) - 2)])
        self.output = nn.Linear(dim_ls[-2], dim_ls[-1])

    def _pool(self, f, x):
        """Reduce the drug input ``x`` to one feature vector per sample.

        ``f`` (the cell features) is only used by attention pooling.
        """
        # isinstance (not ``type(x) ==``) so tensor subclasses take this path
        # instead of being mis-unpacked as a (features, graph) pair below.
        if isinstance(x, torch.Tensor):
            if x.dim() != 3:
                # Not 3-D: assumed to be already pooled, so it is used as is.
                return x
            # Pooling runs over dim 2. Presumably x is
            # (batch, drug_dim, length); TODO confirm.
            if self.pool == 'mean':
                return torch.mean(x, dim=2)
            if self.pool == 'max':
                return torch.max(x, dim=2).values
            if self.pool == 'attention':
                return self.attention(f, x)
            # 'mix': sum of mean, max, and attention pooling.
            return torch.mean(x, dim=2) + torch.max(x, dim=2).values + self.attention(f, x)
        # Graph input: (node_features, g), where ``g.batch`` assigns nodes to
        # samples. Presumably g is a torch_geometric Batch; TODO confirm.
        nodes, g = x
        if self.pool == 'mean':
            return global_mean_pool(nodes, g.batch)
        if self.pool == 'max':
            return global_max_pool(nodes, g.batch)
        if self.pool == 'attention':
            return self.attention(f, (nodes, g))
        return global_mean_pool(nodes, g.batch) + global_max_pool(nodes, g.batch) + self.attention(f, (nodes, g))

    def forward(self, f, x):
        """Run pooling, fusion, and the MLP head.

        Args:
            f: Cell features whose last dimension is ``cell_dim``. Assumed to
                be (batch, cell_dim), since concatenation happens on dim 1.
            x: Drug input. Either a tensor (3-D tensors are pooled over dim 2,
                anything else is used as is) or a ``(node_features, g)`` pair
                that is pooled per graph using ``g.batch``.

        Returns:
            A tuple ``(out, cell_ft, drug_ft)``:
            - ``out``: the prediction, with last dimension 1 (sigmoid-activated
              when ``classify`` is True).
            - ``cell_ft``: the unmodified ``f``.
            - ``drug_ft``: the pooled drug features before projection.
        """
        x = self._pool(f, x)
        cell_ft, drug_ft = f, x
        f = F.relu(self.input_cell(f))
        x = F.relu(self.input_drug(x))
        f = torch.cat((f, x), dim=1) if self.concat else f + x
        for layer, drop in zip(self.encode_dnn, self.dropout):
            f = drop(F.relu(layer(f)))
        f = self.output(f)
        if self.classify:
            # torch.sigmoid replaces the deprecated F.sigmoid alias.
            f = torch.sigmoid(f)
        return f, cell_ft, drug_ft