Note - this is part 2 of a two article series on Double Descent. Part 1 is available here.
In our previous discussion of the double descent phenomenon, we made use of a piecewise linear model that we fit to our data. This may seem a somewhat strange choice, but we made it for a simple reason: you can rigorously prove that it will show double descent for essentially any dataset! This argument does not seem to appear in the literature anywhere, so we've taken it upon ourselves to at least sketch it, to provide some mathematical justification for why double descent occurs.
This article has three parts. First, we provide a very short discussion of the behavior of the model with a small number of segments (the classical regime). Second, we will explain why the model must perform poorly at the interpolation threshold. Third, we will identify what happens in the limit of infinitely many linear pieces.
Throughout this discussion, we will see that the behavior boils down to two core principles:
1. At the interpolation threshold, there can often be a single choice of model that works, and there is no reason to believe that said model will be good.
2. In the limit of infinitely large models, there will be a vast number of interpolating models, and we can pick the best amongst them.
Throughout this discussion, we will use a simple example dataset, pictured below, to demonstrate our points. This example dataset has no noise, and is simply a quadratic function sampled at six equally spaced points. We will get to know this dataset well.
Let's be precise about exactly what model we are working with. We work entirely in one dimension, so our input data is a vector $x \in \mathbb{R}^n$, and our target is a vector $y \in \mathbb{R}^n$. Our model will attempt to fit a piecewise linear function to this dataset, and the way we'll do that is to pick knot points $k_1, \dots, k_m$ where our linear function will be allowed to bend. This will be represented by use of radial basis functions, in particular we will define:

$$\phi_i(x) = |x - k_i|,$$

and say that our model will be written as

$$f(x) = a + bx + \sum_{i=1}^{m} c_i \phi_i(x),$$

where $a$, $b$, and $c_1, \dots, c_m$ are learnable parameters. It can be readily checked (since the absolute value function is linear aside from a singularity at the origin) that this is a parametrization of piecewise linear functions with breakpoints only allowed at the knot points. We will consider the knot points to be fixed, and picked as independent and identically distributed random points from a continuous distribution with density $\rho$. Typically we take $\rho$ to be uniform on the range of $x$, although we discuss the problem in generality.
We will fit this model by ordinary least squares regression on our learnable parameters. As is standard, we will seek the minimum norm solution when the system is under-determined (when we have more parameters than data points). To simplify the analysis, we will assume we are only looking to minimize $\|c\|_2$ (excluding $a$ and $b$ from the minimization).
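To make this concrete, here is a minimal NumPy sketch of the model and its fit. The dataset, knot count, and all names are our own illustrative stand-ins (the article's exact points are not specified), and for simplicity `np.linalg.lstsq` minimizes the norm of all parameters rather than of the $c_i$ alone:

```python
import numpy as np

def design_matrix(x, knots):
    # Columns: 1, x, and the radial basis functions |x - k_i|
    return np.column_stack([np.ones_like(x), x] +
                           [np.abs(x - k) for k in knots])

def fit(x, y, knots):
    # Ordinary least squares; for under-determined systems lstsq
    # returns the minimum-norm solution (over ALL parameters here,
    # a simplification of the article's convention of minimizing
    # only the norm of the c_i).
    X = design_matrix(x, knots)
    params, *_ = np.linalg.lstsq(X, y, rcond=None)
    return params

def predict(params, x, knots):
    return design_matrix(x, knots) @ params

# Illustrative dataset: a quadratic sampled at six equally spaced points
x = np.linspace(-1.0, 1.0, 6)
y = x**2

rng = np.random.default_rng(0)
knots = np.sort(rng.uniform(-1, 1, size=10))  # 12 parameters > 6 points
params = fit(x, y, knots)
print(np.max(np.abs(predict(params, x, knots) - y)))  # interpolates the data
```

Since there are more parameters than data points, the minimum-norm solution interpolates the data exactly (up to floating point error).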
While not the focus of the document, let's take a moment to consider what happens when we are well below the interpolation threshold. Without any knot points this is ordinary least squares regression.
As our dataset is not linear, this fit is not particularly great, and we will be able to improve our fit by adding a knot point.
How much this improves performance depends heavily on the dataset, the ground truth, and the choice of random knot point. Either numerical testing or an exact computation shows that on average the performance is improved versus the linear solution with regard to the $p$-th power of any $L^p$-norm with $p \ge 1$. This includes the MSE and the MAE metrics.
So far, this perfectly matches traditional statistical intuition. Our simple linear fit is poor since the ground truth function is non-linear. Expanding our model to allow for non-linearity should improve the fit, as long as there is sufficient data. In this case, our six data points are enough to provide a better model when you allow for a single bend at a random location.
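A quick numerical check of this claim, using our hypothetical stand-in for the dataset (a quadratic on $[-1, 1]$; the article's exact points are not specified):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in dataset: a quadratic sampled at six
# equally spaced points.
x = np.linspace(-1.0, 1.0, 6)
y = x**2

def train_mse(knots):
    # OLS fit of f(x) = a + b x + sum_i c_i |x - k_i|
    X = np.column_stack([np.ones_like(x), x] +
                        [np.abs(x - k) for k in knots])
    params, *_ = np.linalg.lstsq(X, y, rcond=None)
    return np.mean((X @ params - y) ** 2)

mse_linear = train_mse([])
mse_one_knot = np.mean([train_mse([rng.uniform(-1, 1)])
                        for _ in range(500)])
print(mse_linear, mse_one_knot)  # a single random bend helps on average
```

The linear model is nested inside the one-knot model (set $c_1 = 0$), so the one-knot training error can never be worse, and on this curved dataset it is strictly better on average.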
We now turn our attention to the behavior at the interpolation threshold. Our goal here will be to illustrate why this will perform poorly by showing that, for this model, the average error (averaged over our random knot points) is infinite!
Let's visualize an example of a model at the interpolation threshold.
In this case, we see that we have four knot points placed between the five right-most data points, and then no knot point between the two left-most ones. The main takeaway we want you to have is that, for this choice of knot points, there is no freedom in how the interpolating model is selected once the knot points are fixed. There is a unique line which passes through the left two points, and if we want to exactly interpolate the data we must start with that line.
This line intersects the vertical blue line denoting the first knot point at a height which is uniquely determined by the line on the left, and then it must bend to pass through the next data point. From there it hits the second knot point, at which point again it must bend to hit the following data point, and so on. There is never any choice since it needs to be an interpolating function, so the entire solution is fixed.
This is one of the main lessons of what happens at the interpolation threshold: there is rigidity in what model you fit. Indeed, it is not uncommon that there is exactly one such choice (and a count of variables and constraints tells us that this should frequently happen when we have two fewer knot points than data points, although this is not guaranteed). If there is only one choice, we have no control over whether that model is good or not, and we should expect a bad fit.
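We can see this rigidity numerically on our stand-in dataset. With $n - 2 = 4$ knots placed as in the figure (one in each gap between the five right-most data points, and none between the two left-most), the design matrix is square and invertible, so the interpolant is forced:

```python
import numpy as np

x = np.linspace(-1.0, 1.0, 6)   # hypothetical quadratic dataset
y = x**2

# Four knots (n - 2): one in each gap between the five right-most
# data points, none between the two left-most. Positions are our own.
knots = [-0.5, 0.1, 0.4, 0.8]
X = np.column_stack([np.ones_like(x), x] +
                    [np.abs(x - k) for k in knots])
print(X.shape)                      # (6, 6): a square system
params = np.linalg.solve(X, y)      # invertible => a UNIQUE interpolant
print(np.allclose(X @ params, y))   # True: it interpolates exactly
```

Because the system is square and invertible, there is exactly one parameter vector that interpolates, matching the bend-by-bend argument above.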
To see an example that is a bad fit, let's move the left-most knot point towards the right.
This performance is extremely bad: the prediction jumps to values many times larger than anything the ground-truth function attains in this region. Indeed, we can shift the knot point again and make the spike even more pronounced.
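This blow-up is easy to reproduce numerically. In the sketch below (dataset and knot positions are our own stand-ins), we slide the left-most knot to within $\epsilon$ of the data point it approaches and watch the worst-case error explode:

```python
import numpy as np

x = np.linspace(-1.0, 1.0, 6)   # hypothetical quadratic dataset
y = x**2
grid = np.linspace(-1, 1, 2001)

def sup_error(knots):
    # Unique interpolant at the threshold, compared with the truth
    X = np.column_stack([np.ones_like(x), x] +
                        [np.abs(x - k) for k in knots])
    params = np.linalg.solve(X, y)
    G = np.column_stack([np.ones_like(grid), grid] +
                        [np.abs(grid - k) for k in knots])
    return np.max(np.abs(G @ params - grid**2))

for eps in [0.1, 0.01, 0.001]:
    # left-most knot sits eps to the left of the third data point
    print(eps, sup_error([-0.2 - eps, 0.0, 0.4, 0.8]))
```

Each time $\epsilon$ shrinks by a factor of ten, the worst-case error grows by roughly the same factor: the error behaves like $1/\epsilon$.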
This phenomenon can be made so pronounced that our average error turns out to be infinite for every $L^p$-norm.
First, let's focus on what this picture is showing. We are throwing down four random knot points, three of which are to the right of the third data point (for concreteness, say the first three), and one (say the last) occurs just to its left. To be explicit, let's make that left-most knot point fall in an interval of width $\epsilon$ directly to the left of the third data point. This whole event occurs with probability proportional to $\epsilon$. We will throw out all constants as they are immaterial for the argument being made, and concentrate only on how this depends on $\epsilon$.
When this occurs, the left-most segment will hit the left-most knot point somewhere below the third data point (since the dataset comes from a quadratic function that is curving up, the line through the two left-most data points passes some fixed distance $\delta > 0$ below the third one). To be concrete, let's ensure this by picking $\epsilon$ sufficiently small that the segment's height at the knot point is still at least $\delta/2$ below the third data point. Then, at that knot point, the next linear segment will need to climb from a point no better than $\delta/2$ below the data, which means that it needs to have a slope of at least $\delta/(2\epsilon)$ to pass through the third data point.
With this enormous slope, we have our key component. It means that our function has started oscillating wildly, and we simply need to show it is wild enough to make the average error infinite under any norm. Notice that, on this event, we have no other knot points in a region of fixed width just to the right of the third data point (we may require the other three knot points to sit further right, which only changes constants). On that region, our interpolating function takes on a value of size comparable to $s$, where $s \ge \delta/(2\epsilon)$ is the slope, and thus the value and the error are both of order $1/\epsilon$ on that interval of constant width.
Now we finish the argument. By considering the sequence $\epsilon_j = 2^{-j}$ we can create a sequence of disjoint events which we can use to lower-bound our expected error. In particular, letting $f$ be this piecewise linear interpolating function, and $g$ be the ground-truth quadratic, we see that we may estimate (discarding constants)

$$\mathbb{E}\left[\|f - g\|_p^p\right] \gtrsim \sum_{j=1}^{\infty} \underbrace{2^{-j}}_{\text{probability}} \cdot \underbrace{\left(2^{j}\right)^p}_{\text{error}} \ge \sum_{j=1}^{\infty} 1 = \infty,$$

where the norm is understood as the integral over the range of the data (in essence the infinite data limit). Thus the average error, as measured by the $L^p$-norm (for any $p \ge 1$), is infinite. Indeed, for many reasonable metrics beyond the $L^p$-norms, this error should be infinite.
If you follow through this argument in generality, you get the following fact:
Proposition 1. For any dataset of at least three points, not all co-linear, any bounded ground truth function $g$, and any $p \ge 1$, the best fit model $f$ of this type, with $n - 2$ knot points, has

$$\mathbb{E}\left[\|f - g\|_{L^p}\right] = \infty.$$
The moral of this story is that these models are so rigid that bad behavior is guaranteed, independent of the dataset or ground truth. Simply the fact that you have a unique interpolating function that can only bend at random points forces the fit to be poor on average when you are at the interpolation threshold.
We now understand why the model performs poorly at the interpolation threshold. In essence, there is only one model that can be used, and that model has no reason to be a reasonable interpolating function at all! This provides the first hint as to why larger models might be better: they provide many choices about which interpolating model to pick, and perhaps we can structure the model in such a way that the interpolation of the prediction between points is much better behaved.
To see why this is so, suppose we have a nice function $f$ (at least twice continuously differentiable) and we select a large number $m$ of random knot points from our density $\rho$. Recall that the way we are selecting the model is to minimize $\|c\|_2$, which is the norm of the coefficients we put in front of our basis functions. Our goal is to provide an interpretation of this norm in terms of our smooth function $f$.
In the above figure we see three consecutive knot points: $k_{i-1}$, $k_i$, $k_{i+1}$. We will let $s_-$ and $s_+$ represent the slopes of the two segments adjacent to $k_i$, and $\Delta_i$ denotes the distance between the midpoints of those two segments. We now collect three facts:
1. Since the two segments adjacent to $k_i$ are consecutive intervals in our piecewise linear function, we can express the difference in the slope in terms of the coefficients of the model. In particular, note that $a$, $bx$, and $c_j\phi_j(x)$ for all $j \neq i$ are all linear across both intervals. This means that they all contribute the same constant expression to the slope on both sides, thus the difference in slopes is simply the difference in slopes between the two sides of $\phi_i$. Since $\phi_i$ has slope $-1$ to the left of $k_i$ and $+1$ to the right, this difference is $2c_i$. Thus:
$$s_+ - s_- = 2c_i.$$
2. Since the points on our piecewise linear interpolation are assumed to be close together, and all lie close to our smooth function $f$, we can say that both $s_-$ and $s_+$ are approximately equal to the derivative of $f$ at the midpoints of their intervals. Thus, their difference quotient is approximately the second derivative at $k_i$:
$$\frac{s_+ - s_-}{\Delta_i} \approx f''(k_i).$$
3. Finally, let us collect two approximations for the probability that a randomly sampled point from $\rho$ lands in the interval connecting the two midpoints. On the one hand, the definition of density tells us this probability is obtained by integrating $\rho$ over the interval, which (because the interval is small) is approximately equal to the rectangular approximation $\rho(k_i)\Delta_i$. On the other hand, we sampled $m$ points from $\rho$, and one landed in that interval, so we may also approximate it by $1/m$. Thus:
$$\Delta_i \approx \frac{1}{m\,\rho(k_i)}.$$
Note: this third approximation is too rough to fully rigorously derive what follows. Instead, you need to consider an intermediate scale made by collecting together, for instance, $\sqrt{m}$ many points. Since $\sqrt{m} \to \infty$ and $\sqrt{m}/m \to 0$, this lets us both assume the second derivative is approximately constant in that region and that the length of the interval has random fluctuations much smaller than the length itself. Given this is not a formal paper, we do not follow this any further.
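Among these facts, the first is exact (no approximation at all), and it is easy to sanity-check numerically on a model with arbitrary coefficients (all values below are our own illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(1)
knots = np.sort(rng.uniform(-1, 1, size=8))
a, b = rng.normal(size=2)
c = rng.normal(size=8)

def f(t):
    # The model: a + b t + sum_i c_i |t - k_i|
    return a + b * t + np.sum(c * np.abs(t - knots))

i = 3
h = 1e-6                                         # step well inside each segment
s_minus = (f(knots[i]) - f(knots[i] - h)) / h    # slope just left of k_i
s_plus = (f(knots[i] + h) - f(knots[i])) / h     # slope just right of k_i
print(s_plus - s_minus, 2 * c[i])                # fact 1: these agree
```

The basis function $|t - k_i|$ contributes slope $-c_i$ on the left of $k_i$ and $+c_i$ on the right, while every other term is smooth there, so the jump in slope is exactly $2c_i$.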
We can try to understand what our norm minimization means in terms of the curve. By our first two bullet points, we may write:

$$\|c\|_2^2 = \sum_{i} c_i^2 = \frac{1}{4}\sum_{i} (s_+ - s_-)^2 \approx \frac{1}{4}\sum_{i} f''(k_i)^2 \Delta_i^2.$$
Applying the third bullet point to one of the two factors of $\Delta_i$ yields:

$$\|c\|_2^2 \approx \frac{1}{4m}\sum_{i} \frac{f''(k_i)^2}{\rho(k_i)}\,\Delta_i.$$
The last sum is a discrete approximation (a Riemann sum with cell widths $\Delta_i$) to the integral of $f''(x)^2/\rho(x)$, so we can conclude:

$$\|c\|_2^2 \approx \frac{1}{4m}\int \frac{f''(x)^2}{\rho(x)}\,dx.$$
We typically consider the case where $\rho$ is constant on an interval, and zero outside it, so if we assume that $f'' = 0$ outside our interval (our model is linear outside the interval), then we may conclude that we are trying to minimize

$$\int f''(x)^2\,dx.$$
This is the key observation. By minimizing the $\ell^2$-norm of our parameters, we are actually minimizing the integral of the square of the second derivative of our smooth interpolating function. Since second derivatives tell us the change in derivative, this can be interpreted as trying to find an interpolating function whose derivative changes as little as possible, or equivalently one which is as close to linear as possible.
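A numerical sanity check of the relation $\|c\|_2^2 \approx \frac{1}{4m}\int f''^2/\rho\,dx$, with our own choices of smooth function ($f = \sin$) and equally spaced knots (which sidesteps the random-fluctuation subtlety from the note above):

```python
import numpy as np

# Check ||c||_2^2 ≈ (1/(4m)) ∫ f''(x)^2 / rho(x) dx for f = sin,
# with rho uniform on [-1, 1] (so rho = 1/2) and equally spaced
# knots standing in for random ones.
m = 2000
knots = np.linspace(-1, 1, m)

vals = np.sin(knots)
slopes = np.diff(vals) / np.diff(knots)   # slope of each linear segment
c = 0.5 * np.diff(slopes)                 # fact 1: c_i = (s_+ - s_-) / 2
lhs = np.sum(c ** 2)

# f''(x) = -sin(x), so ∫ f''^2 / rho dx = 2 ∫ sin^2 x dx = 2 (1 - sin(2)/2)
rhs = (1.0 / (4 * m)) * 2 * (1 - np.sin(2.0) / 2)
print(lhs, rhs)   # the two sides agree to a fraction of a percent
```

Here the coefficients $c_i$ are read off from the piecewise linear interpolant of $\sin$ at the knots, using fact 1 to convert slope jumps into coefficients.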
This idea is so fundamental that it has been studied extensively before, and goes by the name of the natural cubic spline (see for example these notes, or Chapter 11 of A&G). This is a well-studied class of interpolating functions.
This is only one portion of a full proof. Some additional work, such as showing that any convergent subsequence of discrete minimizers must converge to a function that is twice differentiable, is needed. However, once done, this produces the following theorem.
Theorem. Take the sequence of our approximating functions, the $m$-th of which is associated to the first $m$ knot points from an infinite sequence of independent and identically distributed knot points drawn uniformly from an interval containing your dataset. Then the sequence almost surely converges uniformly on that interval to the natural cubic spline.
We picture the natural cubic spline and an interpolating function with 1000 knot points below. Notice that, as predicted by the theory above, these pictures are indistinguishable.
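This convergence is easy to check numerically. The sketch below (dataset and knot count are our own stand-ins) computes the minimum-$\|c\|$ interpolant with 1000 uniform knots, with $a$ and $b$ left free as in the article's convention, and compares it to SciPy's natural cubic spline:

```python
import numpy as np
from scipy.interpolate import CubicSpline

x = np.linspace(-1.0, 1.0, 6)   # hypothetical quadratic dataset
y = x**2

m = 1000
rng = np.random.default_rng(0)
knots = np.sort(rng.uniform(-1, 1, size=m))

# Minimize ||c|| subject to interpolation, with a, b unpenalized:
# project the constraint onto the complement of span{1, x}, then
# take the minimum-norm solution of the projected system.
A = np.column_stack([np.ones_like(x), x])       # intercept and slope columns
B = np.abs(x[:, None] - knots[None, :])         # basis |x - k_i| at the data
P = np.eye(6) - A @ np.linalg.pinv(A)           # projector killing span(A)
c = np.linalg.pinv(P @ B) @ (P @ y)             # min ||c|| s.t. interpolation
ab = np.linalg.pinv(A) @ (y - B @ c)            # then recover a and b

grid = np.linspace(-1, 1, 401)
pl = np.column_stack([np.ones_like(grid), grid]) @ ab \
    + np.abs(grid[:, None] - knots[None, :]) @ c

spline = CubicSpline(x, y, bc_type='natural')(grid)
print(np.max(np.abs(pl - spline)))   # small: the two curves nearly coincide
```

Even at $m = 1000$ the minimum-norm piecewise linear interpolant already hugs the natural cubic spline, matching the theorem.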
Let's recall what we've seen.
In the first section, we briefly discussed the classical regime to see that indeed adding a single knot point does better than the simple linear fit. This tells us that, for this problem, there is a benefit in adding non-linearity.
In the second section, we continued to add knot points until we reached the exact same number of parameters as data points, which is where we expect interpolation to hold. There we showed that, no matter the choice of data or target function, the average error must be infinite (as measured by any $L^p$-norm), simply owing to the fact that there is often only a single interpolating function, which with reasonably high probability is wildly behaved.
Finally, in the third section, we took the limit to an infinite number of knot points and saw that our condition of taking the minimum norm solution corresponds to taking the smoothest interpolating function as measured by the integral of the square of the second derivative. This is exactly the energy minimized by the natural cubic spline interpolation.
In this case, we've reduced double descent to a completely unremarkable fact: that for a quadratic function, the cubic spline interpolating the points is a better approximation than any piecewise linear function with four or fewer segments! The fact that the model performs poorly at the interpolation threshold of five segments is an unavoidable consequence of the rigidity of the model forcing a single choice of interpolating function over which we have essentially no control.
This exemplifies one of the reasons why people believe that double descent occurs: that the interpolating functions you find using very large models can be better behaved between the points you are interpolating, and thus can match the ground truth better if you structure your large model in a way to build in an inductive bias towards the correct type of solution.
In many ways, one of the most powerful aspects of modern deep models is that the architecture can be tuned to be specific to the type of problem at hand (convolutions for images, recurrent for text, and now transformers for either). This tuning changes the types of functions that can be found, and thus changes the way that these models interpolate between the data points. It is a reasonable conjecture to believe that this tuning contributes to the observation of double descent commonly seen today.
Thanks for reading! To learn more about machine learning, check out our self-paced courses, our YouTube videos, and the Dive into Deep Learning textbook. If you have any comments, ideas, etc related to MLU-Explain articles, feel free to reach out directly. The code is available here.