Source code for tests.test_nmf

import torchdecomp as td
import torch
import pytest
import numpy as np


[docs] def test_NMFLayer(): x = torch.randn(10, 6) nmf_layer = td.NMFLayer(x, 3) assert nmf_layer.W.size()[0] == 10 assert nmf_layer.W.size()[1] == 3 assert nmf_layer.H.size()[0] == 3 assert nmf_layer.H.size()[1] == 6
[docs] def test_NMFLayer_error(): x = np.random.rand(10, 6) with pytest.raises(AssertionError) as exc_info: td.NMFLayer(x, 3) assert exc_info.type == AssertionError