DrugEncoder.CNN

Click here to go back to the reference.

class CNN(nn.Module):
    """Stacked 1-D convolutional encoder over sequences of integer indices.

    Every stage applies ``Conv1d``, then optionally ``BatchNorm1d``, then
    ReLU, then optionally ``MaxPool1d``. The input indices are turned into
    channel vectors either by a learned ``nn.Embedding`` or by one-hot
    encoding, depending on ``embedding``.
    """

    def __init__(self, embedding: bool = True, num_embedding: int = _num_embedding, embedding_dim: int = 735,
                 hid_channel_ls: list = None, kernel_size_conv: int = 7, stride_conv: int = 1, padding_conv: int = 0,
                 kernel_size_pool: int = 3, stride_pool: int = 3, padding_pool: int = 0, batch_norm: bool = True,
                 max_pool: bool = True, flatten: bool = True, debug: bool = False):
        """Build the encoder layers.

        For the tCNNS configuration, pass ``embedding=False`` and
        ``batch_norm=False``.

        Args:
            embedding: Use a learned ``nn.Embedding`` when True; otherwise the
                input is one-hot encoded with ``num_embedding`` classes.
            num_embedding: Vocabulary size for the embedding table, and the
                number of one-hot classes when ``embedding`` is False.
            embedding_dim: Width of the learned embedding; only determines
                the first conv layer's input channels when ``embedding`` is
                True.
            hid_channel_ls: Output channel count of each conv stage, in
                order. Defaults to ``[40, 80, 60]`` when None.
            kernel_size_conv: Kernel size shared by all conv layers.
            stride_conv: Stride shared by all conv layers.
            padding_conv: Padding shared by all conv layers.
            kernel_size_pool: Kernel size shared by all max-pool layers.
            stride_pool: Stride shared by all max-pool layers.
            padding_pool: Padding shared by all max-pool layers.
            batch_norm: Insert ``BatchNorm1d`` after every conv layer.
            max_pool: Insert ``MaxPool1d`` after every ReLU.
            flatten: Apply ``nn.Flatten`` to the final feature map.
            debug: Print the output shape on every forward call.
        """
        super().__init__()

        # Flags consulted again at forward time.
        self.embedding = embedding
        self.num_embedding = num_embedding
        self.batch_norm = batch_norm
        self.max_pool = max_pool
        self.flatten = flatten
        self.debug = debug

        if hid_channel_ls is None:
            hid_channel_ls = [40, 80, 60]

        # Channel width entering the first conv depends on how the indices
        # are encoded: embedding width, or one channel per one-hot class.
        first_width = embedding_dim if self.embedding else num_embedding
        widths = [first_width] + hid_channel_ls
        out_widths = widths[1:]

        # Submodules are registered in a fixed order (embed, conv, norm,
        # pool, flat) so parameter ordering and state_dict keys stay stable.
        if self.embedding:
            self.embed = nn.Embedding(num_embedding, embedding_dim)

        conv_layers = []
        for c_in, c_out in zip(widths[:-1], out_widths):
            conv_layers.append(
                nn.Conv1d(in_channels=c_in, out_channels=c_out, kernel_size=kernel_size_conv,
                          stride=stride_conv, padding=padding_conv)
            )
        self.conv = nn.ModuleList(conv_layers)

        if self.batch_norm:
            self.norm = nn.ModuleList([nn.BatchNorm1d(c_out) for c_out in out_widths])

        if self.max_pool:
            pool_layers = [
                nn.MaxPool1d(kernel_size=kernel_size_pool, stride=stride_pool, padding=padding_pool)
                for _ in out_widths
            ]
            self.pool = nn.ModuleList(pool_layers)

        if self.flatten:
            self.flat = nn.Flatten()

    def forward(self, f):
        """Encode a batch of index sequences.

        Args:
            f: Tensor of indices. After embedding / one-hot encoding the
                result must be 3-D, because it is permuted with
                ``(0, 2, 1)`` to put the feature axis in the channel
                position expected by ``Conv1d``.

        Returns:
            The feature map after the last stage, passed through
            ``nn.Flatten`` when ``flatten`` is enabled.
        """
        if self.embedding:
            f = self.embed(f)
        else:
            f = F.one_hot(f.to(torch.int64), self.num_embedding).float()

        # Move the feature axis to dim 1 (channels) for the conv layers.
        f = f.permute(0, 2, 1).contiguous()

        for stage, conv_layer in enumerate(self.conv):
            f = conv_layer(f)
            if self.batch_norm:
                f = self.norm[stage](f)
            f = F.relu(f)
            if self.max_pool:
                f = self.pool[stage](f)

        if self.flatten:
            f = self.flat(f)

        if self.debug:
            print(f.shape)
        return f