candle.uq_keras_utils.contamination_loss

candle.uq_keras_utils.contamination_loss(nout, T_k, a, sigmaSQ, gammaSQ)

Function to compute the contamination loss. It is composed of two terms: (i) the loss with respect to the normal distribution that models the training data samples, and (ii) the loss with respect to the Cauchy distribution that models the outlier samples. Note that evaluating this loss on any data other than the training set is not meaningful, because the latent variables are defined only for samples in the training set.

Parameters:
  • nout (int) – Number of outputs without UQ augmentation (in the contamination model, the augmentation corresponds to the index of the sample in the training data).

  • T_k – Keras tensor. Contains the latent variables (probability of membership in the normal and Cauchy distributions) for each sample in the training set. Validation data is usually augmented as well so that training can run with a validation set. However, the validation loss should not be used as an early-stopping criterion: the latent variables are defined for the training samples only and are not valid in combination with any other data.

  • a – Keras variable. Probability of belonging to the normal distribution.

  • sigmaSQ – Keras variable. Variance estimated for the normal distribution.

  • gammaSQ – Keras variable. Scale estimated for the Cauchy distribution.
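
As a rough illustration of how the two terms described above could combine, here is a minimal NumPy sketch. It is not the candle implementation. The function name contamination_nll_sketch is hypothetical, and the sketch assumes that the training index is stored in the last column of the augmented targets, that T_k has shape (n_train, 2) with columns ordered [normal, Cauchy], that gamma_sq is the squared Cauchy scale, and that the summed squared residual is treated as a single scalar per sample.

```python
import numpy as np


def contamination_nll_sketch(y_true_aug, y_pred, T_k, a, sigma_sq, gamma_sq):
    """Hypothetical NumPy illustration of a normal/Cauchy contamination loss.

    Assumptions (not taken from the candle source):
      - y_true_aug has shape (batch, nout + 1); the last column is the
        index of each sample in the training set.
      - T_k has shape (n_train, 2): column 0 is the membership probability
        for the normal component, column 1 for the Cauchy component.
      - gamma_sq is the squared scale of the Cauchy component.
    """
    y_true_aug = np.asarray(y_true_aug, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    T_k = np.asarray(T_k, dtype=float)

    # Split the augmented targets into the real targets and the data index.
    idx = y_true_aug[:, -1].astype(int)
    y_true = y_true_aug[:, :-1]

    # Squared residual per sample, summed over the outputs.
    diff_sq = np.sum((y_true - y_pred) ** 2, axis=-1)

    # Negative log-likelihood under the normal component, weighted by a.
    term_normal = (diff_sq / (2.0 * sigma_sq)
                   + 0.5 * np.log(2.0 * np.pi * sigma_sq)
                   - np.log(a))

    # Negative log-likelihood under the Cauchy component, weighted by 1 - a.
    term_cauchy = (np.log1p(diff_sq / gamma_sq)
                   + 0.5 * np.log(np.pi ** 2 * gamma_sq)
                   - np.log(1.0 - a))

    # Look up the latent variables of the samples in this batch.
    T_batch = T_k[idx]
    return float(np.sum(T_batch[:, 0] * term_normal
                        + T_batch[:, 1] * term_cauchy))


# Two training samples with one output each; the second column is the index.
y_true_aug = [[1.0, 0], [2.0, 1]]
y_pred = [[1.0], [2.0]]
T_k = [[1.0, 0.0], [0.0, 1.0]]  # sample 0 is "normal", sample 1 is "outlier"
print(contamination_nll_sketch(y_true_aug, y_pred, T_k,
                               a=0.5, sigma_sq=1.0, gamma_sq=1.0))
```

Because the lookup into T_k uses the training index, feeding validation samples through the same function would pair them with latent variables that belong to unrelated training samples, which is why the loss is only meaningful on the training set.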