
network_dataset

canari_ml.data.network_dataset

canari_ml.data.network_dataset.logger = logging.getLogger(__name__) module-attribute

canari_ml.data.network_dataset.SplittingMixin

Read train, val, test datasets from tfrecord protocol buffer files.

Split and shuffle the data, if specified.

Example

This mixin is not to be used directly, but to give an idea of its use:

# Initialise SplittingMixin
split_dataset = SplittingMixin()
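SplittingMixin's own tfrecord-reading code is not shown on this page. As a rough illustration of the split-and-shuffle idea only, the pure-Python sketch below batches in-memory train/val/test splits and shuffles just the training split when a flag is set (the `shuffling` property refers to training datasets only). Integer sample IDs stand in for tfrecord records; the function names and the fixed seed are hypothetical and not part of canari_ml.

```python
import random
from typing import Dict, Iterator, List, Sequence


def iterate_split(samples: Sequence[int],
                  batch_size: int,
                  shuffle: bool = False,
                  seed: int = 42) -> Iterator[List[int]]:
    """Yield batches from one split, optionally shuffling its order first.

    Hypothetical helper: integer IDs stand in for one split's records.
    """
    order = list(samples)
    if shuffle:
        # A local RNG keeps the shuffle reproducible for a given seed.
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


def split_datasets(splits: Dict[str, Sequence[int]],
                   batch_size: int,
                   shuffling: bool) -> Dict[str, List[List[int]]]:
    """Batch every split, shuffling only the training split if requested."""
    return {
        name: list(iterate_split(samples, batch_size,
                                 shuffle=shuffling and name == "train"))
        for name, samples in splits.items()
    }


splits = {"train": range(10), "val": range(10, 14), "test": range(14, 17)}
batches = split_datasets(splits, batch_size=4, shuffling=True)
print([len(b) for b in batches["train"]])  # [4, 4, 2]
print(batches["val"])  # [[10, 11, 12, 13]]
```

Validation and test order is left untouched so that evaluation is repeatable; only the training order changes between differently seeded runs.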

canari_ml.data.network_dataset.SplittingMixin.batch_size property

The dataset's batch size.

canari_ml.data.network_dataset.SplittingMixin.dtype property

The dataset's data type.

canari_ml.data.network_dataset.SplittingMixin.lead_time property

The number of time steps to forecast.

canari_ml.data.network_dataset.SplittingMixin.num_channels property

The number of channels in the dataset.

canari_ml.data.network_dataset.SplittingMixin.shape property

The shape of the dataset.

canari_ml.data.network_dataset.SplittingMixin.shuffling property

A flag for whether training dataset(s) are marked to be shuffled.
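These properties explain why the mixin is not meant to be used directly: they appear to read private attributes such as `_batch_size`, `_shape` and `_shuffling`, which `IceNetDataSet.__init__` assigns. The sketch below shows that pattern with hypothetical class names; how the real properties are implemented is an assumption based on the attribute names listed for IceNetDataSet.

```python
from typing import Sequence, Tuple


class PropertiesMixinSketch:
    """Read-only views onto attributes that the concrete class must set."""

    @property
    def batch_size(self) -> int:
        return self._batch_size

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape

    @property
    def shuffling(self) -> bool:
        return self._shuffling


class ConcreteDataSetSketch(PropertiesMixinSketch):
    """Hypothetical concrete class that owns the state the mixin reads."""

    def __init__(self, batch_size: int, shape: Sequence[int],
                 shuffling: bool = False) -> None:
        self._batch_size = batch_size
        self._shape = tuple(shape)
        self._shuffling = shuffling


ds = ConcreteDataSetSketch(batch_size=4, shape=[64, 64])
print(ds.batch_size, ds.shape, ds.shuffling)  # 4 (64, 64) False

# Used on its own, the mixin has no state to read.
try:
    PropertiesMixinSketch().batch_size
except AttributeError as err:
    print("AttributeError:", err)
```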

canari_ml.data.network_dataset.IceNetDataSet(configuration_path, *args, batch_size=4, path=os.path.join('.', 'network_datasets'), shuffling=False, **kwargs)

Bases: SplittingMixin, DataCollection

Initialises and configures a dataset.

It loads a JSON configuration file, updates the _config attribute with the result, and provides a data loader and methods to access the dataset.
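The body of `_load_configuration` is not shown here, so the sketch below assumes it is a plain `json.load`. It writes a hypothetical config file whose keys mirror those read in `__init__` (every value is made up) and then converts two fields the way `__init__` does: the `dtype` string is resolved with `getattr(np, ...)` and the `shape` list becomes a tuple. The helper name and file name are illustrative only.

```python
import json
import os
import tempfile

import numpy as np

# Hypothetical config: keys mirror those read in IceNetDataSet.__init__,
# but every value here is invented for illustration.
example_config = {
    "identifier": "example_dataset",
    "counts": {"train": 100, "val": 20, "test": 10},
    "dtype": "float32",
    "loader_config": "loader.example.json",
    "generate_workers": 4,
    "lead_time": 7,
    "num_channels": 9,
    "shape": [64, 64],
    "dataset_path": None,  # stored as JSON null
}


def load_dataset_config(configuration_path: str) -> dict:
    """Load the JSON file and convert fields as __init__ does (sketch)."""
    with open(configuration_path) as fh:
        config = json.load(fh)
    config["dtype"] = getattr(np, config["dtype"])  # "float32" -> np.float32
    config["shape"] = tuple(config["shape"])         # JSON list -> tuple
    return config


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "dataset_config.example.json")
    with open(path, "w") as fh:
        json.dump(example_config, fh)
    config = load_dataset_config(path)

print(config["dtype"], config["shape"])
```

With `dataset_path` set to null, the check at the end of `__init__` would put the dataset into configuration-only mode.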

Attributes:

_config: A dict storing the configuration loaded from the JSON file.
_configuration_path: The path to the JSON configuration file.
_batch_size: The batch size for the data loader.
_counts: A dict with the number of elements in the train, val and test sets.
_dtype: The data type of the dataset, resolved to a NumPy type from the config's "dtype" string.
_loader_config: The path to the data loader configuration file.
_generate_workers: The number of workers for parallel processing with Dask.
_lead_time: The number of days to forecast.
_num_channels: The number of channels (input variables) in the dataset.
_shape: The shape of the dataset, as a tuple.
_shuffling: A flag indicating whether to shuffle the data.

Parameters:

configuration_path (str, required): The path to the JSON configuration file.
*args: Additional positional arguments, forwarded to the base class. Default: ().
batch_size (int, optional): How many samples to load per batch. Default: 4.
path (str, optional): The path to the directory where the processed tfrecord protocol buffer files will be stored. Default: os.path.join('.', 'network_datasets'), i.e. './network_datasets'.
shuffling (bool, optional): Flag indicating whether to shuffle the data. Default: False.
**kwargs: Additional keyword arguments, forwarded to the base class. Default: {}.
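Because `batch_size`, `path` and `shuffling` follow `*args` in the signature, they are keyword-only: a second positional value is collected into `*args` (and passed on to `super().__init__`) rather than setting the batch size. The stand-in function below copies the signature to show this; it is a hypothetical illustration, not part of canari_ml.

```python
import os


def init_signature_sketch(configuration_path: str,
                          *args,
                          batch_size: int = 4,
                          path: str = os.path.join(".", "network_datasets"),
                          shuffling: bool = False,
                          **kwargs) -> dict:
    """Stand-in with the same signature as IceNetDataSet.__init__."""
    return {"configuration_path": configuration_path, "args": args,
            "batch_size": batch_size, "path": path,
            "shuffling": shuffling, "kwargs": kwargs}


# A second positional value lands in *args; it does NOT set batch_size.
print(init_signature_sketch("dataset_config.json", 8))

# Options after *args must be passed by keyword; unknown keywords go to **kwargs.
print(init_signature_sketch("dataset_config.json", batch_size=8,
                            shuffling=True, extra="forwarded"))
```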
Source code in src/canari_ml/data/network_dataset.py
def __init__(self,
             configuration_path: str,
             *args,
             batch_size: int = 4,
             path: str = os.path.join(".", "network_datasets"),
             shuffling: bool = False,
             **kwargs) -> None:
    """Initialises an instance of the IceNetDataSet class.

    Args:
        configuration_path: The path to the JSON configuration file.
        *args: Additional positional arguments.
        batch_size (optional): How many samples to load per batch. Defaults to 4.
        path (optional): The path to the directory where the processed tfrecord
            protocol buffer files will be stored. Defaults to './network_datasets'.
        shuffling (optional): Flag indicating whether to shuffle the data.
            Defaults to False.
        **kwargs: Additional keyword arguments.
    """

    self._config = dict()
    self._configuration_path = configuration_path
    self._load_configuration(configuration_path)

    super().__init__(*args,
                     identifier=self._config["identifier"],
                     base_path=path,
                     **kwargs)

    # TODO: code smell - loading config twice because not using DataCollection
    self._config = dict()
    self._load_configuration(configuration_path)
    self._batch_size = batch_size
    self._counts = self._config["counts"]
    self._dtype = getattr(np, self._config["dtype"])
    self._loader_config = self._config["loader_config"]
    self._generate_workers = self._config["generate_workers"]
    self._lead_time = self._config["lead_time"]
    self._num_channels = self._config["num_channels"]
    self._shape = tuple(self._config["shape"])
    self._shuffling = shuffling

    path_attr = "dataset_path"

    # Check that the JSON config has an attribute for the path to the Zarr
    #   datasets, and that the path exists.
    dataset_path = self._config.get(path_attr)
    if not (dataset_path and os.path.exists(dataset_path)):
        # Use the module-level logger; note the space before "were".
        logger.warning("Running in configuration only mode, Zarr datasets "
                       "were not generated for this dataset")
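The ordering in `__init__` matters: the JSON config is read first because `identifier` must be passed to the base class. With `SplittingMixin` listed first in the bases, `super().__init__` still reaches `DataCollection.__init__` through the method resolution order, provided the mixin defines no `__init__` of its own (an assumption, since none is documented here). The sketch below reproduces that shape with hypothetical stand-in classes, together with the configuration-only check.

```python
import logging
import os

logger = logging.getLogger(__name__)


class DataCollectionSketch:
    """Hypothetical stand-in for the DataCollection base class."""

    def __init__(self, *args, identifier: str, base_path: str, **kwargs):
        self.identifier = identifier
        self.base_path = base_path


class SplittingMixinSketch:
    """Mixin with no __init__ of its own (assumed for this sketch)."""


class DataSetSketch(SplittingMixinSketch, DataCollectionSketch):
    def __init__(self, config: dict, *args,
                 path: str = os.path.join(".", "network_datasets"), **kwargs):
        self._config = config
        # The identifier comes from the config, so it must be loaded before
        # the base class is initialised.
        super().__init__(*args, identifier=config["identifier"],
                         base_path=path, **kwargs)
        dataset_path = self._config.get("dataset_path")
        self.config_only = not (dataset_path and os.path.exists(dataset_path))
        if self.config_only:
            logger.warning("Running in configuration only mode")


ds = DataSetSketch({"identifier": "example", "dataset_path": None})
print([cls.__name__ for cls in DataSetSketch.__mro__])
print(ds.identifier, ds.config_only)
```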

canari_ml.data.network_dataset.IceNetDataSet.loader_config property

The path to the JSON loader configuration file stored in the dataset config file.

canari_ml.data.network_dataset.IceNetDataSet.channels property

The list of channels (variable names) specified in the dataset config file.

canari_ml.data.network_dataset.IceNetDataSet.counts property

A dict with the number of elements in the train, val and test sets, as stored in the config file.
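Combined with `batch_size`, these counts give the number of batches per split, for example as a steps-per-epoch value. The sketch below assumes a final partial batch is kept rather than dropped, which depends on the loader; the helper name and the counts are hypothetical.

```python
import math


def steps_per_split(counts: dict, batch_size: int) -> dict:
    """Number of batches per split, counting a final partial batch."""
    return {split: math.ceil(n / batch_size) for split, n in counts.items()}


# Made-up counts in the same train/val/test layout as the counts property.
counts = {"train": 100, "val": 22, "test": 9}
print(steps_per_split(counts, batch_size=4))  # {'train': 25, 'val': 6, 'test': 3}
```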