Skip to content

losses

Methods for Loss Computation.

denoisplit_loss(model_outputs, targets, config, gaussian_likelihood=None, noise_model_likelihood=None) #

Loss function for DenoiSplit.

Parameters:

Name Type Description Default
model_outputs tuple[Tensor, dict[str, Any]]

Tuple containing the model predictions (shape is (B, target_ch, [Z], Y, X)) and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).

required
targets Tensor

The target image used to compute the reconstruction loss. Shape is (B, target_ch, [Z], Y, X).

required
config LVAELossConfig

The config for loss function containing all loss hyperparameters.

required
gaussian_likelihood GaussianLikelihood

The Gaussian likelihood object.

None
noise_model_likelihood NoiseModelLikelihood

The noise model likelihood object.

None

Returns:

Name Type Description
output Optional[dict[str, Tensor]]

A dictionary containing the overall loss ["loss"], the reconstruction loss ["reconstruction_loss"], and the KL divergence loss ["kl_loss"].

Source code in src/careamics/losses/lvae/losses.py
def denoisplit_loss(
    model_outputs: tuple[torch.Tensor, dict[str, Any]],
    targets: torch.Tensor,
    config: LVAELossConfig,
    gaussian_likelihood: GaussianLikelihood | None = None,
    noise_model_likelihood: NoiseModelLikelihood | None = None,
) -> dict[str, torch.Tensor] | None:
    """Loss function for DenoiSplit.

    Parameters
    ----------
    model_outputs : tuple[torch.Tensor, dict[str, Any]]
        Model predictions (shape (B, `target_ch`, [Z], Y, X)) paired with the
        top-down layer data (sampled latents, KL-loss values, etc.).
    targets : torch.Tensor
        Target image for the reconstruction loss.
        Shape is (B, `target_ch`, [Z], Y, X).
    config : LVAELossConfig
        Loss configuration holding all loss hyperparameters.
    gaussian_likelihood : GaussianLikelihood
        The Gaussian likelihood object. Unused by this loss.
    noise_model_likelihood : NoiseModelLikelihood
        The noise model likelihood object. Must be provided.

    Returns
    -------
    output : Optional[dict[str, torch.Tensor]]
        Dictionary with the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
        `None` is returned when the total loss is NaN.
    """
    assert noise_model_likelihood is not None

    preds, topdown_data = model_outputs

    # Weighted negative log-likelihood under the noise model.
    rec_term = config.reconstruction_weight * get_reconstruction_loss(
        reconstruction=preds,
        target=targets,
        likelihood_obj=noise_model_likelihood,
    )
    if torch.isnan(rec_term).any():
        rec_term = 0.0

    # Annealed KL weight applied to the denoiSplit KL divergence.
    weight = get_kl_weight(
        config.kl_params.annealing,
        config.kl_params.start,
        config.kl_params.annealtime,
        config.kl_weight,
        config.kl_params.current_epoch,
    )
    kl_term = weight * _get_kl_divergence_loss_denoisplit(
        topdown_data=topdown_data,
        img_shape=targets.shape[2:],
        kl_type=config.kl_params.loss_type,
    )

    total = rec_term + kl_term
    # https://github.com/openai/vdvae/blob/main/train.py#L26
    if torch.isnan(total).any():
        return None

    return {
        "loss": total,
        "reconstruction_loss": (
            rec_term.detach() if isinstance(rec_term, torch.Tensor) else rec_term
        ),
        "kl_loss": kl_term.detach(),
    }

denoisplit_musplit_loss(model_outputs, targets, config, gaussian_likelihood, noise_model_likelihood) #

Loss function for the combined DenoiSplit-muSplit training objective.

Parameters:

Name Type Description Default
model_outputs tuple[Tensor, dict[str, Any]]

Tuple containing the model predictions (shape is (B, target_ch, [Z], Y, X)) and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).

required
targets Tensor

