pomme.loss

class pomme.loss.Loss(keys=[])

Bases: object

A convenience class to store losses.

plot()

Plot the evolution of the losses.

renormalise(key)

Reset the norm of the given loss to one over its current value (the reciprocal of the current loss).

Parameters:

key (str) – Key of the variable to be renormalised.

renormalise_all()

Renormalise all losses.

reset()

Reset all losses and remove all stored losses.

tot()

Return the total loss.
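The bookkeeping described above (storing losses, renormalising them, summing them) can be illustrated with a small stand-alone sketch. The class `LossBook` and its `set`, `value` and `history` members are hypothetical names chosen for illustration and are not pomme's API. The sketch also assumes that a loss's norm acts as a multiplicative factor, so that `renormalise` makes the current normalised value exactly one. Plotting is left out.

```python
class LossBook:
    """Hypothetical stand-in for a loss container; not pomme's implementation."""

    def __init__(self, keys=()):
        self.raw = {k: 0.0 for k in keys}    # latest un-normalised value per key
        self.norm = {k: 1.0 for k in keys}   # multiplicative normalisation factor
        self.history = {k: [] for k in keys} # normalised values, in insertion order

    def set(self, key, value):
        # Store the raw value and record its normalised counterpart.
        self.raw[key] = value
        self.history[key].append(self.norm[key] * value)

    def value(self, key):
        # Normalised value of a single loss.
        return self.norm[key] * self.raw[key]

    def renormalise(self, key):
        # "One over the current loss": afterwards value(key) == 1.
        self.norm[key] = 1.0 / self.raw[key]

    def renormalise_all(self):
        for key in self.raw:
            self.renormalise(key)

    def reset(self):
        # Clear values and stored history (norms are kept; an assumption).
        for key in self.raw:
            self.raw[key] = 0.0
            self.history[key] = []

    def tot(self):
        # Total loss: sum of all normalised losses.
        return sum(self.value(k) for k in self.raw)
```

For example, after `set("fit", 4.0)`, `set("smooth", 0.5)` and `renormalise_all()`, a later `set("fit", 2.0)` gives a normalised fit loss of 0.5 and `tot()` returns 1.5. Renormalising this way puts losses with very different raw magnitudes on a comparable scale before they are summed.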

class pomme.loss.SphericalLoss(model, origin='centre', weights=None)

Bases: object

Computes the deviation from spherical symmetry. This is quantified as the variance of the data in each radial bin.

eval(var)

Evaluate the spherical loss.

Parameters:

var (torch.Tensor) – Variable for which the loss should be evaluated.

Returns:

The spherical loss for the given variable.

Return type:

torch.Tensor
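The idea behind the spherical loss (the variance of the data within each radial bin) can be sketched as follows. The function `spherical_loss` is hypothetical and uses NumPy rather than `torch` so that it runs without extra dependencies. It assumes that `origin='centre'` means the geometric centre of the grid, that bins are shells of unit width in grid units, and that `weights` (if given) holds one weight per radial bin. pomme's actual binning and weighting may differ.

```python
import numpy as np

def spherical_loss(var, weights=None):
    """Weighted sum of per-shell variances of `var` about the grid centre.

    Hypothetical sketch: unit-width radial bins, one optional weight per bin.
    """
    var = np.asarray(var, dtype=float)
    # Geometric centre of the grid (assumed meaning of origin='centre').
    centre = (np.array(var.shape) - 1) / 2.0
    # Cartesian index coordinates of every cell, one array per axis.
    grids = np.indices(var.shape, dtype=float)
    r = np.sqrt(sum((g - c) ** 2 for g, c in zip(grids, centre)))
    # Assign each cell to a radial bin of unit width.
    bins = np.floor(r).astype(int).ravel()
    values = var.ravel()
    loss = 0.0
    for b in range(bins.max() + 1):
        shell = values[bins == b]
        if shell.size == 0:
            continue
        w = 1.0 if weights is None else weights[b]
        # A perfectly spherical field has zero variance in every shell.
        loss += w * shell.var()
    return loss
```

A uniform field gives a loss of zero, while uncorrelated noise gives a positive loss. Note that with finite-width bins even a smooth radial profile contributes a small variance within each shell, so narrower bins trade resolution for fewer samples per bin.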

pomme.loss.diff_loss(arr)

Differential loss, quantifying the local change in a variable along the Cartesian axes.
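A differential loss of this kind can be sketched with finite differences between neighbouring cells along each axis. The function below is a hypothetical illustration written with NumPy: the use of forward differences, the squared norm, and averaging per axis are assumptions, and pomme's `diff_loss` may use a different stencil or norm.

```python
import numpy as np

def diff_loss(arr):
    """Sum over axes of the mean squared forward difference (hypothetical sketch)."""
    arr = np.asarray(arr, dtype=float)
    total = 0.0
    for axis in range(arr.ndim):
        # Differences between neighbouring cells along this Cartesian axis.
        d = np.diff(arr, axis=axis)
        total += np.mean(d ** 2)
    return total
```

A constant array gives zero, and a 2D ramp that increases by 2 per cell along the first axis (and is constant along the second) gives 4.0, since only the first axis contributes.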

pomme.loss.fourier_loss_1D(arr)

Loss based on the (1D) Fourier transform.
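One way a Fourier-based loss can work is to penalise power at high frequencies, which favours smooth profiles. The sketch below is hypothetical: it weights the power spectrum from NumPy's real FFT linearly by frequency index, which is an assumed choice and not necessarily what pomme's `fourier_loss_1D` does.

```python
import numpy as np

def fourier_loss_1D(arr):
    """Frequency-weighted power of the 1D real FFT (hypothetical sketch)."""
    arr = np.asarray(arr, dtype=float)
    spectrum = np.fft.rfft(arr)
    # Frequency indices k = 0, 1, ..., n // 2; the mean (k = 0) is not penalised.
    k = np.arange(spectrum.size)
    return float(np.sum(k * np.abs(spectrum) ** 2) / arr.size)
```

With this weighting, a sine wave with 16 periods over 64 samples scores far higher than one with a single period of the same amplitude, and a constant signal scores (numerically) zero.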

pomme.loss.haar_loss_1D(arr)

Loss based on the Haar wavelet transform.
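A Haar-wavelet loss can be built from the detail coefficients of a multi-level Haar transform, which are large where the signal changes abruptly. The sketch below implements an orthonormal Haar transform by hand in NumPy and sums the absolute detail coefficients. The helper `haar_detail_coefficients`, the L1 norm, and the truncation of odd-length levels are assumptions for illustration, not pomme's documented behaviour.

```python
import numpy as np

def haar_detail_coefficients(arr):
    """Multi-level orthonormal Haar transform; returns the detail coefficients."""
    approx = np.asarray(arr, dtype=float)
    details = []
    while approx.size >= 2:
        # Drop a trailing sample if the length is odd (a simplifying assumption).
        if approx.size % 2:
            approx = approx[:-1]
        even, odd = approx[0::2], approx[1::2]
        # Detail: scaled pairwise difference; approximation: scaled pairwise sum.
        details.append((even - odd) / np.sqrt(2.0))
        approx = (even + odd) / np.sqrt(2.0)
    return details

def haar_loss_1D(arr):
    """Sum of absolute Haar detail coefficients (hypothetical sketch)."""
    return float(sum(np.sum(np.abs(d)) for d in haar_detail_coefficients(arr)))
```

A constant signal has no detail coefficients and scores zero; a single step `[0, 0, 0, 0, 1, 1, 1, 1]` only shows up at the coarsest level and scores about √2, while a rapidly alternating signal such as `[1, -1, 1, -1]` scores about 2√2.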