Source code for torchGB.layers.gnet.model

from typing import Callable, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from torch import Tensor

import numpy as np

from .pinv_fc_layer import PseudoInverseLinear


class GenomicBottleNet(nn.Module):
    """Variable-length g-net built from a list of layer sizes.

    The network is a stack of ``nn.Linear`` layers, each followed by an
    ``activation_fn()`` module, and ends with a ``PseudoInverseLinear`` output
    layer that has no non-linearity. ``forward`` multiplies the raw network
    output by ``2 * output_scale``.

    Args:
        sizes (Sequence[int]): Layer widths from the input width to the output
            width; each pair of consecutive entries defines one layer. At
            least two entries are required.
        output_scale (float | Tensor): Scaling factor for the output of the
            g-net. A Tensor is detached, so no gradient flows into the scale.
        activation_fn (Optional[Callable[[], nn.Module]]): Zero-argument
            factory (e.g. the class ``nn.ReLU``) that creates the activation
            module placed after every hidden layer. ``None`` selects the
            default, ``nn.ReLU``.
    """

    model: nn.Sequential
    sizes: Sequence[int]
    output_scale: float  # may also hold a detached Tensor

    def __init__(self,
                 sizes: Sequence[int],
                 output_scale: float,
                 activation_fn: Optional[Callable[[], nn.Module]] = nn.ReLU) -> None:
        super().__init__()
        if activation_fn is None:
            activation_fn = nn.ReLU

        # Treat the scale as a constant. Plain floats (which the annotation
        # allows) have no ``detach`` and are stored unchanged.
        if isinstance(output_scale, Tensor):
            output_scale = output_scale.detach()
        self.output_scale = output_scale
        self.activation_fn = activation_fn
        self.sizes = sizes

        # Hidden layers: Linear followed by the activation.
        layer_list = []
        for in_features, out_features in zip(sizes[:-2], sizes[1:-1]):
            layer_list.append(nn.Linear(in_features, out_features))
            layer_list.append(activation_fn())
        # Output layer: no non-linearity.
        layer_list.append(PseudoInverseLinear(sizes[-2], sizes[-1]))

        self.model = nn.Sequential(*layer_list)
        self.init_weights()

    def init_weights(self) -> None:
        """Re-initialise every ``nn.Linear`` layer of ``self.model``.

        Weights are drawn from N(0, 0.1**2) and biases are zeroed. Layers that
        are not ``nn.Linear`` instances are left untouched.
        """
        # NOTE(review): whether PseudoInverseLinear subclasses nn.Linear (and is
        # therefore re-initialised here) is not visible from this file; confirm.
        for layer in self.model:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0, std=1e-1)  # initialization here is key!
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Predict new weights from the g-net input ``x``.

        Returns:
            Tensor: ``self.model(x) * 2 * self.output_scale``.
        """
        # NOTE: Do not touch output norm! Carefully computed by hand...
        # A Python scalar gives the same product as ``torch.tensor(2.)``
        # without allocating a fresh CPU tensor on every call.
        return self.model(x) * 2.0 * self.output_scale
# Type alias for a 4-tuple of
# (Tensor, sequence of GenomicBottleNet, sequence of ints, float).
# NOTE(review): the meaning of each element is defined by the callers that
# build/consume this tuple, which are not visible here; confirm there.
GNetLayerTuple = Tuple[
    Tensor,
    Sequence[GenomicBottleNet],
    Sequence[int],
    float,
]
class StochasticGenomicBottleNet(GenomicBottleNet):
    """G-net that samples its output with the reparameterisation trick.

    The network is a stack of ``nn.Linear`` layers with an ``activation_fn()``
    module between consecutive layers and no non-linearity after the last
    one. Its two outputs are read as ``mu`` and ``logsigma``, and ``forward``
    returns ``(mu + exp(logsigma) * eps) * 2 * output_scale`` with fresh
    ``eps ~ N(0, 1)`` on every call.

    Args:
        sizes (Sequence[int]): Layer widths from the input width to the output
            width. The last entry must be 2 (one ``mu`` and one ``logsigma``).
        output_scale (Tensor): Scaling factor for the output of the g-net.
        activation_fn (Optional[Callable[[], nn.Module]]): Zero-argument
            factory for the hidden-layer activation module. ``None`` selects
            the default, ``nn.ReLU``.

    Raises:
        ValueError: If ``sizes[-1] != 2``.
    """

    def __init__(self,
                 sizes: Sequence[int],
                 output_scale: Tensor,
                 activation_fn: Optional[Callable[[], nn.Module]] = nn.ReLU) -> None:
        if activation_fn is None:
            activation_fn = nn.ReLU
        # The output bias below is a (mu, logsigma) pair, so exactly two
        # outputs are required; fail early instead of at the first forward.
        if sizes[-1] != 2:
            raise ValueError(
                "StochasticGenomicBottleNet requires sizes[-1] == 2 "
                f"(mu, logsigma), got {sizes[-1]}"
            )
        super().__init__(sizes, output_scale, activation_fn=activation_fn)

        # Replace the parent's model with a plain Linear stack.
        layer_list = []
        for in_features, out_features in zip(sizes[:-1], sizes[1:]):
            layer_list.append(nn.Linear(in_features, out_features))
            layer_list.append(activation_fn())
        layer_list.pop(-1)  # no non-linearity on the last layer
        self.model = nn.Sequential(*layer_list)
        self.init_weights()

        # Initial output bias: mu = 0.01, logsigma = log(0.02). ``copy_``
        # keeps the Parameter's dtype/device instead of swapping its storage.
        initial_bias = torch.tensor([0.01, np.log(0.02, dtype=np.float32)])
        with torch.no_grad():
            self.model[-1].bias.copy_(initial_bias)

    def init_weights(self) -> None:
        """Re-initialise every ``nn.Linear`` layer of ``self.model``.

        Weights are drawn from N(0, 0.01**2) and biases are zeroed.
        """
        for layer in self.model:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0, std=1e-2)  # initialization here is key!
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Draw one sample per input row.

        The last dimension of the network output holds ``(mu, logsigma)``;
        the result has the output's shape with that dimension removed.
        """
        out = self.model(x)
        # Linear layers act on the last dim, so index it explicitly.
        mu, logsigma = out[..., 0], out[..., 1]
        eps = torch.randn_like(logsigma)  # does not require grad
        return (mu + torch.exp(logsigma) * eps) * 2.0 * self.output_scale