The target image used to compute the reconstruction loss. Shape is (B, target_ch, [Z], Y, X).

required
config LVAELossConfig

The config for loss function containing all loss hyperparameters.

required
gaussian_likelihood GaussianLikelihood

The Gaussian likelihood object.

required
noise_model_likelihood NoiseModelLikelihood

The noise model likelihood object.

required

Returns:

Name Type Description
output Optional[dict[str, Tensor]]

A dictionary containing the overall loss ["loss"], the reconstruction loss ["reconstruction_loss"], and the KL divergence loss ["kl_loss"].

Source code in src/careamics/losses/lvae/losses.py
def denoisplit_musplit_loss(
    model_outputs: tuple[torch.Tensor, dict[str, Any]],
    targets: torch.Tensor,
    config: LVAELossConfig,
    gaussian_likelihood: GaussianLikelihood,
    noise_model_likelihood: NoiseModelLikelihood,
) -> dict[str, torch.Tensor] | None:
    """Loss function for the combined denoiSplit-muSplit objective.

    Parameters
    ----------
    model_outputs : tuple[torch.Tensor, dict[str, Any]]
        Model predictions (shape (B, `target_ch`, [Z], Y, X)) paired with the
        top-down layer data (sampled latents, KL-loss values, etc.).
    targets : torch.Tensor
        Target image for the reconstruction loss.
        Shape is (B, `target_ch`, [Z], Y, X).
    config : LVAELossConfig
        Loss configuration holding all loss hyperparameters.
    gaussian_likelihood : GaussianLikelihood
        The Gaussian likelihood object (muSplit part).
    noise_model_likelihood : NoiseModelLikelihood
        The noise model likelihood object (denoiSplit part).

    Returns
    -------
    output : Optional[dict[str, torch.Tensor]]
        Dictionary with the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
        `None` is returned when the total loss is NaN.
    """
    preds, topdown_data = model_outputs

    # Weighted mix of noise-model and Gaussian reconstruction terms.
    rec_term = _reconstruction_loss_musplit_denoisplit(
        predictions=preds,
        targets=targets,
        nm_likelihood=noise_model_likelihood,
        gaussian_likelihood=gaussian_likelihood,
        nm_weight=config.denoisplit_weight,
        gaussian_weight=config.musplit_weight,
    )
    if torch.isnan(rec_term).any():
        rec_term = 0.0

    # KL loss computation
    # NOTE: 'kl' key stands for the 'kl_samplewise' key in the TopDownLayer class.
    # The different naming comes from `top_down_pass()` method in the LadderVAE.
    spatial_shape = targets.shape[2:]
    kl_dsplit = _get_kl_divergence_loss_denoisplit(
        topdown_data=topdown_data,
        img_shape=spatial_shape,
        kl_type=config.kl_params.loss_type,
    )
    kl_msplit = _get_kl_divergence_loss_musplit(
        topdown_data=topdown_data,
        img_shape=spatial_shape,
        kl_type=config.kl_params.loss_type,
    )
    # TODO `kl_weight` is hardcoded (???)
    kl_term = config.kl_weight * (
        config.denoisplit_weight * kl_dsplit + config.musplit_weight * kl_msplit
    )

    total = rec_term + kl_term
    # https://github.com/openai/vdvae/blob/main/train.py#L26
    if torch.isnan(total).any():
        return None

    return {
        "loss": total,
        "reconstruction_loss": (
            rec_term.detach() if isinstance(rec_term, torch.Tensor) else rec_term
        ),
        "kl_loss": kl_term.detach(),
    }

get_kl_divergence_loss(kl_type, topdown_data, rescaling, aggregation, free_bits_coeff, img_shape=None) #

Compute the KL divergence loss.

