Skip to content

train

canari_ml.hydra.train

canari_ml.hydra.train.logger = logging.getLogger(__name__) module-attribute

canari_ml.hydra.train.train_run(cfg)

Run training based on the provided HYDRA configuration.

This function loads a Hydra configuration and trains the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | DictConfig | Hydra auto-loaded configuration. | required |
Source code in src/canari_ml/hydra/train.py
@hydra.main(
    version_base=None,
    config_path=str(Path(__file__).parent.joinpath("..", "conf")),
    config_name="train",
)
def train_run(cfg: DictConfig) -> None:
    """
    Train a model from a Hydra configuration.

    The configuration is first printed with ``print_omega_config`` and then
    handed to ``HYDRAPytorchNetwork`` (with ``run_type="train"``), whose
    ``train`` method is invoked.

    Args:
        cfg: Hydra auto-loaded configuration.
    """
    # Imports stay function-local and in order of use, so the network module
    # is only imported after the configuration has been printed.
    from canari_ml.hydra.utils import print_omega_config

    print_omega_config(cfg)

    from canari_ml.models.networks.pytorch import HYDRAPytorchNetwork

    HYDRAPytorchNetwork(cfg, run_type="train").train()

canari_ml.hydra.train.main()

Source code in src/canari_ml/hydra/train.py
def main():
    """
    Entry point for training.

    Registers the ``set_preprocess_type`` OmegaConf resolver (it always
    resolves to ``"train"``), calls ``preprocess_register_resolvers``, and
    then invokes ``train_run``.
    """

    def _preprocess_type():
        # Constant value supplied for the "set_preprocess_type" resolver.
        return "train"

    OmegaConf.register_new_resolver("set_preprocess_type", _preprocess_type)
    preprocess_register_resolvers()
    # Called without an argument: ``cfg`` is documented as Hydra auto-loaded.
    train_run()  # type: ignore