canari_ml train CLI Help
Run the following command to show the help information for the canari_ml train command:
canari_ml train --help
train is powered by Hydra.
== Configuration groups ==
Compose your configuration from those groups (group=option)
callbacks: default, early_stopping, model_checkpoint
common: default
hydra_config: predict, train
logger: csv, tensorboard, wandb
model: default, unet
paths: default, download, plot, postprocess, predict, preprocess, train
plot: default, ua700
postprocess: default, netcdf, plot_ua700
predict: default
preprocess: default
profiler: pytorch
train: default
trainer: default
== Config ==
Override anything in the config (foo.bar=value)
train:
  dataset: ???
  name: ???
  seed: 42
  epochs: 50
  workers: 4
  batch_size: 4
  shuffling: true
  wandb_group: unet
  wandb_project: CANARI
  verbose: true
paths:
  train: outputs/${train.name}/training/
source_dataset_id: era5
model:
  model_name: unet
  network:
    _target_: canari_ml.models.models.UNet
    _partial_: true
    filter_size: 3
    n_filters_factor: 1.0
    n_output_classes: 1
  litmodule:
    _target_: canari_ml.models.lightning_modules.LitUNet
    _partial_: true
    criterion:
      _target_: canari_ml.models.losses.WeightedLoss
      loss_type: mse
    learning_rate: 0.0001
    metrics:
    - canari_ml.models.metrics.MAE
    - canari_ml.models.metrics.MSE
    - canari_ml.models.metrics.RMSE
    enable_leadtime_metrics: false
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${hydra:runtime.output_dir}/checkpoints
    filename: epoch={epoch}-${callbacks.model_checkpoint.monitor}={${callbacks.model_checkpoint.monitor}:.4f}
    monitor: val_rmse
    verbose: true
    save_last: true
    save_top_k: 1
    mode: min
    auto_insert_metric_name: false
    save_weights_only: false
    every_n_train_steps: null
    train_time_interval: null
    every_n_epochs: null
    save_on_train_epoch_end: null
    enable_version_counter: false
  early_stopping:
    _target_: lightning.pytorch.callbacks.EarlyStopping
    monitor: val_rmse
    min_delta: 0.0
    patience: 10
    mode: min
    strict: true
    check_finite: true
    stopping_threshold: null
    divergence_threshold: null
    check_on_train_epoch_end: null
    log_rank_zero_only: null
logger:
  _target_: pytorch_lightning.loggers.TensorBoardLogger
  save_dir: tb_logs
  name: ${train.name}_${train.seed}
trainer:
  _target_: lightning.pytorch.Trainer
  _partial_: true
  accelerator: auto
  devices: 1
  precision: 16-mixed
  log_every_n_steps: 5
  max_epochs: ${train.epochs}
  num_sanity_val_steps: 0
  deterministic: true
  fast_dev_run: false
  logger: ${logger}
  callbacks: ${callbacks}
Powered by Hydra (https://hydra.cc)
Use --hydra-help to view Hydra specific help
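
The help text above describes two kinds of command-line arguments: config-group choices (group=option) and dotted overrides (foo.bar=value). The two entries shown as ??? (train.dataset and train.name) are mandatory values in Hydra and have to be supplied on the command line. The sketch below only assembles such a command as a list of arguments and prints it; it does not run the training. The helper build_train_command, the run name demo_run, and the dataset path are hypothetical and chosen purely for illustration, and whether logger=csv composes cleanly depends on the project's defaults list.

```python
import shlex


def build_train_command(groups, overrides):
    """Assemble argv for `canari_ml train` with Hydra-style arguments.

    groups:    config-group choices, rendered as group=option
    overrides: dotted config keys, rendered as foo.bar=value
    (Hypothetical helper, not part of canari_ml.)
    """
    argv = ["canari_ml", "train"]
    argv += [f"{group}={option}" for group, option in groups.items()]
    argv += [f"{key}={value}" for key, value in overrides.items()]
    return argv


argv = build_train_command(
    # Pick the csv option from the logger group listed in the help output.
    groups={"logger": "csv"},
    overrides={
        # Mandatory values (shown as ??? in the help output).
        "train.name": "demo_run",                        # hypothetical run name
        "train.dataset": "path/to/dataset_config.json",  # hypothetical path
        # Optional override of a default (50 in the help output).
        "train.epochs": 10,
    },
)

# Print the command as a single shell-quoted line for copy and paste.
print(shlex.join(argv))
```

Because trainer.max_epochs is defined as ${train.epochs} in the config above, overriding train.epochs should be enough to change the epoch limit passed to the trainer, assuming the interpolation is resolved at run time as usual in Hydra.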