NOTE: Description of rescaling methods: - If "latent_dim", the KL-loss values are rescaled w.r.t. the latent space dimensions (spatial + number of channels, i.e., (C, [Z], Y, X)). In this way they have the same magnitude across layers. - If "image_dim", the KL-loss values are rescaled w.r.t. the input image spatial dimensions. In this way, the lower layers have a larger KL-loss value compared to the higher layers, since the latent space and hence the KL tensor has more entries. Specifically, at hierarchy i, the total KL loss is larger by a factor (128/i**2).

NOTE: the type of aggregation determines the magnitude of the KL-loss. Clearly, "sum" aggregation results in a larger KL-loss value compared to "mean" by a factor of n_layers.

NOTE: recall that sample-wise KL is obtained by summing over all dimensions, including Z. Also recall that in current 3D implementation of LVAE, no downsampling is done on Z. Therefore, to avoid emphasizing KL loss too much, we divide it by the Z dimension of input image in every case.

Parameters:

Name Type Description Default
kl_type Literal['kl', 'kl_restricted']

The type of KL divergence loss to compute.

required
topdown_data dict[str, Tensor]

A dictionary containing information computed for each layer during the top-down pass. The dictionary must include the following keys: - "kl": The KL-loss values for each layer. Shape of each tensor is (B,). - "z": The sampled latents for each layer. Shape of each tensor is (B, layers, z_dims[i], H, W).

required
rescaling Literal['latent_dim', 'image_dim']

The rescaling method used for the KL-loss values. If "latent_dim", the KL-loss values are rescaled w.r.t. the latent space dimensions (spatial + number of channels, i.e., (C, [Z], Y, X)). If "image_dim", the KL-loss values are rescaled w.r.t. the input image spatial dimensions.

required
aggregation Literal['mean', 'sum']

The aggregation method used to combine the KL-loss values across layers. If "mean", the KL-loss values are averaged across layers. If "sum", the KL-loss values are summed across layers.

required
free_bits_coeff float

The free bits coefficient used for the KL-loss computation.

required
img_shape Optional[tuple[int]]

The shape of the input image to the LVAE model. Shape is ([Z], Y, X).

None

Returns:

Name Type Description
kl_loss Tensor

The KL divergence loss. Shape is (1, ).

