torchGB.layers.xox package
Submodules
torchGB.layers.xox.attn_xox module
torchGB.layers.xox.conv_xox module
torchGB.layers.xox.linear_xox module
torchGB.layers.xox.model module
- class torchGB.layers.xox.model.XOXLayer(num_input: int, num_output: int, num_genes: int)
Bases: Module
Implements a linear layer with a low-rank factorization through a central matrix ‘O’.
This layer performs a linear transformation by factorizing the weight matrix into three smaller matrices: X_input, O, and X_output. The forward pass computes output = scaling_input * X_output @ O @ X_input.T @ input, so the effective weight matrix X_output @ O @ X_input.T has rank at most num_genes. When num_genes is much smaller than the input and output dimensions, this structure needs fewer parameters than a standard linear layer with the same input and output sizes.
Parameters:
- num_input (int): The dimensionality of the input.
- num_output (int): The dimensionality of the output.
- num_genes (int): The dimensionality of the central matrix ‘O’, which controls the rank of the factorization. It acts as a bottleneck, hence the “Genomic” nomenclature, which refers to the idea of a compressed genetic representation.
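A minimal sketch of the factorization, assuming X_input has shape (num_input, num_genes), O has shape (num_genes, num_genes) and X_output has shape (num_output, num_genes), and treating scaling_input as a fixed scalar. The class name XOXSketch, the helper param_counts, the initialization and the default scaling_input=1.0 are hypothetical and are not taken from torchGB. The sketch evaluates the product right to left so the full weight matrix is never formed, and compares its parameter count with a dense layer.

```python
import math

import torch
from torch import nn


class XOXSketch(nn.Module):
    """Hypothetical illustration of output = scaling_input * X_output @ O @ X_input.T @ input.

    Assumed shapes (not stated in the torchGB docs):
      X_input:  (num_input, num_genes)
      O:        (num_genes, num_genes)
      X_output: (num_output, num_genes)
    """

    def __init__(self, num_input: int, num_output: int, num_genes: int,
                 scaling_input: float = 1.0) -> None:
        super().__init__()
        # Assumption: scaling_input is a fixed scalar; 1.0 is a placeholder default.
        self.scaling_input = scaling_input
        # Simple scaled Gaussian init; the real layer's init scheme is unknown.
        self.X_input = nn.Parameter(torch.randn(num_input, num_genes) / math.sqrt(num_input))
        self.O = nn.Parameter(torch.randn(num_genes, num_genes) / math.sqrt(num_genes))
        self.X_output = nn.Parameter(torch.randn(num_output, num_genes) / math.sqrt(num_genes))

    def effective_weight(self) -> torch.Tensor:
        # (num_output, g) @ (g, g) @ (g, num_input) -> (num_output, num_input), rank <= g
        return self.scaling_input * self.X_output @ self.O @ self.X_input.T

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., num_input). Row-vector form of X_output @ O @ X_input.T @ x:
        # x @ X_input @ O.T @ X_output.T, computed through the num_genes bottleneck.
        return self.scaling_input * (((x @ self.X_input) @ self.O.T) @ self.X_output.T)


def param_counts(num_input: int, num_output: int, num_genes: int) -> tuple[int, int]:
    """Return (factorized, dense) parameter counts under the assumed shapes."""
    factorized = num_input * num_genes + num_genes * num_genes + num_genes * num_output
    dense = num_input * num_output
    return factorized, dense


torch.manual_seed(0)
layer = XOXSketch(num_input=512, num_output=256, num_genes=16)
x = torch.randn(8, 512)
y = layer(x)
print(y.shape)                     # torch.Size([8, 256])
print(param_counts(512, 256, 16))  # (12544, 131072)
```

With these sizes the factorized layer holds 512·16 + 16·16 + 16·256 = 12,544 parameters, compared with 131,072 for a dense 256 × 512 weight matrix.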
- forward(x: Tensor) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
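The behaviour described in the note can be checked with a forward hook. This sketch uses torch.nn.Linear as a stand-in for XOXLayer, since every Module behaves the same way here: calling the instance runs hooks registered with register_forward_hook, while calling forward() directly skips them.

```python
import torch
from torch import nn

calls = []

# nn.Linear stands in for XOXLayer; the hook mechanism is shared by all Modules.
layer = nn.Linear(4, 2)
# The hook records each time it fires; returning None leaves the output unchanged.
handle = layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(3, 4)
layer(x)          # __call__ runs forward() and then the registered hook
layer.forward(x)  # direct call: forward() runs, but the hook does not fire
print(calls)      # ['hook']

handle.remove()
```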