Shortcuts

Source code for ignite.contrib.metrics.regression._base

from abc import abstractmethod
from typing import Tuple

import torch

from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced


def _check_output_shapes(output: Tuple[torch.Tensor, torch.Tensor]) -> None:
    """Validate the shapes of a ``(y_pred, y)`` pair.

    Each tensor must be either flat ``(N,)`` or a single column ``(N, 1)``,
    and both tensors must have exactly the same shape.

    Args:
        output: pair ``(y_pred, y)`` of tensors to validate.

    Raises:
        ValueError: if ``y_pred`` or ``y`` is neither ``(N,)`` nor ``(N, 1)``,
            or if the two shapes differ.
    """
    y_pred, y = output

    def _is_single_column(tensor: torch.Tensor) -> bool:
        # True only for 2-D tensors whose second dimension has size 1.
        return tensor.dim() == 2 and tensor.size(1) == 1

    # y_pred is checked first, then y, then their mutual consistency.
    if y_pred.dim() != 1 and not _is_single_column(y_pred):
        raise ValueError(f"Input y_pred should have shape (N,) or (N, 1), but given {y_pred.shape}")

    if y.dim() != 1 and not _is_single_column(y):
        raise ValueError(f"Input y should have shape (N,) or (N, 1), but given {y.shape}")

    # Mixed (N,) vs (N, 1) inputs are rejected here as well.
    if not (y_pred.shape == y.shape):
        raise ValueError(f"Input data shapes should be the same, but given {y_pred.shape} and {y.shape}")


def _check_output_types(output: Tuple[torch.Tensor, torch.Tensor]) -> None:
    """Validate the dtypes of a ``(y_pred, y)`` pair.

    Both tensors must have dtype ``torch.float16``, ``torch.float32`` or
    ``torch.float64``; any other dtype is rejected.

    Args:
        output: pair ``(y_pred, y)`` of tensors to validate.

    Raises:
        TypeError: if ``y_pred`` or ``y`` has a dtype outside the accepted set.
    """
    accepted_dtypes = (torch.float16, torch.float32, torch.float64)
    y_pred, y = output

    # y_pred is reported first when both tensors have an unsupported dtype.
    if not (y_pred.dtype in accepted_dtypes):
        raise TypeError(f"Input y_pred dtype should be float 16, 32 or 64, but given {y_pred.dtype}")

    if not (y.dtype in accepted_dtypes):
        raise TypeError(f"Input y dtype should be float 16, 32 or 64, but given {y.dtype}")


class _BaseRegression(Metric):
    """Base class for all regression metrics.

    The public ``update`` method checks the shapes and dtypes of the incoming
    ``(y_pred, y)`` pair and then calls the internal ``_update`` method, which
    every subclass has to implement.
    """

    @reinit__is_reduced
    def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        """Validate ``output`` and forward the prepared tensors to ``_update``.

        Args:
            output: pair ``(y_pred, y)`` of tensors.

        Raises:
            ValueError: propagated from ``_check_output_shapes``.
            TypeError: propagated from ``_check_output_types``.
        """
        _check_output_shapes(output)
        _check_output_types(output)

        prepared = []
        for tensor in (output[0], output[1]):
            tensor = tensor.detach()
            # Collapse a single-column (N, 1) tensor into a flat (N,) one so
            # that `_update` always receives 1-D inputs.
            if tensor.dim() == 2 and tensor.size(1) == 1:
                tensor = tensor.squeeze(-1)
            prepared.append(tensor)

        self._update((prepared[0], prepared[1]))

    @abstractmethod
    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        """Consume a validated, detached ``(y_pred, y)`` pair.

        Called by ``update`` after the shape and dtype checks have passed and
        any trailing singleton dimension has been squeezed away.
        """

© Copyright 2021, PyTorch-Ignite Contributors. Last updated on 08/06/2021, 9:33:07 AM.

Built with Sphinx using a theme provided by Read the Docs.