Source code in src/careamics/losses/lvae/losses.py
def get_kl_divergence_loss(
    kl_type: Literal["kl", "kl_restricted"],
    topdown_data: dict[str, torch.Tensor],
    rescaling: Literal["latent_dim", "image_dim"],
    aggregation: Literal["mean", "sum"],
    free_bits_coeff: float,
    img_shape: tuple[int] | None = None,
) -> torch.Tensor:
    """Compute the KL divergence loss.

    NOTE: Description of `rescaling` methods:
    - If "latent_dim", the KL-loss values are rescaled w.r.t. the latent space
    dimensions (spatial + number of channels, i.e., (C, [Z], Y, X)). In this way they
    have the same magnitude across layers.
    - If "image_dim", the KL-loss values are rescaled w.r.t. the input image spatial
    dimensions. In this way, the lower layers have a larger KL-loss value compared to
    the higher layers, since the latent space and hence the KL tensor has more entries.
    Specifically, at hierarchy `i`, the total KL loss is larger by a factor (128/i**2).

    NOTE: the type of `aggregation` determines the magnitude of the KL-loss. Clearly,
    "sum" aggregation results in a larger KL-loss value compared to "mean" by a factor
    of `n_layers`.

    NOTE: recall that sample-wise KL is obtained by summing over all dimensions,
    including Z. Also recall that in current 3D implementation of LVAE, no downsampling
    is done on Z. Therefore, to avoid emphasizing KL loss too much, we divide it
    by the Z dimension of input image in every case.

    Parameters
    ----------
    kl_type : Literal["kl", "kl_restricted"]
        The type of KL divergence loss to compute.
    topdown_data : dict[str, torch.Tensor]
        A dictionary containing information computed for each layer during the top-down
        pass. The dictionary must include the following keys:
        - "kl": The KL-loss values for each layer. Shape of each tensor is (B,).
        - "z": The sampled latents for each layer. Shape of each tensor is
        (B, layers, `z_dims[i]`, H, W).
    rescaling : Literal["latent_dim", "image_dim"]
        The rescaling method used for the KL-loss values. If "latent_dim", the KL-loss
        values are rescaled w.r.t. the latent space dimensions (spatial + number of
        channels, i.e., (C, [Z], Y, X)). If "image_dim", the KL-loss values are
        rescaled w.r.t. the input image spatial dimensions.
    aggregation : Literal["mean", "sum"]
        The aggregation method used to combine the KL-loss values across layers. If
        "mean", the KL-loss values are averaged across layers. If "sum", the KL-loss
        values are summed across layers.
    free_bits_coeff : float
        The free bits coefficient used for the KL-loss computation.
    img_shape : Optional[tuple[int]]
        The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
        Required when `rescaling == "image_dim"`; also needed for the 3D
        Z-dimension rescaling.

    Returns
    -------
    kl_loss : torch.Tensor
        The KL divergence loss. Shape is (1, ).

    Raises
    ------
    ValueError
        If `rescaling == "image_dim"` and `img_shape` is not provided.
    """
    # Stack the per-layer sample-wise KL values into shape (B, n_layers).
    kl = torch.cat(
        [kl_layer.unsqueeze(1) for kl_layer in topdown_data[kl_type]],
        dim=1,
    )

    # Apply free bits (& batch average)
    kl = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)

    # In 3D case, rescale by Z dim.
    # NOTE: guard on `img_shape is None` — with the documented default this
    # function previously crashed with a TypeError on `len(None)`.
    # TODO If we have downsampling in Z dimension, then this needs to change.
    if img_shape is not None and len(img_shape) == 3:
        kl = kl / img_shape[0]

    # Rescaling
    if rescaling == "latent_dim":
        for i in range(len(kl)):
            latent_dim = topdown_data["z"][i].shape[1:]
            norm_factor = np.prod(latent_dim)
            kl[i] = kl[i] / norm_factor
    elif rescaling == "image_dim":
        if img_shape is None:
            raise ValueError(
                "`img_shape` must be provided when rescaling is 'image_dim'."
            )
        kl = kl / np.prod(img_shape[-2:])

    # Aggregation
    if aggregation == "mean":
        kl = kl.mean()  # shape: (1,)
    elif aggregation == "sum":
        kl = kl.sum()  # shape: (1,)

    return kl

get_reconstruction_loss(reconstruction, target, likelihood_obj) #

Compute the reconstruction loss (negative log-likelihood).

Parameters:

Name Type Description Default
reconstruction Tensor

The output of the LVAE decoder. Shape is (B, C, [Z], Y, X), where C is the number of output channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).

required
target Tensor

The target image used to compute the reconstruction loss. Shape is (B, C, [Z], Y, X), where C is the number of output channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).

required
likelihood_obj Likelihood

The likelihood object used to compute the reconstruction loss.

required

Returns:

Type Description
Tensor

The reconstruction loss (negative log-likelihood).

Source code in src/careamics/losses/lvae/losses.py
def get_reconstruction_loss(
    reconstruction: torch.Tensor,
    target: torch.Tensor,
    likelihood_obj: Likelihood,
) -> torch.Tensor:
    """Compute the reconstruction loss (negative log-likelihood).

    Parameters
    ----------
    reconstruction: torch.Tensor
        The output of the LVAE decoder. Shape is (B, C, [Z], Y, X), where C is the
        number of output channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
    target: torch.Tensor
        The target image used to compute the reconstruction loss. Shape is
        (B, C, [Z], Y, X), where C is the number of output channels
        (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
    likelihood_obj: Likelihood
        The likelihood object used to compute the reconstruction loss.

    Returns
    -------
    torch.Tensor
        The reconstruction loss (negative log-likelihood). A scalar tensor.
    """
    # NOTE: return annotation fixed from `dict[str, torch.Tensor]` — the
    # function returns a scalar tensor, as the docstring already stated.
    # Compute log-likelihood; the likelihood callable returns (ll, extra_data).
    ll, _ = likelihood_obj(reconstruction, target)  # shape: (B, C, [Z], Y, X)
    return -1 * ll.mean()