class Reshape(nn.Module):
    """Reshape each sample of a batch while keeping the batch dimension.

    Args:
        *shape: Target per-sample shape. The input's leading dimension is
            kept and prepended to it.
    """

    def __init__(self, *shape) -> None:
        super().__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x.view`` with shape ``(x.size(0), *self.shape)``."""
        batch_size = x.size(0)
        target_shape = (batch_size,) + tuple(self.shape)
        return x.view(target_shape)
class FastGenomicBottleNet(GenomicBottleNet):
    """Evaluate ``num_tiles`` independent two-layer g-nets in one pass.

    The first ``nn.Conv1d`` computes the hidden layer of every tile at once.
    ``Reshape`` then splits the stacked hidden units into one row per tile.
    The second ``nn.Conv1d`` uses ``groups=num_tiles`` so that each tile has
    its own output weights. The convolutional model reads only ``sizes[0]``
    (input width) and ``sizes[1]`` (hidden width), and each tile yields a
    single output value.

    Input ``x`` must have shape ``(batch, 1, sizes[0])``: the reshape after
    the first convolution only works if that convolution's output length is
    1. The result has shape ``(batch, num_tiles, 1)``.

    TODO: generalize the implementation?
    TODO: use this implementation for fast computation of a set of adjacent
    tiles.

    Args:
        num_tiles (int): Number of independent g-nets evaluated together.
        sizes (Sequence[int]): Layer widths; see above for which are used.
        output_scale (Tensor): Scaling factor for the output of the g-net.
        activation_fn (Callable[[], nn.Module]): Zero-argument factory for
            the hidden activation module. Default is ``nn.ReLU``.
    """

    num_tiles: int

    def __init__(self,
                 num_tiles: int,
                 sizes: Sequence[int],
                 output_scale: Tensor,
                 activation_fn: Callable[[], nn.Module] = nn.ReLU) -> None:
        super().__init__(sizes, output_scale, activation_fn=activation_fn)
        self.num_tiles = num_tiles
        hidden = sizes[1]

        layer_list = [
            nn.Conv1d(1, num_tiles * hidden, kernel_size=sizes[0]),
            activation_fn(),
            Reshape(num_tiles, hidden),
            # NOTE: The `groups` argument ensures that each tile has its own
            # set of weights.
            nn.Conv1d(num_tiles, num_tiles, kernel_size=hidden, groups=num_tiles),
        ]
        self.model = nn.Sequential(*layer_list)
        # The parent's init_weights() ran on the model that was just replaced,
        # so the convolutions must be initialised explicitly here.
        self.init_weights()

    def init_weights(self) -> None:
        """Re-initialise every ``nn.Conv1d`` layer of ``self.model``.

        Weights are drawn from N(0, 0.1**2) and biases are zeroed.
        """
        for layer in self.model:
            if isinstance(layer, nn.Conv1d):
                nn.init.normal_(layer.weight, mean=0, std=1e-1)  # initialization here is key!
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``self.model(x) * 2 * self.output_scale``.

        The result has shape ``(batch, num_tiles, 1)``.
        """
        return self.model(x) * 2.0 * self.output_scale
