Skip to content

base

canari_ml.data.loaders.base

canari_ml.data.loaders.base.DATE_FORMAT = '%Y-%m-%d' module-attribute

canari_ml.data.loaders.base.CanariMLBaseDataLoader(loader_configuration, identifier, *args, dates_override=None, dry=False, generate_workers=8, lag_time=None, lead_time=None, loss_weight_days=True, output_batch_size=32, base_path=os.path.join('.', 'network_datasets'), config_path='.', pickup=False, var_lag_override=None, **kwargs)

Bases: DataCollection

Loader base class for all data loaders used in this codebase. Based on IceNet v0.4.0_dev

Attributes:

Name Type Description
loader_configuration

This is a string that will be converted to a dictionary and passed as an argument to the appropriate loader class

identifier

The identifier for the current loader

lag_time int

The number of previous months/days for which to use features. Defaults to the ground-truth processor's lag_time when not given.

generate_workers int

The number of workers to use (default 8); exposed through the workers property.

loss_weight_days bool

The number of months/days used to calculate loss weights

lead_time int

Number of months/days ahead we want to predict. Defaults to the ground-truth processor's lead_time when not given.

output_batch_size int

The batch size that is passed to the model

base_path str

Path where cached Zarr files can be stored

config_path str

The path to the dataset config file

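var_lag_override object

Optional lag overrides (presumably keyed by variable name); an empty dict is used when none is given.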
Source code in src/canari_ml/data/loaders/base.py
def __init__(
    self,
    loader_configuration: str,
    identifier: str,
    *args,
    dates_override: object = None,
    dry: bool = False,
    generate_workers: int = 8,
    lag_time: int = None,
    lead_time: int = None,
    loss_weight_days: bool = True,
    output_batch_size: int = 32,
    base_path: str = os.path.join(".", "network_datasets"),
    config_path: str = ".",
    pickup: bool = False,
    var_lag_override: object = None,
    **kwargs,
):
    """Initialise the loader from a loader configuration.

    The configuration is loaded via ``_load_configuration``. The *first* entry of
    ``config["sources"]`` is then treated as the ground truth, and its processor
    is used to derive the dtype, the array shape, the default lag/lead times and
    the ``north``/``south`` location values. After that, the channels are
    constructed.

    Args:
        loader_configuration: Passed to ``_load_configuration`` and stored as
            ``_configuration_path`` (presumably a path to the loader config
            file - confirm against ``_load_configuration``).
        identifier: Identifier for this loader; forwarded to the parent class.
        *args: Forwarded positionally to the parent ``__init__``.
        dates_override: Stored as ``_dates_override``.
        dry: Stored as ``_dry``.
        generate_workers: Stored as ``_workers``.
        lag_time: If None, the ground-truth processor's ``lag_time`` is used.
        lead_time: If None, the ground-truth processor's ``lead_time`` is used.
        loss_weight_days: Stored as ``_loss_weight_days``.
        output_batch_size: Stored as ``_output_batch_size``.
        base_path: Forwarded to the parent class.
        config_path: Stored as ``_dataset_config_path``.
        pickup: Stored as ``_pickup``.
        var_lag_override: Stored as ``_var_lag_override``; any falsy value
            (including None) is replaced with an empty dict.
        **kwargs: Forwarded to the parent ``__init__``.
    """
    # Parent class receives identifier/base_path plus any extra args/kwargs.
    super().__init__(*args, identifier=identifier, base_path=base_path, **kwargs)

    self._channels = dict()
    self._channel_files = dict()

    self._configuration_path = loader_configuration
    self._dataset_config_path = config_path
    self._dates_override = dates_override
    self._config = dict()
    self._dry = dry
    self._loss_weight_days = loss_weight_days
    self._meta_channels = []
    self._missing_dates = []
    self._output_batch_size = output_batch_size
    self._pickup = pickup
    self._trend_steps = dict()
    self._workers = generate_workers

    # Populates self._config; must happen before "sources" is read below.
    self._load_configuration(loader_configuration)

    # TODO: we assume that ground truth is the first dataset in the ordering
    ground_truth_id, ground_truth_cfg = list(self._config["sources"].items())[0]
    processor = get_processor_from_source(ground_truth_id, ground_truth_cfg)
    ds_config = get_dataset_config_implementation(processor.dataset_config)
    # TODO: this is smelly, it suggests there is missing logic between Processor and
    #  NormalisingChannelProcessor to handle suffixes
    ref_ds = processor.get_dataset(
        ["{}_abs".format(el) for el in processor.abs_vars]
    )
    # Reference array: first data variable at the first time index. Its dtype
    # and shape are reused as the loader's dtype/shape below.
    ref_da = getattr(ref_ds.isel(time=0), list(ref_ds.data_vars)[0])

    # Things that come from preprocessing by default
    self._dtype = ref_da.dtype
    # TODO: we shouldn't ideally need this but we do need a concept of location for masks
    self._ds_config = processor.dataset_config
    self._frequency_attr = ds_config.frequency.attribute
    # Explicit arguments win; otherwise fall back to the processor's values.
    self._lag_time = lag_time if lag_time is not None else processor.lag_time
    self._lead_time = lead_time if lead_time is not None else processor.lead_time
    self._north = ds_config.location.north
    self._shape = ref_da.shape
    self._south = ds_config.location.south
    self._var_lag_override = dict() if not var_lag_override else var_lag_override

    self._construct_channels()

    # NOTE(review): _missing_dates was already initialised above. This reset
    # after _construct_channels() is either redundant or deliberately discards
    # anything _construct_channels() recorded - confirm before removing.
    self._missing_dates = []

canari_ml.data.loaders.base.CanariMLBaseDataLoader.channel_names property

canari_ml.data.loaders.base.CanariMLBaseDataLoader.config property

canari_ml.data.loaders.base.CanariMLBaseDataLoader.dates_override property

canari_ml.data.loaders.base.CanariMLBaseDataLoader.north property

canari_ml.data.loaders.base.CanariMLBaseDataLoader.num_channels property

canari_ml.data.loaders.base.CanariMLBaseDataLoader.pickup property

canari_ml.data.loaders.base.CanariMLBaseDataLoader.south property

canari_ml.data.loaders.base.CanariMLBaseDataLoader.workers property

canari_ml.data.loaders.base.CanariMLBaseDataLoader.get_data_var_folder(var_name, append=None, missing_error=False)

Returns the path for a specific data variable.

Appends additional folders to the path if specified in the append parameter.

Parameters:

Name Type Description Default
var_name str

The data variable.

required
append object

Additional folders to append to the path. Defaults to None.

None
missing_error bool

Flag to specify if missing directories should be treated as an error. Defaults to False.

False

Returns:

Type Description
PathLike

The path for the specific data variable.

Raises:

Type Description
FileNotFoundError

If missing_error is True and a directory is missing.

Source code in src/canari_ml/data/loaders/base.py
def get_data_var_folder(
    self, var_name: str, append: object = None, missing_error: bool = False
) -> str:
    """Returns the path for a specific data variable.

    The path is ``<self.path>/<var_name>[/<append>...]``. Unless
    ``missing_error`` is set, a missing directory is created on demand.

    Args:
        var_name: The data variable.
        append: Additional folder name(s) to append to the path, either a
                single string/path or an iterable of them. Defaults to None
                (no extra folders).
        missing_error (optional): Flag to specify if missing directories should be
                treated as an error. Defaults to False.

    Returns:
        The path (as a string) for the specific data variable.

    Raises:
        FileNotFoundError: If `missing_error` is True and the directory is
            missing. This subclasses OSError, so existing `except OSError`
            handlers still catch it.
    """
    if not append:
        append = []
    elif isinstance(append, (str, os.PathLike)):
        # A bare string would otherwise be unpacked character by character.
        append = [append]

    data_var_path = os.path.join(self.path, var_name, *append)

    if not os.path.exists(data_var_path):
        if missing_error:
            raise FileNotFoundError(
                "Directory {} is missing and this is flagged as an error!".format(
                    data_var_path
                )
            )
        os.makedirs(data_var_path, exist_ok=True)

    return data_var_path

canari_ml.data.loaders.base.CanariMLBaseDataLoader.write_dataset_config_only()

Source code in src/canari_ml/data/loaders/base.py
def write_dataset_config_only(self):
    """Write the dataset configuration without generating any cached data.

    The split names are taken from the ``"splits"`` mapping of the *first*
    configured source. For each split, the date strings listed by every source
    are parsed with ``DATE_FORMAT`` and de-duplicated, and the number of unique
    dates is recorded. The resulting ``{split: count}`` mapping is passed to
    ``_write_dataset_config`` with ``network_dataset=False``.

    Raises:
        IndexError: If no sources are configured.
        KeyError: If a source lacks one of the first source's splits.
    """
    sources = self._config["sources"]
    splits = sources[list(sources)[0]]["splits"]
    counts = {el: 0 for el in splits}

    logging.info("Writing dataset configuration without data generation")

    # FIXME: cloned mechanism from generate() - do we need to treat these as
    #  sets that might have missing data for fringe cases?
    for dataset in splits:
        # Parse before de-duplicating so equivalent spellings of a date collapse.
        forecast_dates = sorted(
            {
                dt.datetime.strptime(s, DATE_FORMAT).date()
                for source_cfg in sources.values()
                for s in source_cfg["splits"][dataset]
            }
        )

        logging.info(
            "{} {} dates in total, NOT generating cache data.".format(
                len(forecast_dates), dataset
            )
        )
        counts[dataset] += len(forecast_dates)

    self._write_dataset_config(counts, network_dataset=False)

canari_ml.data.loaders.base.CanariMLBaseDataLoader.generate_sample(date, prediction=False) abstractmethod

Abstract method: concrete loaders must implement it to generate a sample for date. The prediction flag defaults to False.

Source code in src/canari_ml/data/loaders/base.py
@abstractmethod
def generate_sample(self, date: object, prediction: bool = False):
    """Generate the sample for ``date``; must be overridden by concrete loaders.

    The base implementation has no body and returns None.

    :param date: the date to build the sample for (presumably a
        ``datetime.date`` - confirm against subclasses)
    :param prediction: flag forwarded to implementations; defaults to False
    :return: implementation-defined
    """
    ...

canari_ml.data.loaders.base.CanariMLBaseDataLoader.get_sample_files()

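Returns a dict mapping each channel variable name to the file returned by _get_var_file for it. Raises RuntimeError if no file is returned for a variable.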

Source code in src/canari_ml/data/loaders/base.py
def get_sample_files(self) -> object:
    """Return the file backing each channel variable.

    ``self._get_var_file`` is called once for each variable name in
    ``self._channels``, in insertion order.

    :return: dict mapping variable name to the value returned by
        ``_get_var_file`` for it
    :raises RuntimeError: if ``_get_var_file`` returns a falsy value for any
        variable
    """
    # FIXME: is this not just the same as _channel_files now?
    # FIXME: still experimental code, move to multiple implementations
    # FIXME: CLEAN THIS ALL UP ONCE VERIFIED FOR local/shared STORAGE!
    var_files = dict()

    # Only the names are needed; the channel counts are ignored. Because dict
    # keys are unique, each name is seen once, so the previous "Differing files?"
    # consistency check could never trigger and has been dropped.
    for var_name in self._channels:
        var_file = self._get_var_file(var_name)

        if not var_file:
            raise RuntimeError("No file returned for {}".format(var_name))

        var_files[var_name] = var_file

    return var_files