hdn_loss(model_outputs, targets, config, gaussian_likelihood, noise_model_likelihood) #

Loss function for HDN.

Parameters:

Name Type Description Default
model_outputs tuple[Tensor, dict[str, Any]]

Tuple containing the model predictions (shape is (B, target_ch, [Z], Y, X)) and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).

required
targets Tensor

The target image used to compute the reconstruction loss. In this case we use the input patch itself as target. Shape is (B, target_ch, [Z], Y, X).

required
config LVAELossConfig

The config for loss function containing all loss hyperparameters.

required
gaussian_likelihood GaussianLikelihood

The Gaussian likelihood object.

required
noise_model_likelihood NoiseModelLikelihood

The noise model likelihood object.

required

Returns:

Name Type Description
output Optional[dict[str, Tensor]]

A dictionary containing the overall loss ["loss"], the reconstruction loss ["reconstruction_loss"], and the KL divergence loss ["kl_loss"].

Source code in src/careamics/losses/lvae/losses.py
def hdn_loss(
    model_outputs: tuple[torch.Tensor, dict[str, Any]],
    targets: torch.Tensor,
    config: LVAELossConfig,
    gaussian_likelihood: GaussianLikelihood | None,
    noise_model_likelihood: NoiseModelLikelihood | None,
) -> dict[str, torch.Tensor] | None:
    """Loss function for HDN.

    Parameters
    ----------
    model_outputs : tuple[torch.Tensor, dict[str, Any]]
        Model predictions (shape (B, `target_ch`, [Z], Y, X)) paired with the
        top-down layer data (sampled latents, KL-loss values, etc.).
    targets : torch.Tensor
        Target image for the reconstruction loss; in HDN the input patch itself
        serves as target. Shape is (B, `target_ch`, [Z], Y, X).
    config : LVAELossConfig
        Loss configuration holding all loss hyperparameters.
    gaussian_likelihood : GaussianLikelihood
        The Gaussian likelihood object; used when provided.
    noise_model_likelihood : NoiseModelLikelihood
        The noise model likelihood object; fallback when no Gaussian
        likelihood is given.

    Returns
    -------
    output : Optional[dict[str, torch.Tensor]]
        Dictionary with the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
        `None` is returned when the total loss is NaN.
    """
    # Prefer the Gaussian likelihood, else fall back to the noise model one.
    # TODO refactor loss signature
    likelihood = (
        gaussian_likelihood
        if gaussian_likelihood is not None
        else noise_model_likelihood
    )
    if likelihood is None:
        raise ValueError("Invalid likelihood object.")

    preds, topdown_data = model_outputs

    # Weighted negative log-likelihood reconstruction term.
    rec_term = config.reconstruction_weight * get_reconstruction_loss(
        reconstruction=preds,
        target=targets,
        likelihood_obj=likelihood,
    )
    if torch.isnan(rec_term).any():
        rec_term = 0.0

    # Annealed KL weight applied to the denoiSplit-style KL divergence.
    weight = get_kl_weight(
        config.kl_params.annealing,
        config.kl_params.start,
        config.kl_params.annealtime,
        config.kl_weight,
        config.kl_params.current_epoch,
    )
    kl_term = weight * _get_kl_divergence_loss_denoisplit(
        topdown_data=topdown_data,
        img_shape=targets.shape[2:],
        kl_type=config.kl_params.loss_type,
    )

    total = rec_term + kl_term  # TODO add check that losses coefs sum to 1
    # https://github.com/openai/vdvae/blob/main/train.py#L26
    if torch.isnan(total).any():
        return None

    return {
        "loss": total,
        "reconstruction_loss": (
            rec_term.detach() if isinstance(rec_term, torch.Tensor) else rec_term
        ),
        "kl_loss": kl_term.detach(),
    }

