Skip to content

cli

canari_ml.hydra.cli

canari_ml.hydra.cli.main()

Source code in src/canari_ml/hydra/cli.py
def _build_parser(prog_name):
    """Build the argument parser for the top-level CLI.

    Only the command (and, for ``preprocess``, the stage) is declared here.
    Every leaf subparser is created with ``add_help=False`` and no options,
    so ``--help`` and all remaining arguments are left unparsed for the
    downstream entry point.

    Args:
        prog_name: Program name shown in usage/help messages.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(prog=prog_name, add_help=True)
    # required=True: otherwise running with no command parses cleanly and
    # main() silently does nothing.
    subparsers = parser.add_subparsers(
        dest="command", required=True, help="Available subcommands"
    )

    # Leaf commands without help (let Hydra handle it).
    subparsers.add_parser("download", add_help=False)

    # Pre-processing needs an explicit stage; without required=True a bare
    # `preprocess` would forward preprocess_type=None downstream.
    preprocess_parser = subparsers.add_parser("preprocess", add_help=True)
    preprocess_subparsers = preprocess_parser.add_subparsers(
        dest="subcommand", required=True
    )
    preprocess_subparsers.add_parser("train", add_help=False)
    preprocess_subparsers.add_parser("predict", add_help=False)

    # Remaining leaf commands, registered in the original help order.
    for command in ("train", "predict", "postprocess", "plot"):
        subparsers.add_parser(command, add_help=False)

    return parser


def main():
    """Dispatch a CLI command to the matching ``canari_ml.hydra`` module.

    Only the command name (and the ``preprocess`` stage) is consumed here.
    ``sys.argv`` is then rewritten to ``[prog_name, *remaining_args]`` and the
    selected module's ``main()`` is called, so that module parses the
    remaining arguments itself (presumably as Hydra overrides/flags).

    Raises:
        SystemExit: From argparse, with status 2 when the command is missing
            or unknown, or when ``preprocess`` is given without a stage;
            status 0 for ``-h``/``--help`` at the top or ``preprocess`` level.
    """
    prog_name = os.path.basename(sys.argv[0])
    parser = _build_parser(prog_name)

    # Let argparse only parse known args; everything else is passed through.
    args, unknown_args = parser.parse_known_args()

    # Reconstruct `sys.argv` for Hydra (removing the command/subcommand parts).
    sys.argv = [prog_name] + unknown_args

    # Imports are done per branch, so only the selected module is imported.
    if args.command == "download":
        from canari_ml.hydra import download
        download.main()
    elif args.command == "preprocess":
        from canari_ml.hydra import preprocess
        # `args.subcommand` is guaranteed to be "train" or "predict" here.
        preprocess.main(preprocess_type=args.subcommand)
    elif args.command == "train":
        from canari_ml.hydra import train
        train.main()
    elif args.command == "predict":
        from canari_ml.hydra import predict
        predict.main()
    elif args.command == "postprocess":
        from canari_ml.hydra import postprocess
        postprocess.main()
    elif args.command == "plot":
        from canari_ml.hydra import plot
        plot.main()