Skip to content

predict

canari_ml.hydra.predict

canari_ml.hydra.predict.logger = logging.getLogger(__name__) module-attribute

canari_ml.hydra.predict.predict_run(cfg)

Run prediction based on the provided Hydra configuration.

This function loads a Hydra configuration and generates predictions from a trained model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `DictConfig` | Hydra auto-loaded configuration. | *required* |
Source code in src/canari_ml/hydra/predict.py
@hydra.main(
    version_base=None,
    config_path=str(Path(__file__).parent / "../conf"),
    config_name="predict",
)
def predict_run(cfg: DictConfig) -> None:
    """
    Generate predictions from a trained model using a Hydra config.

    The configuration is first echoed via ``print_omega_config``; a
    ``HYDRAPytorchNetwork`` is then built in ``"predict"`` mode and its
    ``predict`` method is invoked.

    Args:
        cfg: Configuration composed and injected by Hydra from the
            ``predict`` config in the package's ``conf`` directory.
    """
    # Imported lazily, as in the original, so the helper is only loaded
    # when a prediction run actually starts.
    from canari_ml.hydra.utils import print_omega_config

    print_omega_config(cfg)
    HYDRAPytorchNetwork(cfg, run_type="predict").predict()

canari_ml.hydra.predict.main()

Source code in src/canari_ml/hydra/predict.py
def main():
    """
    Command-line entry point for prediction.

    Registers the ``set_preprocess_type`` OmegaConf resolver so that it
    always resolves to ``"predict"`` (its argument is ignored). It then
    calls ``predict_run``, whose ``cfg`` argument is supplied by the
    ``@hydra.main`` decorator.
    """
    # replace=True keeps this entry point idempotent: without it, OmegaConf
    # raises ValueError if the resolver is already registered in this
    # process (e.g. main() invoked twice, or registered by another entry
    # point earlier).
    OmegaConf.register_new_resolver(
        "set_preprocess_type", lambda x: "predict", replace=True
    )
    predict_run()  # type: ignore