musplit_loss(model_outputs, targets, config, gaussian_likelihood, noise_model_likelihood=None) #

Loss function for muSplit.

Parameters:

Name Type Description Default
model_outputs tuple[Tensor, dict[str, Any]]

Tuple containing the model predictions (shape is (B, target_ch, [Z], Y, X)) and the top-down layer data (e.g., sampled latents, KL-loss values, etc.).

required
targets Tensor

The target image used to compute the reconstruction loss. Shape is (B, target_ch, [Z], Y, X).

required
config LVAELossConfig

The config for loss function (e.g., KL hyperparameters, likelihood module, noise model, etc.).

required
gaussian_likelihood GaussianLikelihood

The Gaussian likelihood object.

required
noise_model_likelihood Optional[NoiseModelLikelihood]

The noise model likelihood object. Not used here.

None

Returns:

Name Type Description
output Optional[dict[str, Tensor]]

A dictionary containing the overall loss ["loss"], the reconstruction loss ["reconstruction_loss"], and the KL divergence loss ["kl_loss"].

Source code in src/careamics/losses/lvae/losses.py
def musplit_loss(
    model_outputs: tuple[torch.Tensor, dict[str, Any]],
    targets: torch.Tensor,
    config: LVAELossConfig,
    gaussian_likelihood: GaussianLikelihood | None,
    noise_model_likelihood: NoiseModelLikelihood | None = None,  # TODO: ugly
) -> dict[str, torch.Tensor] | None:
    """Loss function for muSplit.

    Parameters
    ----------
    model_outputs : tuple[torch.Tensor, dict[str, Any]]
        Model predictions (shape (B, `target_ch`, [Z], Y, X)) paired with the
        top-down layer data (sampled latents, KL-loss values, etc.).
    targets : torch.Tensor
        Target image for the reconstruction loss.
        Shape is (B, `target_ch`, [Z], Y, X).
    config : LVAELossConfig
        Loss configuration (KL hyperparameters, likelihood module, noise
        model, etc.).
    gaussian_likelihood : GaussianLikelihood
        The Gaussian likelihood object. Must be provided.
    noise_model_likelihood : Optional[NoiseModelLikelihood]
        The noise model likelihood object. Not used here.

    Returns
    -------
    output : Optional[dict[str, torch.Tensor]]
        Dictionary with the overall loss `["loss"]`, the reconstruction loss
        `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`.
        `None` is returned when the total loss is NaN.
    """
    assert gaussian_likelihood is not None

    preds, topdown_data = model_outputs

    # Weighted negative log-likelihood under the Gaussian likelihood.
    rec_term = config.reconstruction_weight * get_reconstruction_loss(
        reconstruction=preds,
        target=targets,
        likelihood_obj=gaussian_likelihood,
    )
    if torch.isnan(rec_term).any():
        rec_term = 0.0

    # Annealed KL weight applied to the muSplit KL divergence.
    weight = get_kl_weight(
        config.kl_params.annealing,
        config.kl_params.start,
        config.kl_params.annealtime,
        config.kl_weight,
        config.kl_params.current_epoch,
    )
    kl_term = weight * _get_kl_divergence_loss_musplit(
        topdown_data=topdown_data,
        img_shape=targets.shape[2:],
        kl_type=config.kl_params.loss_type,
    )

    total = rec_term + kl_term
    # https://github.com/openai/vdvae/blob/main/train.py#L26
    if torch.isnan(total).any():
        return None

    return {
        "loss": total,
        "reconstruction_loss": (
            rec_term.detach() if isinstance(rec_term, torch.Tensor) else rec_term
        ),
        "kl_loss": kl_term.detach(),
    }