CellEncoder.DAE

Click here to go back to the reference.

class DAE(nn.Module):
    """Wrap a pretrained ``DNN`` encoder and expose it as a module.

    Only the encoder is built and restored from a checkpoint; no decoder is
    instantiated by this class.
    """

    def __init__(self, subset: bool = True, path: str = None):
        """Build the encoder and load its pretrained weights.

        Args:
            subset: If True, build ``DNN(in_dim=6163, ft_dim=100)`` and load
                the ``DefaultData/DAE.pt`` checkpoint located in the same
                directory as this file. If False, build
                ``DNN(in_dim=17420, hid_dim=512, num_layers=3, ft_dim=100)``
                and load its weights from ``path``.
            path: Checkpoint file for the ``subset=False`` encoder. Not used
                when ``subset`` is True.

        Raises:
            ValueError: If ``subset`` is False and ``path`` is None.
        """
        super().__init__()
        if subset:
            encoder = DNN(in_dim=6163, ft_dim=100)
            weights_file = os.path.join(os.path.dirname(__file__), 'DefaultData/DAE.pt')
        else:
            # Explicit check instead of `assert`, which is stripped under
            # `python -O` and would let None reach torch.load().
            if path is None:
                raise ValueError("`path` must be provided when `subset` is False")
            encoder = DNN(in_dim=17420, hid_dim=512, num_layers=3, ft_dim=100)
            weights_file = path

        # Deserialize onto CPU so a checkpoint saved from a GPU still loads on
        # a CPU-only host; load_state_dict copies the values into the
        # encoder's existing parameters.
        state = torch.load(weights_file, map_location='cpu')
        encoder.load_state_dict(state)
        self.encoder = encoder

    def forward(self, f):
        """Return the encoder's output for ``f``.

        NOTE(review): presumably ``f`` has ``in_dim`` features on its last
        axis (6163 or 17420) -- confirm against ``DNN``.
        """
        return self.encoder(f)