"""Source code for torchdecomp.helper: internal helper functions for torchdecomp."""

import sys
import os
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


# Helper functions
def _check_dimension(size, n_components):
    """An internal function used only in the torchdecomp module
    """
    assert size >= n_components, 'Specify n_components as a smaller value'


def _check_torch_tensor(x):
    """An internal function used only in the torchdecomp module
    """
    assert isinstance(x, torch.Tensor), 'Specify torch.Tensor as input'


def _check_square_matrix(x):
    """An internal function used only in the torchdecomp module
    """
    size = x.size()
    assert size[0] == size[1], 'Specify input as a square matrix'


def _check_symmetric_matrix(x):
    """An internal function used only in the torchdecomp module
    """
    check1 = len(x) == len(x[0])
    check2 = all(
        x[i][j] == x[j][i] for i in range(len(x)) for j in range(len(x)))
    assert (check1 & check2), 'Specify input as a symmetric matrix'


def create_dummy_matrix(class_vector):
    """Creates a dummy (one-hot) matrix from a class label vector.

    Args:
        class_vector: A 1-D PyTorch tensor with numeric class labels.

    Returns:
        A float32 tensor of shape ``(len(class_vector), n_classes)`` on the
        same device as ``class_vector``. Row ``k`` holds a single 1.0 in the
        column of ``class_vector[k]``'s class; columns follow the sorted
        order of the unique labels.

    Example:
        >>> import torch
        >>> import torchdecomp as td
        >>> td.create_dummy_matrix(
        ...     torch.tensor([0, 1, 2, 1, 0, 2, 1, 0])).shape
        torch.Size([8, 3])

    Note:
        The number of rows is the number of data and the number of columns
        is the number of classes.
    """
    # ``inverse[k]`` is the position of class_vector[k] within the sorted
    # unique labels, i.e. exactly the column that must be set to 1.
    unique_classes, inverse = torch.unique(
        class_vector, sorted=True, return_inverse=True)
    num_data = len(class_vector)
    num_classes = len(unique_classes)
    device = class_vector.device
    dummy_matrix = torch.zeros(
        (num_data, num_classes), dtype=torch.float32, device=device)
    # One vectorized scatter instead of a Python loop over classes.
    dummy_matrix[torch.arange(num_data, device=device), inverse] = 1.0
    return dummy_matrix
# Disable def _blockPrint(): sys.stdout = open(os.devnull, 'w') # Restore def _enablePrint(): sys.stdout = sys.__stdout__ def _plot(imgs, row_title=None, **imshow_kwargs): if not isinstance(imgs[0], list): # Make a 2d grid even if there's just 1 row imgs = [imgs] num_rows = len(imgs) num_cols = len(imgs[0]) _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False) for row_idx, row in enumerate(imgs): for col_idx, img in enumerate(row): boxes = None masks = None if isinstance(img, tuple): img, target = img if isinstance(target, dict): boxes = target.get("boxes") masks = target.get("masks") elif isinstance(target, tv_tensors.BoundingBoxes): boxes = target else: raise ValueError(f"Unexpected target type: {type(target)}") img = F.to_image(img) if img.dtype.is_floating_point and img.min() < 0: # Poor man's re-normalization for the colors to be OK-ish. This # is useful for images coming out of Normalize() img -= img.min() img /= img.max() img = F.to_dtype(img, torch.uint8, scale=True) if boxes is not None: img = draw_bounding_boxes(img, boxes, colors="yellow", width=3) if masks is not None: img = draw_segmentation_masks( img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65) ax = axs[row_idx, col_idx] ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs) ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) if row_title is not None: for row_idx in range(num_rows): axs[row_idx, 0].set(ylabel=row_title[row_idx]) plt.tight_layout() def _rho(beta, root=False): if root: out = 0.5 else: if beta < 1: out = 1 / (2 - beta) if (1 <= beta) & (beta <= 2): out = 1 if beta > 2: out = 1 / (beta - 1) return out