CellEncoder.CNN
Click here to go back to the reference.
class CNN(nn.Module):
    """1-D convolutional encoder applied to a linearly projected feature vector.

    The input is first projected by a linear layer to ``ft_dim`` features,
    treated as a single-channel sequence, and passed through a stack of
    ``Conv1d`` layers, each optionally followed by ``BatchNorm1d``, always
    followed by ReLU, and optionally followed by ``MaxPool1d``. The result is
    either flattened or reduced over the channel dimension.

    Note from the original author: for tCNNS, let ``batch_norm=False``.
    """

    def __init__(self, in_dim: int, ft_dim: int = 735, hid_channel_ls: list = None, kernel_size_conv: int = 7,
                 stride_conv: int = 1, padding_conv: int = 0, kernel_size_pool: int = 3, stride_pool: int = 3,
                 padding_pool: int = 0, batch_norm: bool = True, max_pool: bool = True, flatten: bool = True,
                 debug: bool = False):
        """Build the projection layer and the convolutional stack.

        Args:
            in_dim: Size of the last dimension of the input passed to ``forward``.
            ft_dim: Output size of the initial linear projection, i.e. the
                sequence length seen by the first convolution.
            hid_channel_ls: Output channel count of each convolution, in order.
                Defaults to ``[40, 80, 60]`` when ``None``.
            kernel_size_conv: Kernel size shared by every convolution.
            stride_conv: Stride shared by every convolution.
            padding_conv: Padding shared by every convolution.
            kernel_size_pool: Kernel size shared by every max-pooling layer.
            stride_pool: Stride shared by every max-pooling layer.
            padding_pool: Padding shared by every max-pooling layer.
            batch_norm: If truthy, apply ``BatchNorm1d`` after each convolution.
            max_pool: If truthy, apply ``MaxPool1d`` after each ReLU.
            flatten: If truthy, flatten the final feature map; otherwise
                reduce it over the channel dimension (mean + max + min).
            debug: If truthy, print the output shape on every forward pass.
        """
        super().__init__()
        self.batch_norm = batch_norm
        self.max_pool = max_pool
        self.flatten = flatten
        self.debug = debug
        if hid_channel_ls is None:
            hid_channel_ls = [40, 80, 60]
        # The projected vector enters the stack as a single-channel sequence.
        channel_ls = [1] + hid_channel_ls
        n_layers = len(channel_ls) - 1

        # NOTE: submodule attribute names are kept as-is so that existing
        # state_dict checkpoints remain loadable.
        self.input = nn.Linear(in_dim, ft_dim)
        self.conv = nn.ModuleList([
            nn.Conv1d(in_channels=c_in, out_channels=c_out, kernel_size=kernel_size_conv,
                      stride=stride_conv, padding=padding_conv)
            for c_in, c_out in zip(channel_ls[:-1], channel_ls[1:])
        ])
        if self.batch_norm:
            self.norm = nn.ModuleList([nn.BatchNorm1d(c_out) for c_out in channel_ls[1:]])
        if self.max_pool:
            self.pool = nn.ModuleList([
                nn.MaxPool1d(kernel_size=kernel_size_pool, stride=stride_pool, padding=padding_pool)
                for _ in range(n_layers)
            ])
        if self.flatten:
            self.flat = nn.Flatten()

    def forward(self, f):
        """Encode ``f`` and return the flattened or channel-reduced features.

        Args:
            f: Input tensor whose last dimension is ``in_dim``; expected to be
                2-D ``(batch, in_dim)`` so that inserting a channel axis at
                dim 1 yields the ``(batch, 1, ft_dim)`` layout the first
                convolution requires.

        Returns:
            If ``flatten`` is truthy, the final feature map flattened from
            dim 1 onward. Otherwise the element-wise sum of the mean, max and
            min taken over dim 1 (the channel dimension) of the final map.
        """
        f = F.relu(self.input(f))
        # Insert the channel axis expected by Conv1d.
        f = torch.unsqueeze(f, dim=1)
        for i, conv in enumerate(self.conv):
            f = conv(f)
            if self.batch_norm:
                f = self.norm[i](f)
            f = F.relu(f)
            if self.max_pool:
                f = self.pool[i](f)
        # Use truthiness here, matching the check in __init__ that decides
        # whether ``self.flat`` exists. The previous ``is False`` test sent
        # falsy non-bool values (e.g. 0 or None) to the flatten branch even
        # though ``self.flat`` had never been created.
        if not self.flatten:
            f_mean = torch.mean(f, dim=1)
            f_max, _ = torch.max(f, dim=1)
            f_min, _ = torch.min(f, dim=1)
            f = f_mean + f_max + f_min
        else:
            f = self.flat(f)
        if self.debug:
            print(f.shape)
        return f