candle.uq_keras_utils.AbstentionAdapt_Callback

class candle.uq_keras_utils.AbstentionAdapt_Callback(acc_monitor, abs_monitor, alpha0, init_abs_epoch=4, alpha_scale_factor=0.8, min_abs_acc=0.9, max_abs_frac=0.4, acc_gain=5.0, abs_gain=1.0)

This callback is used to adapt the parameter alpha in the abstention loss.

The parameter alpha (the weight of the abstention term in the abstention loss) is adapted during the training run. It is decreased if the current abstention accuracy is less than the minimum accuracy set, and increased if the current abstention fraction is greater than the maximum fraction set.

The abstention accuracy metric to use must be specified as the ‘acc_monitor’ argument when initializing the callback. It could be the global abstention accuracy (abstention_acc), the abstention accuracy over the ith class (acc_class_i), etc. The abstention metric to use must be specified as the ‘abs_monitor’ argument. It should be the metric that computes the fraction of samples for which the model is abstaining (abstention).

Minimum and maximum thresholds for the correction factor are computed, and the correction applied to alpha is clamped between them to avoid huge swings in the evolution of the abstention loss.
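The adaptation rule described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the function name `alpha_scale` is hypothetical, the linear combination of the two errors with `acc_gain` and `abs_gain` is an assumption, and so is the choice of `alpha_scale_factor` and its reciprocal as the lower and upper bounds on the correction.

```python
def alpha_scale(abs_acc, abs_frac, alpha_scale_factor=0.8, min_abs_acc=0.9,
                max_abs_frac=0.4, acc_gain=5.0, abs_gain=1.0):
    """Return a multiplicative correction for alpha (hypothetical helper)."""
    # Negative when accuracy falls short of the target, zero otherwise.
    acc_error = min(abs_acc - min_abs_acc, 0.0)
    # Positive when the model abstains more than tolerated, zero otherwise.
    abs_error = max(abs_frac - max_abs_frac, 0.0)

    # Assumed linear form: a shortfall in accuracy pushes the scale below 1,
    # an excess of abstention pushes it above 1.
    scale = 1.0 + acc_gain * acc_error + abs_gain * abs_error

    # Assumed bounds: never shrink below alpha_scale_factor or grow beyond
    # its reciprocal, so alpha cannot swing wildly in a single epoch.
    min_scale = alpha_scale_factor
    max_scale = 1.0 / alpha_scale_factor
    return max(min_scale, min(scale, max_scale))


alpha = 0.1
# Accuracy slightly below the 0.9 target, abstention within budget:
# the correction is below 1, so alpha decreases.
alpha *= alpha_scale(abs_acc=0.88, abs_frac=0.30)
```

Under these assumptions, a model that meets both targets leaves alpha unchanged, because both error terms are zero and the scale is exactly 1.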

__init__(acc_monitor, abs_monitor, alpha0, init_abs_epoch=4, alpha_scale_factor=0.8, min_abs_acc=0.9, max_abs_frac=0.4, acc_gain=5.0, abs_gain=1.0)

Initializer of the AbstentionAdapt_Callback.

Parameters:
  • acc_monitor (keras.metric) – Accuracy metric monitored during the run and used as the basis for adapting the weight of the abstention term (i.e. alpha) in the abstention loss function. It must be an accuracy metric that takes abstention into account.

  • abs_monitor (keras.metric) – Abstention metric monitored during the run and used as the other factor for adapting the weight of the abstention term (i.e. alpha) in the abstention loss function.

  • alpha0 (float) – Initial weight of the abstention term in the loss function.

  • init_abs_epoch (int) – Epoch at which to start adjusting the weight of the abstention term (i.e. alpha). Default: 4.

  • alpha_scale_factor (float) – Factor to scale (increase by dividing or decrease by multiplying) the weight of the abstention term (i.e. alpha). Default: 0.8.

  • min_abs_acc (float) – Minimum accuracy to target in the current training. Default: 0.9.

  • max_abs_frac (float) – Maximum abstention fraction to tolerate in the current training. Default: 0.4.

  • acc_gain (float) – Gain applied to the accuracy term when adjusting the alpha scale. Default: 5.0.

  • abs_gain (float) – Gain applied to the abstention term when adjusting the alpha scale. Default: 1.0.

Methods

__init__(acc_monitor, abs_monitor, alpha0[, ...])

Initializer of the AbstentionAdapt_Callback.

on_batch_begin(batch[, logs])

A backwards compatibility alias for on_train_batch_begin.

on_batch_end(batch[, logs])

A backwards compatibility alias for on_train_batch_end.

on_epoch_begin(epoch[, logs])

Called at the start of an epoch.

on_epoch_end(epoch[, logs])

Updates the weight of the abstention term at the end of an epoch.

on_predict_batch_begin(batch[, logs])

Called at the beginning of a batch in predict methods.

on_predict_batch_end(batch[, logs])

Called at the end of a batch in predict methods.

on_predict_begin([logs])

Called at the beginning of prediction.

on_predict_end([logs])

Called at the end of prediction.

on_test_batch_begin(batch[, logs])

Called at the beginning of a batch in evaluate methods.

on_test_batch_end(batch[, logs])

Called at the end of a batch in evaluate methods.

on_test_begin([logs])

Called at the beginning of evaluation or validation.

on_test_end([logs])

Called at the end of evaluation or validation.

on_train_batch_begin(batch[, logs])

Called at the beginning of a training batch in fit methods.

on_train_batch_end(batch[, logs])

Called at the end of a training batch in fit methods.

on_train_begin([logs])

Called at the beginning of training.

on_train_end([logs])

Called at the end of training.

set_model(model)

set_params(params)
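
To illustrate how on_epoch_end might interact with init_abs_epoch and the monitored metrics, here is a framework-free sketch. The class `AlphaAdapterSketch` is a hypothetical stand-in, not the CANDLE class: it does not subclass the Keras callback, it stores alpha as a plain float, and it assumes that adaptation starts at epoch `init_abs_epoch` inclusive, that metrics are read from the `logs` dictionary by name, and that the correction is the clamped linear form shown in the code. A missing metric simply raises `KeyError` here, which is a choice made for this sketch.

```python
class AlphaAdapterSketch:
    """Hypothetical stand-in for the epoch-end update (not the CANDLE class)."""

    def __init__(self, acc_monitor, abs_monitor, alpha0, init_abs_epoch=4,
                 alpha_scale_factor=0.8, min_abs_acc=0.9, max_abs_frac=0.4,
                 acc_gain=5.0, abs_gain=1.0):
        self.acc_monitor = acc_monitor
        self.abs_monitor = abs_monitor
        self.alpha = alpha0
        self.init_abs_epoch = init_abs_epoch
        self.alpha_scale_factor = alpha_scale_factor
        self.min_abs_acc = min_abs_acc
        self.max_abs_frac = max_abs_frac
        self.acc_gain = acc_gain
        self.abs_gain = abs_gain
        self.history = []  # alpha recorded after every epoch

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Assumption: adaptation starts at init_abs_epoch (inclusive).
        if epoch >= self.init_abs_epoch:
            abs_acc = logs[self.acc_monitor]    # KeyError if not logged
            abs_frac = logs[self.abs_monitor]
            # Assumed linear correction: accuracy shortfall lowers the
            # scale, excess abstention raises it.
            scale = (1.0
                     + self.acc_gain * min(abs_acc - self.min_abs_acc, 0.0)
                     + self.abs_gain * max(abs_frac - self.max_abs_frac, 0.0))
            # Clamp the correction to avoid huge swings in alpha.
            low = self.alpha_scale_factor
            scale = max(low, min(scale, 1.0 / low))
            self.alpha *= scale
        self.history.append(self.alpha)


adapter = AlphaAdapterSketch("abstention_acc", "abstention",
                             alpha0=0.1, init_abs_epoch=2)
epoch_logs = [
    {"abstention_acc": 0.70, "abstention": 0.10},  # epoch 0: too early, alpha kept
    {"abstention_acc": 0.70, "abstention": 0.10},  # epoch 1: too early, alpha kept
    {"abstention_acc": 0.70, "abstention": 0.10},  # epoch 2: accuracy too low, alpha shrinks
    {"abstention_acc": 0.95, "abstention": 0.90},  # epoch 3: abstaining too much, alpha grows
    {"abstention_acc": 0.95, "abstention": 0.20},  # epoch 4: both targets met, alpha kept
]
for epoch, logs in enumerate(epoch_logs):
    adapter.on_epoch_end(epoch, logs)
```

In this sketch the large accuracy shortfall at epoch 2 and the large abstention excess at epoch 3 both hit the clamp, so alpha changes by at most a factor of `alpha_scale_factor` (or its reciprocal) per epoch.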