
Distribution-Free, Risk-Controlling Prediction Sets

  • Stephen Bates*, Anastasios Angelopoulos*, Lihua Lei*, Jitendra Malik, and Michael I. Jordan

@misc{bates-rcps,
	title={Distribution-Free, Risk-Controlling Prediction Sets},
	author={Bates, Stephen and Angelopoulos, Anastasios Nikolas and Lei, Lihua and Malik, Jitendra and Jordan, Michael I.},
	url={https://arxiv.org/abs/2101.02703},
	journal={arXiv:2101.02703},
	year={2021}
}

Summary

We show how to generate set-valued predictions for any predictor such that the expected loss on a future test point is controlled at a user-specified level, on any dataset, for a broad class of losses. Our approach provides explicit finite-sample guarantees by using a holdout set to calibrate the size of the prediction sets. This framework enables simple, distribution-free, rigorous, instance-wise error control for many new tasks, and we demonstrate it in five large-scale machine learning problems: (1) classification problems where some mistakes are more costly than others; (2) multi-label classification, where each observation has multiple associated labels; (3) classification problems where the labels have a hierarchical structure; (4) image segmentation, where we wish to predict a set of pixels containing an object of interest (in our case, tumors); and (5) protein structure prediction. Our GitHub implementation showcases all of these examples.

Motivation

Black-box predictive algorithms have begun to be deployed in many real-world decision-making settings. Problematically, however, these algorithms are rarely accompanied by reliable uncertainty quantification. Algorithm developers often rely on having followed the standard training/validation/test paradigm to make assertions of accuracy, stopping short of any further attempt to provide an indication that an algorithm's predictions should be treated with skepticism. Thus, prediction failures will often be silent ones, which is particularly alarming when the consequences of failure are significant.

We introduce a method for modifying a black-box predictor to return a set of plausible responses that limits the frequency of costly errors to a level chosen by the user. Returning a set of responses is a useful way to represent uncertainty, since such sets can be readily constructed from any existing predictor and, moreover, they are often interpretable. We call our proposed technique risk-controlling prediction sets (RCPS). The idea is to produce prediction sets that provide distribution-free, finite-sample control of a general loss. Like split-conformal prediction (see the last blog post), RCPS achieve this by using a small holdout dataset. However, the mathematical tools are entirely different, relying on concentration, a more general tool that applies to a wide range of problems. As a result, in contrast to the standard train/validation/test split paradigm which only estimates global uncertainty (in the form of overall prediction accuracy), RCPS allow the user to automatically return valid instance-wise uncertainty estimates for many prediction tasks.

Examples

RCPS can be used for any prediction problem, including classification with single, multiple, or hierarchical labels, regression, object segmentation, and more. They guarantee that the risk (expected loss) will be no more than $\gamma$ with probability $1-\delta$; we defer the mathematical treatment of RCPS to the paper. We will now demonstrate three (out of many) examples of RCPS below.
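
For reference, writing $R(\mathcal{T}) = \mathbb{E}\big[L(Y,\mathcal{T}(X))\big]$ for the risk of a set-valued predictor $\mathcal{T}$ with loss $L$, a $(\gamma,\delta)$-RCPS satisfies \begin{equation} \mathbb{P}\big( R(\mathcal{T}) \leq \gamma \big) \geq 1-\delta, \end{equation} where the probability is over the draw of the calibration data used to construct $\mathcal{T}$.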

Protein Folding

A critical step in AlphaFold's protein folding pipeline involves predicting the distance between the $\beta$-carbons (the first carbon atom of each amino acid's side chain) of each pair of amino acids. These distances are then used to determine the protein's 3D structure. We express uncertainty directly on the distances between $\beta$-carbons; this is only one possible choice of uncertainty quantification for protein folding using RCPS. Below we show an example of an instance-wise uncertainty estimate of the distances between amino acids in protein T0995 (PDB identifier 3WUY), a nitrilase from the organism Synechocystis sp. The distance between the prediction set and the true inter-residue distances is guaranteed to be at most 2 Angstroms with probability $90\%$.

Protein Folding Sets
Prediction sets for protein T0995 from CASP-13. We show AlphaFold's predicted distances between residues of protein T0995 along with prediction sets at $\gamma=2$ Angstroms and $\delta=0.1$. The prediction set for the whole protein is the union of distance intervals for each pair of residues, and the right two panels report the distance from the point prediction to the lower and upper endpoints for each of these intervals.

Gut Polyp Segmentation

Imagine a surgical machine tasked with excising cancerous tumors identified via a computer vision algorithm and an endoscope. RCPS can guarantee, for example, that $90\%$ of the tumor pixels of each individual tumor in the image are identified with probability $90\%$. We show the resulting tumor segmentations below.

Polyp Segmentation Sets
Polyp Segmentations. We show examples of polyps along with prediction sets that capture $90\%$ of the true polyp pixels per polyp per image. White pixels are correctly identified polyp pixels, blue ones are spurious, and red ones are missed. The top two rows show examples with a single polyp per image, and the bottom two rows show examples with two polyps per image.

Hierarchical Classification

When classifying objects in a hierarchy (tree), one may want to output predictions that are only as granular as the classifier is certain. For example, if the classifier isn't sure whether it sees a basset or a beagle, it could choose to output hound, a superclass of both. RCPS allow the classifier to make this decision adaptively for each test-time example, depending on the difficulty of the image, while controlling any loss on the hierarchy. For example, RCPS could guarantee that the classifier will predict a superclass of the true label with probability $1-\delta$. They could also guarantee that the classifier will output an interior node of the tree that is no more than, say, $\gamma$ nodes away from a superclass of the true label with probability $1-\delta$. We show sets that satisfy the latter type of guarantee below on ImageNet.

Hierarchical Prediction Sets
Hierarchical predictions. We show randomly selected examples of hierarchical prediction sets on ImageNet where the point prediction is incorrect but the prediction sets cover the true label. The black label is the ground truth class, the blue label is our prediction, and the red label is the top-1 output of a ResNet-18. Our prediction is an ancestor in the WordNet hierarchy of both the true class and the model's top-1 prediction. See the rightmost panel for an example subtree from the WordNet hierarchy.
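
To make the hierarchical idea concrete, here is a minimal sketch (not the exact construction from the paper) of how nested hierarchical predictions can arise: start at the top-1 leaf and climb toward the root until the subtree below the current node captures enough softmax mass. The tree encoding via `parent` and `leaves_under` is a hypothetical interface.

import numpy as np

def hierarchical_prediction(scores, parent, leaves_under, lam):
    """Climb from the argmax leaf toward the root until the softmax mass
    of the leaves below the current node is at least lam. Larger lam
    forces coarser (higher, safer) nodes, so the predictions are nested."""
    node = int(np.argmax(scores))                      # start at the top-1 leaf
    while scores[leaves_under[node]].sum() < lam and parent[node] != node:
        node = parent[node]                            # move up one level
    return node                                        # possibly an interior node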

Methods

To create an RCPS that controls risk at level $\gamma$ with probability $1-\delta$, you need three ingredients:

  • a calibration dataset that the model was not trained on;
  • a loss function $L$ encoding the user's notion of error;
  • a family of set-valued predictors $\{\mathcal{T}_\lambda\}$, nested in the parameter $\lambda$.

Then you can define the RCPS as a random function $\mathcal{T}_{\hat{\lambda}}$, where the randomness arises because $\hat{\lambda}$ is chosen using the calibration data to control $L$ at the user's choice of $\gamma$ with probability $1-\delta$. In plain language, $\mathcal{T}$ defines the space of all possible uncertainty sets; the whole game is to pick a good choice of $\mathcal{T}$ and then tune the parameter $\lambda$ to achieve the desired risk control based on $L$. Concentration of measure provides a general toolkit to do this for any model, loss, dataset, $\gamma$, or $\delta$.

The Calibration Dataset

The user needs $n$ data points that the model has not been trained on to build an RCPS. Our guarantees hold for any $n$. However, if $n$ is smaller than, say, $500$, and the user desires stringent risk control ($\gamma, \delta < 0.1$), the finite-sample guarantees detailed in our paper will result in conservatively large sets. If the user can tolerate not having the finite-sample guarantees provided by concentration bounds, the CLT version of our method works surprisingly well in practice even with a small number of samples (e.g., $n=35$ in the protein folding experiment), and the risk is controlled almost exactly.
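
For intuition, here is a sketch of the CLT-based upper confidence bound on the risk at a fixed $\lambda$, an asymptotic stand-in for the finite-sample concentration bounds in the paper:

import numpy as np
from scipy.stats import norm

def clt_upper_bound(losses, delta):
    """Asymptotic (CLT) upper confidence bound on the risk R(lambda).
    losses: the n calibration losses L(y_i, T_lambda(x_i)) at a fixed lambda."""
    n = len(losses)
    return losses.mean() + norm.ppf(1 - delta) * losses.std(ddof=1) / np.sqrt(n)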

The Loss

The loss can be rigorously defined as a function $L : \mathcal{Y} \times 2^\mathcal{Y} \to \mathbb{R}$. The first argument is the true label and the second is the prediction set, so the output represents how bad the prediction set is on a user-defined scale. Now we will give some examples of common losses.

In medical image classification with $K$ classes, missing an instance of class 0 (ischemic stroke) might be worse than missing an instance of class 5 (concussion). We can express this by defining a penalty $p_y$ for each class $y$, where $p_0 \gg p_5$. An RCPS could use the naturally resulting loss on a label $y$ and prediction set $\mathcal{S}$, namely $L(y,\mathcal{S}) = p_y\mathbb{1}\{y \notin \mathcal{S}\}$. This simple example can be further extended to handle more detailed notions of consequence; the loss we wrote could be replaced by any function of $y$ and $\mathcal{S}$. For example, if the true class is ischemic stroke and the set includes hemorrhagic stroke, we might penalize the set less even though the true label wasn't included. The hierarchical example above takes this even further, defining the loss as a graph distance on hierarchically-defined labels.
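
As a minimal sketch, this loss is one line of code (the penalty values here are hypothetical):

def penalty_loss(y, prediction_set, penalties):
    """Class-weighted miscoverage: pay penalties[y] when the true
    label y is missing from the prediction set, and nothing otherwise."""
    return penalties[y] if y not in prediction_set else 0.0

# Hypothetical penalties: missing ischemic stroke (class 0) is ten
# times worse than missing concussion (class 5).
penalties = {0: 10.0, 5: 1.0}
print(penalty_loss(0, {1, 2}, penalties))  # 10.0: a costly miss
print(penalty_loss(5, {5, 0}, penalties))  # 0.0: the true label is covered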

Losses can also be defined for regression problems such as the image segmentation and protein folding examples above. In a simple 1D regression case where $y \in \mathbb{R}$, we could, for example, take the loss to be $L(y,\mathcal{S}) = |y - \mathrm{proj}_{\mathcal{S}}(y)|$, where $\mathrm{proj}_{\mathcal{S}}$ denotes the Euclidean projection onto $\mathcal{S}$. This would allow us to control the distance between our prediction set and the true label $y$. In the multidimensional case $y \in \mathbb{R}^m$, we could take the loss to be (for example) $\sup_{i \in \{1,\dots,m\}} |y_i - \mathrm{proj}_{\mathcal{S}}(y)_i|$. The supremum could be replaced by any other pooling function, like the sample average, as in our protein folding example.
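
A sketch of these regression losses, assuming for simplicity that the prediction set is an interval $[l, u]$ in 1D, or a coordinate-wise box in $\mathbb{R}^m$ (the framework allows more general sets):

import numpy as np

def projection_loss(y, lower, upper):
    """Distance from a scalar y to the interval [lower, upper]:
    zero when y is inside the set, else the gap to the nearest endpoint."""
    return max(lower - y, 0.0, y - upper)

def sup_projection_loss(y, lower, upper):
    """Multidimensional version for a box of coordinate-wise intervals:
    the worst projection distance over the m coordinates."""
    return float(np.maximum(np.maximum(lower - y, 0.0), y - upper).max())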

To summarize, RCPS allow error control for any loss.

The Family of Sets

The family of nested sets seems like the most mysterious part of the RCPS pipeline, but it is not difficult to construct at all. In fact, almost any predictive system has nested sets automatically built in. Consider, for example, a classifier that outputs a softmax score $s_i$ for each class $i=1,...,K$. A natural choice of nested sets emerges immediately: \begin{equation} \mathcal{T}_\lambda = \{ i : s_i \geq -\lambda \}. \end{equation} In English, $\mathcal{T}_\lambda$ includes all classes with scores above a threshold $-\lambda$ (the negative sign is an artifact of our definition of nesting). Of course, these sets are nested in the threshold. Any model that outputs a probability distribution at some point in its prediction process could use this same set construction, with slight and obvious modifications needed for multivariate outputs like our image segmentation and protein folding examples. However, RCPS also support intricate and exotic set constructions; see the paper for a more detailed discussion.
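
In code, this construction is a one-liner (a sketch; note that $\lambda$ is negative here, matching the sign convention above):

import numpy as np

def nested_set(softmax_scores, lam):
    """All classes whose softmax score is at least -lam. As lam increases,
    the threshold -lam decreases, so the sets only grow: they are nested."""
    return {i for i, s in enumerate(softmax_scores) if s >= -lam}

scores = np.array([0.7, 0.2, 0.1])
print(nested_set(scores, -0.5))   # {0}: strict threshold of 0.5
print(nested_set(scores, -0.05))  # {0, 1, 2}: loose threshold of 0.05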

We next include an animation of how the family of nested sets we used for the polyp segmentation example changes as $\lambda$ grows.

Choosing $\hat{\lambda}$

We've now defined all the preliminaries and can get to the main subject of the paper, which we will preview here: picking $\hat{\lambda}$. Define the risk of $\mathcal{T}_\lambda$ with respect to loss $L$ as $R(\lambda) = \mathbb{E}\big[L(Y,\mathcal{T}_\lambda(X))\big]$. Next assume we have access to a pointwise upper confidence bound for the risk, $\widehat{R}^+(\lambda)$, satisfying $\mathbb{P}\big(R(\lambda) \leq \widehat{R}^+(\lambda) \big) \geq 1-\delta$. For example, we might get $\widehat{R}^+(\lambda)$ via Hoeffding's inequality, although in practice we use much more powerful concentration results to get tighter sets. Then we choose $\hat{\lambda}$ as the smallest value of $\lambda$ such that the upper confidence bound stays below $\gamma$ for all larger values: \begin{equation} \hat{\lambda} = \inf \big\{ \lambda : \widehat{R}^+(\lambda') < \gamma, \; \forall \lambda' \geq \lambda \big\}. \end{equation} We animate the algorithm for choosing $\hat{\lambda}$ below; reload the page to replay the animation. With this choice of $\hat{\lambda}$, $\mathcal{T}_{\hat{\lambda}}$ is a $(\gamma,\delta)$-RCPS, meaning that with probability $1-\delta$, \begin{equation} R(\hat{\lambda}) \le \gamma. \end{equation}
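
A minimal sketch of this selection rule over a finite grid of $\lambda$ values, using Hoeffding's inequality for a loss bounded in $[0, B]$ (the paper derives tighter bounds):

import numpy as np

def choose_lambda_hat(loss_matrix, lambdas, gamma, delta, B=1.0):
    """Pick the smallest lambda on the grid whose Hoeffding upper confidence
    bound on the risk stays below gamma for all larger lambdas.
    loss_matrix[i, j] = L(y_i, T_{lambdas[j]}(x_i)) on the n calibration points;
    lambdas must be sorted in increasing order."""
    n = loss_matrix.shape[0]
    slack = B * np.sqrt(np.log(1 / delta) / (2 * n))       # Hoeffding term
    ucb = loss_matrix.mean(axis=0) + slack                 # R-hat-plus(lambda) on the grid
    valid_from_here = np.logical_and.accumulate((ucb < gamma)[::-1])[::-1]
    if not valid_from_here.any():
        raise ValueError("no lambda controls the risk at this (gamma, delta)")
    return lambdas[np.argmax(valid_from_here)]             # first lambda where all larger ones pass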

Conclusion

Congratulations! You made it to the end of the post. Now you know the basic strategy for creating prediction sets that control a general risk for any model, on any dataset, in finite samples, at almost no computational cost. If you've enjoyed this, you will really enjoy the paper, which has a lot more detail. And if you'd like to use our method, check out our GitHub.