Source code for torchdecomp.symrec

import torch
import torch.nn as nn
from .helper import _check_torch_tensor, _check_symmetric_matrix
from .helper import _check_dimension


[docs] class SymRecLayer(nn.Module): """Symmetric Reconstruction Layer A symmetric matrix X (n times n) is decomposed to the product of Q (n times k), Lambda (k times k), and Q^T (k times n). Attributes: x (torch.Tensor): A symmetric matrix X (n times n) n_components (int): The number of lower dimensions (k) Example: >>> import torchdecomp as td >>> import torch >>> torch.manual_seed(123456) >>> x = torch.randn(6, 6) # Test datasets >>> x = torch.mm(x, x.t()) # Symmetalization >>> symrec_layer = td.SymRecLayer(x, 3) # Instantiation """ def __init__(self, x, n_components): """Initialization function """ super(SymRecLayer, self).__init__() _check_torch_tensor(x) _check_symmetric_matrix(x) size = x.size(0) _check_dimension(size, n_components) Q = torch.nn.init.orthogonal_(torch.randn(size, n_components), gain=1) Lambda = torch.diag(torch.sort( torch.randn(n_components)**2, descending=True).values) self.Q = nn.Parameter(Q) self.Lambda = nn.Parameter(Lambda)
[docs] def forward(self): """Forward propagation function """ return torch.mm(torch.mm(self.Q, self.Lambda), self.Q.t())