Training Models
Overview
This section shows how to use the canari_ml train command to train the default UNet model on preprocessed ERA5 data. As in previous sections, you can override the default settings with CLI arguments or YAML configuration files.
The training process involves:
- Using the preprocessed training datasets generated by canari_ml preprocess train (see the preprocess page).
- Configuring the model architecture, training parameters, and evaluation metrics.
- Running the training loop with options for monitoring and early stopping.
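The default early_stopping callback appears in the configuration output on this page. It monitors val_rmse in min mode with patience: 10 and min_delta: 0.0, and the trainer caps each run at train.epochs (50 by default). The Python sketch below is a simplified model of that patience rule, not Lightning's EarlyStopping implementation. The function name early_stop_epoch is hypothetical, and the sketch counts epochs from 1.

```python
from typing import Optional, Sequence


def early_stop_epoch(
    val_rmse: Sequence[float],
    patience: int = 10,
    min_delta: float = 0.0,
    max_epochs: int = 50,
) -> int:
    """Return how many epochs run before training stops (simplified sketch).

    In "min" mode an epoch counts as an improvement only if the metric
    falls below the best value seen so far by more than ``min_delta``.
    """
    best: Optional[float] = None
    wait = 0  # consecutive epochs without improvement
    for epoch, value in enumerate(val_rmse[:max_epochs], start=1):
        if best is None or value < best - min_delta:
            best = value  # new best value: reset the patience counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch  # patience exhausted: stop here
    # Never triggered: training runs until the data or max_epochs runs out
    return min(len(val_rmse), max_epochs)
```

For example, early_stop_epoch([1.0, 0.9, 0.8] + [0.85] * 20) returns 13: three improving epochs are followed by ten epochs without improvement.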
Getting Started
Prerequisites
- Preprocess your ERA5 data for training by following the steps on the preprocess page.
- The training command assumes that preprocessing has produced a Zarr-formatted dataset and a corresponding JSON config file.
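A quick pre-flight check can catch missing inputs before a long training run starts. The sketch below is hypothetical: check_training_inputs is not part of canari_ml, and it assumes you already know the paths of the Zarr store and JSON config written by preprocessing. It treats a directory as a Zarr store if its root contains a .zgroup or .zarray file (Zarr v2) or a zarr.json file (Zarr v3).

```python
import json
from pathlib import Path

# Marker files at the root of a Zarr store: .zgroup/.zarray (v2), zarr.json (v3)
ZARR_MARKERS = (".zgroup", ".zarray", "zarr.json")


def check_training_inputs(zarr_path: str, config_path: str) -> dict:
    """Hypothetical check that preprocessing outputs exist and are readable.

    Raises FileNotFoundError if ``zarr_path`` does not look like a Zarr
    store, and json.JSONDecodeError if ``config_path`` is not valid JSON.
    Returns the parsed JSON config on success.
    """
    store = Path(zarr_path)
    if not store.is_dir() or not any((store / m).exists() for m in ZARR_MARKERS):
        raise FileNotFoundError(f"No Zarr store found at {store}")
    with open(config_path, encoding="utf-8") as fh:
        return json.load(fh)
```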
Usage
The canari_ml train command trains models on preprocessed ERA5 data. You can run it with the default configuration, or override settings with CLI arguments, custom YAML config files, or both.
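The Hydra help output for this command says that any value can be overridden with dotted foo.bar=value arguments. Assuming canari_ml train forwards these arguments to Hydra, a run could look like canari_ml train train.dataset=<your dataset> train.name=run1 train.epochs=20, where the dataset value is a placeholder. The stdlib sketch below models what such overrides do to the nested config, and reports which mandatory ??? values are still unset. It is a simplification, not Hydra's override grammar or type handling. apply_overrides and find_missing are hypothetical helpers.

```python
import copy
from typing import Any

MISSING = "???"  # how mandatory values with no default appear in the config


def _parse(raw: str) -> Any:
    """Convert an override string to int, float, bool, or leave it as str."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    if raw.lower() in ("true", "false"):
        return raw.lower() == "true"
    return raw


def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Return a copy of ``config`` with ``a.b.c=value`` overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"Override must look like key=value: {item}")
        *parents, leaf = dotted_key.split(".")
        node = result
        for part in parents:
            node = node[part]  # KeyError for an unknown group or key
        if leaf not in node:
            raise KeyError(f"Unknown config key: {dotted_key}")
        node[leaf] = _parse(raw)
    return result


def find_missing(config: dict, prefix: str = "") -> list[str]:
    """List the dotted keys that are still set to the ??? marker."""
    missing = []
    for key, value in config.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            missing.extend(find_missing(value, path + "."))
        elif value == MISSING:
            missing.append(path)
    return missing
```

The sketch rejects keys that are not already in the config. Hydra behaves similarly: it only lets you add a new key if you prefix the override with +.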
Basic Usage
To see all available configuration options, run canari_ml train --help. The output looks like this:
```
train is powered by Hydra.

== Configuration groups ==
Compose your configuration from those groups (group=option)

callbacks: default, early_stopping, model_checkpoint
common: default
hydra_config: predict, train
logger: csv, tensorboard, wandb
model: default, unet
paths: default, download, plot, postprocess, predict, preprocess, train
plot: default, ua700
postprocess: default, netcdf, plot_ua700
predict: default
preprocess: default
profiler: pytorch
train: default
trainer: default

== Config ==
Override anything in the config (foo.bar=value)

train:
  dataset: ???
  name: ???
  seed: 42
  epochs: 50
  workers: 4
  batch_size: 4
  shuffling: true
  wandb_group: unet
  wandb_project: CANARI
  verbose: true
paths:
  train: outputs/${train.name}/training/
  source_dataset_id: era5
model:
  model_name: unet
  network:
    _target_: canari_ml.models.models.UNet
    _partial_: true
    filter_size: 3
    n_filters_factor: 1.0
    n_output_classes: 1
  litmodule:
    _target_: canari_ml.models.lightning_modules.LitUNet
    _partial_: true
    criterion:
      _target_: canari_ml.models.losses.WeightedLoss
      loss_type: mse
    learning_rate: 0.0001
    metrics:
    - canari_ml.models.metrics.MAE
    - canari_ml.models.metrics.MSE
    - canari_ml.models.metrics.RMSE
    enable_leadtime_metrics: false
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${hydra:runtime.output_dir}/checkpoints
    filename: epoch={epoch}-${callbacks.model_checkpoint.monitor}={${callbacks.model_checkpoint.monitor}:.4f}
    monitor: val_rmse
    verbose: true
    save_last: true
    save_top_k: 1
    mode: min
    auto_insert_metric_name: false
    save_weights_only: false
    every_n_train_steps: null
    train_time_interval: null
    every_n_epochs: null
    save_on_train_epoch_end: null
    enable_version_counter: false
  early_stopping:
    _target_: lightning.pytorch.callbacks.EarlyStopping
    monitor: val_rmse
    min_delta: 0.0
    patience: 10
    mode: min
    strict: true
    check_finite: true
    stopping_threshold: null
    divergence_threshold: null
    check_on_train_epoch_end: null
    log_rank_zero_only: null
logger:
  _target_: pytorch_lightning.loggers.TensorBoardLogger
  save_dir: tb_logs
  name: ${train.name}_${train.seed}
trainer:
  _target_: lightning.pytorch.Trainer
  _partial_: true
  accelerator: auto
  devices: 1
  precision: 16-mixed
  log_every_n_steps: 5
  max_epochs: ${train.epochs}
  num_sanity_val_steps: 0
  deterministic: true
  fast_dev_run: false
  logger: ${logger}
  callbacks: ${callbacks}

Powered by Hydra (https://hydra.cc)
Use --hydra-help to view Hydra specific help
```
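The output lists the available configuration groups, followed by the default value of every parameter you can override. Values shown as ??? (train.dataset and train.name) have no default and must be supplied.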
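The default model_checkpoint filename is a template built in two stages. First, the ${callbacks.model_checkpoint.monitor} interpolation resolves to val_rmse. Then Lightning fills the remaining {epoch} and {val_rmse:.4f} fields when it saves a checkpoint. Because save_top_k: 1 and save_last: true, the checkpoints directory should hold the single best checkpoint plus a last.ckpt file.

The Python sketch below reproduces the naming step with plain string operations. It assumes that only this one interpolation appears in the template. The function checkpoint_filename is hypothetical, not part of canari_ml or Lightning.

```python
def checkpoint_filename(template: str, monitor: str, metrics: dict) -> str:
    """Sketch how the default checkpoint filename is assembled.

    Stage 1 stands in for the config interpolation of the monitored metric
    name; stage 2 fills the {name:format} fields with logged values and
    appends the .ckpt extension.
    """
    # Stage 1: resolve ${callbacks.model_checkpoint.monitor} (simplified)
    resolved = template.replace("${callbacks.model_checkpoint.monitor}", monitor)
    # Stage 2: e.g. "epoch={epoch}-val_rmse={val_rmse:.4f}" -> formatted name
    return resolved.format(**metrics) + ".ckpt"


template = (
    "epoch={epoch}-${callbacks.model_checkpoint.monitor}"
    "={${callbacks.model_checkpoint.monitor}:.4f}"
)
print(checkpoint_filename(template, "val_rmse", {"epoch": 7, "val_rmse": 0.123456}))
# prints: epoch=7-val_rmse=0.1235.ckpt
```

Since auto_insert_metric_name is false, the "epoch=" and "val_rmse=" prefixes come from the template itself rather than being added automatically.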
Next Steps
After training your model, you can:
- Evaluate its performance using validation metrics.
- Generate predictions using the trained model.
Todo
Show how to run ensembles.