candle.uq_utils.compute_statistics_heteroscedastic

candle.uq_utils.compute_statistics_heteroscedastic(df_data, col_true=4, col_pred_start=6, col_std_pred_start=7)

Extracts the ground truth, the mean prediction, the error, the standard deviation of the predictions and the predicted (learned) standard deviation from an inference data frame. The data frame contains all the individual inference realizations.

Parameters:
  • df_data (pandas DataFrame) – Data frame generated by the current heteroscedastic inference experiments. Column indices are hard-coded to agree with the current version. (The inference file is usually named <model>.predicted_INFER_HET.tsv.)

  • col_true (int) – Index of the column in the data frame where the true value is stored (Default: 4, the index in the current HET format).

  • col_pred_start (int) – Index of the column in the data frame where the first predicted value is stored. All the values predicted during inference are stored, interleaved with the standard deviation predictions (Default: 6, with a step of 2, in the current HET format).

  • col_std_pred_start (int) – Index of the column in the data frame where the first predicted standard deviation value is stored. All the standard deviation values predicted during inference are stored, interleaved with the predictions (Default: 7, with a step of 2, in the current HET format).
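The interleaved layout that these defaults assume can be illustrated with a small, made-up data frame. The column names and values are hypothetical (the placeholders c0 to c3 and c5 stand in for whatever the real HET file stores there); only the positions follow the defaults: the true value at index 4, predictions every second column from index 6, and predicted standard deviation values every second column from index 7. The strided `iloc[:, start::2]` selection is one way to pick out such columns and is not necessarily what CANDLE does internally.

```python
import pandas as pd

# Hypothetical HET-style frame. Only the column positions mirror the defaults:
# index 4 = true value, then (prediction, predicted std value) pairs from index 6.
df = pd.DataFrame(
    {
        "c0": [0, 1, 2],                # placeholder column (index 0)
        "c1": [0, 1, 2],                # placeholder column (index 1)
        "c2": [0, 1, 2],                # placeholder column (index 2)
        "c3": [0, 1, 2],                # placeholder column (index 3)
        "target": [0.50, 0.70, 0.90],   # true values (col_true=4)
        "c5": [0, 0, 0],                # placeholder column (index 5)
        "pred_1": [0.48, 0.72, 0.85],   # realization 1 prediction (index 6)
        "s_1": [-2.0, -1.5, -1.0],      # realization 1 predicted std value (index 7)
        "pred_2": [0.52, 0.68, 0.95],   # realization 2 prediction (index 8)
        "s_2": [-2.2, -1.3, -1.2],      # realization 2 predicted std value (index 9)
    }
)

col_true, col_pred_start, col_std_pred_start = 4, 6, 7

# Step of 2 skips over the interleaved partner column.
preds = df.iloc[:, col_pred_start::2]       # columns 6, 8, ...
s_vals = df.iloc[:, col_std_pred_start::2]  # columns 7, 9, ...

print(df.columns[col_true])   # target
print(list(preds.columns))    # ['pred_1', 'pred_2']
print(list(s_vals.columns))   # ['s_1', 's_2']
```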

Returns:

Tuple of NumPy arrays and the name of the predicted quantity

  • Ytrue (numpy array): Array with the true (observed) values.

  • Ypred_mean (numpy array): Array with predicted values (mean of predictions).

  • yerror (numpy array): Array with the computed errors (observed - predicted).

  • sigma (numpy array): Array with the standard deviations learned by the deep learning model. For heteroscedastic inference this corresponds to sqrt(exp(s^2)), where s^2 is the predicted value.

  • Ypred_std (numpy array): Array with the standard deviations of the predictions across the inference realizations, computed as in regular (homoscedastic) inference.

  • pred_name (string): Name of the data column or quantity predicted (as extracted from the data frame using the col_true index).
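A minimal NumPy/pandas sketch of how the returned tuple could be assembled is shown below. It is not the CANDLE source: the function name `het_statistics_sketch` and the toy frame are hypothetical, and the sketch assumes that (1) the mean and the population standard deviation (`ddof=0`, NumPy's default) of the predictions are taken across realizations for each row, and (2) the predicted values s^2 are averaged across realizations before sigma = sqrt(exp(s^2)) is applied.

```python
import numpy as np
import pandas as pd


def het_statistics_sketch(df_data, col_true=4, col_pred_start=6, col_std_pred_start=7):
    """Hypothetical re-creation of the documented return tuple (not the CANDLE code)."""
    # Observed values and the name of the predicted quantity.
    Ytrue = df_data.iloc[:, col_true].to_numpy()
    pred_name = df_data.columns[col_true]

    # Every second column from col_pred_start holds one realization's prediction.
    preds = df_data.iloc[:, col_pred_start::2].to_numpy()
    Ypred_mean = preds.mean(axis=1)
    Ypred_std = preds.std(axis=1)  # spread across realizations (assumed ddof=0)
    yerror = Ytrue - Ypred_mean    # observed - predicted

    # Every second column from col_std_pred_start holds the predicted value s^2.
    # Assumption: average s^2 over realizations, then sigma = sqrt(exp(s^2)).
    s2 = df_data.iloc[:, col_std_pred_start::2].to_numpy().mean(axis=1)
    sigma = np.sqrt(np.exp(s2))

    return Ytrue, Ypred_mean, yerror, sigma, Ypred_std, pred_name


# Toy frame with two rows and two realizations (hypothetical values).
df = pd.DataFrame(
    [[0, 0, 0, 0, 1.1, 0, 0.8, 0.0, 1.2, 0.0],
     [0, 0, 0, 0, 2.0, 0, 2.0, np.log(4.0), 2.0, np.log(4.0)]],
    columns=["c0", "c1", "c2", "c3", "target", "c5", "pred_1", "s_1", "pred_2", "s_2"],
)
Ytrue, Ypred_mean, yerror, sigma, Ypred_std, pred_name = het_statistics_sketch(df)
```

In the second row both realizations store s^2 = log(4), so the sketch yields sigma = sqrt(4) = 2 for that row, while its Ypred_std is 0 because the two predictions agree.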