torchGB.layers.xox package

Submodules

torchGB.layers.xox.attn_xox module

torchGB.layers.xox.conv_xox module

torchGB.layers.xox.linear_xox module

torchGB.layers.xox.model module

class torchGB.layers.xox.model.XOXLayer(num_input: int, num_output: int, num_genes: int)

Bases: Module

Implements a linear layer with a low-rank factorization through a central matrix ‘O’.

This layer performs a linear transformation using a factorization of the weight matrix into three smaller matrices: X_input, O, and X_output. The forward pass calculates output = scaling_input * X_output @ O @ X_input.T @ input. This structure aims to reduce the number of parameters compared to a standard linear layer, especially when the dimension of ‘O’ (num_genes) is much smaller than the input and output dimensions.

Parameters:

    num_input (int): The dimensionality of the input.
    num_output (int): The dimensionality of the output.
    num_genes (int): The dimensionality of the central matrix ‘O’, controlling the rank of the factorization. This acts as a bottleneck, hence the “Genomic” nomenclature referencing the idea of a compressed genetic representation.
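The factorization described above can be illustrated with a minimal sketch in plain PyTorch. It does not use XOXLayer itself: the function name xox_forward_sketch, the tensor shapes, and the column-per-sample input layout are assumptions chosen to match the formula as written, and the actual layer may store or initialize its matrices differently.

```python
import torch


def xox_forward_sketch(x, x_input, o, x_output, scale=1.0):
    """Apply scale * X_output @ O @ X_input.T @ x without building the full weight.

    Hypothetical helper for illustration only (not part of torchGB).
    Assumed shapes: x_input (num_input, num_genes), o (num_genes, num_genes),
    x_output (num_output, num_genes), x (num_input, batch).
    """
    # Multiply right to left so every intermediate has only num_genes rows.
    return scale * (x_output @ (o @ (x_input.T @ x)))


torch.manual_seed(0)
num_input, num_output, num_genes = 64, 32, 4

x_input = torch.randn(num_input, num_genes)
o = torch.randn(num_genes, num_genes)
x_output = torch.randn(num_output, num_genes)
x = torch.randn(num_input, 8)  # assumed layout: each column is one sample

y = xox_forward_sketch(x, x_input, o, x_output)

# The equivalent dense weight has rank at most num_genes.
dense_weight = x_output @ o @ x_input.T

# Parameter counts (biases ignored): factorized vs. a standard dense weight.
factored_params = x_input.numel() + o.numel() + x_output.numel()
dense_params = dense_weight.numel()
```

With these assumed sizes the three factors hold num_input * num_genes + num_genes**2 + num_output * num_genes values, against num_input * num_output for the dense weight, which is where the saving comes from when num_genes is small.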

calc_ref_scaling(init_fn=<function kaiming_normal_>) -> Tensor
forward(x: Tensor) -> Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
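The difference described in the note can be seen with a small sketch. It uses torch.nn.Linear as a stand-in for XOXLayer (an assumption, so that the example runs without torchGB installed); the hook function record_call and the calls list are illustrative names.

```python
import torch
from torch import nn

layer = nn.Linear(3, 2)  # stand-in for any Module subclass
calls = []


def record_call(module, inputs, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    calls.append(tuple(output.shape))


handle = layer.register_forward_hook(record_call)
x = torch.zeros(5, 3)

layer(x)  # calling the instance runs registered hooks
hooks_after_call = len(calls)

layer.forward(x)  # calling forward directly skips the hooks
hooks_after_forward = len(calls)

handle.remove()  # detach the hook once it is no longer needed
```

The hook fires only for the instance call, so the recorded count does not change after the direct forward call.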

output_scale: float
sizes: Sequence[int]
torchGB.layers.xox.model.std(array: Tensor) -> Tensor

Module contents