Source code for torchdecomp.nmf

import torch
import torch.nn as nn
from .helper import _check_torch_tensor, _check_dimension, _rho


class NMFLayer(nn.Module):
    """Non-negative Matrix Factorization Layer

    A non-negative matrix X (n times m) is decomposed into the product of
    W (n times k) and H (k times m).

    Attributes:
        x (torch.Tensor): A non-negative matrix X (n times m)
        n_components (int): The number of lower dimensions (k)
        l1_lambda_w (float): L1 regularization parameter for W
        l1_lambda_h (float): L1 regularization parameter for H
        l2_lambda_w (float): L2 regularization parameter for W
        l2_lambda_h (float): L2 regularization parameter for H
        bin_lambda_w (float): Binarization regularization parameter for W
        bin_lambda_h (float): Binarization regularization parameter for H
        eps (float): Offset value to avoid zero division
        beta (float): Beta parameter of the Beta-divergence

    Example:
        >>> import torchdecomp as td
        >>> import torch
        >>> torch.manual_seed(123456)
        >>> x = torch.randn(10, 6).abs()  # Non-negative test data
        >>> nmf_layer = td.NMFLayer(x, 3)  # Instantiation
    """

    def __init__(
            self, x, n_components,
            l1_lambda_w=torch.finfo(torch.float64).eps,
            l1_lambda_h=torch.finfo(torch.float64).eps,
            l2_lambda_w=torch.finfo(torch.float64).eps,
            l2_lambda_h=torch.finfo(torch.float64).eps,
            bin_lambda_w=torch.finfo(torch.float64).eps,
            bin_lambda_h=torch.finfo(torch.float64).eps,
            eps=torch.finfo(torch.float64).eps,
            beta=2):
        """Initialization function"""
        super(NMFLayer, self).__init__()
        _check_torch_tensor(x)
        size0 = x.size(0)
        size1 = x.size(1)
        _check_dimension(size0, n_components)
        _check_dimension(size1, n_components)
        self.eps = eps
        self.W = nn.Parameter(torch.rand(
            size0, n_components, dtype=torch.float64))
        self.H = nn.Parameter(torch.rand(
            n_components, size1, dtype=torch.float64))
        self.l1_lambda_w = l1_lambda_w
        self.l1_lambda_h = l1_lambda_h
        self.l2_lambda_w = l2_lambda_w
        self.l2_lambda_h = l2_lambda_h
        self.bin_lambda_w = bin_lambda_w
        self.bin_lambda_h = bin_lambda_h
        self.beta = beta
    def positive(self, X, WH, beta):
        """Positive terms of the Beta-NMF objective function"""
        if beta == 0:
            return X / WH
        if beta == 1:
            # A small offset is applied at beta = 1, mirroring the
            # offset used in the negative terms
            return (1 / (beta + 0.001)) * (WH**(beta + 0.001))
        else:
            return (1 / beta) * (WH**beta)
    def negative(self, X, WH, beta):
        """Negative terms of the Beta-NMF objective function

        A numerical check of these terms for beta = 2 is sketched after
        the class body.
        """
        if beta == 0:
            return torch.log(X / WH)
        if beta == 1:
            # The closed form below is singular at beta = 1 (KL
            # divergence), so a small offset is used instead
            return (1 / (beta - 0.999)) * (X * (WH**(beta - 0.999)))
        else:
            return (1 / (beta - 1)) * (X * (WH**(beta - 1)))
    def positive_w(self, W, l1_lambda_w, l2_lambda_w, bin_lambda_w):
        """Positive terms of the L1, L2, and binarization regularization
        against W
        """
        l1_term = l1_lambda_w * W
        l2_term = l2_lambda_w * W**2
        bin_term = bin_lambda_w * (W**4 + W**2)
        return l1_term + l2_term + bin_term
    def negative_w(self, W, bin_lambda_w):
        """Negative term of the binarization regularization against W"""
        bin_term = bin_lambda_w * 2 * W**3
        return bin_term
    def positive_h(self, H, l1_lambda_h, l2_lambda_h, bin_lambda_h):
        """Positive terms of the L1, L2, and binarization regularization
        against H
        """
        l1_term = l1_lambda_h * H
        l2_term = l2_lambda_h * H**2
        bin_term = bin_lambda_h * (H**4 + H**2)
        return l1_term + l2_term + bin_term
    def negative_h(self, H, bin_lambda_h):
        """Negative term of the binarization regularization against H

        An identity check of the combined binarization penalty is
        sketched after the class body.
        """
        bin_term = bin_lambda_h * 2 * H**3
        return bin_term
    def loss(self, pos, neg, pos_w, neg_w, pos_h, neg_h):
        """Total loss with the reconstruction term and regularization terms"""
        loss1 = torch.sum(pos - neg)
        loss2 = torch.sum(pos_w - neg_w)
        loss3 = torch.sum(pos_h - neg_h)
        return loss1 + loss2 + loss3
    def forward(self, X):
        """Forward propagation function"""
        WH = torch.mm(self.W, self.H)
        # Clamp entries below eps to avoid zero division in the
        # Beta-divergence terms
        WH[WH < self.eps] = self.eps
        pos = self.positive(X, WH, self.beta)
        neg = self.negative(X, WH, self.beta)
        pos_w = self.positive_w(
            self.W, self.l1_lambda_w, self.l2_lambda_w, self.bin_lambda_w)
        neg_w = self.negative_w(self.W, self.bin_lambda_w)
        pos_h = self.positive_h(
            self.H, self.l1_lambda_h, self.l2_lambda_h, self.bin_lambda_h)
        neg_h = self.negative_h(self.H, self.bin_lambda_h)
        loss = self.loss(pos, neg, pos_w, neg_w, pos_h, neg_h)
        return loss, WH, pos, neg, pos_w, neg_w, pos_h, neg_h
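For beta = 2 (the default), the summed difference of the positive and negative terms recovers the squared Frobenius reconstruction error up to a constant that does not depend on WH: sum(WH**2 / 2 - X * WH) = ||X - WH||**2 / 2 - ||X||**2 / 2. A minimal numerical check of this identity (an editor's sketch, not part of the module):

import torch
import torchdecomp as td

torch.manual_seed(123456)
X = torch.rand(10, 6, dtype=torch.float64)
layer = td.NMFLayer(X, 3)  # beta defaults to 2
WH = torch.mm(layer.W, layer.H)
diff = torch.sum(layer.positive(X, WH, 2) - layer.negative(X, WH, 2))
frobenius = 0.5 * torch.sum((X - WH) ** 2) - 0.5 * torch.sum(X ** 2)
print(torch.allclose(diff, frobenius))  # True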
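Similarly, the binarization penalty collected across the positive and negative regularization terms is bin_lambda_w * (W**4 + W**2 - 2 * W**3), which factors as bin_lambda_w * (W * (W - 1))**2 and therefore vanishes exactly when every entry of W is 0 or 1; the H terms are analogous. A quick check of that factorization (an editor's sketch, not part of the module):

import torch

W = torch.rand(5, 3, dtype=torch.float64)
penalty = W**4 + W**2 - 2 * W**3
factored = (W * (W - 1)) ** 2
print(torch.allclose(penalty, factored))  # True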
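A single forward pass returns the scalar loss, the clamped reconstruction WH, and every term needed by gradNMF and updateNMF below (an editor's sketch following the docstring example):

import torch
import torchdecomp as td

torch.manual_seed(123456)
X = torch.rand(10, 6, dtype=torch.float64)
nmf_layer = td.NMFLayer(X, 3)
loss, WH, pos, neg, pos_w, neg_w, pos_h, neg_h = nmf_layer(X)
print(loss.item())  # scalar objective value
print(WH.shape)     # torch.Size([10, 6]); the current reconstruction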
def gradNMF(WH, pos, neg, pos_w, neg_w, pos_h, neg_h, nmf_layer):
    """Gradients of the positive/negative terms with respect to WH, W, and H"""
    grad_pos = torch.autograd.grad(
        pos, WH, grad_outputs=torch.ones_like(pos))[0]
    grad_neg = torch.autograd.grad(
        neg, WH, grad_outputs=torch.ones_like(neg))[0]
    grad_pos_w = torch.autograd.grad(
        pos_w, nmf_layer.W, grad_outputs=torch.ones_like(pos_w))[0]
    grad_neg_w = torch.autograd.grad(
        neg_w, nmf_layer.W, grad_outputs=torch.ones_like(neg_w))[0]
    grad_pos_h = torch.autograd.grad(
        pos_h, nmf_layer.H, grad_outputs=torch.ones_like(pos_h))[0]
    grad_neg_h = torch.autograd.grad(
        neg_h, nmf_layer.H, grad_outputs=torch.ones_like(neg_h))[0]
    return grad_pos, grad_neg, grad_pos_w, grad_neg_w, grad_pos_h, grad_neg_h
def updateNMF(
        grad_pos, grad_neg, grad_pos_w, grad_neg_w,
        grad_pos_h, grad_neg_h, nmf_layer, beta=2):
    """Multiplicative update of W and H from the term-wise gradients"""
    # Detach the current factors from the autograd graph
    W = nmf_layer.W.data.detach()
    H = nmf_layer.H.data.detach()
    # Multiplicative update: ratio of negative to positive gradient
    # terms, raised to the exponent _rho(beta); the gradients with
    # respect to WH are mapped to W and H via the chain rule
    W *= (
        (torch.mm(grad_neg, H.T) + grad_neg_w) /
        (torch.mm(grad_pos, H.T) + grad_pos_w))**_rho(beta=beta)
    H *= (
        (torch.mm(W.T, grad_neg) + grad_neg_h) /
        (torch.mm(W.T, grad_pos) + grad_pos_h))**_rho(beta=beta)
    # Normalization: unit columns of W, unit rows of H
    W = torch.nn.functional.normalize(W, dim=0)
    H = torch.nn.functional.normalize(H, dim=1)
    return W, H
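forward, gradNMF, and updateNMF are meant to be chained into a fitting loop: run a forward pass, take the term-wise gradients, form the multiplicative update, and write the new factors back into the layer. The write-back step and the iteration count are not fixed by this file; the sketch below assumes a plain .data assignment:

import torch
import torchdecomp as td
from torchdecomp.nmf import gradNMF, updateNMF

torch.manual_seed(123456)
X = torch.rand(10, 6, dtype=torch.float64)
nmf_layer = td.NMFLayer(X, 3)
for _ in range(100):
    loss, WH, pos, neg, pos_w, neg_w, pos_h, neg_h = nmf_layer(X)
    grads = gradNMF(WH, pos, neg, pos_w, neg_w, pos_h, neg_h, nmf_layer)
    W, H = updateNMF(*grads, nmf_layer)
    # Assumed write-back of the updated factors (not part of this file)
    nmf_layer.W.data = W
    nmf_layer.H.data = H
print(loss.item())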