
Loss Functions

Perceptrain works with loss functions having the following interface:

from typing import Any

import torch

def loss_fn(batch: Any, model: torch.nn.Module) -> tuple[torch.Tensor, dict]:
    ...

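Therefore, loss functions from torch.nn, which take a prediction and a target rather than a batch and a model, and which return a single tensor rather than a (loss, metrics) tuple, won't work out of the box.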
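One way to bridge the two signatures is a small adapter that unpacks the batch, calls the model, and returns the criterion's value together with an empty metrics dictionary. The sketch below assumes the batch is an (inputs, targets) tuple; `wrap_criterion` is a hypothetical helper written for illustration, not part of Perceptrain.

```python
from typing import Any, Callable

import torch


def wrap_criterion(
    criterion: torch.nn.Module,
) -> Callable[[Any, torch.nn.Module], tuple[torch.Tensor, dict]]:
    """Hypothetical helper: adapt a torch.nn loss to the (batch, model) interface."""

    def loss_fn(batch: Any, model: torch.nn.Module) -> tuple[torch.Tensor, dict]:
        # Assumption: the batch is an (inputs, targets) pair.
        x, y = batch
        loss = criterion(model(x), y)
        # No extra metrics are reported by this adapter.
        return loss, {}

    return loss_fn


# Usage: a zero-initialized linear model always predicts 0.
mse_fn = wrap_criterion(torch.nn.MSELoss())
model = torch.nn.Linear(2, 1)
torch.nn.init.zeros_(model.weight)
torch.nn.init.zeros_(model.bias)
x = torch.ones(4, 2)
y = torch.full((4, 1), 3.0)
loss, metrics = mse_fn((x, y), model)
print(loss.item(), metrics)  # 9.0 {}
```

The built-in wrappers listed below serve the same purpose for common criteria, so an adapter like this is mainly useful for torch.nn losses that have no built-in counterpart.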

Users can either choose a built-in loss function or define a custom one.

Built-in loss functions

  • mse_loss: wrapper around torch.nn.MSELoss().
  • cross_entropy_loss: wrapper around torch.nn.CrossEntropyLoss().
  • GradWeigthedLoss: dynamic loss function that implements the learning rate annealing algorithm introduced here. It shifts weight toward metrics with smaller gradients: the weight of each re-weighted metric is updated using the ratio between the maximum partial derivative of a fixed reference metric and the mean partial derivative of the metric being re-weighted. This can help keep training from falling into trivial local minima.
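The re-weighting rule behind GradWeigthedLoss can be illustrated in plain PyTorch. The sketch below computes one weight update from the gradients of a reference metric and a re-weighted metric with respect to shared parameters. It is not Perceptrain's implementation: the function name `annealed_weight`, the use of absolute gradient values, and the moving-average smoothing with factor `alpha` are assumptions of this sketch.

```python
import torch


def annealed_weight(
    reference_loss: torch.Tensor,
    weighted_loss: torch.Tensor,
    params: list[torch.Tensor],
    prev_weight: float,
    alpha: float = 0.9,
    eps: float = 1e-12,
) -> float:
    """Hypothetical sketch of one gradient-ratio weight update."""
    # Gradients of both metrics with respect to the shared parameters.
    ref_grads = torch.autograd.grad(reference_loss, params, retain_graph=True, allow_unused=True)
    w_grads = torch.autograd.grad(weighted_loss, params, retain_graph=True, allow_unused=True)
    # Flatten absolute gradients, skipping parameters a metric does not depend on.
    ref_flat = torch.cat([g.abs().flatten() for g in ref_grads if g is not None])
    w_flat = torch.cat([g.abs().flatten() for g in w_grads if g is not None])
    # Max gradient of the fixed metric over mean gradient of the re-weighted one.
    ratio = ref_flat.max() / (w_flat.mean() + eps)
    # Assumed moving average to smooth the weight across iterations.
    return float((1.0 - alpha) * prev_weight + alpha * ratio)


# Usage with a single shared parameter vector.
w = torch.tensor([1.0, 2.0], requires_grad=True)
reference = (w ** 2).sum()  # gradient 2w = [2, 4], max 4
other = w.sum()             # gradient [1, 1], mean 1
new_weight = annealed_weight(reference, other, [w], prev_weight=1.0)
print(new_weight)  # approximately 3.7
```

Because the second metric's gradients are four times smaller than the reference's largest gradient, its weight grows, which is the redistribution effect described above.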

Custom loss functions

Users can define custom loss functions tailored to their specific tasks. The Trainer accepts a loss_fn parameter, which should be a callable that takes the data batch and the model as inputs and returns a tuple containing the loss tensor and a dictionary of metrics.

Example of using a custom loss function:

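import torch
from itertools import count

cnt = count()  # counts how many times the loss function is called
criterion = torch.nn.MSELoss()

def loss_fn_custom(batch: tuple[torch.Tensor, torch.Tensor], model: torch.nn.Module) -> tuple[torch.Tensor, dict]:
    next(cnt)  # increment the call counter
    x, y = batch  # unpack inputs and targets
    out = model(x)
    loss = criterion(out, y)
    return loss, {}  # no extra metrics are reported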
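The example above returns an empty metrics dictionary. A variant that also reports additional quantities could look like the sketch below. The metric names "mse" and "mae" are illustrative, and storing detached tensors as dictionary values is an assumption about what the Trainer accepts for logging.

```python
import torch

criterion = torch.nn.MSELoss()


def loss_fn_with_metrics(
    batch: tuple[torch.Tensor, torch.Tensor], model: torch.nn.Module
) -> tuple[torch.Tensor, dict]:
    x, y = batch
    out = model(x)
    loss = criterion(out, y)
    # Extra quantity for monitoring only; it does not contribute to the gradient.
    with torch.no_grad():
        mae = torch.nn.functional.l1_loss(out, y)
    return loss, {"mse": loss.detach(), "mae": mae}


# Usage: a zero-initialized model predicts 0 against targets of 3.
model = torch.nn.Linear(2, 1)
torch.nn.init.zeros_(model.weight)
torch.nn.init.zeros_(model.bias)
batch = (torch.ones(4, 2), torch.full((4, 1), 3.0))
loss, metrics = loss_fn_with_metrics(batch, model)
print(loss.item(), metrics["mae"].item())  # 9.0 3.0
```

Only the first element of the tuple is used for backpropagation; the dictionary is for reporting.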

The custom loss function can then be passed to the Trainer:

from perceptrain import Trainer, TrainConfig
from torch.optim import Adam

# Initialize model and optimizer
model = ...  # Define or load a model here
optimizer = Adam(model.parameters(), lr=0.01)
config = TrainConfig(max_iter=100, print_every=10)

trainer = Trainer(model=model, optimizer=optimizer, config=config, loss_fn=loss_fn_custom)
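To make the contract concrete, the plain PyTorch loop below shows how a function with this interface can drive optimization: call it on a batch, backpropagate the returned loss, and step the optimizer. It is an illustrative sketch and makes no claim about how Trainer is implemented internally; the model, the synthetic data, and the `train_steps` helper are hypothetical.

```python
import torch
from torch.optim import Adam

criterion = torch.nn.MSELoss()


def mse_batch_loss(
    batch: tuple[torch.Tensor, torch.Tensor], model: torch.nn.Module
) -> tuple[torch.Tensor, dict]:
    x, y = batch
    return criterion(model(x), y), {}


def train_steps(model, optimizer, loss_fn, batch, n_steps: int) -> list[float]:
    """Hypothetical helper: repeatedly apply loss_fn to a single batch."""
    history = []
    for _ in range(n_steps):
        optimizer.zero_grad()
        loss, metrics = loss_fn(batch, model)  # interface: (batch, model) -> (loss, metrics)
        loss.backward()
        optimizer.step()
        history.append(loss.item())
    return history


torch.manual_seed(0)
model = torch.nn.Linear(2, 1)
optimizer = Adam(model.parameters(), lr=0.01)
x = torch.randn(32, 2)
y = x.sum(dim=1, keepdim=True)  # synthetic linear target
history = train_steps(model, optimizer, mse_batch_loss, (x, y), n_steps=100)
print(history[0], history[-1])
```

A loop like this is also a quick way to check that a custom loss function runs and decreases before handing it to the Trainer.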

NOTE: when working with custom loss functions, you must make sure that the type of your data batch is compatible with the model. For instance, the batch (or its input part) is a Tensor if the model is an FFNN (feed-forward neural network), but it must be a dict[str, Tensor] if the model is a PINN (physics-informed neural network).
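For a dict-typed batch, the loss function indexes the batch by key instead of unpacking a tuple. The sketch below uses hypothetical keys "interior" and "boundary" and a plain torch.nn.Module as a stand-in model; it does not use Perceptrain's PINN class, and the two terms are placeholders for the residuals a real physics problem would define.

```python
import torch


def loss_fn_dict(
    batch: dict[str, torch.Tensor], model: torch.nn.Module
) -> tuple[torch.Tensor, dict]:
    # Hypothetical keys: collocation points inside the domain and on its boundary.
    interior_out = model(batch["interior"])
    boundary_out = model(batch["boundary"])
    # Placeholder terms: a real PINN would apply its differential operator here.
    interior_term = interior_out.pow(2).mean()
    boundary_term = (boundary_out - 1.0).pow(2).mean()
    loss = interior_term + boundary_term
    # Report each term separately so they can be monitored.
    return loss, {"interior": interior_term.detach(), "boundary": boundary_term.detach()}


# Usage with a zero-initialized stand-in model.
model = torch.nn.Linear(1, 1)
torch.nn.init.zeros_(model.weight)
torch.nn.init.zeros_(model.bias)
batch = {"interior": torch.rand(16, 1), "boundary": torch.tensor([[0.0], [1.0]])}
loss, metrics = loss_fn_dict(batch, model)
print(loss.item())  # 1.0
```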