
Callbacks

Callback(on='idle', called_every=1, callback=None, callback_condition=None, modify_optimize_result=None)

Base class for defining various training callbacks.

ATTRIBUTE DESCRIPTION
on

The event on which to trigger the callback. Must be a valid on value from: ["train_start", "train_end", "train_epoch_start", "train_epoch_end", "train_batch_start", "train_batch_end", "val_epoch_start", "val_epoch_end", "val_batch_start", "val_batch_end", "test_batch_start", "test_batch_end"]

TYPE: str

called_every

Frequency of callback calls in terms of iterations.

TYPE: int

callback

The function to call if the condition is met.

TYPE: CallbackFunction | None

callback_condition

Condition to check before calling.

TYPE: CallbackConditionFunction | None

modify_optimize_result

Function to modify OptimizeResult.

TYPE: CallbackFunction | dict[str, Any] | None

A callback can be defined in two ways:

  1. By providing a callback function directly in the base class: This is useful for simple callbacks that don't require subclassing.

Example:

from perceptrain.callbacks import Callback

def custom_callback_function(trainer, config, writer):
    print("Custom callback executed.")

custom_callback = Callback(
    on="train_end",
    called_every=5,
    callback=custom_callback_function
)

  2. By inheriting and implementing the run_callback method: This is suitable for more complex callbacks that require customization.

Example:

from perceptrain.callbacks import Callback
class CustomCallback(Callback):
    def run_callback(self, trainer, config, writer):
        print("Custom behavior in the inherited run_callback method.")

custom_callback = CustomCallback(on="train_end", called_every=10)
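
A callback_condition can additionally gate when the callback body runs. The sketch below is illustrative only: it assumes the condition function receives the current OptimizeResult (consistent with how _should_call(when, opt_result) is invoked in the source below) and that loss and iteration are available on that object.

from perceptrain.callbacks import Callback

def log_high_loss(trainer, config, writer):
    print(f"Loss still high at iteration {trainer.opt_result.iteration}.")

# Assumption: the condition receives the current OptimizeResult.
conditional_callback = Callback(
    on="train_epoch_end",
    called_every=100,
    callback=log_high_loss,
    callback_condition=lambda opt_result: opt_result.loss is not None and opt_result.loss > 1.0,
)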

Source code in perceptrain/callbacks/callback.py
def __init__(
    self,
    on: str | TrainingStage = "idle",
    called_every: int = 1,
    callback: CallbackFunction | None = None,
    callback_condition: CallbackConditionFunction | None = None,
    modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
):
    if not isinstance(called_every, int):
        raise ValueError("called_every must be a positive integer or 0")

    self.callback: CallbackFunction | None = callback
    self.on: str | TrainingStage = on
    self.called_every: int = called_every
    self.callback_condition = (
        callback_condition if callback_condition else Callback.default_callback
    )

    if isinstance(modify_optimize_result, dict):
        self.modify_optimize_result = lambda opt_res: Callback.modify_opt_res_dict(
            opt_res, modify_optimize_result
        )
    else:
        self.modify_optimize_result = (
            modify_optimize_result
            if modify_optimize_result
            else Callback.modify_opt_res_default
        )

on property writable

Returns the TrainingStage.

RETURNS DESCRIPTION
TrainingStage

TrainingStage for the callback

TYPE: TrainingStage | str

__call__(when, trainer, config, writer)

Executes the callback if conditions are met.

PARAMETER DESCRIPTION
when

The event when the callback is triggered.

TYPE: str

trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

RETURNS DESCRIPTION
Any

Result of the callback function if executed.

TYPE: Any

Source code in perceptrain/callbacks/callback.py
def __call__(
    self, when: TrainingStage, trainer: Any, config: TrainConfig, writer: BaseWriter
) -> Any:
    """Executes the callback if conditions are met.

    Args:
        when (str): The event when the callback is triggered.
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter ): The writer object for logging.

    Returns:
        Any: Result of the callback function if executed.
    """
    opt_result = trainer.opt_result
    if self.on == when:
        if opt_result:
            opt_result = self.modify_optimize_result(opt_result)
        if self._should_call(when, opt_result):
            return self.run_callback(trainer, config, writer)

run_callback(trainer, config, writer)

Executes the defined callback.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

RETURNS DESCRIPTION
Any

Result of the callback execution.

TYPE: Any

RAISES DESCRIPTION
NotImplementedError

If not implemented in subclasses.

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
    """Executes the defined callback.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter ): The writer object for logging.

    Returns:
        Any: Result of the callback execution.

    Raises:
        NotImplementedError: If not implemented in subclasses.
    """
    if self.callback is not None:
        return self.callback(trainer, config, writer)
    raise NotImplementedError("Subclasses should override the run_callback method.")

EarlyStopping(on, called_every, monitor, patience=5, mode='min')

Bases: Callback

Stops training when a monitored metric has not improved for a specified number of epochs.

This callback monitors a specified metric (e.g., validation loss or accuracy). If the metric does not improve for a given patience period, training is stopped.

Example Usage in TrainConfig: To use EarlyStopping, include it in the callbacks list when setting up your TrainConfig:

from perceptrain import TrainConfig
from perceptrain.callbacks import EarlyStopping

# Create an instance of the EarlyStopping callback
early_stopping = EarlyStopping(on="val_epoch_end",
                               called_every=1,
                               monitor="val_loss",
                               patience=5,
                               mode="min")

config = TrainConfig(
    max_iter=10000,
    print_every=1000,
    callbacks=[early_stopping]
)

Initializes the EarlyStopping callback.

PARAMETER DESCRIPTION
on

The event to trigger the callback (e.g., "val_epoch_end").

TYPE: str

called_every

Frequency of callback calls in terms of iterations.

TYPE: int

monitor

The metric to monitor (e.g., "val_loss" or "train_loss"). All metrics returned by the optimize step are available to monitor; prefix the metric name with "val_" or "train_" accordingly.

TYPE: str

patience

Number of iterations to wait for improvement. Default is 5.

TYPE: int DEFAULT: 5

mode

Whether to minimize ("min") or maximize ("max") the metric. Default is "min".

TYPE: str DEFAULT: 'min'

Source code in perceptrain/callbacks/callback.py
def __init__(
    self, on: str, called_every: int, monitor: str, patience: int = 5, mode: str = "min"
):
    """Initializes the EarlyStopping callback.

    Args:
        on (str): The event to trigger the callback (e.g., "val_epoch_end").
        called_every (int): Frequency of callback calls in terms of iterations.
        monitor (str): The metric to monitor (e.g., "val_loss" or "train_loss").
            All metrics returned by optimize step are available to monitor.
            Please add "val_" and "train_" strings at the start of the metric name.
        patience (int, optional): Number of iterations to wait for improvement. Default is 5.
        mode (str, optional): Whether to minimize ("min") or maximize ("max") the metric.
            Default is "min".
    """
    super().__init__(on=on, called_every=called_every)
    self.monitor = monitor
    self.patience = patience
    self.mode = mode
    self.best_value = float("inf") if mode == "min" else -float("inf")
    self.counter = 0

run_callback(trainer, config, writer)

Monitors the metric and stops training if no improvement is observed.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
    """
    Monitors the metric and stops training if no improvement is observed.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter): The writer object for logging.
    """
    current_value = trainer.opt_result.metrics.get(self.monitor)
    if current_value is None:
        raise ValueError(f"Metric '{self.monitor}' is not available in the trainer's metrics.")

    if (self.mode == "min" and current_value < self.best_value) or (
        self.mode == "max" and current_value > self.best_value
    ):
        self.best_value = current_value
        self.counter = 0
    else:
        self.counter += 1

    if self.counter >= self.patience:
        logger.info(
            f"EarlyStopping: No improvement in '{self.monitor}' for {self.patience} epochs. "
            "Stopping training."
        )
        trainer._stop_training.fill_(1)

GradientMonitoring(on, called_every=1)

Bases: Callback

Logs gradient statistics (e.g., mean, standard deviation, max) during training.

This callback monitors and logs statistics about the gradients of the model parameters to help debug or optimize the training process.

Example Usage in TrainConfig: To use GradientMonitoring, include it in the callbacks list when setting up your TrainConfig:

from perceptrain import TrainConfig
from perceptrain.callbacks import GradientMonitoring

# Create an instance of the GradientMonitoring callback
gradient_monitoring = GradientMonitoring(on="train_batch_end", called_every=10)

config = TrainConfig(
    max_iter=10000,
    print_every=1000,
    callbacks=[gradient_monitoring]
)

Initializes the GradientMonitoring callback.

PARAMETER DESCRIPTION
on

The event to trigger the callback (e.g., "train_batch_end").

TYPE: str

called_every

Frequency of callback calls in terms of iterations.

TYPE: int DEFAULT: 1

Source code in perceptrain/callbacks/callback.py
def __init__(self, on: str, called_every: int = 1):
    """Initializes the GradientMonitoring callback.

    Args:
        on (str): The event to trigger the callback (e.g., "train_batch_end").
        called_every (int): Frequency of callback calls in terms of iterations.
    """
    super().__init__(on=on, called_every=called_every)

run_callback(trainer, config, writer)

Logs gradient statistics.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
    """
    Logs gradient statistics.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter): The writer object for logging.
    """
    if trainer.accelerator.rank == 0:
        gradient_stats = {}
        for name, param in trainer.model.named_parameters():
            if param.grad is not None:
                grad = param.grad
                gradient_stats.update(
                    {
                        name + "_mean": grad.mean().item(),
                        name + "_std": grad.std().item(),
                        name + "_max": grad.max().item(),
                        name + "_min": grad.min().item(),
                    }
                )

        writer.write(trainer.opt_result.iteration, gradient_stats)

LRSchedulerCosineAnnealing(on, called_every, t_max, min_lr=0.0)

Bases: Callback

Applies cosine annealing to the learning rate during training.

This callback decreases the learning rate following a cosine curve, starting from the initial learning rate and annealing to a minimum (min_lr).
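
Concretely, each time the callback runs, every parameter group's learning rate is updated as (matching run_callback below): new_lr = min_lr + (current_lr - min_lr) * (1 + cos(pi * iteration / t_max)) / 2, where current_lr is the group's current learning rate and iteration is the current training iteration.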

Example Usage in TrainConfig: To use LRSchedulerCosineAnnealing, include it in the callbacks list when setting up your TrainConfig:

from perceptrain import TrainConfig
from perceptrain.callbacks import LRSchedulerCosineAnnealing

# Create an instance of the LRSchedulerCosineAnnealing callback
lr_cosine = LRSchedulerCosineAnnealing(on="train_batch_end",
                                       called_every=1,
                                       t_max=5000,
                                       min_lr=1e-6)

config = TrainConfig(
    max_iter=10000,
    # Print metrics every 1000 training epochs
    print_every=1000,
    # Add the custom callback
    callbacks=[lr_cosine]
)

Initializes the LRSchedulerCosineAnnealing callback.

PARAMETER DESCRIPTION
on

The event to trigger the callback.

TYPE: str

called_every

Frequency of callback calls in terms of iterations.

TYPE: int

t_max

The total number of iterations for one annealing cycle.

TYPE: int

min_lr

The minimum learning rate. Default is 0.0.

TYPE: float DEFAULT: 0.0

Source code in perceptrain/callbacks/callback.py
def __init__(self, on: str, called_every: int, t_max: int, min_lr: float = 0.0):
    """Initializes the LRSchedulerCosineAnnealing callback.

    Args:
        on (str): The event to trigger the callback.
        called_every (int): Frequency of callback calls in terms of iterations.
        t_max (int): The total number of iterations for one annealing cycle.
        min_lr (float, optional): The minimum learning rate. Default is 0.0.
    """
    super().__init__(on=on, called_every=called_every)
    self.t_max = t_max
    self.min_lr = min_lr

run_callback(trainer, config, writer)

Adjusts the learning rate using cosine annealing.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
    """
    Adjusts the learning rate using cosine annealing.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter): The writer object for logging.
    """
    for param_group in trainer.optimizer.param_groups:
        max_lr = param_group["lr"]
        new_lr = (
            self.min_lr
            + (max_lr - self.min_lr)
            * (1 + math.cos(math.pi * trainer.opt_result.iteration / self.t_max))
            / 2
        )
        param_group["lr"] = new_lr

LRSchedulerCyclic(on, called_every, base_lr, max_lr, step_size)

Bases: Callback

Applies a cyclic learning rate schedule during training.

This callback oscillates the learning rate between a minimum (base_lr) and a maximum (max_lr) over a defined cycle length (step_size). The learning rate follows a triangular wave pattern.
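
Concretely, with t the current iteration (matching run_callback below): cycle = floor(t / (2 * step_size)), x = |t / step_size - 2 * cycle - 1|, and new_lr = base_lr + (max_lr - base_lr) * max(0, 1 - x).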

Example Usage in TrainConfig: To use LRSchedulerCyclic, include it in the callbacks list when setting up your TrainConfig:

from perceptrain import TrainConfig
from perceptrain.callbacks import LRSchedulerCyclic

# Create an instance of the LRSchedulerCyclic callback
lr_cyclic = LRSchedulerCyclic(on="train_batch_end",
                              called_every=1,
                              base_lr=0.001,
                              max_lr=0.01,
                              step_size=2000)

config = TrainConfig(
    max_iter=10000,
    # Print metrics every 1000 training epochs
    print_every=1000,
    # Add the custom callback
    callbacks=[lr_cyclic]
)

Initializes the LRSchedulerCyclic callback.

PARAMETER DESCRIPTION
on

The event to trigger the callback.

TYPE: str

called_every

Frequency of callback calls in terms of iterations.

TYPE: int

base_lr

The minimum learning rate.

TYPE: float

max_lr

The maximum learning rate.

TYPE: float

step_size

Number of iterations for half a cycle.

TYPE: int

Source code in perceptrain/callbacks/callback.py
def __init__(self, on: str, called_every: int, base_lr: float, max_lr: float, step_size: int):
    """Initializes the LRSchedulerCyclic callback.

    Args:
        on (str): The event to trigger the callback.
        called_every (int): Frequency of callback calls in terms of iterations.
        base_lr (float): The minimum learning rate.
        max_lr (float): The maximum learning rate.
        step_size (int): Number of iterations for half a cycle.
    """
    super().__init__(on=on, called_every=called_every)
    self.base_lr = base_lr
    self.max_lr = max_lr
    self.step_size = step_size

run_callback(trainer, config, writer)

Adjusts the learning rate cyclically.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
    """
    Adjusts the learning rate cyclically.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter): The writer object for logging.
    """
    cycle = trainer.opt_result.iteration // (2 * self.step_size)
    x = abs(trainer.opt_result.iteration / self.step_size - 2 * cycle - 1)
    scale = max(0, (1 - x))
    new_lr = self.base_lr + (self.max_lr - self.base_lr) * scale
    for param_group in trainer.optimizer.param_groups:
        param_group["lr"] = new_lr

LRSchedulerReduceOnPlateau(on, called_every=1, monitor='train_loss', patience=20, mode='min', gamma=0.5, threshold=0.0001, min_lr=1e-06, verbose=True)

Bases: Callback

Reduces learning rate when a given metric reaches a plateau.

This callback multiplies the learning rate by a factor gamma whenever the monitored metric fails to improve by more than a given threshold for a given number of epochs, until a minimum learning rate is reached.

Example Usage in TrainConfig: To use LRSchedulerReduceOnPlateau, include it in the callbacks list when setting up your TrainConfig:

from perceptrain import TrainConfig
from perceptrain.callbacks import LRSchedulerReduceOnPlateau

# Create an instance of the LRSchedulerReduceOnPlateau callback
lr_plateau = LRSchedulerReduceOnPlateau(
                on="train_epoch_end",
                called_every=1,
                monitor="train_loss",
                patience=20,
                mode="min",
                gamma=0.5,
                threshold=1e-4,
                min_lr=1e-5,
            )

config = TrainConfig(
    max_iter=10000,
    # Print metrics every 1000 training epochs
    print_every=1000,
    # Add the custom callback
    callbacks=[lr_plateau]
)

Initializes the LRSchedulerReduceOnPlateau callback.

PARAMETER DESCRIPTION
on

The event to trigger the callback. Default is train_epoch_end.

TYPE: str

called_every

Frequency of callback calls in terms of iterations. Default is 1.

TYPE: int DEFAULT: 1

monitor

The metric to monitor (e.g., "val_loss" or "train_loss"). All metrics returned by the optimize step are available to monitor; prefix the metric name with "val_" or "train_" accordingly. Default is "train_loss".

TYPE: str DEFAULT: 'train_loss'

mode

Whether to minimize ("min") or maximize ("max") the metric. Default is "min".

TYPE: str DEFAULT: 'min'

patience

Number of allowed iterations with no loss improvement before reducing the learning rate. Default is 20.

TYPE: int DEFAULT: 20

gamma

The decay factor applied to the learning rate. A value < 1 reduces the learning rate over time. Default is 0.5.

TYPE: float DEFAULT: 0.5

threshold

Amount by which the loss must improve to count as an improvement. Default is 1e-4.

TYPE: float DEFAULT: 0.0001

min_lr

Minimum learning rate below which no further reduction is applied. Default is 1e-6.

TYPE: float DEFAULT: 1e-06

verbose

If True, the logger prints a message (INFO level) when the learning rate decreases or reaches the minimum.

TYPE: bool DEFAULT: True

Source code in perceptrain/callbacks/callback.py
def __init__(
    self,
    on: str,
    called_every: int = 1,
    monitor: str = "train_loss",
    patience: int = 20,
    mode: str = "min",
    gamma: float = 0.5,
    threshold: float = 1e-4,
    min_lr: float = 1e-6,
    verbose: bool = True,
):
    """Initializes the LRSchedulerReduceOnPlateau callback.

    Args:
        on (str): The event to trigger the callback. Default is `train_epoch_end`.
        called_every (int, optional): Frequency of callback calls in terms of iterations.
            Default is 1.
        monitor (str, optional): The metric to monitor (e.g., "val_loss" or "train_loss").
            All metrics returned by optimize step are available to monitor.
            Please add "val_" and "train_" strings at the start of the metric name.
            Default is "train_loss".
        mode (str, optional): Whether to minimize ("min") or maximize ("max") the metric.
            Default is "min".
        patience (int, optional): Number of allowed iterations with no loss
            improvement before reducing the learning rate. Default is 20.
        gamma (float, optional): The decay factor applied to the learning rate. A value
            < 1 reduces the learning rate over time. Default is 0.5.
        threshold (float, optional): Amount by which the loss must improve to count as an
            improvement. Default is 1e-4.
        min_lr (float, optional): Minimum learning rate past which no further reducing
            is applied.  Default is 1e-5.
        verbose (bool, optional): If True, the logger prints when the learning rate
            decreases or reaches the minimum (INFO level)
    """
    super().__init__(on=on, called_every=called_every)
    self.monitor = monitor
    self.mode = mode
    self.patience = patience
    self.gamma = gamma
    self.threshold = threshold
    self.min_lr = min_lr
    self.verbose = verbose

    self.best_value = float("inf")
    self.counter = 0
    self.reached_minimum_lr = False

run_callback(trainer, config, writer)

Reduces the learning rate when the loss reaches a plateau.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
    """
    Reduces the learning rate when the loss reaches a plateau.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter): The writer object for logging.
    """
    if self.reached_minimum_lr:
        pass
    else:
        current_value = trainer.opt_result.metrics.get(self.monitor)
        if current_value is None:
            raise ValueError(
                f"Metric '{self.monitor}' is not available in the trainer's metrics."
            )

        if (self.mode == "min" and current_value + self.threshold < self.best_value) or (
            self.mode == "max" and current_value - self.threshold > self.best_value
        ):
            self.best_value = current_value
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            for param_group in trainer.optimizer.param_groups:
                new_lr = param_group["lr"] * self.gamma
                if new_lr < self.min_lr:
                    param_group["lr"] = self.min_lr
                    self.reached_minimum_lr = True
                    if self.verbose:
                        logger.info(
                            f"Learning rate has reached the minimum learning rate {self.min_lr}"
                        )
                else:
                    param_group["lr"] = new_lr
                    if self.verbose:
                        logger.info(
                            f"Loss has reached a plateau, reducing learning rate to {new_lr}"
                        )
            self.counter = 0

LRSchedulerStepDecay(on, called_every, gamma=0.5)

Bases: Callback

Reduces the learning rate by a factor at regular intervals.

This callback adjusts the learning rate by multiplying it with a decay factor after a specified number of iterations. The learning rate is updated as: lr = lr * gamma

Example Usage in TrainConfig: To use LRSchedulerStepDecay, include it in the callbacks list when setting up your TrainConfig:

from perceptrain import TrainConfig
from perceptrain.callbacks import LRSchedulerStepDecay

# Create an instance of the LRSchedulerStepDecay callback
lr_step_decay = LRSchedulerStepDecay(on="train_epoch_end",
                                     called_every=100,
                                     gamma=0.5)

config = TrainConfig(
    max_iter=10000,
    # Print metrics every 1000 training epochs
    print_every=1000,
    # Add the custom callback
    callbacks=[lr_step_decay]
)

Initializes the LRSchedulerStepDecay callback.

PARAMETER DESCRIPTION
on

The event to trigger the callback.

TYPE: str

called_every

Frequency of callback calls in terms of iterations.

TYPE: int

gamma

The decay factor applied to the learning rate. A value < 1 reduces the learning rate over time. Default is 0.5.

TYPE: float DEFAULT: 0.5

Source code in perceptrain/callbacks/callback.py
def __init__(self, on: str, called_every: int, gamma: float = 0.5):
    """Initializes the LRSchedulerStepDecay callback.

    Args:
        on (str): The event to trigger the callback.
        called_every (int): Frequency of callback calls in terms of iterations.
        gamma (float, optional): The decay factor applied to the learning rate.
            A value < 1 reduces the learning rate over time. Default is 0.5.
    """
    super().__init__(on=on, called_every=called_every)
    self.gamma = gamma

run_callback(trainer, config, writer)

Runs the callback to apply step decay to the learning rate.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
    """
    Runs the callback to apply step decay to the learning rate.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter): The writer object for logging.
    """
    for param_group in trainer.optimizer.param_groups:
        param_group["lr"] *= self.gamma

LivePlotMetrics(on, called_every, arrange=None)

Bases: Callback

Callback to follow metrics on screen during training.

It uses livelossplot to update losses and metrics at every call and plot them via matplotlib.
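
Example Usage in TrainConfig: To use LivePlotMetrics, include it in the callbacks list when setting up your TrainConfig. The sketch below is illustrative; the metric names and the arrange grouping are assumptions, so substitute the metric names actually produced by your optimize step:

from perceptrain import TrainConfig
from perceptrain.callbacks import LivePlotMetrics

# Group training and validation losses into a single subplot
# (metric names here are assumptions).
live_plot = LivePlotMetrics(
    on="train_epoch_end",
    called_every=10,
    arrange={"losses": ["train_loss", "val_loss"]},
)

config = TrainConfig(
    max_iter=10000,
    callbacks=[live_plot]
)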

Initializes the callback.

PARAMETER DESCRIPTION
on

The event to trigger the callback.

TYPE: str

called_every

Frequency of callback calls in terms of iterations.

TYPE: int

arrange

How metrics are arranged for the subplots. Each entry is a different group, which will correspond to a different subplot. If None, all metrics will be plotted in a single subplot. Defaults to None.

TYPE: dict[str, list[str]] | None DEFAULT: None

Source code in perceptrain/callbacks/callback.py
def __init__(self, on: str, called_every: int, arrange: dict[str, list[str]] | None = None):
    """Initializes the callback.

    Args:
        on (str): The event to trigger the callback.
        called_every (int): Frequency of callback calls in terms of iterations.
        arrange (dict[str, dict[str, list[str]]): How metrics are arranged for the subplots. Each entry
            is a different group, which will correspond to a different subplot.
            If None, all metrics will be plotted in a single subplot.
            Defaults to None.
    """
    super().__init__(on=on, called_every=called_every)
    self.arrange = arrange

    self.output_mode = llp.get_mode()

    self._first_call = True

LoadCheckpoint(on='idle', called_every=1, callback=None, callback_condition=None, modify_optimize_result=None)

Bases: Callback

Callback to load a model checkpoint.
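
Example Usage in TrainConfig (illustrative sketch): To use LoadCheckpoint, include it in the callbacks list when setting up your TrainConfig. The "train_start" event used here is one choice among the valid events listed above:

from perceptrain import TrainConfig
from perceptrain.callbacks import LoadCheckpoint

# Load the latest checkpoint from the log folder when training starts
load_checkpoint_callback = LoadCheckpoint(on="train_start", called_every=1)

config = TrainConfig(
    max_iter=10000,
    callbacks=[load_checkpoint_callback]
)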

Source code in perceptrain/callbacks/callback.py
def __init__(
    self,
    on: str | TrainingStage = "idle",
    called_every: int = 1,
    callback: CallbackFunction | None = None,
    callback_condition: CallbackConditionFunction | None = None,
    modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
):
    if not isinstance(called_every, int):
        raise ValueError("called_every must be a positive integer or 0")

    self.callback: CallbackFunction | None = callback
    self.on: str | TrainingStage = on
    self.called_every: int = called_every
    self.callback_condition = (
        callback_condition if callback_condition else Callback.default_callback
    )

    if isinstance(modify_optimize_result, dict):
        self.modify_optimize_result = lambda opt_res: Callback.modify_opt_res_dict(
            opt_res, modify_optimize_result
        )
    else:
        self.modify_optimize_result = (
            modify_optimize_result
            if modify_optimize_result
            else Callback.modify_opt_res_default
        )

run_callback(trainer, config, writer)

Loads a model checkpoint.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

RETURNS DESCRIPTION
Any

The result of loading the checkpoint.

TYPE: Any

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
    """Loads a model checkpoint.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter ): The writer object for logging.

    Returns:
        Any: The result of loading the checkpoint.
    """
    if trainer.accelerator.rank == 0:
        folder = config.log_folder
        model = trainer.model
        optimizer = trainer.optimizer
        device = trainer.accelerator.execution.log_device
        return load_checkpoint(folder, model, optimizer, device=device)

LogHyperparameters(on='idle', called_every=1, callback=None, callback_condition=None, modify_optimize_result=None)

Bases: Callback

Callback to log hyperparameters using the writer.

The LogHyperparameters callback can be added to the TrainConfig callbacks as a custom user defined callback.

Example Usage in TrainConfig: To use LogHyperparameters, include it in the callbacks list when setting up your TrainConfig:

from perceptrain import TrainConfig
from perceptrain.callbacks import LogHyperparameters

# Create an instance of the LogHyperparameters callback
log_hyper_callback = LogHyperparameters(on = "val_batch_end", called_every = 100)

config = TrainConfig(
    max_iter=10000,
    # Print metrics every 1000 training epochs
    print_every=1000,
    # Add the custom callback that runs every 100 val_batch_end
    callbacks=[log_hyper_callback]
)

Source code in perceptrain/callbacks/callback.py
def __init__(
    self,
    on: str | TrainingStage = "idle",
    called_every: int = 1,
    callback: CallbackFunction | None = None,
    callback_condition: CallbackConditionFunction | None = None,
    modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
):
    if not isinstance(called_every, int):
        raise ValueError("called_every must be a positive integer or 0")

    self.callback: CallbackFunction | None = callback
    self.on: str | TrainingStage = on
    self.called_every: int = called_every
    self.callback_condition = (
        callback_condition if callback_condition else Callback.default_callback
    )

    if isinstance(modify_optimize_result, dict):
        self.modify_optimize_result = lambda opt_res: Callback.modify_opt_res_dict(
            opt_res, modify_optimize_result
        )
    else:
        self.modify_optimize_result = (
            modify_optimize_result
            if modify_optimize_result
            else Callback.modify_opt_res_default
        )

run_callback(trainer, config, writer)

Logs hyperparameters using the writer.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
    """Logs hyperparameters using the writer.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter ): The writer object for logging.
    """
    if trainer.accelerator.rank == 0:
        hyperparams = config.hyperparams
        writer.log_hyperparams(hyperparams)

LogModelTracker(on='idle', called_every=1, callback=None, callback_condition=None, modify_optimize_result=None)

Bases: Callback

Callback to log the model using the writer.
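
Example Usage in TrainConfig (illustrative sketch): To use LogModelTracker, include it in the callbacks list when setting up your TrainConfig; logging the model once at "train_end" is one reasonable choice:

from perceptrain import TrainConfig
from perceptrain.callbacks import LogModelTracker

# Log the trained model once training finishes
log_model_callback = LogModelTracker(on="train_end", called_every=1)

config = TrainConfig(
    max_iter=10000,
    callbacks=[log_model_callback]
)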

Source code in perceptrain/callbacks/callback.py
def __init__(
    self,
    on: str | TrainingStage = "idle",
    called_every: int = 1,
    callback: CallbackFunction | None = None,
    callback_condition: CallbackConditionFunction | None = None,
    modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
):
    if not isinstance(called_every, int):
        raise ValueError("called_every must be a positive integer or 0")

    self.callback: CallbackFunction | None = callback
    self.on: str | TrainingStage = on
    self.called_every: int = called_every
    self.callback_condition = (
        callback_condition if callback_condition else Callback.default_callback
    )

    if isinstance(modify_optimize_result, dict):
        self.modify_optimize_result = lambda opt_res: Callback.modify_opt_res_dict(
            opt_res, modify_optimize_result
        )
    else:
        self.modify_optimize_result = (
            modify_optimize_result
            if modify_optimize_result
            else Callback.modify_opt_res_default
        )

run_callback(trainer, config, writer)

Logs the model using the writer.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
    """Logs the model using the writer.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter ): The writer object for logging.
    """
    if trainer.accelerator.rank == 0:
        model = trainer.model
        writer.log_model(
            model, trainer.train_dataloader, trainer.val_dataloader, trainer.test_dataloader
        )

PrintMetrics(on='idle', called_every=1, callback=None, callback_condition=None, modify_optimize_result=None)

Bases: Callback

Callback to print metrics using the writer.

The PrintMetrics callback can be added to the TrainConfig callbacks as a custom user defined callback.

Example Usage in TrainConfig: To use PrintMetrics, include it in the callbacks list when setting up your TrainConfig:

from perceptrain import TrainConfig
from perceptrain.callbacks import PrintMetrics

# Create an instance of the PrintMetrics callback
print_metrics_callback = PrintMetrics(on = "val_batch_end", called_every = 100)

config = TrainConfig(
    max_iter=10000,
    # Print metrics every 1000 training epochs
    print_every=1000,
    # Add the custom callback that runs every 100 val_batch_end
    callbacks=[print_metrics_callback]
)

Source code in perceptrain/callbacks/callback.py
def __init__(
    self,
    on: str | TrainingStage = "idle",
    called_every: int = 1,
    callback: CallbackFunction | None = None,
    callback_condition: CallbackConditionFunction | None = None,
    modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
):
    if not isinstance(called_every, int):
        raise ValueError("called_every must be a positive integer or 0")

    self.callback: CallbackFunction | None = callback
    self.on: str | TrainingStage = on
    self.called_every: int = called_every
    self.callback_condition = (
        callback_condition if callback_condition else Callback.default_callback
    )

    if isinstance(modify_optimize_result, dict):
        self.modify_optimize_result = lambda opt_res: Callback.modify_opt_res_dict(
            opt_res, modify_optimize_result
        )
    else:
        self.modify_optimize_result = (
            modify_optimize_result
            if modify_optimize_result
            else Callback.modify_opt_res_default
        )

run_callback(trainer, config, writer)

Prints metrics using the writer.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
    """Prints metrics using the writer.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter ): The writer object for logging.
    """
    opt_result = trainer.opt_result
    writer.print_metrics(opt_result)

R3Sampling(initial_dataset, fitness_function, verbose=False, called_every=1)

Bases: Callback

Callback for R3 sampling (https://arxiv.org/abs/2207.02338#).

PARAMETER DESCRIPTION
initial_dataset

The dataset updating according to the R3 scheme.

TYPE: R3Dataset

fitness_function

The function to compute fitness scores for samples. Based on the fitness scores, the samples are retained or released.

TYPE: Callable[[Tensor, Module], Tensor]

verbose

Whether to print the callback's summary. Defaults to False.

TYPE: bool DEFAULT: False

called_every

Number of events between consecutive calls of the callback. Defaults to 1.

TYPE: int DEFAULT: 1

Notes
  • R3 sampling was developed as a technique for efficient sampling of physics-informed neural networks (PINNs). In this case, the fitness function can be any function of the residuals of the equations.

Examples:

Learning a harmonic oscillator with PINNs and R3 sampling. For a well-posed problem, also add the two initial conditions.

    import torch

    m = 1.0
    k = 1.0

    def uniform_1d(n: int):
        return torch.rand(size=(n, 1))

    def harmonic_oscillator(x: torch.Tensor, model: torch.nn.Module) -> torch.Tensor:
        u = model(x)
        dudt = torch.autograd.grad(
            outputs=u,
            inputs=x,
            grad_outputs=torch.ones_like(u),
            create_graph=True,
            retain_graph=True,
        )[0]
        d2udt2 = torch.autograd.grad(
            outputs=dudt,
            inputs=x,
            grad_outputs=torch.ones_like(dudt),
        )[0]
        return m * d2udt2 + k * u

    def fitness_function(x: torch.Tensor, model: PINN) -> torch.Tensor:
        return torch.linalg.vector_norm(harmonic_oscillator(x, model.nn), ord=2)

    dataset = R3Dataset(
        proba_dist=uniform_1d,
        n_samples=20,
        release_threshold=1.0,
    )

    callback_r3 = R3Sampling(
        initial_dataset=dataset,
        fitness_function=fitness_function,
        called_every=10,
    )
Source code in perceptrain/callbacks/callback.py
def __init__(
    self,
    initial_dataset: R3Dataset,
    fitness_function: Callable[[Tensor, nn.Module], Tensor],
    verbose: bool = False,
    called_every: int = 1,
):
    """Callback for R3 sampling (https://arxiv.org/abs/2207.02338#).

    Args:
        initial_dataset (R3Dataset): The dataset updating according to the R3 scheme.
        fitness_function (Callable[[Tensor, nn.Module], Tensor]): The function to
            compute fitness scores for samples. Based on the fitness scores, the
            samples are retained or released.
        verbose (bool, optional): Whether to print the callback's summary.
            Defaults to False.
        called_every (int, optional): Every how many events the callback is called.
            Defaults to 1.

    Notes:
        - R3 sampling was developed as a technique for efficient sampling of physics-informed
        neural networks (PINNs). In this case, the fitness function can be any function of the
        residuals of the equations

    Examples:
        Learning an harmonic oscillator with PINNs and R3 sampling. For a well-posed problem, also
        add the two initial conditions.

        ```python
            import torch

            m = 1.0
            k = 1.0

            def uniform_1d(n: int):
                return torch.rand(size=(n, 1))

            def harmonic_oscillator(x: torch.Tensor, model: torch.nn.Module) -> torch.Tensor:
                u = model(x)
                dudt = torch.autograd.grad(
                    outputs=u,
                    inputs=x,
                    grad_outputs=torch.ones_like(u),
                    create_graph=True,
                    retain_graph=True,
                )[0]
                d2udt2 = torch.autograd.grad(
                    outputs=dudt,
                    inputs=x,
                    grad_outputs=torch.ones_like(dudt),
                )[0]
                return m * d2udt2 + k * u

            def fitness_function(x: torch.Tensor, model: PINN) -> torch.Tensor:
                return torch.linalg.vector_norm(harmonic_oscillator(x, model.nn), ord=2)

            dataset = R3Dataset(
                proba_dist=uniform_1d,
                n_samples=20,
                release_threshold=1.0,
            )

            callback_r3 = R3Sampling(
                initial_dataset=dataset,
                fitness_function=fitness_function,
                called_every=10,
            )
        ```
    """
    self.dataset = initial_dataset
    self.fitness_function = fitness_function
    self.verbose = verbose

    super().__init__(on="train_epoch_start", called_every=called_every)

run_callback(trainer, config, writer)

Runs the callback.

Computes fitness scores for samples and triggers the dataset update.

PARAMETER DESCRIPTION
trainer

The trainer instance.

TYPE: Any

config

The training configuration.

TYPE: TrainConfig

writer

The writer instance.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
    """Runs the callback.

    Computes fitness scores for samples and triggers the dataset update.

    Args:
        trainer (Any): The trainer instance.
        config (TrainConfig): The training configuration.
        writer (BaseWriter): The writer instance.
    """
    # Compute fitness function on all samples
    fitnesses = self.fitness_function(self.dataset.features, trainer.model)

    # R3 update
    self.dataset.update(fitnesses)

    if self.verbose:
        print(
            f"Epoch {trainer.current_epoch};"
            f" num. retained: {self.dataset.n_samples - self.dataset.n_released:5d};"
            f" num. released: {self.dataset.n_released:5d}"
        )

SaveBestCheckpoint(on, called_every)

Bases: SaveCheckpoint

Callback to save the best model checkpoint based on a validation criterion.
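
Example Usage in TrainConfig (illustrative sketch): To use SaveBestCheckpoint, include it in the callbacks list when setting up your TrainConfig; a "best" checkpoint is written whenever config.validation_criterion reports an improvement over the best loss seen so far:

from perceptrain import TrainConfig
from perceptrain.callbacks import SaveBestCheckpoint

# Save a "best" checkpoint whenever the validation criterion improves
save_best_callback = SaveBestCheckpoint(on="val_epoch_end", called_every=1)

config = TrainConfig(
    max_iter=10000,
    callbacks=[save_best_callback]
)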

Initializes the SaveBestCheckpoint callback.

PARAMETER DESCRIPTION
on

The event to trigger the callback.

TYPE: str

called_every

Frequency of callback calls in terms of iterations.

TYPE: int

Source code in perceptrain/callbacks/callback.py
def __init__(self, on: str, called_every: int):
    """Initializes the SaveBestCheckpoint callback.

    Args:
        on (str): The event to trigger the callback.
        called_every (int): Frequency of callback calls in terms of iterations.
    """
    super().__init__(on=on, called_every=called_every)
    self.best_loss = float("inf")

run_callback(trainer, config, writer)

Saves the checkpoint if the current loss is better than the best loss.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
    """Saves the checkpoint if the current loss is better than the best loss.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter ): The writer object for logging.
    """
    if trainer.accelerator.rank == 0:
        opt_result = trainer.opt_result
        if config.validation_criterion and config.validation_criterion(
            opt_result.loss, self.best_loss, config.val_epsilon
        ):
            self.best_loss = opt_result.loss

            folder = config.log_folder
            model = trainer.model
            optimizer = trainer.optimizer
            opt_result = trainer.opt_result
            write_checkpoint(folder, model, optimizer, "best")

SaveCheckpoint(on='idle', called_every=1, callback=None, callback_condition=None, modify_optimize_result=None)

Bases: Callback

Callback to save a model checkpoint.

The SaveCheckpoint callback can be added to the TrainConfig callbacks as a custom user defined callback.

Example Usage in TrainConfig: To use SaveCheckpoint, include it in the callbacks list when setting up your TrainConfig:

from perceptrain import TrainConfig
from perceptrain.callbacks import SaveCheckpoint

# Create an instance of the SaveCheckpoint callback
save_checkpoint_callback = SaveCheckpoint(on = "val_batch_end", called_every = 100)

config = TrainConfig(
    max_iter=10000,
    # Print metrics every 1000 training epochs
    print_every=1000,
    # Add the custom callback that runs every 100 val_batch_end
    callbacks=[save_checkpoint_callback]
)

Source code in perceptrain/callbacks/callback.py
def __init__(
    self,
    on: str | TrainingStage = "idle",
    called_every: int = 1,
    callback: CallbackFunction | None = None,
    callback_condition: CallbackConditionFunction | None = None,
    modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
):
    if not isinstance(called_every, int):
        raise ValueError("called_every must be a positive integer or 0")

    self.callback: CallbackFunction | None = callback
    self.on: str | TrainingStage = on
    self.called_every: int = called_every
    self.callback_condition = (
        callback_condition if callback_condition else Callback.default_callback
    )

    if isinstance(modify_optimize_result, dict):
        self.modify_optimize_result = lambda opt_res: Callback.modify_opt_res_dict(
            opt_res, modify_optimize_result
        )
    else:
        self.modify_optimize_result = (
            modify_optimize_result
            if modify_optimize_result
            else Callback.modify_opt_res_default
        )

run_callback(trainer, config, writer)

Saves a model checkpoint.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
    """Saves a model checkpoint.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter ): The writer object for logging.
    """
    if trainer.accelerator.rank == 0:
        folder = config.log_folder
        model = trainer.model
        optimizer = trainer.optimizer
        opt_result = trainer.opt_result
        write_checkpoint(folder, model, optimizer, opt_result.iteration)

WriteMetrics(on='idle', called_every=1, callback=None, callback_condition=None, modify_optimize_result=None)

Bases: Callback

Callback to write metrics using the writer.

The WriteMetrics callback can be added to the TrainConfig callbacks as a custom user defined callback.

Example Usage in TrainConfig: To use WriteMetrics, include it in the callbacks list when setting up your TrainConfig:

from perceptrain import TrainConfig
from perceptrain.callbacks import WriteMetrics

# Create an instance of the WriteMetrics callback
write_metrics_callback = WriteMetrics(on = "val_batch_end", called_every = 100)

config = TrainConfig(
    max_iter=10000,
    # Print metrics every 1000 training epochs
    print_every=1000,
    # Add the custom callback that runs every 100 val_batch_end
    callbacks=[write_metrics_callback]
)

Source code in perceptrain/callbacks/callback.py
def __init__(
    self,
    on: str | TrainingStage = "idle",
    called_every: int = 1,
    callback: CallbackFunction | None = None,
    callback_condition: CallbackConditionFunction | None = None,
    modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
):
    if not isinstance(called_every, int):
        raise ValueError("called_every must be a positive integer or 0")

    self.callback: CallbackFunction | None = callback
    self.on: str | TrainingStage = on
    self.called_every: int = called_every
    self.callback_condition = (
        callback_condition if callback_condition else Callback.default_callback
    )

    if isinstance(modify_optimize_result, dict):
        self.modify_optimize_result = lambda opt_res: Callback.modify_opt_res_dict(
            opt_res, modify_optimize_result
        )
    else:
        self.modify_optimize_result = (
            modify_optimize_result
            if modify_optimize_result
            else Callback.modify_opt_res_default
        )

run_callback(trainer, config, writer)

Writes metrics using the writer.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
    """Writes metrics using the writer.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter ): The writer object for logging.
    """
    if trainer.accelerator.rank == 0:
        opt_result = trainer.opt_result
        writer.write(opt_result.iteration, opt_result.metrics)

WritePlots(on='idle', called_every=1, callback=None, callback_condition=None, modify_optimize_result=None)

Bases: Callback

Callback to plot metrics using the writer.

The WritePlots callback can be added to the TrainConfig callbacks as a custom user defined callback.

Example Usage in TrainConfig: To use WritePlots, include it in the callbacks list when setting up your TrainConfig:

from perceptrain import TrainConfig
from perceptrain.callbacks import WritePlots

# Create an instance of the WritePlots callback
plot_metrics_callback = WritePlots(on = "val_batch_end", called_every = 100)

config = TrainConfig(
    max_iter=10000,
    # Print metrics every 1000 training epochs
    print_every=1000,
    # Add the custom callback that runs every 100 val_batch_end
    callbacks=[plot_metrics_callback]
)

Source code in perceptrain/callbacks/callback.py
def __init__(
    self,
    on: str | TrainingStage = "idle",
    called_every: int = 1,
    callback: CallbackFunction | None = None,
    callback_condition: CallbackConditionFunction | None = None,
    modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
):
    if not isinstance(called_every, int):
        raise ValueError("called_every must be a positive integer or 0")

    self.callback: CallbackFunction | None = callback
    self.on: str | TrainingStage = on
    self.called_every: int = called_every
    self.callback_condition = (
        callback_condition if callback_condition else Callback.default_callback
    )

    if isinstance(modify_optimize_result, dict):
        self.modify_optimize_result = lambda opt_res: Callback.modify_opt_res_dict(
            opt_res, modify_optimize_result
        )
    else:
        self.modify_optimize_result = (
            modify_optimize_result
            if modify_optimize_result
            else Callback.modify_opt_res_default
        )

run_callback(trainer, config, writer)

Plots metrics using the writer.

PARAMETER DESCRIPTION
trainer

The training object.

TYPE: Any

config

The configuration object.

TYPE: TrainConfig

writer

The writer object for logging.

TYPE: BaseWriter

Source code in perceptrain/callbacks/callback.py
def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any:
    """Plots metrics using the writer.

    Args:
        trainer (Any): The training object.
        config (TrainConfig): The configuration object.
        writer (BaseWriter ): The writer object for logging.
    """
    if trainer.accelerator.rank == 0:
        opt_result = trainer.opt_result
        plotting_functions = config.plotting_functions
        writer.plot(trainer.model, opt_result.iteration, plotting_functions)

BaseWriter

Bases: ABC

Abstract base class for experiment tracking writers.

METHOD DESCRIPTION
open

Opens the writer and sets up the logging environment.

close

Closes the writer and finalizes any ongoing logging processes.

print_metrics

Prints metrics and loss in a formatted manner.

write

Writes the optimization results to the tracking tool.

log_hyperparams

Logs the hyperparameters to the tracking tool.

plot

Logs model plots using provided plotting functions.

log_model

Logs the model and any relevant information.
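
A minimal sketch of a custom writer, assuming BaseWriter can be imported from perceptrain.callbacks.writer_registry (the module shown in the source listings below). It implements the abstract methods with simple console output and no-ops:

from perceptrain.callbacks.writer_registry import BaseWriter

class ConsoleWriter(BaseWriter):
    """Writer that logs everything to stdout (illustrative only)."""

    def open(self, config, iteration=None):
        print("Opening console writer")
        return self

    def close(self):
        print("Closing console writer")

    def write(self, iteration, metrics):
        # Print one line per call with all metric key/value pairs
        print(f"[{iteration}] " + " ".join(f"{k}: {v}" for k, v in metrics.items()))

    def log_hyperparams(self, hyperparams):
        print(f"Hyperparameters: {hyperparams}")

    def plot(self, model, iteration, plotting_functions):
        pass  # no-op: this sketch does not render figures

    def log_model(self, model, train_dataloader=None, val_dataloader=None, test_dataloader=None):
        pass  # no-op: model artifacts are not persisted here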

close() abstractmethod

Closes the writer and finalizes logging.

Source code in perceptrain/callbacks/writer_registry.py
@abstractmethod
def close(self) -> None:
    """Closes the writer and finalizes logging."""
    raise NotImplementedError("Writers must implement a close method.")

log_hyperparams(hyperparams) abstractmethod

Logs hyperparameters.

PARAMETER DESCRIPTION
hyperparams

A dictionary of hyperparameters to log.

TYPE: dict

Source code in perceptrain/callbacks/writer_registry.py
@abstractmethod
def log_hyperparams(self, hyperparams: dict) -> None:
    """
    Logs hyperparameters.

    Args:
        hyperparams (dict): A dictionary of hyperparameters to log.
    """
    raise NotImplementedError("Writers must implement a log_hyperparams method.")

log_model(model, train_dataloader=None, val_dataloader=None, test_dataloader=None) abstractmethod

Logs the model and associated data.

PARAMETER DESCRIPTION
model

The model to log.

TYPE: Module

train_dataloader

DataLoader for training data.

TYPE: DataLoader | DictDataLoader | None DEFAULT: None

val_dataloader

DataLoader for validation data.

TYPE: DataLoader | DictDataLoader | None DEFAULT: None

test_dataloader

DataLoader for testing data.

TYPE: DataLoader | DictDataLoader | None DEFAULT: None

Source code in perceptrain/callbacks/writer_registry.py
@abstractmethod
def log_model(
    self,
    model: Module,
    train_dataloader: DataLoader | DictDataLoader | None = None,
    val_dataloader: DataLoader | DictDataLoader | None = None,
    test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
    """
    Logs the model and associated data.

    Args:
        model (Module): The model to log.
        train_dataloader (DataLoader | DictDataLoader |  None): DataLoader for training data.
        val_dataloader (DataLoader | DictDataLoader |  None): DataLoader for validation data.
        test_dataloader (DataLoader | DictDataLoader |  None): DataLoader for testing data.
    """
    raise NotImplementedError("Writers must implement a log_model method.")

open(config, iteration=None) abstractmethod

Opens the writer and prepares it for logging.

PARAMETER DESCRIPTION
config

Configuration object containing settings for logging.

TYPE: TrainConfig

iteration

The iteration step to start logging from. Defaults to None.

TYPE: int DEFAULT: None

Source code in perceptrain/callbacks/writer_registry.py
@abstractmethod
def open(self, config: TrainConfig, iteration: int | None = None) -> Any:
    """
    Opens the writer and prepares it for logging.

    Args:
        config: Configuration object containing settings for logging.
        iteration (int, optional): The iteration step to start logging from.
            Defaults to None.
    """
    raise NotImplementedError("Writers must implement an open method.")

plot(model, iteration, plotting_functions) abstractmethod

Logs plots of the model using provided plotting functions.

PARAMETER DESCRIPTION
model

The model to plot.

TYPE: Module

iteration

The current iteration number.

TYPE: int

plotting_functions

Functions used to generate plots.

TYPE: tuple[PlottingFunction, ...]

Source code in perceptrain/callbacks/writer_registry.py
@abstractmethod
def plot(
    self,
    model: Module,
    iteration: int,
    plotting_functions: tuple[PlottingFunction, ...],
) -> None:
    """
    Logs plots of the model using provided plotting functions.

    Args:
        model (Module): The model to plot.
        iteration (int): The current iteration number.
        plotting_functions (tuple[PlottingFunction, ...]): Functions used to
            generate plots.
    """
    raise NotImplementedError("Writers must implement a plot method.")
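
The built-in writers invoke each plotting function as descr, fig = pf(model, iteration): a PlottingFunction takes the model and the current iteration and returns a description string together with a matplotlib figure. A minimal sketch follows; the function name and the plotted quantity are illustrative assumptions.

Example:

import matplotlib.pyplot as plt
import torch
from torch.nn import Module

def plot_parameter_histogram(model: Module, iteration: int):
    # Flatten all trainable parameters and plot their distribution.
    values = torch.cat([p.detach().flatten() for p in model.parameters()])
    fig, ax = plt.subplots()
    ax.hist(values.cpu().numpy(), bins=50)
    ax.set_title(f"Parameter histogram at iteration {iteration}")
    return f"parameter_histogram_{iteration}.png", fig

For MLFlowWriter the returned description is used as the artifact path of the figure, while TensorBoardWriter uses it as the figure tag.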

print_metrics(result)

Prints the metrics and loss in a readable format.

PARAMETER DESCRIPTION
result

The optimization results to display.

TYPE: OptimizeResult

Source code in perceptrain/callbacks/writer_registry.py
def print_metrics(self, result: OptimizeResult) -> None:
    """Prints the metrics and loss in a readable format.

    Args:
        result (OptimizeResult): The optimization results to display.
    """

    # Find the key in result.metrics that contains "loss" (case-insensitive)
    loss_key = next((k for k in result.metrics if "loss" in k.lower()), None)
    initial = f"P {result.rank: >2}|{result.device: <7}| Iteration {result.iteration: >7}| "
    if loss_key:
        loss_value = result.metrics[loss_key]
        msg = initial + f"{loss_key.title()}: {loss_value:.7f} - "
    else:
        msg = initial + "Loss: None - "
    msg += " ".join([f"{k}: {v:.7f}" for k, v in result.metrics.items() if k != loss_key])
    print(msg)
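
For illustration, a result coming from rank 0 on device "cpu" at iteration 100, with metrics {"train_loss": 0.0012345, "val_mae": 0.0456789} (values are made up), would be printed roughly as:

P  0|cpu    | Iteration     100| Train_Loss: 0.0012345 - val_mae: 0.0456789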

write(iteration, metrics) abstractmethod

Logs the results of the current iteration.

PARAMETER DESCRIPTION
iteration

The current training iteration.

TYPE: int

metrics

A dictionary of metrics to log, where keys are metric names and values are the corresponding metric values.

TYPE: dict

Source code in perceptrain/callbacks/writer_registry.py
@abstractmethod
def write(self, iteration: int, metrics: dict) -> None:
    """
    Logs the results of the current iteration.

    Args:
        iteration (int): The current training iteration.
        metrics (dict): A dictionary of metrics to log, where keys are metric names
                        and values are the corresponding metric values.
    """
    raise NotImplementedError("Writers must implement a write method.")

MLFlowWriter()

Bases: BaseWriter

Writer for logging to MLflow.

ATTRIBUTE DESCRIPTION
run

The active MLflow run.

TYPE: Run

mlflow

The MLflow module.

TYPE: ModuleType

Source code in perceptrain/callbacks/writer_registry.py
def __init__(self) -> None:
    try:
        from mlflow.entities import Run
    except ImportError:
        raise ImportError(
            "mlflow is not installed. Please install perceptrain with the mlflow feature: "
            "`pip install perceptrain[mlflow]`."
        )

    self.run: Run
    self.mlflow: ModuleType

close()

Closes the MLflow run.

Source code in perceptrain/callbacks/writer_registry.py
def close(self) -> None:
    """Closes the MLflow run."""
    if self.run:
        self.mlflow.end_run()

get_signature_from_dataloader(model, dataloader)

Infers the signature of the model based on the input data from the dataloader.

PARAMETER DESCRIPTION
model

The model to use for inference.

TYPE: Module

dataloader

DataLoader for model inputs.

TYPE: DataLoader | DictDataLoader | None

RETURNS DESCRIPTION
Any

The inferred signature, if available.

Source code in perceptrain/callbacks/writer_registry.py
def get_signature_from_dataloader(
    self, model: Module, dataloader: DataLoader | DictDataLoader | None
) -> Any:
    """
    Infers the signature of the model based on the input data from the dataloader.

    Args:
        model (Module): The model to use for inference.
        dataloader (DataLoader | DictDataLoader |  None): DataLoader for model inputs.

    Returns:
        Optional[Any]: The inferred signature, if available.
    """
    from mlflow.models import infer_signature

    if dataloader is None:
        return None

    xs: InputData
    xs, *_ = next(iter(dataloader))
    preds = model(xs)

    if isinstance(xs, Tensor):
        xs = xs.detach().cpu().numpy()
        preds = preds.detach().cpu().numpy()
        return infer_signature(xs, preds)

    return None

log_hyperparams(hyperparams)

Logs hyperparameters to MLflow.

PARAMETER DESCRIPTION
hyperparams

A dictionary of hyperparameters to log.

TYPE: dict

Source code in perceptrain/callbacks/writer_registry.py
def log_hyperparams(self, hyperparams: dict) -> None:
    """
    Logs hyperparameters to MLflow.

    Args:
        hyperparams (dict): A dictionary of hyperparameters to log.
    """
    if self.mlflow:
        self.mlflow.log_params(hyperparams)
    else:
        raise RuntimeError(
            "The writer is not initialized."
            "Please call the 'writer.open()' method before writing"
        )

log_model(model, train_dataloader=None, val_dataloader=None, test_dataloader=None)

Logs the model and its signature to MLflow using the provided data loaders.

PARAMETER DESCRIPTION
model

The model to log.

TYPE: Module

train_dataloader

DataLoader for training data.

TYPE: DataLoader | DictDataLoader | None DEFAULT: None

val_dataloader

DataLoader for validation data.

TYPE: DataLoader | DictDataLoader | None DEFAULT: None

test_dataloader

DataLoader for testing data.

TYPE: DataLoader | DictDataLoader | None DEFAULT: None

Source code in perceptrain/callbacks/writer_registry.py
def log_model(
    self,
    model: Module,
    train_dataloader: DataLoader | DictDataLoader | None = None,
    val_dataloader: DataLoader | DictDataLoader | None = None,
    test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
    """
    Logs the model and its signature to MLflow using the provided data loaders.

    Args:
        model (Module): The model to log.
        train_dataloader (DataLoader | DictDataLoader |  None): DataLoader for training data.
        val_dataloader (DataLoader | DictDataLoader |  None): DataLoader for validation data.
        test_dataloader (DataLoader | DictDataLoader |  None): DataLoader for testing data.
    """
    if not self.mlflow:
        raise RuntimeError(
            "The writer is not initialized."
            "Please call the 'writer.open()' method before writing"
        )

    signatures = self.get_signature_from_dataloader(model, train_dataloader)
    self.mlflow.pytorch.log_model(model, artifact_path="model", signature=signatures)

open(config, iteration=None)

Opens the MLflow writer and initializes an MLflow run.

PARAMETER DESCRIPTION
config

Configuration object containing settings for logging.

TYPE: TrainConfig

iteration

The iteration step to start logging from. Defaults to None.

TYPE: int DEFAULT: None

RETURNS DESCRIPTION
mlflow

The MLflow module instance.

TYPE: ModuleType | None

Source code in perceptrain/callbacks/writer_registry.py
def open(self, config: TrainConfig, iteration: int | None = None) -> ModuleType | None:
    """
    Opens the MLflow writer and initializes an MLflow run.

    Args:
        config: Configuration object containing settings for logging.
        iteration (int, optional): The iteration step to start logging from.
            Defaults to None.

    Returns:
        mlflow: The MLflow module instance.
    """
    import mlflow

    self.mlflow = mlflow
    tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "")
    experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", str(uuid4()))
    run_name = os.getenv("MLFLOW_RUN_NAME", str(uuid4()))

    if self.mlflow:
        self.mlflow.set_tracking_uri(tracking_uri)

        # Create or get the experiment
        exp_filter_string = f"name = '{experiment_name}'"
        experiments = self.mlflow.search_experiments(filter_string=exp_filter_string)
        if not experiments:
            self.mlflow.create_experiment(name=experiment_name)

        self.mlflow.set_experiment(experiment_name)
        self.run = self.mlflow.start_run(run_name=run_name, nested=False)

    return self.mlflow
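
Because open() reads its tracking settings from the MLFLOW_TRACKING_URI, MLFLOW_EXPERIMENT_NAME and MLFLOW_RUN_NAME environment variables (falling back to random UUIDs for the names), a typical setup configures them before opening the writer. In the sketch below the URI, the names and the config object are placeholders; config stands for an existing TrainConfig.

Example:

import os
from perceptrain.callbacks.writer_registry import MLFlowWriter

os.environ["MLFLOW_TRACKING_URI"] = "http://localhost:5000"   # placeholder tracking server
os.environ["MLFLOW_EXPERIMENT_NAME"] = "my_experiment"
os.environ["MLFLOW_RUN_NAME"] = "run_001"

writer = MLFlowWriter()
writer.open(config)                                  # `config` is an existing TrainConfig
writer.log_hyperparams({"lr": 1e-3, "epochs": 10})
writer.write(iteration=1, metrics={"train_loss": 0.42})
writer.close()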

plot(model, iteration, plotting_functions)

Logs plots of the model using provided plotting functions.

PARAMETER DESCRIPTION
model

The model to plot.

TYPE: Module

iteration

The current iteration number.

TYPE: int

plotting_functions

Functions used to generate plots.

TYPE: tuple[PlottingFunction, ...]

Source code in perceptrain/callbacks/writer_registry.py
def plot(
    self,
    model: Module,
    iteration: int,
    plotting_functions: tuple[PlottingFunction, ...],
) -> None:
    """
    Logs plots of the model using provided plotting functions.

    Args:
        model (Module): The model to plot.
        iteration (int): The current iteration number.
        plotting_functions (tuple[PlottingFunction, ...]): Functions used
            to generate plots.
    """
    if self.mlflow:
        for pf in plotting_functions:
            descr, fig = pf(model, iteration)
            self.mlflow.log_figure(fig, descr)
    else:
        raise RuntimeError(
            "The writer is not initialized."
            "Please call the 'writer.open()' method before writing"
        )

write(iteration, metrics)

Logs the results of the current iteration to MLflow.

PARAMETER DESCRIPTION
iteration

The current training iteration.

TYPE: int

metrics

A dictionary of metrics to log, where keys are metric names and values are the corresponding metric values.

TYPE: dict

Source code in perceptrain/callbacks/writer_registry.py
def write(self, iteration: int, metrics: dict) -> None:
    """
    Logs the results of the current iteration to MLflow.

    Args:
        iteration (int): The current training iteration.
        metrics (dict): A dictionary of metrics to log, where keys are metric names
                        and values are the corresponding metric values.
    """
    if self.mlflow:
        self.mlflow.log_metrics(metrics, step=iteration)
    else:
        raise RuntimeError(
            "The writer is not initialized."
            "Please call the 'writer.open()' method before writing."
        )

TensorBoardWriter()

Bases: BaseWriter

Writer for logging to TensorBoard.

ATTRIBUTE DESCRIPTION
writer

The TensorBoard SummaryWriter instance.

TYPE: SummaryWriter

Source code in perceptrain/callbacks/writer_registry.py
def __init__(self) -> None:
    self.writer = None

close()

Closes the TensorBoard writer.

Source code in perceptrain/callbacks/writer_registry.py
def close(self) -> None:
    """Closes the TensorBoard writer."""
    if self.writer:
        self.writer.close()

log_hyperparams(hyperparams)

Logs hyperparameters to TensorBoard.

PARAMETER DESCRIPTION
hyperparams

A dictionary of hyperparameters to log.

TYPE: dict

Source code in perceptrain/callbacks/writer_registry.py
def log_hyperparams(self, hyperparams: dict) -> None:
    """
    Logs hyperparameters to TensorBoard.

    Args:
        hyperparams (dict): A dictionary of hyperparameters to log.
    """
    if self.writer:
        self.writer.add_hparams(hyperparams, {})
    else:
        raise RuntimeError(
            "The writer is not initialized."
            "Please call the 'writer.open()' method before writing"
        )

log_model(model, train_dataloader=None, val_dataloader=None, test_dataloader=None)

Logs the model.

Currently not supported by TensorBoard.

PARAMETER DESCRIPTION
model

The model to log.

TYPE: Module

train_dataloader

DataLoader for training data.

TYPE: DataLoader | DictDataLoader | None DEFAULT: None

val_dataloader

DataLoader for validation data.

TYPE: DataLoader | DictDataLoader | None DEFAULT: None

test_dataloader

DataLoader for testing data.

TYPE: DataLoader | DictDataLoader | None DEFAULT: None

Source code in perceptrain/callbacks/writer_registry.py
def log_model(
    self,
    model: Module,
    train_dataloader: DataLoader | DictDataLoader | None = None,
    val_dataloader: DataLoader | DictDataLoader | None = None,
    test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
    """
    Logs the model.

    Currently not supported by TensorBoard.

    Args:
        model (Module): The model to log.
        train_dataloader (DataLoader | DictDataLoader |  None): DataLoader for training data.
        val_dataloader (DataLoader | DictDataLoader |  None): DataLoader for validation data.
        test_dataloader (DataLoader | DictDataLoader |  None): DataLoader for testing data.
    """
    logger.warning("Model logging is not supported by tensorboard. No model will be logged.")

open(config, iteration=None)

Opens the TensorBoard writer.

PARAMETER DESCRIPTION
config

Configuration object containing settings for logging.

TYPE: TrainConfig

iteration

The iteration step to start logging from. Defaults to None.

TYPE: int DEFAULT: None

RETURNS DESCRIPTION
SummaryWriter

The initialized TensorBoard writer.

TYPE: SummaryWriter

Source code in perceptrain/callbacks/writer_registry.py
def open(self, config: TrainConfig, iteration: int | None = None) -> SummaryWriter:
    """
    Opens the TensorBoard writer.

    Args:
        config: Configuration object containing settings for logging.
        iteration (int, optional): The iteration step to start logging from.
            Defaults to None.

    Returns:
        SummaryWriter: The initialized TensorBoard writer.
    """
    log_dir = str(config.log_folder)
    purge_step = iteration if isinstance(iteration, int) else None
    self.writer = SummaryWriter(log_dir=log_dir, purge_step=purge_step)
    return self.writer
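
The writer logs to config.log_folder, so a run can be inspected afterwards with the standard tensorboard --logdir command. A minimal usage sketch, assuming config is an existing TrainConfig:

Example:

from perceptrain.callbacks.writer_registry import TensorBoardWriter

writer = TensorBoardWriter()
writer.open(config)                                  # `config` is an existing TrainConfig
writer.write(iteration=1, metrics={"train_loss": 0.42})
writer.close()
# From a shell: tensorboard --logdir <config.log_folder>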

plot(model, iteration, plotting_functions)

Logs plots of the model using provided plotting functions.

PARAMETER DESCRIPTION
model

The model to plot.

TYPE: Module

iteration

The current iteration number.

TYPE: int

plotting_functions

Functions used to generate plots.

TYPE: tuple[PlottingFunction, ...]

Source code in perceptrain/callbacks/writer_registry.py
def plot(
    self,
    model: Module,
    iteration: int,
    plotting_functions: tuple[PlottingFunction, ...],
) -> None:
    """
    Logs plots of the model using provided plotting functions.

    Args:
        model (Module): The model to plot.
        iteration (int): The current iteration number.
        plotting_functions (tuple[PlottingFunction, ...]): Functions used
            to generate plots.
    """
    if self.writer:
        for pf in plotting_functions:
            descr, fig = pf(model, iteration)
            self.writer.add_figure(descr, fig, global_step=iteration)
    else:
        raise RuntimeError(
            "The writer is not initialized."
            "Please call the 'writer.open()' method before writing"
        )

write(iteration, metrics)

Logs the results of the current iteration to TensorBoard.

PARAMETER DESCRIPTION
iteration

The current training iteration.

TYPE: int

metrics

A dictionary of metrics to log, where keys are metric names and values are the corresponding metric values.

TYPE: dict

Source code in perceptrain/callbacks/writer_registry.py
def write(self, iteration: int, metrics: dict) -> None:
    """
    Logs the results of the current iteration to TensorBoard.

    Args:
        iteration (int): The current training iteration.
        metrics (dict): A dictionary of metrics to log, where keys are metric names
                        and values are the corresponding metric values.
    """
    if self.writer:
        for key, value in metrics.items():
            self.writer.add_scalar(key, value, iteration)
    else:
        raise RuntimeError(
            "The writer is not initialized."
            "Please call the 'writer.open()' method before writing."
        )

get_writer(tracking_tool)

Factory method to get the appropriate writer based on the tracking tool.

PARAMETER DESCRIPTION
tracking_tool

The experiment tracking tool to use.

TYPE: ExperimentTrackingTool

RETURNS DESCRIPTION
BaseWriter

An instance of the appropriate writer.

TYPE: BaseWriter

Source code in perceptrain/callbacks/writer_registry.py
def get_writer(tracking_tool: ExperimentTrackingTool) -> BaseWriter:
    """Factory method to get the appropriate writer based on the tracking tool.

    Args:
        tracking_tool (ExperimentTrackingTool): The experiment tracking tool to use.

    Returns:
        BaseWriter: An instance of the appropriate writer.
    """
    writer_class = WRITER_REGISTRY.get(tracking_tool)
    if writer_class:
        return writer_class()
    else:
        raise ValueError(f"Unsupported tracking tool: {tracking_tool}")
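
As a sketch, assuming the ExperimentTrackingTool enum exposes a TENSORBOARD member (the import path below is an assumption):

Example:

from perceptrain.callbacks.writer_registry import get_writer
from perceptrain.types import ExperimentTrackingTool  # import path is an assumption

writer = get_writer(ExperimentTrackingTool.TENSORBOARD)  # returns a TensorBoardWriter instance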