Skip to content

metrics

canari_ml.models.metrics

canari_ml.models.metrics.BaseMetric(leadtimes_to_evaluate=None)

Bases: Metric

Base class for all metrics.

Reference: https://lightning.ai/docs/torchmetrics/stable/pages/implement.html

Parameters:

Name Type Description Default
leadtimes_to_evaluate list

A list of leadtimes to consider in the accuracy computation. For example, [0, 1, 2, 3, 4, 5] considers the first n days/months (i.e. leadtimes); [0] looks only at the first day's accuracy; [5] looks only at the sixth day's accuracy.

None
Source code in src/canari_ml/models/metrics.py
def __init__(self, leadtimes_to_evaluate: list = None):
    """Set up a weighted metric restricted to chosen leadtimes.

    Args:
        leadtimes_to_evaluate: Leadtime indices to include in the metric.
            ``[0, 1, 2, 3, 4, 5]`` covers the first n days/months,
            ``[0]`` only the first day, ``[5]`` only the sixth day.
            When ``None`` is given, ``slice(None)`` is stored instead.
    """
    super().__init__()

    # Substitute a full slice when the caller gave no explicit leadtimes.
    if leadtimes_to_evaluate is None:
        leadtimes_to_evaluate = slice(None)
    self.leadtimes_to_evaluate = leadtimes_to_evaluate

canari_ml.models.metrics.BaseMetric.leadtimes_to_evaluate = leadtimes_to_evaluate if leadtimes_to_evaluate is not None else slice(None) instance-attribute

canari_ml.models.metrics.MAE(*args, **kwargs)

Bases: BaseMetric

Weighted MAE metric for use at multiple leadtimes.

Source code in src/canari_ml/models/metrics.py
def __init__(self, *args, **kwargs):
    """Initialise the parent metric and register the MAE accumulators.

    All positional and keyword arguments are forwarded to the parent
    constructor unchanged.
    """
    super().__init__(*args, **kwargs)

    # Both running totals start at zero and are registered with
    # dist_reduce_fx="sum". A fresh tensor is built per state so the
    # two accumulators never share storage.
    for state_name in ("sum_weighted_abs_error", "total_weight"):
        self.add_state(
            state_name, default=torch.tensor(0.0), dist_reduce_fx="sum"
        )

canari_ml.models.metrics.MAE.update(predictions, targets, sample_weight, **kwargs)

Update state with predictions and targets.

Source code in src/canari_ml/models/metrics.py
def update(
    self,
    predictions: torch.Tensor,
    targets: torch.Tensor,
    sample_weight: torch.Tensor,
    **kwargs,
) -> None:
    """Accumulate the weighted absolute error for one batch.

    The three tensors are first passed through ``self._select_leadtimes``;
    the weighted absolute error and the weight total of the result are
    then added in place to the running states.

    Args:
        predictions: Model outputs.
        targets: Reference values compared against ``predictions``.
        sample_weight: Weights multiplied element-wise with the error.
        **kwargs: Accepted but not used by this method.
    """
    selected = self._select_leadtimes(predictions, targets, sample_weight)
    preds, truth, weights = selected

    batch_weighted_error = torch.sum(torch.abs(preds - truth) * weights)
    batch_weight = torch.sum(weights)

    # In-place accumulation keeps the registered state tensors in use.
    self.sum_weighted_abs_error += batch_weighted_error
    self.total_weight += batch_weight

canari_ml.models.metrics.MAE.compute()

Source code in src/canari_ml/models/metrics.py
def compute(self) -> torch.Tensor:
    """Return the accumulated weighted absolute error over the total weight.

    The division is not guarded against a zero ``total_weight``.
    """
    numerator = self.sum_weighted_abs_error
    denominator = self.total_weight
    return torch.div(numerator, denominator)

canari_ml.models.metrics.MSE(*args, **kwargs)

Bases: BaseMetric

Weighted MSE metric for use at multiple leadtimes.

Source code in src/canari_ml/models/metrics.py
def __init__(self, *args, **kwargs):
    """Initialise the parent metric and register the MSE accumulators.

    All positional and keyword arguments are forwarded to the parent
    constructor unchanged.
    """
    super().__init__(*args, **kwargs)

    # Both running totals start at zero and are registered with
    # dist_reduce_fx="sum". A fresh tensor is built per state so the
    # two accumulators never share storage.
    for state_name in ("sum_weighted_squared_error", "total_weight"):
        self.add_state(
            state_name, default=torch.tensor(0.0), dist_reduce_fx="sum"
        )

canari_ml.models.metrics.MSE.update(predictions, targets, sample_weight, **kwargs)

Update state with predictions and targets.

Source code in src/canari_ml/models/metrics.py
def update(
    self,
    predictions: torch.Tensor,
    targets: torch.Tensor,
    sample_weight: torch.Tensor,
    **kwargs,
) -> None:
    """Accumulate the weighted squared error for one batch.

    The three tensors are first passed through ``self._select_leadtimes``;
    the weighted squared error and the weight total of the result are
    then added in place to the running states.

    Args:
        predictions: Model outputs.
        targets: Reference values compared against ``predictions``.
        sample_weight: Weights multiplied element-wise with the error.
        **kwargs: Accepted but not used by this method.
    """
    selected = self._select_leadtimes(predictions, targets, sample_weight)
    preds, truth, weights = selected

    # Float exponent retained so dtype promotion matches ``** 2.0``.
    residual_sq = torch.pow(preds - truth, 2.0)
    batch_weighted_error = torch.sum(residual_sq * weights)
    batch_weight = torch.sum(weights)

    # In-place accumulation keeps the registered state tensors in use.
    self.sum_weighted_squared_error += batch_weighted_error
    self.total_weight += batch_weight

canari_ml.models.metrics.MSE.compute()

Source code in src/canari_ml/models/metrics.py
def compute(self) -> torch.Tensor:
    """Return the accumulated weighted squared error over the total weight.

    The division is not guarded against a zero ``total_weight``.
    """
    numerator = self.sum_weighted_squared_error
    denominator = self.total_weight
    return torch.div(numerator, denominator)

canari_ml.models.metrics.RMSE(*args, **kwargs)

Bases: MSE

Weighted Root Mean Squared Error for use at multiple leadtimes, computed as the square root of the weighted MSE.

Source code in src/canari_ml/models/metrics.py
def __init__(self, *args, **kwargs):
    """Forward all arguments unchanged to the parent constructor."""
    super().__init__(*args, **kwargs)

canari_ml.models.metrics.RMSE.compute()

Source code in src/canari_ml/models/metrics.py
def compute(self) -> torch.Tensor:
    """Return the square root of the value computed by the parent class."""
    parent_value = super().compute()
    return torch.sqrt(parent_value)