Callbacks with Trainer
1. Adding Callbacks to TrainConfig
To use callbacks in `TrainConfig`, add them to the `callbacks` list when configuring the training process.
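```python
from perceptrain import TrainConfig
from perceptrain.callbacks import SaveCheckpoint, PrintMetrics

config = TrainConfig(
    max_iter=10000,
    callbacks=[
        # Save a checkpoint at the end of validation, every 50 epochs.
        SaveCheckpoint(on="val_epoch_end", called_every=50),
        # Print training metrics at the end of the training epoch, every 100 epochs.
        PrintMetrics(on="train_epoch_end", called_every=100),
    ],
)
```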
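The `on` argument names the hook a callback is attached to, and `called_every` throttles how often it fires there. The sketch below models that behaviour with a hypothetical `ScheduledCallback` class (not the perceptrain class), assuming a callback fires when the hook matches and the current epoch is a multiple of `called_every`; perceptrain's actual scheduling rules may differ in details.

```python
from dataclasses import dataclass


@dataclass
class ScheduledCallback:
    """Hypothetical stand-in for a perceptrain callback (not the library class)."""

    name: str
    on: str
    called_every: int = 1

    def should_run(self, hook: str, epoch: int) -> bool:
        # Assumption: fire only on the configured hook, and only every `called_every` epochs.
        return hook == self.on and epoch % self.called_every == 0


callbacks = [
    ScheduledCallback("SaveCheckpoint", on="val_epoch_end", called_every=50),
    ScheduledCallback("PrintMetrics", on="train_epoch_end", called_every=100),
]


def firing_epochs(cb: ScheduledCallback, hook: str, max_iter: int) -> list[int]:
    # Epochs in [1, max_iter] at which `cb` would run for the given hook.
    return [e for e in range(1, max_iter + 1) if cb.should_run(hook, e)]


print(firing_epochs(callbacks[0], "val_epoch_end", 200))    # [50, 100, 150, 200]
print(firing_epochs(callbacks[1], "train_epoch_end", 300))  # [100, 200, 300]
print(firing_epochs(callbacks[1], "val_epoch_end", 300))    # [] (wrong hook)
```

Under this model, a callback attached to one hook never fires on another, regardless of `called_every`.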
2. Using Callbacks with Trainer
The `Trainer` class in `perceptrain` provides built-in support for executing callbacks at various stages of the training process, managed through a callback manager. By default, several callbacks are attached to specific hooks to automate common tasks such as checkpointing, metric logging, and model tracking.
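Conceptually, a callback manager keeps, for each hook, the callbacks registered on it and runs them in order when the trainer reaches that hook. The following is a minimal sketch of that idea using hypothetical `CallbackManager` and `ToyTrainer` classes; it is not perceptrain's implementation, and the assumption that validation runs after every training epoch is specific to this toy loop.

```python
from collections import defaultdict
from typing import Callable

# Hypothetical callback signature: any function taking (trainer, epoch).
CallbackFn = Callable[["ToyTrainer", int], None]


class CallbackManager:
    """Minimal hook registry: hook name -> ordered list of callbacks."""

    def __init__(self) -> None:
        self._hooks: dict[str, list[CallbackFn]] = defaultdict(list)

    def add(self, hook: str, fn: CallbackFn) -> None:
        self._hooks[hook].append(fn)

    def run(self, hook: str, trainer: "ToyTrainer", epoch: int) -> None:
        # Callbacks run in registration order; hooks with nothing registered are a no-op.
        for fn in self._hooks.get(hook, []):
            fn(trainer, epoch)


class ToyTrainer:
    def __init__(self, max_iter: int, manager: CallbackManager) -> None:
        self.max_iter = max_iter
        self.manager = manager
        self.events: list[str] = []

    def fit(self) -> None:
        self.manager.run("train_start", self, 0)
        for epoch in range(1, self.max_iter + 1):
            # ... one training epoch would go here ...
            self.manager.run("train_epoch_end", self, epoch)
            # ... validation would go here ...
            self.manager.run("val_epoch_end", self, epoch)
        self.manager.run("train_end", self, self.max_iter)


manager = CallbackManager()
manager.add("train_start", lambda t, e: t.events.append("start"))
manager.add("val_epoch_end", lambda t, e: t.events.append(f"checkpoint@{e}"))
manager.add("train_end", lambda t, e: t.events.append("end"))

trainer = ToyTrainer(max_iter=2, manager=manager)
trainer.fit()
print(trainer.events)  # ['start', 'checkpoint@1', 'checkpoint@2', 'end']
```

Keeping the registry separate from the training loop means the loop only needs to announce hook names; which callbacks respond is decided entirely by what was registered.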
Default Callbacks
Below is a list of the default callbacks and their assigned hooks:

- `train_start`: `WritePlots`, `SaveCheckpoint`, `WriteMetrics`
- `train_epoch_end`: `SaveCheckpoint`, `PrintMetrics`, `WritePlots`, `LivePlotMetrics`, `WriteMetrics`
- `val_epoch_end`: `SaveBestCheckpoint`, `WriteMetrics`
- `train_end`: `LogHyperparameters`, `LogModelTracker`, `WriteMetrics`, `SaveCheckpoint`, `WritePlots`
These defaults handle common needs, but you can also add custom callbacks to any hook.
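The defaults above can be read as a mapping from hook to callback names. The sketch below encodes that list as a plain dictionary and shows one plausible way user-supplied callbacks could be combined with it: each user callback is appended after the defaults for its hook. The hypothetical `merge_callbacks` helper, the `MyCustomCallback` name, and the append-after-defaults ordering are assumptions for illustration, not documented perceptrain behaviour.

```python
# Default hook -> callback names, copied from the list above.
DEFAULT_CALLBACKS: dict[str, list[str]] = {
    "train_start": ["WritePlots", "SaveCheckpoint", "WriteMetrics"],
    "train_epoch_end": [
        "SaveCheckpoint", "PrintMetrics", "WritePlots", "LivePlotMetrics", "WriteMetrics",
    ],
    "val_epoch_end": ["SaveBestCheckpoint", "WriteMetrics"],
    "train_end": [
        "LogHyperparameters", "LogModelTracker", "WriteMetrics", "SaveCheckpoint", "WritePlots",
    ],
}


def merge_callbacks(
    defaults: dict[str, list[str]], user: list[tuple[str, str]]
) -> dict[str, list[str]]:
    """Hypothetical helper: append each (hook, name) pair after that hook's defaults."""
    merged = {hook: list(names) for hook, names in defaults.items()}  # copy; leave defaults intact
    for hook, name in user:
        merged.setdefault(hook, []).append(name)
    return merged


user_callbacks = [("val_epoch_end", "PrintMetrics"), ("train_start", "MyCustomCallback")]
merged = merge_callbacks(DEFAULT_CALLBACKS, user_callbacks)
print(merged["val_epoch_end"])  # ['SaveBestCheckpoint', 'WriteMetrics', 'PrintMetrics']
print(merged["train_start"])    # ['WritePlots', 'SaveCheckpoint', 'WriteMetrics', 'MyCustomCallback']
```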
Example: Adding a Custom Callback
To create a custom `Trainer` that runs a `PrintMetrics` callback at the end of each training epoch, subclass `Trainer` and override its `on_train_epoch_end` hook:
```python
from perceptrain.trainer import Trainer
from perceptrain.callbacks import PrintMetrics


class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep a PrintMetrics instance on the trainer so the hook below can reuse it.
        self.print_metrics_callback = PrintMetrics(on="train_epoch_end", called_every=10)

    def on_train_epoch_end(self, train_epoch_loss_metrics):
        # Hook invoked by the Trainer at the end of each training epoch.
        self.print_metrics_callback.run_callback(self)
```
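The subclassing pattern above relies on the trainer calling overridable hook methods such as `on_train_epoch_end`. The self-contained sketch below mimics that pattern with a hypothetical `ToyTrainer`, whose hook methods do nothing by default, and a hypothetical `PrintEveryN` callback that applies its own `called_every` check inside `run_callback`. It illustrates the design only and does not reproduce perceptrain's `Trainer` or `PrintMetrics`.

```python
class PrintEveryN:
    """Hypothetical callback: records a message every `called_every` epochs."""

    def __init__(self, called_every: int) -> None:
        self.called_every = called_every
        self.messages: list[str] = []

    def run_callback(self, trainer: "ToyTrainer") -> None:
        if trainer.epoch % self.called_every == 0:
            self.messages.append(f"epoch {trainer.epoch}: loss={trainer.loss:.3f}")


class ToyTrainer:
    """Hypothetical base trainer whose hook methods are no-ops by default."""

    def __init__(self, max_iter: int) -> None:
        self.max_iter = max_iter
        self.epoch = 0
        self.loss = float("inf")

    def on_train_epoch_end(self, train_epoch_loss_metrics: list[float]) -> None:
        pass  # subclasses override this hook

    def fit(self) -> None:
        for epoch in range(1, self.max_iter + 1):
            self.epoch = epoch
            self.loss = 1.0 / epoch  # stand-in for a real optimisation step
            self.on_train_epoch_end([self.loss])


class CustomTrainer(ToyTrainer):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.print_metrics_callback = PrintEveryN(called_every=10)

    def on_train_epoch_end(self, train_epoch_loss_metrics: list[float]) -> None:
        # Delegate to the callback at the end of every training epoch.
        self.print_metrics_callback.run_callback(self)


trainer = CustomTrainer(max_iter=30)
trainer.fit()
print(trainer.print_metrics_callback.messages)
# ['epoch 10: loss=0.100', 'epoch 20: loss=0.050', 'epoch 30: loss=0.033']
```

Here the hook runs every epoch, and the throttling lives in the callback; where that check happens in perceptrain itself is not shown by this sketch.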