class FastStochasticGenomicBottleNet(FastGenomicBottleNet):
    """Stochastic variant of ``FastGenomicBottleNet``.

    Each of the ``num_tiles`` g-nets emits a ``(mu, logsigma)`` pair, and
    ``forward`` returns ``(mu + exp(logsigma) * eps) * 2 * output_scale``
    with fresh ``eps ~ N(0, 1)`` on every call. The final grouped
    convolution gives tile ``t`` the two consecutive output channels
    ``(2*t, 2*t + 1)``: even channels are ``mu`` and odd channels are
    ``logsigma``.

    Input ``x`` must have shape ``(batch, 1, sizes[0])``. The result has
    shape ``(batch, num_tiles, 1)``.

    TODO: generalize the implementation?
    TODO: use this implementation for fast computation of a set of adjacent
    tiles.

    Args:
        num_tiles (int): Number of independent g-nets evaluated together.
        sizes (Sequence[int]): Layer widths; only ``sizes[0]`` (input width)
            and ``sizes[1]`` (hidden width) are used by the conv model.
        output_scale (Tensor): Scaling factor for the output of the g-net.
        activation_fn (Callable[[], nn.Module]): Zero-argument factory for
            the hidden activation module. Default is ``nn.ReLU``.
    """

    num_tiles: int

    def __init__(self,
                 num_tiles: int,
                 sizes: Sequence[int],
                 output_scale: Tensor,
                 activation_fn: Callable[[], nn.Module] = nn.ReLU) -> None:
        super().__init__(num_tiles, sizes, output_scale, activation_fn=activation_fn)
        hidden = sizes[1]

        layer_list = [
            nn.Conv1d(1, num_tiles * hidden, kernel_size=sizes[0]),
            activation_fn(),
            Reshape(num_tiles, hidden),
            # NOTE: The `groups` argument ensures that each tile has its own
            # set of weights and owns output channels (2*t, 2*t + 1).
            nn.Conv1d(num_tiles, 2 * num_tiles, kernel_size=hidden, groups=num_tiles),
        ]
        self.model = nn.Sequential(*layer_list)
        # Initialise before setting the bias: init_weights() zeroes biases.
        self.init_weights()

        # Per-tile initial bias: mu = 0, logsigma = log(0.02). Channels are
        # interleaved (mu_0, logsigma_0, mu_1, logsigma_1, ...), matching the
        # ``::2`` / ``1::2`` split in ``forward``, so the pair is tiled with
        # ``repeat``. (``repeat_interleave`` would give the first ``num_tiles``
        # channels bias 0 and the rest log(0.02), mixing mu and logsigma
        # biases across tiles.)
        # NOTE(review): StochasticGenomicBottleNet starts mu at 0.01, not 0.0;
        # confirm which value is intended.
        pair = torch.tensor([0.0, np.log(0.02, dtype=np.float32)])
        with torch.no_grad():
            self.model[-1].bias.copy_(pair.repeat(num_tiles))

    def init_weights(self) -> None:
        """Re-initialise every ``nn.Conv1d`` layer of ``self.model``.

        Weights are drawn from N(0, 0.01**2) and biases are zeroed.
        """
        for layer in self.model:
            if isinstance(layer, nn.Conv1d):
                nn.init.normal_(layer.weight, mean=0, std=1e-2)  # initialization here is key!
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Draw one sample per tile with the reparameterisation trick."""
        out = self.model(x)
        # Even channels are mu, odd channels are logsigma (one pair per tile).
        mu, logsigma = out[:, ::2], out[:, 1::2]
        eps = torch.randn_like(logsigma)  # does not require grad
        return (mu + torch.exp(logsigma) * eps) * 2.0 * self.output_scale