
Training Models

Overview

This section explains how to use the canari_ml train command to train the default UNet model on preprocessed ERA5 data. As in previous sections, you can override the default settings with CLI arguments or YAML configuration files.

The training process involves:

  1. Using preprocessed training datasets generated by canari_ml preprocess train (refer to the preprocess page).
  2. Configuring model architecture, training parameters, and evaluation metrics.
  3. Running the training loop with options for monitoring and early stopping.
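Step 3 relies on Lightning's EarlyStopping callback, which the default config below points at val_rmse with patience: 10, min_delta: 0.0 and mode: min. The following stdlib-only sketch imitates that stopping rule, assuming one validation run per epoch. It is a simplified illustration, not Lightning's implementation, which additionally handles options such as stopping_threshold, divergence_threshold and check_finite.

```python
from typing import Optional, Sequence


def early_stop_epoch(
    val_scores: Sequence[float],
    patience: int = 10,
    min_delta: float = 0.0,
) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None.

    Simplified 'min' mode: a score only counts as an improvement if it is
    lower than the best score so far by more than min_delta.
    """
    best = float("inf")
    wait = 0
    for epoch, score in enumerate(val_scores):
        if score < best - min_delta:
            best = score  # new best score: reset the patience counter
            wait = 0
        else:
            wait += 1  # no improvement at this validation check
            if wait >= patience:
                return epoch
    return None  # patience never ran out


if __name__ == "__main__":
    # Three improving epochs, then ten epochs without improvement.
    scores = [1.0, 0.8, 0.7] + [0.75] * 10
    print(early_stop_epoch(scores))  # 12
```

With patience: 10, training stops at the tenth consecutive validation check that fails to beat the best val_rmse; raising min_delta makes small improvements count as "no improvement".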

Getting Started

Prerequisites

  • Preprocess your ERA5 data for training, following the steps on the preprocess page.
  • The training command assumes that preprocessing has produced a Zarr-formatted dataset and a corresponding JSON configuration file.
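Before launching training, it can help to confirm that both preprocessing outputs exist. The sketch below uses only the standard library; the paths passed to it are hypothetical placeholders (use whatever canari_ml preprocess train actually wrote), and it treats a directory containing Zarr metadata files (.zgroup or .zattrs for Zarr v2, zarr.json for Zarr v3) as a valid store, which is a heuristic rather than full Zarr validation.

```python
import json
from pathlib import Path


def check_training_inputs(zarr_path: str, config_path: str) -> dict:
    """Sanity-check preprocessed inputs and return the parsed JSON config.

    Raises FileNotFoundError or ValueError if something looks wrong.
    """
    store = Path(zarr_path)
    if not store.is_dir():
        raise FileNotFoundError(f"Zarr store not found: {store}")
    # Heuristic only: Zarr v2 groups carry .zgroup/.zattrs, Zarr v3 uses zarr.json.
    markers = (".zgroup", ".zattrs", "zarr.json")
    if not any((store / m).exists() for m in markers):
        raise ValueError(f"{store} does not look like a Zarr store")

    cfg = Path(config_path)
    if not cfg.is_file():
        raise FileNotFoundError(f"JSON config not found: {cfg}")
    with cfg.open() as fh:
        return json.load(fh)  # raises json.JSONDecodeError if malformed


# Hypothetical usage (substitute your own preprocessing outputs):
# config = check_training_inputs("path/to/train.zarr", "path/to/train.json")
```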

Usage

The canari_ml train command trains models on preprocessed ERA5 data. You can run it with the default configuration, or override settings with CLI arguments, custom YAML config files, or a combination of the two.
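Because train is a Hydra application, CLI overrides take the form key=value for individual config keys (for example train.epochs=100) and group=option for the configuration groups listed by --help (for example logger=wandb). The helper below is a hypothetical convenience wrapper, not part of canari_ml: it assembles the argument list and refuses to build a command without train.dataset and train.name, the two keys marked ??? in the default configuration. The dataset value and run name in the usage lines are placeholders.

```python
import shlex
import subprocess
from typing import Dict, List, Optional

# Keys shown as ??? in `canari_ml train --help`: they have no default value.
REQUIRED_KEYS = ("train.dataset", "train.name")


def build_train_command(
    overrides: Dict[str, object],
    groups: Optional[Dict[str, str]] = None,
) -> List[str]:
    """Assemble a `canari_ml train` argv from Hydra-style overrides."""
    missing = [k for k in REQUIRED_KEYS if k not in overrides]
    if missing:
        raise ValueError(f"missing mandatory overrides: {missing}")
    argv = ["canari_ml", "train"]
    # Config group selections, e.g. logger=wandb (see --help for the options).
    for group, option in (groups or {}).items():
        argv.append(f"{group}={option}")
    # Plain value overrides, e.g. train.epochs=100.
    for key, value in overrides.items():
        if isinstance(value, bool):
            value = str(value).lower()  # YAML-style true/false
        argv.append(f"{key}={value}")
    return argv


def run_training(argv: List[str]) -> None:
    """Launch the command; raises CalledProcessError on a non-zero exit."""
    subprocess.run(argv, check=True)


if __name__ == "__main__":
    cmd = build_train_command(
        {"train.dataset": "my_dataset", "train.name": "unet_run1", "train.epochs": 100},
        groups={"logger": "wandb"},
    )
    print(shlex.join(cmd))
    # canari_ml train logger=wandb train.dataset=my_dataset train.name=unet_run1 train.epochs=100
    # run_training(cmd)  # uncomment to actually start training
```

The printed string is exactly what you would type in a shell, so the same overrides can be used directly on the command line without the wrapper.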

Basic Usage

To see all available configuration options, run:

$ canari_ml train --help
train is powered by Hydra.

== Configuration groups ==
Compose your configuration from those groups (group=option)

callbacks: default, early_stopping, model_checkpoint
common: default
hydra_config: predict, train
logger: csv, tensorboard, wandb
model: default, unet
paths: default, download, plot, postprocess, predict, preprocess, train
plot: default, ua700
postprocess: default, netcdf, plot_ua700
predict: default
preprocess: default
profiler: pytorch
train: default
trainer: default


== Config ==
Override anything in the config (foo.bar=value)

train:
  dataset: ???
  name: ???
  seed: 42
  epochs: 50
  workers: 4
  batch_size: 4
  shuffling: true
  wandb_group: unet
  wandb_project: CANARI
verbose: true
paths:
  train: outputs/${train.name}/training/
source_dataset_id: era5
model:
  model_name: unet
  network:
    _target_: canari_ml.models.models.UNet
    _partial_: true
    filter_size: 3
    n_filters_factor: 1.0
    n_output_classes: 1
  litmodule:
    _target_: canari_ml.models.lightning_modules.LitUNet
    _partial_: true
    criterion:
      _target_: canari_ml.models.losses.WeightedLoss
      loss_type: mse
    learning_rate: 0.0001
    metrics:
    - canari_ml.models.metrics.MAE
    - canari_ml.models.metrics.MSE
    - canari_ml.models.metrics.RMSE
    enable_leadtime_metrics: false
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${hydra:runtime.output_dir}/checkpoints
    filename: epoch={epoch}-${callbacks.model_checkpoint.monitor}={${callbacks.model_checkpoint.monitor}:.4f}
    monitor: val_rmse
    verbose: true
    save_last: true
    save_top_k: 1
    mode: min
    auto_insert_metric_name: false
    save_weights_only: false
    every_n_train_steps: null
    train_time_interval: null
    every_n_epochs: null
    save_on_train_epoch_end: null
    enable_version_counter: false
  early_stopping:
    _target_: lightning.pytorch.callbacks.EarlyStopping
    monitor: val_rmse
    min_delta: 0.0
    patience: 10
    mode: min
    strict: true
    check_finite: true
    stopping_threshold: null
    divergence_threshold: null
    check_on_train_epoch_end: null
    log_rank_zero_only: null
logger:
  _target_: pytorch_lightning.loggers.TensorBoardLogger
  save_dir: tb_logs
  name: ${train.name}_${train.seed}
trainer:
  _target_: lightning.pytorch.Trainer
  _partial_: true
  accelerator: auto
  devices: 1
  precision: 16-mixed
  log_every_n_steps: 5
  max_epochs: ${train.epochs}
  num_sanity_val_steps: 0
  deterministic: true
  fast_dev_run: false
  logger: ${logger}
  callbacks: ${callbacks}


Powered by Hydra (https://hydra.cc)
Use --hydra-help to view Hydra specific help

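This displays the default configuration and every parameter you can override. Values shown as ??? (train.dataset and train.name) are mandatory and have no default, so you must supply them, for example as CLI overrides.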
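Several defaults are interpolations (${...}) that are resolved at run time, so overriding one key changes others: paths.train and the logger name derive from train.name (and train.seed), trainer.max_epochs follows train.epochs, and the checkpoint filename is built from callbacks.model_checkpoint.monitor. The stdlib sketch below imitates only simple ${a.b.c} lookups on a hand-copied subset of the printed config, returning strings; real resolution is done by OmegaConf/Hydra, which also supports resolvers such as ${hydra:runtime.output_dir} that this sketch ignores. The run name unet_run1 and the metric value are made up.

```python
import re
from typing import Any, Dict

# Hand-copied subset of the defaults printed by `canari_ml train --help`,
# with the mandatory train.name filled in by a hypothetical override.
CONFIG: Dict[str, Any] = {
    "train": {"name": "unet_run1", "seed": 42, "epochs": 50},
    "paths": {"train": "outputs/${train.name}/training/"},
    "logger": {"name": "${train.name}_${train.seed}"},
    "trainer": {"max_epochs": "${train.epochs}"},
    "callbacks": {
        "model_checkpoint": {
            "monitor": "val_rmse",
            "filename": (
                "epoch={epoch}-${callbacks.model_checkpoint.monitor}"
                "={${callbacks.model_checkpoint.monitor}:.4f}"
            ),
        }
    },
}

# Matches ${a.b.c} but not plain {epoch}-style format fields.
_INTERP = re.compile(r"\$\{([^${}]+)\}")


def lookup(cfg: Dict[str, Any], dotted: str) -> Any:
    """Walk a nested dict using a dotted key such as 'train.name'."""
    node: Any = cfg
    for part in dotted.split("."):
        node = node[part]
    return node


def resolve(cfg: Dict[str, Any], dotted: str) -> str:
    """Resolve ${a.b.c} references in a value, recursively (no cycle check)."""
    value = str(lookup(cfg, dotted))
    return _INTERP.sub(lambda m: resolve(cfg, m.group(1)), value)


if __name__ == "__main__":
    print(resolve(CONFIG, "paths.train"))  # outputs/unet_run1/training/
    print(resolve(CONFIG, "logger.name"))  # unet_run1_42
    template = resolve(CONFIG, "callbacks.model_checkpoint.filename")
    print(template)  # epoch={epoch}-val_rmse={val_rmse:.4f}
    print(template.format(epoch=7, val_rmse=0.012345))  # epoch=7-val_rmse=0.0123
```

Because auto_insert_metric_name is false, the resolved template is filled in much like str.format above, and Lightning's ModelCheckpoint appends the .ckpt extension; with save_last: true it also keeps a last.ckpt alongside the best checkpoint.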

Next Steps

After training your model, you can:

  1. Evaluate its performance using validation metrics.
  2. Generate predictions using the trained model.
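The default litmodule logs MAE, MSE and RMSE (canari_ml.models.metrics), and val_rmse is the value that checkpointing and early stopping monitor. For reference when evaluating, here is a plain-Python sketch of the three metrics over flattened predictions and targets. It is unweighted and unmasked; canari_ml's own metric classes and its WeightedLoss criterion may apply weights or masks that are not shown here.

```python
import math
from typing import Dict, Sequence


def regression_metrics(pred: Sequence[float], target: Sequence[float]) -> Dict[str, float]:
    """Unweighted MAE, MSE and RMSE over two equal-length sequences."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and of equal length")
    errors = [p - t for p, t in zip(pred, target)]
    mae = sum(abs(e) for e in errors) / len(errors)
    mse = sum(e * e for e in errors) / len(errors)
    return {"mae": mae, "mse": mse, "rmse": math.sqrt(mse)}


if __name__ == "__main__":
    print(regression_metrics([2.0, 4.0], [1.0, 1.0]))
```

Since RMSE is the square root of an averaged squared error, averaging per-batch RMSE values generally gives a different number from the RMSE computed over the whole validation set.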

Todo

Show how to run ensembles.