Built-in Callbacks

perceptrain offers several built-in callbacks for common tasks like saving checkpoints, logging metrics, and tracking models. Below is an overview of each.
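
As the examples below show, every callback takes an on argument selecting the training hook at which it fires (for example train_epoch_end or val_batch_end) and, where relevant, a called_every interval controlling how often it runs at that hook.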

1. PrintMetrics

Prints metrics at specified intervals.

from perceptrain import TrainConfig
from perceptrain.callbacks import PrintMetrics

print_metrics_callback = PrintMetrics(on="val_batch_end", called_every=100)

config = TrainConfig(
    max_iter=10000,
    callbacks=[print_metrics_callback]
)

2. WriteMetrics

Writes metrics to a specified logging destination.

from perceptrain import TrainConfig
from perceptrain.callbacks import WriteMetrics

write_metrics_callback = WriteMetrics(on="train_epoch_end", called_every=50)

config = TrainConfig(
    max_iter=5000,
    callbacks=[write_metrics_callback]
)

3. WritePlots

Plots metrics based on user-defined plotting functions.

from perceptrain import TrainConfig
from perceptrain.callbacks import WritePlots

plot_metrics_callback = WritePlots(on="train_epoch_end", called_every=100)

config = TrainConfig(
    max_iter=5000,
    callbacks=[plot_metrics_callback]
)
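
The plotting functions themselves are supplied by the user. The sketch below assumes they are registered through a plotting_functions field on TrainConfig and follow a (model, iteration) -> (name, figure) convention; both of these are illustrative assumptions, not confirmed perceptrain API.

import matplotlib.pyplot as plt
import torch

from perceptrain import TrainConfig
from perceptrain.callbacks import WritePlots

def plot_predictions(model, iteration):
    # Evaluate the model on a fixed grid and return a named figure.
    x = torch.linspace(0, 1, 100).reshape(-1, 1)
    y = model(x).detach()
    fig, ax = plt.subplots()
    ax.plot(x.squeeze().numpy(), y.squeeze().numpy())
    ax.set_title(f"Model predictions at iteration {iteration}")
    return "predictions", fig  # assumed (name, figure) return convention

config = TrainConfig(
    max_iter=5000,
    callbacks=[WritePlots(on="train_epoch_end", called_every=100)],
    plotting_functions=(plot_predictions,),  # assumed TrainConfig field
)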

4. LivePlotMetrics

Dynamically plots the tracked metrics on screen during training. The arrange parameter allows for a custom arrangement of subplots.

from perceptrain import TrainConfig
from perceptrain.callbacks import LivePlotMetrics

live_plot_callback = LivePlotMetrics(
    on="train_epoch_end",
    called_every=100,
    arrange={
        "training": ["train_loss", "train_metric_first"],
        "validation": ["val_loss", "val_metric_second"],
    },
)

config = TrainConfig(
    max_iter=5000,
    callbacks=[live_plot_callback]
)

5. LogHyperparameters

Logs hyperparameters to keep track of training settings.

from perceptrain import TrainConfig
from perceptrain.callbacks import LogHyperparameters

log_hyper_callback = LogHyperparameters(on="train_start", called_every=1)

config = TrainConfig(
    max_iter=1000,
    callbacks=[log_hyper_callback]
)

6. SaveCheckpoint

Saves model checkpoints at specified intervals.

from perceptrain import TrainConfig
from perceptrain.callbacks import SaveCheckpoint

save_checkpoint_callback = SaveCheckpoint(on="train_epoch_end", called_every=100)

config = TrainConfig(
    max_iter=10000,
    callbacks=[save_checkpoint_callback]
)

7. SaveBestCheckpoint

Saves the best model checkpoint based on a validation criterion.

from perceptrain import TrainConfig
from perceptrain.callbacks import SaveBestCheckpoint

save_best_checkpoint_callback = SaveBestCheckpoint(on="val_epoch_end", called_every=10)

config = TrainConfig(
    max_iter=10000,
    callbacks=[save_best_checkpoint_callback]
)

8. LoadCheckpoint

Loads a saved model checkpoint at the start of training.

from perceptrain import TrainConfig
from perceptrain.callbacks import LoadCheckpoint

load_checkpoint_callback = LoadCheckpoint(on="train_start")

config = TrainConfig(
    max_iter=10000,
    callbacks=[load_checkpoint_callback]
)

9. LogModelTracker

Logs the model structure and parameters.

from perceptrain import TrainConfig
from perceptrain.callbacks import LogModelTracker

log_model_callback = LogModelTracker(on="train_end")

config = TrainConfig(
    max_iter=1000,
    callbacks=[log_model_callback]
)

10. LRSchedulerStepDecay

Reduces the learning rate by a factor at regular intervals.

from perceptrain import TrainConfig
from perceptrain.callbacks import LRSchedulerStepDecay

lr_step_decay = LRSchedulerStepDecay(on="train_epoch_end", called_every=100, gamma=0.5)

config = TrainConfig(
    max_iter=10000,
    callbacks=[lr_step_decay]
)
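
With gamma=0.5 and called_every=100, the learning rate is multiplied by 0.5 at every 100th epoch end, i.e. halved every 100 epochs.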

11. LRSchedulerCyclic

Applies a cyclic learning rate schedule during training.

from perceptrain import TrainConfig
from perceptrain.callbacks import LRSchedulerCyclic

lr_cyclic = LRSchedulerCyclic(on="train_batch_end", called_every=1, base_lr=0.001, max_lr=0.01, step_size=2000)

config = TrainConfig(
    max_iter=10000,
    callbacks=[lr_cyclic]
)
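
In a standard cyclic schedule the learning rate ramps from base_lr up to max_lr over step_size steps and back down again, so with called_every=1 one full cycle here spans 4000 batch ends.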

12. LRSchedulerCosineAnnealing

Applies cosine annealing to the learning rate during training.

from perceptrain import TrainConfig
from perceptrain.callbacks import LRSchedulerCosineAnnealing

lr_cosine = LRSchedulerCosineAnnealing(on="train_batch_end", called_every=1, t_max=5000, min_lr=1e-6)

config = TrainConfig(
    max_iter=10000,
    callbacks=[lr_cosine]
)
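
In standard cosine annealing the learning rate decays from its initial value down to min_lr over t_max steps along a half-cosine curve, i.e. lr_t = min_lr + (lr_0 - min_lr) * (1 + cos(pi * t / t_max)) / 2.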

13. LRSchedulerReduceOnPlateau

Reduces the learning rate when a given metric does not improve for a number of epochs.

from perceptrain import TrainConfig
from perceptrain.callbacks import LRSchedulerReduceOnPlateau

lr_plateau = LRSchedulerReduceOnPlateau(
    on="train_epoch_end",
    called_every=1,
    monitor="train_loss",
    patience=20,
    mode="min",
    gamma=0.5,
    threshold=1e-4,
    min_lr=1e-5,
)

config = TrainConfig(
    max_iter=10000,
    callbacks=[lr_plateau]
)
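
With these settings the learning rate is multiplied by gamma=0.5 whenever train_loss fails to improve by more than threshold=1e-4 for patience=20 consecutive checks, and is never reduced below min_lr=1e-5.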

14. EarlyStopping

Stops training when a monitored metric has not improved for a specified number of epochs.

from perceptrain import TrainConfig
from perceptrain.callbacks import EarlyStopping

early_stopping = EarlyStopping(on="val_epoch_end", called_every=1, monitor="val_loss", patience=5, mode="min")

config = TrainConfig(
    max_iter=10000,
    callbacks=[early_stopping]
)

15. GradientMonitoring

Logs gradient statistics (e.g., mean, standard deviation, max) during training.

from perceptrain import TrainConfig
from perceptrain.callbacks import GradientMonitoring

gradient_monitoring = GradientMonitoring(on="train_batch_end", called_every=10)

config = TrainConfig(
    max_iter=10000,
    callbacks=[gradient_monitoring]
)

16. R3Sampling

Triggers an update of the dataset using the R3 (retain, resample, release) sampling technique.

The following example shows how to set up R3 sampling to learn a harmonic oscillator with a physics-informed neural network (PINN).

import torch

from perceptrain import TrainConfig
from perceptrain.callbacks import R3Sampling
from perceptrain.data import R3Dataset
from perceptrain.models import PINN

# Mass and spring constant of the harmonic oscillator
m = 1.0
k = 1.0

def uniform_1d(n: int):
    # Sample n collocation points uniformly in [0, 1)
    return torch.rand(size=(n, 1))

def harmonic_oscillator(x: torch.Tensor, model: torch.nn.Module) -> torch.Tensor:
    # The input must track gradients so u can be differentiated w.r.t. x.
    x = x.requires_grad_(True)
    u = model(x)
    dudt = torch.autograd.grad(
        outputs=u,
        inputs=x,
        grad_outputs=torch.ones_like(u),
        create_graph=True,
        retain_graph=True,
    )[0]
    d2udt2 = torch.autograd.grad(
        outputs=dudt,
        inputs=x,
        grad_outputs=torch.ones_like(dudt),
    )[0]
    # Residual of the oscillator equation m * u'' + k * u = 0
    return m * d2udt2 + k * u

def fitness_function(x: torch.Tensor, model: PINN) -> torch.Tensor:
    # L2 norm of the ODE residual, used by R3 to score collocation points.
    return torch.linalg.vector_norm(harmonic_oscillator(x, model.nn), ord=2)

dataset = R3Dataset(
    proba_dist=uniform_1d,   # sampler for new collocation points
    n_samples=20,            # number of collocation points in the dataset
    release_threshold=1.0,   # controls when points are released and resampled
)

callback_r3 = R3Sampling(
    initial_dataset=dataset,
    fitness_function=fitness_function,
    called_every=100,
)

config = TrainConfig(
    max_iter=1000,
    callbacks=[callback_r3]
)
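
Finally, callbacks are not mutually exclusive: since callbacks accepts a list, several of the callbacks above can be registered in a single TrainConfig. A minimal sketch combining earlier examples:

from perceptrain import TrainConfig
from perceptrain.callbacks import EarlyStopping, PrintMetrics, SaveBestCheckpoint

# Each callback fires independently on its own hook and interval.
config = TrainConfig(
    max_iter=10000,
    callbacks=[
        PrintMetrics(on="val_batch_end", called_every=100),
        SaveBestCheckpoint(on="val_epoch_end", called_every=10),
        EarlyStopping(on="val_epoch_end", called_every=1, monitor="val_loss", patience=5, mode="min"),
    ],
)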