Skip to content

dataloader

canari_ml.data.dataloader

canari_ml.data.dataloader.ZarrDataset(root_path, zarr_name, train_split=True)

Bases: Dataset

Parameters:

Name Type Description Default
root_path str

Path to the directory containing 'train.zarr', 'val.zarr', 'test.zarr'.

required
zarr_name str

Name of the Zarr file to load (e.g., 'train.zarr', 'val.zarr', 'test.zarr').

required
train_split bool

Whether to load the training split. Defaults to True.

True
Source code in src/canari_ml/data/dataloader.py
def __init__(
    self, root_path: str, zarr_name: str, train_split: bool = True
) -> None:
    """
    Initialise the dataset from a Zarr store located inside ``root_path``.

    The store is opened read-only and its ``"x"`` and ``"y"`` members are
    bound to the instance. ``"sample_weights"`` is optional and resolves to
    ``None`` when the store does not contain it.

    Args:
        root_path: Path to the directory containing the Zarr store named by
            ``zarr_name``.
        zarr_name: Name of the Zarr file to load (e.g., 'train.zarr', 'val.zarr',
            'test.zarr').
        train_split: Whether to load the training split. Defaults to True.
            This initialiser only records the flag on the instance; it does
            not influence which store is opened.

    Raises:
        KeyError: If the opened store has no ``"x"`` or ``"y"`` member.
    """
    self.root_path = root_path
    self.train_split = train_split

    zarr_path = os.path.join(root_path, zarr_name)

    # Open read-only. zarr.open() defaults to mode="a", which creates an empty
    # store when the path is wrong (hiding the real error behind a later
    # KeyError on "x") and leaves the dataset writable from this object.
    self.store = zarr.open(zarr_path, mode="r")
    self.x_array = self.store["x"]
    self.y_array = self.store["y"]
    # Sample weights are optional; fall back to None when absent.
    self.sw_array = self.store.get("sample_weights", None)

canari_ml.data.dataloader.ZarrDataset.root_path = root_path instance-attribute

canari_ml.data.dataloader.ZarrDataset.train_split = train_split instance-attribute

canari_ml.data.dataloader.ZarrDataset.store = zarr.open(zarr_path) instance-attribute

canari_ml.data.dataloader.ZarrDataset.x_array = self.store['x'] instance-attribute

canari_ml.data.dataloader.ZarrDataset.y_array = self.store['y'] instance-attribute

canari_ml.data.dataloader.ZarrDataset.sw_array = self.store.get('sample_weights', None) instance-attribute

canari_ml.data.dataloader.CANARIMLDataSetTorch(configuration_path, *args, batch_size=4, path=os.path.join('.', 'network_datasets'), shuffling=False, **kwargs)

Bases: IceNetDataSet

Source code in src/canari_ml/data/dataloader.py
def __init__(
    self,
    configuration_path,
    *args,
    batch_size=4,
    path=os.path.join(".", "network_datasets"),
    shuffling=False,
    **kwargs,
):
    """
    Set up the dataset wrapper from a configuration file.

    The parent initialiser is invoked first, after which the configuration is
    reset and reloaded through ``_load_configuration`` and the values needed
    later are copied onto the instance.

    Args:
        configuration_path: Path handed to the parent initialiser and to
            ``_load_configuration``.
        *args: Extra positional arguments forwarded to the parent initialiser.
        batch_size: Batch size stored for later use when building loaders.
        path: Forwarded to the parent initialiser as ``path``.
        shuffling: Stored flag controlling shuffling of the training loader.
        **kwargs: Extra keyword arguments forwarded to the parent initialiser.
    """
    super().__init__(
        *args, configuration_path=configuration_path, path=path, **kwargs
    )

    # Start from an empty configuration, then populate it from the file.
    self._config = {}
    self._configuration_path = configuration_path
    self._load_configuration(configuration_path)

    config = self._config

    self._batch_size = batch_size
    self._lead_time = config["lead_time"]
    self._num_channels = config["num_channels"]
    self._shape = tuple(config["shape"])
    self._shuffling = shuffling

    # "south" takes precedence over "north"; neither set leaves hemi as None.
    if config["south"]:
        hemisphere = "south"
    elif config["north"]:
        hemisphere = "north"
    else:
        hemisphere = None
    self.hemi = hemisphere

    dataset_path = config.get("dataset_path")
    if dataset_path and os.path.exists(dataset_path):
        logging.warning("Will generate cache dataset")
    else:
        logging.warning(
            "Running in configuration only mode, Zarr cache files are not being "
            "generated for this dataset"
        )

