Loss functions

GradWeightedLoss(batch, unweighted_loss_function, optimizer, metric_weights, fixed_metric, alpha=0.9)

Loss function with gradient weighting for training physics-informed neural networks (PINNs).

Implements the learning rate annealing algorithm described in this article: https://arxiv.org/abs/2001.04536.

PARAMETER DESCRIPTION
batch

Batch of data.

TYPE: dict[str, Tensor]

unweighted_loss_function

Loss function applied before weighting.

TYPE: LossFunction

optimizer

A torch optimizer (gradient-based optimization) or a nevergrad optimizer (gradient-free optimization).

TYPE: torch.optim.Optimizer | ng.optimization.Optimizer

metric_weights

Initial metric weights.

TYPE: dict[str, float | Tensor]

fixed_metric

Metric whose weight is not updated and whose gradient determines the weights of the other metrics.

TYPE: str

alpha

Scaling factor that sets the inertia of the weights, i.e. how strongly they resist updates. Defaults to 0.9.

TYPE: float DEFAULT: 0.9

Source code in perceptrain/loss/loss.py
def __init__(
    self,
    batch: dict[str, Tensor],
    unweighted_loss_function: LossFunction,
    optimizer: torch.optim.Optimizer | ng.optimization.Optimizer,
    metric_weights: dict[str, float | Tensor],
    fixed_metric: str,
    alpha: float = 0.9,
):
    """Loss function with gradient weighting for PINN training.

    Implements the learning rate annealing algorithm in [this article](https://arxiv.org/abs/2001.04536).

    Args:
        batch (dict[str, Tensor]): Batch of data.
        unweighted_loss_function (LossFunction): Loss function applied before weighting.
        optimizer (torch.optim.Optimizer | ng.optimization.Optimizer): torch or nevergrad
            optimizer for gradient or gradient-free optimization.
        metric_weights (dict[str, float | Tensor]): Initial metric weights.
        fixed_metric (str): Metric whose weight is not updated and whose gradient determines the
            weights of the other metrics.
        alpha (float, optional): Scaling factor. Corresponds to the inertia of the weights to
            updates. Defaults to 0.9.
    """
    self.metric_names = batch.keys()
    self.metric_weights = metric_weights
    self.gradients: dict[str, dict[str, Tensor]] = {key: {} for key in self.metric_names}
    self.unweighted_loss_function = unweighted_loss_function
    self.optimizer = optimizer
    self.fixed_metric = fixed_metric
    self.alpha = alpha
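
The weight update behind learning rate annealing can be illustrated in plain Python. This is a sketch under stated assumptions, not the perceptrain implementation: the helper `anneal_weights` is hypothetical, gradients are passed as flat lists of floats, each non-fixed weight is pulled toward the ratio max|gradient of the fixed metric| / mean|gradient of that metric|, and `alpha` multiplies the previous weight (one reading of "inertia"; the convention in the article or in the library may differ).

```python
def anneal_weights(weights, grads, fixed_metric, alpha=0.9):
    """Hypothetical helper: one annealing step over per-metric gradient lists.

    Assumes every gradient list is non-empty and has a nonzero mean magnitude.
    """
    # Largest absolute gradient entry of the fixed (reference) metric.
    max_fixed = max(abs(g) for g in grads[fixed_metric])
    new_weights = dict(weights)
    for name, grad in grads.items():
        if name == fixed_metric:
            continue  # the fixed metric keeps its weight unchanged
        mean_abs = sum(abs(g) for g in grad) / len(grad)
        target = max_fixed / mean_abs
        # Assumption: alpha weighs the old value, (1 - alpha) the new target.
        new_weights[name] = alpha * weights[name] + (1 - alpha) * target
    return new_weights


# Illustrative numbers only; the metric names are made up.
weights = {"residual": 1.0, "boundary": 1.0}
grads = {"residual": [0.5, -2.0], "boundary": [0.25, -0.75]}
updated = anneal_weights(weights, grads, fixed_metric="residual")
```

Here the target for "boundary" is 2.0 / 0.5 = 4.0, so with `alpha=0.9` its weight moves from 1.0 to roughly 1.3, while the weight of "residual" stays at 1.0.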

cross_entropy_loss(batch, model)

Cross Entropy Loss.

PARAMETER DESCRIPTION
batch

The input batch.

TYPE: TBatch

model

The model to compute the loss for.

TYPE: Module

RETURNS DESCRIPTION
tuple[Tensor, dict[str, Tensor]]

A tuple (loss, metrics):
- loss (Tensor): The computed loss value.
- metrics (dict[str, Tensor]): Empty dictionary. Not relevant for this loss function.

Source code in perceptrain/loss/loss.py
def cross_entropy_loss(batch: TBatch, model: nn.Module) -> tuple[Tensor, dict[str, Tensor]]:
    """Cross Entropy Loss.

    Args:
        batch (TBatch): The input batch.
        model (nn.Module): The model to compute the loss for.

    Returns:
        tuple[Tensor, dict[str, Tensor]]:
            - loss (Tensor): The computed loss value.
            - metrics (dict[str, Tensor]): Empty dictionary. Not relevant for this loss function.
    """
    inputs, labels = batch
    predictions = model(inputs)
    metrics: dict[str, Tensor] = {}
    criterion = nn.CrossEntropyLoss()
    loss = criterion(predictions, labels)

    return loss, metrics
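
The listing above unpacks the batch as an (inputs, labels) pair and feeds raw model outputs to `nn.CrossEntropyLoss`. A minimal sketch of that flow, using only torch so it runs without perceptrain installed: `cross_entropy_sketch` is a stand-in that mirrors the documented body, and the zeroed linear model is a hypothetical toy chosen so the result can be checked by hand.

```python
import math

import torch
from torch import nn


def cross_entropy_sketch(batch, model):
    """Stand-in mirroring the documented body of cross_entropy_loss."""
    inputs, labels = batch  # the batch must unpack into exactly two items
    predictions = model(inputs)  # raw logits, shape (N, num_classes)
    loss = nn.CrossEntropyLoss()(predictions, labels)
    return loss, {}  # metrics stay empty for this loss


# Hypothetical toy model: a linear layer zeroed out so every logit is 0.
model = nn.Linear(4, 3)
with torch.no_grad():
    model.weight.zero_()
    model.bias.zero_()

inputs = torch.randn(5, 4)
labels = torch.tensor([0, 1, 2, 1, 0])  # class indices (int64), not one-hot
loss, metrics = cross_entropy_sketch((inputs, labels), model)

# Equal logits give a uniform softmax, so the loss equals log(num_classes).
expected = math.log(3)
```

Note that the labels are class indices rather than one-hot vectors, which is the most common target format accepted by `nn.CrossEntropyLoss`.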

get_loss(loss_fn)

Returns the appropriate loss function based on the input argument.

PARAMETER DESCRIPTION
loss_fn

The loss function to use.
- If loss_fn is a callable, it will be returned directly.
- If loss_fn is a string, it should be one of:
    - "mse": Returns the MSE loss function.
    - "cross_entropy": Returns the Cross Entropy function.
- If loss_fn is None, the default MSE loss function will be returned.

TYPE: str | Callable | None

RETURNS DESCRIPTION
Callable

The corresponding loss function.

TYPE: Callable

RAISES DESCRIPTION
ValueError

If loss_fn is a string but not a supported loss function name.

Source code in perceptrain/loss/loss.py
def get_loss(loss_fn: str | Callable | None) -> Callable:
    """
    Returns the appropriate loss function based on the input argument.

    Args:
        loss_fn (str | Callable | None): The loss function to use.
            - If `loss_fn` is a callable, it will be returned directly.
            - If `loss_fn` is a string, it should be one of:
                - "mse": Returns the MSE loss function.
                - "cross_entropy": Returns the Cross Entropy function.
            - If `loss_fn` is `None`, the default MSE loss function will be returned.

    Returns:
        Callable: The corresponding loss function.

    Raises:
        ValueError: If `loss_fn` is a string but not a supported loss function name.
    """
    if callable(loss_fn):
        return loss_fn
    elif isinstance(loss_fn, str):
        if loss_fn == "mse":
            return mse_loss
        elif loss_fn == "cross_entropy":
            return cross_entropy_loss
        else:
            raise ValueError(f"Unsupported loss function: {loss_fn}")
    else:
        # default case
        return mse_loss
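
The dispatch rules above can be tried out in isolation. The following is a self-contained sketch, not the library code: `mse_stub` and `cross_entropy_stub` are placeholders standing in for the real loss functions, and `get_loss_sketch` uses a dictionary lookup in place of the if/elif chain, with the intent of matching the documented behavior.

```python
from typing import Callable, Union


def mse_stub(batch, model):
    """Placeholder standing in for mse_loss."""
    return 0.0, {}


def cross_entropy_stub(batch, model):
    """Placeholder standing in for cross_entropy_loss."""
    return 0.0, {}


_LOSSES = {"mse": mse_stub, "cross_entropy": cross_entropy_stub}


def get_loss_sketch(loss_fn: Union[str, Callable, None]) -> Callable:
    if callable(loss_fn):
        return loss_fn  # custom callables pass through untouched
    if isinstance(loss_fn, str):
        if loss_fn not in _LOSSES:
            raise ValueError(f"Unsupported loss function: {loss_fn}")
        return _LOSSES[loss_fn]
    return mse_stub  # None, or any other non-callable value, falls back to MSE


def my_loss(batch, model):
    """Hypothetical user-defined loss with the (loss, metrics) return shape."""
    return 0.0, {}


by_callable = get_loss_sketch(my_loss)
by_name = get_loss_sketch("cross_entropy")
by_default = get_loss_sketch(None)

try:
    get_loss_sketch("hinge")  # not a supported name
except ValueError as err:
    message = str(err)
```

As in the source listing, the final branch catches every argument that is neither callable nor a string, so an unexpected value such as an integer silently selects the MSE loss rather than raising.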

mse_loss(batch, model)

Mean Squared Error Loss.

PARAMETER DESCRIPTION
batch

The input batch.

TYPE: TBatch

model

The model to compute the loss for.

TYPE: Module

RETURNS DESCRIPTION
tuple[Tensor, dict[str, Tensor]]

A tuple (loss, metrics):
- loss (Tensor): The computed loss value.
- metrics (dict[str, Tensor]): A dictionary of metrics (loss components).

Source code in perceptrain/loss/loss.py
def mse_loss(batch: TBatch, model: nn.Module) -> tuple[Tensor, dict[str, Tensor]]:
    """Mean Squared Error Loss.

    Args:
        batch (TBatch): The input batch.
        model (nn.Module): The model to compute the loss for.

    Returns:
        tuple[Tensor, dict[str, Tensor]]:
            - loss (Tensor): The computed loss value.
            - metrics (dict[str, Tensor]): A dictionary of metrics (loss components).
    """
    return _compute_loss_and_metrics_based_on_model(batch, model, criterion=nn.MSELoss())  # type: ignore[no-any-return]