def main():
    """Run the command-line interface.

    Only the leading command (and, for ``preprocess``, its ``train``/``predict``
    subcommand) is parsed with argparse. Every argument argparse does not
    recognise is forwarded to the selected Hydra-based module by rebuilding
    ``sys.argv`` as ``[prog_name] + unknown_args``. The ``canari_ml.hydra``
    submodules are imported lazily, inside the branch for the chosen command.

    Raises:
        SystemExit: raised by argparse when no command is given, the command
            is not recognised, or ``preprocess`` is given without a
            ``train``/``predict`` subcommand.
    """
    import argparse
    import os
    import sys

    prog_name = os.path.basename(sys.argv[0])
    parser = argparse.ArgumentParser(prog=prog_name, add_help=True)
    # `required=True`: with no command, argparse reports a usage error
    # instead of the program silently doing nothing.
    subparsers = parser.add_subparsers(
        dest="command", help="Available subcommands", required=True
    )

    # Subcommands that defer to Hydra are created with `add_help=False` so
    # `-h/--help` is not consumed by argparse and is forwarded to Hydra.
    subparsers.add_parser("download", add_help=False)

    # Pre-processing commands. `preprocess` itself keeps argparse help (which
    # lists its subcommands); its `train`/`predict` leaves defer to Hydra.
    preprocess_parser = subparsers.add_parser("preprocess", add_help=True)
    preprocess_subparsers = preprocess_parser.add_subparsers(
        dest="subcommand", required=True
    )
    preprocess_subparsers.add_parser("train", add_help=False)
    preprocess_subparsers.add_parser("predict", add_help=False)

    # Train/predict commands
    subparsers.add_parser("train", add_help=False)
    subparsers.add_parser("predict", add_help=False)

    # Post-processing commands.
    # TODO: per-format subcommands (e.g. "netcdf") are not wired up yet; all
    # further arguments are passed straight through to Hydra.
    subparsers.add_parser("postprocess", add_help=False)

    # Plotting commands.
    # TODO: per-plot subcommands (e.g. "ua700") are not wired up yet; all
    # further arguments are passed straight through to Hydra.
    subparsers.add_parser("plot", add_help=False)

    # Parse only the command/subcommand; anything argparse does not recognise
    # (e.g. Hydra overrides and flags) is collected in `unknown_args`.
    args, unknown_args = parser.parse_known_args()

    # Hydra reads `sys.argv`, so rebuild it without the command/subcommand.
    sys.argv = [prog_name] + unknown_args

    if args.command == "download":
        from canari_ml.hydra import download

        download.main()
    elif args.command == "preprocess":
        from canari_ml.hydra import preprocess

        # `args.subcommand` is "train" or "predict" (enforced by argparse).
        preprocess.main(preprocess_type=args.subcommand)
    elif args.command == "train":
        from canari_ml.hydra import train

        train.main()
    elif args.command == "predict":
        from canari_ml.hydra import predict

        predict.main()
    elif args.command == "postprocess":
        from canari_ml.hydra import postprocess

        postprocess.main()
    elif args.command == "plot":
        from canari_ml.hydra import plot

        plot.main()