Source code for torchdecomp.cholesky

import torch
import torch.nn as nn
from .helper import _check_torch_tensor, _check_symmetric_matrix


class CholeskyLayer(nn.Module):
    """Cholesky Decomposition Layer

    A symmetric matrix X (n times n) is decomposed to
    the product of L (n times n) and L^T (n times n),
    where L is a lower-triangular matrix.

    Args:
        x (torch.Tensor): A symmetric matrix X (n times n). It is
            validated and its size determines the shape of ``L``;
            the tensor itself is not stored on the layer.

    Attributes:
        L (torch.nn.Parameter): Trainable factor with the same size
            as ``x``. It is initialized from standard-normal noise as a
            lower-triangular matrix with strictly positive diagonal
            entries, using the default dtype and device.

    Example:
        >>> import torchdecomp as td
        >>> import torch
        >>> torch.manual_seed(123456)
        >>> x = torch.randn(6, 6) # Test datasets
        >>> x = torch.mm(x, x.t()) # Symmetrization
        >>> cholesky_layer = td.CholeskyLayer(x) # Instantiation

    """

    def __init__(self, x):
        """Initialization function

        Args:
            x (torch.Tensor): A symmetric matrix X (n times n), checked
                with ``_check_torch_tensor`` and
                ``_check_symmetric_matrix`` before use.
        """
        super().__init__()
        _check_torch_tensor(x)
        _check_symmetric_matrix(x)
        size = x.size()
        L = torch.tril(torch.randn(size))
        # Set diagonal elements as positive values. ``diagonal()`` returns
        # a view, so the in-place ``exp_`` writes straight into ``L``
        # without a Python-level loop over the entries.
        L.diagonal().exp_()
        self.L = nn.Parameter(L)

    def forward(self):
        """Forward propagation function

        Returns:
            torch.Tensor: The reconstruction ``L @ L^T`` computed from the
            lower-triangular part of the parameter ``L``.
        """
        # The gradient of a loss on L @ L^T with respect to L is generally
        # dense, so an optimizer step on the raw parameter would fill in the
        # upper triangle. Masking here keeps the factor lower-triangular:
        # entries above the diagonal do not affect the output and receive
        # zero gradient.
        lower = torch.tril(self.L)
        return torch.mm(lower, lower.t())