CellEncoder.DNN
Click here to go back to the reference.
class DNN(nn.Module):
    """Feed-forward (MLP) encoder mapping ``in_dim`` features to ``ft_dim`` features.

    The network consists of ``num_layers - 1`` hidden blocks, each
    ``Linear -> ReLU -> Dropout``, followed by a final ``Linear`` projection
    (no activation) to ``ft_dim``. With ``num_layers == 1`` it reduces to a
    single ``Linear(in_dim, ft_dim)``.

    Args:
        in_dim: Size of the input feature (last) dimension.
        ft_dim: Size of the output feature dimension.
        hid_dim: Width of every hidden layer.
        num_layers: Total number of ``Linear`` layers, including the output
            projection. Must be >= 1.
        dropout: Dropout probability applied after each hidden ReLU.

    Raises:
        ValueError: If ``num_layers`` is less than 1 (``nn.Dropout`` raises
            ``ValueError`` itself for an out-of-range ``dropout``).
    """

    def __init__(self, in_dim: int, ft_dim: int, hid_dim: int = 100, num_layers: int = 2, dropout: float = 0.3):
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # Layer widths: input, (num_layers - 1) hidden widths, output.
        dim_ls = [in_dim] + [hid_dim] * (num_layers - 1) + [ft_dim]
        # Attribute names are kept as-is so existing state_dict keys still load.
        self.encode_dnn = nn.ModuleList(
            [nn.Linear(dim_ls[i], dim_ls[i + 1]) for i in range(num_layers - 1)]
        )
        self.dropout = nn.ModuleList([nn.Dropout(p=dropout) for _ in range(num_layers - 1)])
        # dim_ls[-2] is hid_dim when there are hidden layers, otherwise in_dim.
        self.output = nn.Linear(dim_ls[-2], dim_ls[-1])

    def forward(self, f):
        """Encode ``f``, whose last dimension must equal ``in_dim``.

        Returns a tensor with the same leading dimensions and last dimension
        ``ft_dim``.
        """
        for linear, drop in zip(self.encode_dnn, self.dropout):
            f = drop(F.relu(linear(f)))
        return self.output(f)