Bases: ModelCheckpoint
A custom ModelCheckpoint that saves only when the monitored metric improves.
This checkpoint class tracks the best score of a monitored metric and
saves model checkpoints only when there's improvement. It evaluates
the metric after each training epoch and compares it to the previous
best score, saving the checkpoint if an improvement is detected.
Attributes:
| Name | Type | Description |
| --- | --- | --- |
| best_score | | The best metric value encountered so far. |
Source code in src/canari_ml/lightning/checkpoints.py
def __init__(self, *args, **kwargs):
    """Set up the checkpoint callback with no best score recorded yet.

    All positional and keyword arguments are forwarded unchanged to the
    parent class initialiser.
    """
    super().__init__(*args, **kwargs)
    # ``None`` marks that no metric value has been observed so far.
    self.best_score = None
|
canari_ml.lightning.checkpoints.ModelCheckpointOnImprovement.best_score = None
instance-attribute
canari_ml.lightning.checkpoints.ModelCheckpointOnImprovement.on_train_epoch_end(trainer, pl_module)
Source code in src/canari_ml/lightning/checkpoints.py
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Run the parent checkpoint hook only when the monitored metric improved.

    The current value of ``self.monitor`` is compared against
    ``self.best_score`` using ``torch.lt`` when ``self.mode`` is ``"min"``
    and ``torch.gt`` when it is ``"max"``. While ``self.best_score`` is
    ``None`` (nothing recorded yet) the current value always counts as an
    improvement.

    Args:
        trainer: The trainer whose ``callback_metrics`` and
            ``current_epoch`` are read.
        pl_module: The module being trained; passed through to the parent
            hook untouched.

    Raises:
        KeyError: If ``self.monitor`` is not a key of
            ``trainer.callback_metrics``.
    """
    monitor_candidates = self._monitor_candidates(trainer)
    valid_metrics = trainer.callback_metrics

    # Fail early with the list of available metrics rather than a bare
    # KeyError from the lookup below.
    if self.monitor not in valid_metrics:
        raise KeyError(
            f"`{self.monitor}` is not a metric being monitored, select from: "
            f"{valid_metrics.keys()}"
        )

    current_score = monitor_candidates[self.monitor]
    monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]

    # Lazy %-style argument: the message needs a placeholder for the
    # metrics, otherwise formatting the record fails when DEBUG is enabled.
    logging.debug("Metric candidates for monitoring: %s", valid_metrics)

    # Check if metric's best score has improved.
    if self.best_score is None or monitor_op(current_score, self.best_score):
        logging.info(
            f"Checkpoint saved at epoch {trainer.current_epoch} with "
            f"{self.monitor}: {current_score:.4f}"
        )
        self.best_score = current_score
        # Only delegate to the parent hook if the score has improved.
        # NOTE(review): the parent may apply further conditions of its own
        # before actually writing a file -- confirm against its docs.
        super().on_train_epoch_end(trainer, pl_module)
    else:
        logging.info(
            f"No improvement in {self.monitor} at epoch {trainer.current_epoch}:"
            f" {current_score:.4f} (Best: {self.best_score:.4f})"
        )
|