Skip to content

losses

canari_ml.models.losses

canari_ml.models.losses.LOSS_REGISTRY = {'l1': nn.L1Loss, 'mse': nn.MSELoss, 'huber': nn.HuberLoss} module-attribute

canari_ml.models.losses.WeightedLoss(loss_type='mse', **kwargs)

Bases: Module

Weighted loss.

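Compute an element-wise loss, scale it by per-element sample weights (for example, a binary mask), and average the result over all elements.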

Source code in src/canari_ml/models/losses.py
def __init__(self, loss_type="mse", **kwargs) -> None:
    """Build the element-wise loss used by ``forward``.

    Args:
        loss_type: Name of a loss in ``LOSS_REGISTRY`` (``"l1"``, ``"mse"``
            or ``"huber"``), matched case-insensitively.
        **kwargs: Extra keyword arguments forwarded to the selected loss
            constructor. ``reduction`` is always set to ``"none"`` here, so
            passing it again raises ``TypeError``.

    Raises:
        ValueError: If ``loss_type`` is not a registered loss name.
    """
    super().__init__()

    # Normalise once so the membership check and the lookup use the same key.
    # Previously "MSE" failed the check even though the lookup lowercases.
    key = loss_type.lower() if isinstance(loss_type, str) else loss_type
    if key not in LOSS_REGISTRY:
        raise ValueError(f"Unsupported loss type: {loss_type}")

    # reduction="none" keeps per-element losses so forward() can weight them.
    self.loss_fn = LOSS_REGISTRY[key](reduction="none", **kwargs)

canari_ml.models.losses.WeightedLoss.reduction instance-attribute

canari_ml.models.losses.WeightedLoss.loss_fn = LOSS_REGISTRY[loss_type.lower()](reduction='none', **kwargs) instance-attribute

canari_ml.models.losses.WeightedLoss.forward(predictions, targets, sample_weights)

Source code in src/canari_ml/models/losses.py
def forward(self, predictions, targets, sample_weights):
    """Return the weighted loss averaged over every element.

    The per-element loss from ``self.loss_fn`` is multiplied by
    ``sample_weights``, and the product is reduced with ``.mean()``. The
    mean divides by the element count, not by the sum of the weights.
    """
    per_element = self.loss_fn(predictions, targets)
    weighted = per_element * sample_weights
    return weighted.mean()