# Source code for torchGB.layers.xox.model

from typing import Sequence

import torch
import torch.nn as nn
from torch import Tensor

import numpy as np


def std(array: Tensor) -> float:
    """Return the standard deviation of ``array`` as a Python float.

    A small constant (1e-9) is added to the result so that it can be used as
    a divisor even when every element of ``array`` is identical.

    Args:
        array: Tensor whose elements are reduced to one standard deviation.

    Returns:
        ``array.std()`` converted with ``.item()``, plus 1e-9.
    """
    return array.std().item() + 1e-9
class XOXLayer(nn.Module):
    """Low-rank weight matrix factorized through a central matrix ``O``.

    The layer owns three trainable matrices:

    * ``X_input``  with shape ``(prod(num_input), num_genes)``
    * ``O_mat``    with shape ``(num_genes, num_genes)``
    * ``X_output`` with shape ``(prod(num_output), num_genes)``

    and combines them into a matrix of shape
    ``(prod(num_output), prod(num_input))``::

        W = scaling_input * X_output @ O_mat @ X_input.T

    Because ``num_genes`` bounds the rank of ``W``, the factorization needs
    fewer parameters than a dense matrix whenever ``num_genes`` is much
    smaller than the input and output dimensions. This bottleneck is the
    reason for the "genomic" nomenclature (a compressed, gene-like
    representation of the weights).

    NOTE: ``forward`` returns ``W`` itself. Its argument ``x`` is not used by
    the current implementation, and no bias term is added.

    Args:
        num_input (int): The dimensionality of the input. ``np.prod`` is
            applied to it, so a sequence of dimensions is reduced to its
            product.
        num_output (int): The dimensionality of the output. ``np.prod`` is
            applied to it as well.
        num_genes (int): The dimensionality of the central matrix ``O``,
            controlling the rank of the factorization.

    Attributes:
        sizes: The tuple ``(num_input, num_genes, num_output)`` as passed in.
        output_scale: Constant tensor ``1.0``; not used inside this class.
        scaling_input: Trainable scalar multiplying the factorized matrix,
            initialised by :meth:`calc_ref_scaling`.
    """

    sizes: Sequence[int]
    output_scale: Tensor

    def __init__(self, num_input: int, num_output: int, num_genes: int) -> None:
        super().__init__()
        self.sizes = (num_input, num_genes, num_output)
        self.output_scale = torch.tensor(1.0)

        # Plain Python numbers keep numpy scalar types out of the torch calls.
        rows_in = int(np.prod(num_input))
        rows_out = int(np.prod(num_output))
        norm = float(1.0 / np.sqrt(3 * num_genes))

        # The order of the random draws (O_mat, X_input, X_output) is kept
        # fixed so that seeded initialisation stays reproducible.
        self.O_mat = nn.Parameter(
            (1.0 / num_genes) * torch.randn(num_genes, num_genes)
        )
        self.X_input = nn.Parameter(norm * torch.randn(rows_in, num_genes))
        self.X_output = nn.Parameter(norm * torch.randn(rows_out, num_genes))

        # Calculate a scaling factor based on the initialization scheme.
        # `calc_ref_scaling` matches the standard deviation of the factorized
        # weights to that of a reference-initialised matrix of the same shape
        # (Kaiming normal by default).
        # NOTE: this implementation is unnecessarily complicated. Since the
        # stddevs of X_input, O_mat and X_output are known, the scaling factor
        # could be computed directly instead of being measured.
        self.scaling_input = nn.Parameter(self.calc_ref_scaling())

    def calc_ref_scaling(self, init_fn=torch.nn.init.kaiming_normal_) -> Tensor:
        """Return the ratio of a reference init's std to the current std.

        The unscaled product ``X_output @ O_mat @ X_input.T`` is compared with
        a tensor of the same shape filled by ``init_fn``.

        Args:
            init_fn: In-place initialiser that receives an empty tensor and
                returns the initialised tensor. Defaults to
                ``torch.nn.init.kaiming_normal_``.
                NOTE: Not sure if Kaiming is the best init! Rather use Xavier...

        Returns:
            A 0-dim tensor holding ``std(reference) / std(weights)``.
        """
        # Only scalar statistics are extracted, so no autograd graph is needed.
        with torch.no_grad():
            weights = self.X_output @ self.O_mat @ self.X_input.T
            ref_weights = init_fn(torch.empty(weights.size()))
            return torch.tensor(std(ref_weights) / std(weights))

    def forward(self, x: Tensor) -> Tensor:
        """Return the scaled low-rank matrix ``W``.

        NOTE: ``x`` is accepted but not used by the current implementation;
        the result is the matrix itself rather than its product with ``x``.
        A conventional linear layer would take ``x`` of shape
        ``(batch_size, input_dim)`` and return ``x @ W.T + b``. Whether this
        layer should do the same, and the missing bias term, are open
        questions to settle before treating it as a drop-in replacement.

        Args:
            x: Ignored.

        Returns:
            ``(scaling_input * X_output) @ O_mat @ X_input.T``.
        """
        # `*` and `@` share precedence and associate left to right, so the
        # scalar is applied to X_output before the matrix products.
        return self.scaling_input * self.X_output @ self.O_mat @ self.X_input.T