# Source code for torchdecomp.factor

import torch
import torch.nn as nn
from .helper import _check_torch_tensor, _check_dimension


class FactorLayer(nn.Module):
    """Factor Matrix Layer

    A matrix X (n times m) is projected to a smaller matrix
    XV (n times k, k << m).

    Args:
        x (torch.Tensor): A matrix X (n times m). The constructor reads
            its second dimension (m) to size the factor matrix; x itself
            is not stored on the layer.
        n_components (int): The number of lower dimensions (k)

    Attributes:
        V (torch.nn.Parameter): Factor matrix (m times k), initialized
            with ``torch.nn.init.orthogonal_`` (gain=1)

    Example:
        >>> import torchdecomp as td
        >>> import torch
        >>> torch.manual_seed(123456)
        >>> x = torch.randn(10, 6) # Test datasets
        >>> factor_layer = td.FactorLayer(x, 3) # Instantiation

    """

    def __init__(self, x, n_components):
        """Initialization function

        Args:
            x (torch.Tensor): A matrix X (n times m)
            n_components (int): The number of lower dimensions (k)
        """
        super().__init__()
        # Validation is delegated to the package helpers; presumably they
        # raise on a non-tensor input / an invalid k -- see .helper.
        _check_torch_tensor(x)
        size = x.size(1)
        _check_dimension(size, n_components)
        # orthogonal_ fills the tensor in place and returns it. The randn
        # draw is kept (rather than an uninitialized tensor) so that seeded
        # runs consume the random number generator exactly as before.
        V = torch.nn.init.orthogonal_(
            torch.randn(size, n_components), gain=1)
        self.V = nn.Parameter(V)

    def forward(self, x):
        """Forward propagation function

        Args:
            x (torch.Tensor): A 2-D matrix whose number of columns equals
                the number of rows of ``V`` (``torch.mm`` accepts 2-D
                operands only)

        Returns:
            torch.Tensor: The matrix product XV
        """
        return torch.mm(x, self.V)