lightning_modules
canari_ml.models.lightning_modules
canari_ml.models.lightning_modules.BaseLightningModule(model, criterion, learning_rate, metrics, enable_leadtime_metrics=True, **kwargs)
Bases: LightningModule
Base class for all Canari ML models using PyTorch Lightning.
This class inherits from `pytorch_lightning.LightningModule` and provides the basic
functionality required to train, validate, and test Canari ML models. It also
supports saving hyperparameters to checkpoints and recording metrics during
training and validation.
Attributes:
| Name | Type | Description |
|---|---|---|
| model | Module | The PyTorch model being wrapped. |
| criterion | callable | The loss function used during training and validation. |
| learning_rate | float | The learning rate used for optimisation. |
| metrics | Iterable[callable] | An iterable of callable objects representing the metrics to be recorded during training and validation. |
| enable_leadtime_metrics | bool | Flag indicating whether to enable lead-time related metrics. Defaults to True. |
| n_output_classes | int | The number of output classes in the model. |
Source code in src/canari_ml/models/lightning_modules.py
Instance attributes of `canari_ml.models.lightning_modules.BaseLightningModule`:
| Name | Value |
|---|---|
| model | `model` |
| criterion | `criterion` |
| learning_rate | `learning_rate` |
| metrics | `metrics` |
| enable_leadtime_metrics | `enable_leadtime_metrics` |
| n_output_classes | `model.n_output_classes` |
| metrics_history | `defaultdict(list)` |
| train_metrics | `metric_collection.clone(prefix='train_')` |
| val_metrics | `metric_collection.clone(prefix='val_')` |
| test_metrics | `metric_collection.clone(prefix='test_')` |
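The `train_metrics`, `val_metrics` and `test_metrics` attributes are per-stage copies of one metric collection that differ only in the prefix on their output keys. The values suggest a torchmetrics `MetricCollection`, whose `clone` method accepts a `prefix` argument, but that is an assumption. The framework-free sketch below illustrates the idea; `PrefixedMetrics` and `mae` are hypothetical stand-ins, not part of canari_ml.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Sequence

# A metric maps (predictions, targets) to a single float.
Metric = Callable[[Sequence[float], Sequence[float]], float]


class PrefixedMetrics:
    """Hypothetical stand-in for a metric collection whose output keys carry a prefix."""

    def __init__(self, metrics: Dict[str, Metric], prefix: str = "") -> None:
        self.metrics = metrics
        self.prefix = prefix

    def clone(self, prefix: str) -> "PrefixedMetrics":
        # Same metric functions, new prefix for the reported keys.
        return PrefixedMetrics(dict(self.metrics), prefix=prefix)

    def __call__(self, preds: Sequence[float], target: Sequence[float]) -> Dict[str, float]:
        return {self.prefix + name: fn(preds, target) for name, fn in self.metrics.items()}


def mae(preds: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error (hypothetical example metric)."""
    return sum(abs(p - t) for p, t in zip(preds, target)) / len(preds)


metric_collection = PrefixedMetrics({"mae": mae})
train_metrics = metric_collection.clone(prefix="train_")
val_metrics = metric_collection.clone(prefix="val_")

# Mirrors the metrics_history = defaultdict(list) attribute: one list of values per key.
metrics_history: Dict[str, List[float]] = defaultdict(list)
for key, value in val_metrics([0.0, 1.0], [0.5, 1.0]).items():
    metrics_history[key].append(value)
```

Because each stage reports under its own prefix (`train_mae`, `val_mae`, and so on), all three collections can write into one history dictionary without key clashes.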
canari_ml.models.lightning_modules.BaseLightningModule.forward(x)
Implement the forward pass.
Applies the wrapped model to the input tensor `x`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Inputs to the model. Expected shape is … | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output of the model with shape … |
canari_ml.models.lightning_modules.BaseLightningModule.on_save_checkpoint(checkpoint)
Override PyTorch Lightning's default `on_save_checkpoint` hook to add custom data.
This method adds the class name and the path to the Lightning module to the checkpoint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| checkpoint | dict | The checkpoint dictionary to which additional data will be added. | required |
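Lightning passes the checkpoint dictionary to `on_save_checkpoint` before writing it, so entries added in place end up in the saved file. A plausible reason for recording the class name and module path is to re-import the right class when loading. The sketch below shows that round trip under stated assumptions: the helpers `add_class_metadata` and `resolve_class` and the keys `class_name` and `module_path` are hypothetical, since the actual keys used by canari_ml are not given here. Inside the hook, one would call `add_class_metadata(checkpoint, type(self))`.

```python
import importlib
from collections import OrderedDict
from typing import Any, Dict


def add_class_metadata(checkpoint: Dict[str, Any], cls: type) -> None:
    """Record which class produced the checkpoint (hypothetical key names)."""
    checkpoint["class_name"] = cls.__name__
    checkpoint["module_path"] = cls.__module__


def resolve_class(checkpoint: Dict[str, Any]) -> type:
    """Re-import the class recorded by add_class_metadata."""
    module = importlib.import_module(checkpoint["module_path"])
    return getattr(module, checkpoint["class_name"])


# A standard-library class stands in for a LightningModule subclass here.
ckpt: Dict[str, Any] = {"state_dict": {}}
add_class_metadata(ckpt, OrderedDict)
restored_cls = resolve_class(ckpt)  # the OrderedDict class itself
```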
canari_ml.models.lightning_modules.LitUNet(*args, **kwargs)
Bases: BaseLightningModule
A LightningModule wrapping the `UNet` implementation of IceNet.
canari_ml.models.lightning_modules.LitUNet.training_step(batch, batch_idx)
Perform a pass through a batch of training data.
This method implements the core training loop for a single batch. It takes the inputs, targets, and sample weights from the batch, passes the inputs through the model to obtain predictions, computes the pixel-weighted loss using the provided criterion, and updates the relevant metrics. It returns the computed loss for use in backpropagation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict | A dictionary with the keys 'x', 'y', and 'sample_weights', holding the input data, target data, and sample weights respectively. | required |
| batch_idx | int | Index of the current batch. | required |
Returns:
| Type | Description |
|---|---|
| dict | A dictionary containing the computed loss for this batch, used in backpropagation to update the model's parameters. |
Note:
The method computes a pixel-weighted loss by reducing it manually, following the
approach outlined
[here](https://discuss.pytorch.org/t/unet-pixel-wise-weighted-loss-function/46689/5).
It also logs the computed loss and metrics for monitoring training progress.
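Manual reduction means the criterion returns one loss value per pixel, which is multiplied by the per-pixel sample weights before being reduced to a scalar. The sketch below shows that idea on flattened Python lists. It assumes a squared-error criterion and a plain mean over all pixels; the real criterion is configurable and the exact reduction used by canari_ml (mean versus a sum normalised by the weights) is not stated here. `pixel_weighted_mse` is a hypothetical name.

```python
from typing import Sequence


def pixel_weighted_mse(y_hat: Sequence[float], y: Sequence[float],
                       sample_weights: Sequence[float]) -> float:
    """Weighted squared error, reduced manually (inputs flattened to 1-D for clarity)."""
    # Equivalent of a loss with reduction="none": one value per pixel.
    per_pixel = [(p - t) ** 2 for p, t in zip(y_hat, y)]
    # Scale each pixel's loss by its weight; a weight of 0 removes the pixel entirely.
    weighted = [loss * w for loss, w in zip(per_pixel, sample_weights)]
    # Manual reduction: mean over all pixels (an assumption, see above).
    return sum(weighted) / len(weighted)


loss = pixel_weighted_mse([0.5, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 0.0, 2.0])  # 0.75
```

The zero-weighted second pixel contributes nothing even though it is still counted in the denominator, which is one way weights can mask pixels that should not influence training.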
canari_ml.models.lightning_modules.LitUNet.validation_step(batch, batch_idx)
Perform a pass through a batch of validation data.
This method implements the core validation loop for a single batch.
It follows the same procedure as `training_step`, and the computed loss is logged
to monitor validation progress.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict | A dictionary with the keys 'x', 'y', and 'sample_weights', holding the input data, target data, and sample weights respectively. | required |
| batch_idx | int | Index of the current batch. | required |
Returns:
| Type | Description |
|---|---|
| dict | A dictionary containing the computed loss for this batch, used in logging to monitor validation progress. |
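Since `training_step`, `validation_step` and `test_step` share the same procedure and differ mainly in what they log, one common way to avoid repeating the logic is a helper parameterised by stage. The sketch below is a framework-free illustration, not necessarily how canari_ml structures its code: `shared_step`, `weighted_mae`, the `log` dictionary (a stand-in for `self.log`) and the `{stage}_loss` key format are all hypothetical.

```python
from typing import Callable, Dict, List, Sequence

Batch = Dict[str, List[float]]
Criterion = Callable[[Sequence[float], Sequence[float], Sequence[float]], float]


def weighted_mae(y_hat: Sequence[float], y: Sequence[float], w: Sequence[float]) -> float:
    """Hypothetical criterion: per-pixel absolute error, weighted, then averaged."""
    return sum(wi * abs(p - t) for p, t, wi in zip(y_hat, y, w)) / len(y)


def shared_step(model: Callable[[List[float]], List[float]], criterion: Criterion,
                batch: Batch, stage: str, log: Dict[str, List[float]]) -> Dict[str, float]:
    """One pass over a batch; only the logged key differs between stages."""
    x, y, sample_weights = batch["x"], batch["y"], batch["sample_weights"]
    y_hat = model(x)
    loss = criterion(y_hat, y, sample_weights)
    # Stand-in for self.log(f"{stage}_loss", loss) inside a LightningModule.
    log.setdefault(f"{stage}_loss", []).append(loss)
    return {"loss": loss}
```

Each Lightning hook would then reduce to a one-line call such as `shared_step(..., stage="val", ...)`.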
canari_ml.models.lightning_modules.LitUNet.test_step(batch, batch_idx)
Perform a pass through a batch of test data.
This method implements the core testing loop for a single batch.
It follows the same procedure as `training_step`, and the computed loss is logged
to monitor test progress.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict | A dictionary with the keys 'x', 'y', and 'sample_weights', holding the input data, target data, and sample weights respectively. | required |
| batch_idx | int | Index of the current batch. | required |
Returns:
| Type | Description |
|---|---|
| dict | A dictionary containing the computed loss for this batch, used in logging to monitor test progress. |
canari_ml.models.lightning_modules.LitUNet.on_train_epoch_end()
Perform actions at the end of each training epoch.
This method is called by PyTorch Lightning at the end of each training epoch. It computes and stores the average loss for the completed epoch, then resets the metrics computed during individual training steps in preparation for the next epoch.
Note:
The implementation follows the migration guide from Lightning v1.5 to v2.0,
as outlined [here](https://github.com/Lightning-AI/pytorch-lightning/pull/16520).
It also references [this issue](https://github.com/Lightning-AI/pytorch-lightning/issues/13147#issuecomment-1138975446)
for accessing logged results.
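The epoch-end bookkeeping described above (average the step losses, store the result, reset for the next epoch) can be sketched without Lightning as follows. `EpochLossTracker` is a hypothetical class, and it uses a plain mean of the step losses. When a loss is logged with `self.log(..., on_epoch=True)`, Lightning performs its own epoch-level reduction, which by default weights steps by batch size, so the two can differ when batches have unequal sizes.

```python
from collections import defaultdict
from typing import Dict, List


class EpochLossTracker:
    """Collect per-step losses and store their mean when an epoch ends (hypothetical)."""

    def __init__(self) -> None:
        self.step_losses: List[float] = []
        self.metrics_history: Dict[str, List[float]] = defaultdict(list)

    def on_step(self, loss: float) -> None:
        self.step_losses.append(loss)

    def on_epoch_end(self, stage: str) -> float:
        if not self.step_losses:
            raise ValueError("no steps were recorded in this epoch")
        epoch_loss = sum(self.step_losses) / len(self.step_losses)
        self.metrics_history[f"{stage}_loss"].append(epoch_loss)
        # Reset so the next epoch starts from an empty buffer.
        self.step_losses.clear()
        return epoch_loss
```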
canari_ml.models.lightning_modules.LitUNet.on_validation_epoch_end()
Perform actions at the end of each validation epoch.
This method is called by PyTorch Lightning at the end of each validation epoch. It computes and stores the average loss for the completed epoch, then updates and stores the metrics accumulated during the validation steps and resets them in preparation for the next epoch.
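Metrics that are "updated, stored and reset" per epoch typically follow an update/compute/reset lifecycle, as torchmetrics metrics do: each step adds to some running state, the epoch-level value is computed once at the end, and the state is cleared. The stdlib sketch below shows that lifecycle with a hypothetical `RunningMAE`; whether canari_ml's metrics accumulate exactly this way is an assumption.

```python
from typing import Sequence


class RunningMAE:
    """Hypothetical accumulating metric with an update/compute/reset lifecycle."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Clear the running state at the start of each epoch.
        self.abs_error_sum = 0.0
        self.count = 0

    def update(self, preds: Sequence[float], target: Sequence[float]) -> None:
        # Called once per step: accumulate totals rather than per-batch averages.
        self.abs_error_sum += sum(abs(p - t) for p, t in zip(preds, target))
        self.count += len(preds)

    def compute(self) -> float:
        # Called once per epoch: exact mean over every element seen.
        return self.abs_error_sum / self.count
```

Accumulating totals gives the exact epoch-level value: after `update([1.0, 2.0], [1.0, 3.0])` and `update([0.0], [2.0])`, `compute()` returns 1.0, whereas averaging the two per-batch MAEs (0.5 and 2.0) would give 1.25.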
canari_ml.models.lightning_modules.LitUNet.on_test_epoch_end()
Perform actions at the end of each test epoch.
This method is called by PyTorch Lightning at the end of each test epoch. It logs and resets the metrics computed during individual test steps in preparation for the next epoch.
canari_ml.models.lightning_modules.LitUNet.predict_step(batch, batch_idx)
Generate predictions for a given input batch.
PyTorch Lightning calls this method during prediction (for example via `Trainer.predict`) and it returns the model's outputs for the inputs in the batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict | A dictionary containing the input ('x') and target ('y') tensors, plus any additional entries such as 'sample_weights'. | required |
| batch_idx | int | The index of the current batch in the dataloader. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| y_hat | Tensor | The model's predictions for the given input batch. |
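Prediction only needs the inputs: the batch may still carry 'y' and 'sample_weights', but they play no part in producing `y_hat`. The sketch below illustrates this with plain lists and roughly mirrors how `Trainer.predict` collects one output per batch, in order. `predict_step` here is a free function standing in for the method, and `predict_all` is hypothetical.

```python
from typing import Callable, Dict, Iterable, List

Model = Callable[[List[float]], List[float]]


def predict_step(model: Model, batch: Dict[str, List[float]], batch_idx: int) -> List[float]:
    # Only the input is used; any 'y' or 'sample_weights' entries are ignored.
    return model(batch["x"])


def predict_all(model: Model, batches: Iterable[Dict[str, List[float]]]) -> List[List[float]]:
    # One prediction per batch, in dataloader order.
    return [predict_step(model, batch, idx) for idx, batch in enumerate(batches)]
```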
canari_ml.models.lightning_modules.LitUNet.configure_optimizers()
Configure and return the optimizer and learning rate scheduler.
This method is called by PyTorch Lightning to initialise the optimizer and learning rate scheduler used for training the model. It returns a dictionary containing both.
Returns:
| Type | Description |
|---|---|
| dict | A dictionary with two entries: optimizer (torch.optim.optimizer.Optimizer), the optimizer instance used to update model parameters during training; and lr_scheduler (torch.optim.lr_scheduler._LRScheduler), the learning rate scheduler instance, which adjusts the learning rate over time based on specified criteria. |