Source code for torchdecomp.qr

import torch
import torch.nn as nn
from .helper import _check_torch_tensor, _check_square_matrix


class QRLayer(nn.Module):
    """QR Decomposition Layer

    A square matrix X (n times n) is decomposed to
    the product of Q (n times n) and R (n times n).

    Only the shape of X is used here: Q and R are randomly initialized
    with that shape and registered as learnable parameters.

    Args:
        x (torch.Tensor): A square matrix X (n times n)

    Attributes:
        Q (torch.nn.Parameter): n times n matrix, initialized with
            ``torch.nn.init.orthogonal_``.
        R (torch.nn.Parameter): n times n matrix, initialized as the
            upper-triangular part of a random normal matrix.

    Example:
        >>> import torchdecomp as td
        >>> import torch
        >>> torch.manual_seed(123456)
        >>> x = torch.randn(6, 6) # Test datasets
        >>> qr_layer = td.QRLayer(x) # Instantiation

    """

    def __init__(self, x):
        """Initialization function

        Args:
            x (torch.Tensor): A square matrix X (n times n). It is passed
                to the ``_check_torch_tensor`` and ``_check_square_matrix``
                helpers (presumably these reject non-tensor / non-square
                input -- confirm in ``.helper``); afterwards only
                ``x.size()`` is read.
        """
        super().__init__()
        _check_torch_tensor(x)
        _check_square_matrix(x)
        size = x.size()
        # Q is drawn before R; keep this order so that results obtained
        # under a fixed random seed stay reproducible.
        Q = torch.nn.init.orthogonal_(torch.randn(size), gain=1)
        # NOTE: the upper-triangular structure is imposed at initialization
        # only; nothing in this class re-applies it after R is updated.
        R = torch.triu(torch.randn(size), diagonal=0)
        self.Q = nn.Parameter(Q)
        self.R = nn.Parameter(R)

    def forward(self):
        """Forward propagation function

        Returns:
            torch.Tensor: The matrix product ``Q @ R`` (n times n), i.e.
            the current reconstruction of X.
        """
        return torch.mm(self.Q, self.R)