canari_ml.data.dataloader.CANARIMLDataSetTorch.hemi = 'south' if self._config['south'] else 'north' if self._config['north'] else None instance-attribute

canari_ml.data.dataloader.CANARIMLDataSetTorch.get_data_loaders(num_workers=4, ratio=None)

Source code in src/canari_ml/data/dataloader.py
def get_data_loaders(self, num_workers=4, ratio=None):
    """
    Build the train, validation and test ``DataLoader`` objects.

    Each split is read from ``<dataset_path>/<split>/<split>.zarr`` where
    ``dataset_path`` comes from the loaded configuration.

    Args:
        num_workers: Number of worker processes passed to every ``DataLoader``.
            Workers are kept persistent whenever this value is truthy.
        ratio: Accepted for interface compatibility; not used by this method.

    Returns:
        A tuple ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    dataset_root = self._config["dataset_path"]
    keep_workers_alive = bool(num_workers)

    # Only the training split honours the configured shuffling flag.
    split_shuffle = (
        ("train", self._shuffling),
        ("val", False),
        ("test", False),
    )

    loaders = []
    for split, shuffle in split_shuffle:
        dataset = ZarrDataset(
            root_path=os.path.join(dataset_root, split),
            zarr_name=f"{split}.zarr",
        )
        loaders.append(
            DataLoader(
                dataset,
                batch_size=self._batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                persistent_workers=keep_workers_alive,
                pin_memory=True,  # For faster transfer to GPU if using one
            )
        )

    train_dataloader, val_dataloader, test_dataloader = loaders
    return train_dataloader, val_dataloader, test_dataloader

canari_ml.data.dataloader.CANARIMLDataSetTorch.get_data_loader(lead_time=None, generate_workers=None, base_path=os.path.join('.', 'network_datasets'), dummy=False)

Create an instance of the CANARIDataLoader class.

Parameters:

Name Type Description Default
lead_time optional

The number of forecast steps to be used by the data loader. If not provided, defaults to the value specified in the configuration file.

None
generate_workers optional

The number of workers to use for parallel processing with Dask. If not provided, defaults to the value specified in the configuration file.

None

Returns:

Type Description
object

An instance of the SerialLoader class configured with the specified parameters.

Source code in src/canari_ml/data/dataloader.py
def get_data_loader(self,
                    lead_time: object = None,
                    generate_workers: object = None,
                    base_path: str = os.path.join(".", "network_datasets"),
                    dummy: bool = False,
                    ) -> object:
    """Create a data loader through ``CanariMLDataLoaderFactory``.

    The factory is asked for its ``"serial"`` loader and is given this
    dataset's loader configuration and identifier, together with settings
    read from the loaded configuration.

    Args:
        lead_time (optional): The number of forecast steps to be used by the data
            loader. Falls back to the configuration's ``lead_time`` when ``None``.
        generate_workers (optional): Number of workers to use for parallel
            processing with Dask. Falls back to the configuration's
            ``generate_workers`` when ``None``.
        base_path: Forwarded unchanged to the factory as ``base_path``.
        dummy: Forwarded unchanged to the factory as ``dummy``.

    Returns:
        The loader object produced by the factory for the ``"serial"`` key.
    """
    config = self._config

    # Explicit arguments win; otherwise use the configured values.
    steps = config["lead_time"] if lead_time is None else lead_time
    workers = (
        config["generate_workers"] if generate_workers is None else generate_workers
    )

    factory = CanariMLDataLoaderFactory()
    return factory.create_data_loader(
        "serial",  # This will load the `SerialLoader` class.
        self.loader_config,
        self.identifier,
        lag_time=config["lag_time"],
        lead_time=steps,
        generate_workers=workers,
        dataset_config_path=os.path.dirname(self._configuration_path),
        loss_weight_days=config["loss_weight_days"],
        output_batch_size=config["output_batch_size"],
        var_lag_override=config["var_lag_override"],
        base_path=base_path,
        dummy=dummy,
    )