Computer Scientists Prove Why Bigger Neural Networks Do Better
Our species owes a lot to opposable thumbs. But if evolution had given us extra thumbs, things probably wouldn’t have improved much. One thumb per hand is enough.
Not so for neural networks, the leading artificial intelligence systems for performing humanlike tasks. As they’ve gotten bigger, they have come to grasp more. This has been a surprise to onlookers. Fundamental mathematical results had suggested that networks should only need to be so big, but modern neural networks are commonly scaled up far beyond that predicted requirement — a situation known as overparameterization.
In a paper presented in December at NeurIPS, a leading conference, Sébastien Bubeck of Microsoft Research and Mark Sellke of Stanford University provided a new explanation for the mystery behind scaling’s success. They show that neural networks must be much larger than conventionally expected to avoid certain basic problems. The finding offers general insight into a question that has persisted over several decades.
“It’s a really interesting math and theory result,” said Lenka Zdeborová of the Swiss Federal Institute of Technology Lausanne. “They prove it in this very generic way. So in that sense, it’s going to the core of computer science.”
The standard expectations for the size of neural networks come from an analysis of how they memorize data. But to understand memorization, we must first understand what networks do.
One common task for neural networks is identifying objects in images. To create a network that can do this, researchers first provide it with many images and object labels, training it to learn the correlations between them. Afterward, the network will correctly identify the object in an image it has already seen. In other words, training causes a network to memorize data. More remarkably, once a network has memorized enough training data, it also gains the ability to predict the labels of objects it has never seen — to varying degrees of accuracy. That latter process is known as generalization.
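For readers who want to see the memorize-then-generalize loop in miniature, here is a small sketch using a tiny scikit-learn network on synthetic data. The data, labels, and network size are invented for illustration; real image pipelines are far larger, but the pattern is the same.

```python
# A minimal sketch of memorization and generalization on made-up data.
# Nothing here comes from the paper; it only illustrates the loop above.
import numpy as np
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 20))          # 200 toy "images", 20 features each
y_train = (X_train[:, 0] > 0).astype(int)     # a simple hidden rule standing in for labels

net = MLPClassifier(hidden_layer_sizes=(64,), max_iter=2000, random_state=0)
net.fit(X_train, y_train)                     # training: learn the label correlations

print(net.score(X_train, y_train))            # memorization: accuracy on data it has seen

X_new = rng.normal(size=(50, 20))             # data the network has never seen
y_new = (X_new[:, 0] > 0).astype(int)
print(net.score(X_new, y_new))                # generalization: accuracy on unseen data
```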
A network’s size determines how much it can memorize. This can be understood graphically. Imagine plotting two data points on an xy-plane. You can connect these points with a line described by two parameters: the line’s slope and the height at which it crosses the vertical axis. If someone else is then given the line, as well as the x-coordinate of one of the original data points, they can figure out the corresponding y-coordinate just by looking at the line (or using the parameters). The line has memorized the two data points.
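To make that concrete, here is a minimal sketch in Python; the two data points are invented purely for illustration.

```python
# Two parameters (slope and intercept) are enough to "memorize" two data points.

x1, y1 = 1.0, 3.0   # first data point (made up)
x2, y2 = 4.0, 9.0   # second data point (made up)

slope = (y2 - y1) / (x2 - x1)    # rise over run
intercept = y1 - slope * x1      # height where the line crosses the vertical axis

def line(x):
    """The two-parameter curve that stores both points."""
    return slope * x + intercept

# Given only the line and an x-coordinate, the y-coordinate is recovered exactly.
assert line(x1) == y1
assert line(x2) == y2
print(slope, intercept)          # 2.0 1.0
```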
Neural networks do something similar. Images, for example, are described by hundreds or thousands of values — one for each pixel. This set of many free values is mathematically equivalent to the coordinates of a point in a high-dimensional space. The number of coordinates is called the dimension.
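Here is the same idea in code, using a made-up 28-by-28 grayscale image; any real image would work the same way.

```python
# A 28 x 28 grayscale image is just 784 numbers, i.e. a single point
# in 784-dimensional space. The image here is random, for illustration only.
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((28, 28))     # hypothetical grayscale image, values in [0, 1]
point = image.reshape(-1)        # flatten to one vector of coordinates

print(point.shape)               # (784,) -> the dimension d of the input
```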
An old mathematical result says that to fit n data points with a curve, you need a function with n parameters. (In the previous example, the two points were described by a curve with two parameters.) When neural networks first emerged as a force in the 1980s, it made sense to think the same thing. They should only need n parameters to fit n data points — regardless of the dimension of the data.
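The classical picture is easy to check directly: a polynomial with n coefficients can pass exactly through n points, whatever their labels. This sketch uses NumPy and invented data.

```python
# Fit n points exactly with a curve that has n parameters (a degree n-1 polynomial).
# The points below are arbitrary, chosen only to illustrate the classical count.
import numpy as np

rng = np.random.default_rng(0)
n = 5
xs = np.arange(n, dtype=float)            # 5 x-coordinates
ys = rng.random(n)                        # 5 arbitrary labels

coeffs = np.polyfit(xs, ys, deg=n - 1)    # exactly n coefficients
fitted = np.polyval(coeffs, xs)

print(np.allclose(fitted, ys))            # True: the curve memorizes all n points
```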
“This is no longer what’s happening,” said Alex Dimakis of the University of Texas, Austin. “Right now, we are routinely creating neural networks that have a number of parameters more than the number of training samples. This says that the books have to be rewritten.”
Bubeck and Sellke didn’t set out to rewrite anything. They were studying a different property that neural networks often lack, called robustness, which is the ability of a network to deal with small changes. For example, a network that’s not robust may have learned to recognize a giraffe, but it would mislabel a barely modified version as a gerbil. In 2019, Bubeck and colleagues were seeking to prove theorems about the problem when they realized it was connected to a network’s size.
“We were studying adversarial examples — and then scale imposed itself on us,” said Bubeck. “We recognized it was this incredible opportunity, because there was this need to understand scale itself.”
In their new proof, the pair show that overparameterization is necessary for a network to be robust. They do it by figuring out how many parameters are needed to fit data points with a curve that has a mathematical property equivalent to robustness: smoothness.
To see this, again imagine a curve in the plane, where the x-coordinate represents the color of a single pixel, and the y-coordinate represents an image label. Since the curve is smooth, if you were to slightly modify the pixel’s color, moving a short distance along the curve, the corresponding prediction would only change a small amount. On the other hand, for an extremely jagged curve, a small change in the x-coordinate (the color) can lead to a dramatic change in the y-coordinate (the image label). Giraffes can become gerbils.
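A rough numerical illustration of the difference, using two toy functions that stand in for a smooth fit and a jagged one; neither is an actual network.

```python
# Smooth versus jagged fits: the same tiny input change moves the jagged
# curve's output far more than the smooth curve's. Toy functions only.
import numpy as np

def smooth_fit(x):
    return 0.5 * x                 # gentle slope: small input change, small output change

def jagged_fit(x):
    return np.sin(200.0 * x)       # rapid oscillation: tiny input change, big output change

x0, eps = 0.30, 0.01               # original pixel value and a small perturbation

print(abs(smooth_fit(x0 + eps) - smooth_fit(x0)))   # 0.005: prediction barely moves
print(abs(jagged_fit(x0 + eps) - jagged_fit(x0)))   # ~0.43: the "giraffe to gerbil" jump
```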
Bubeck and Sellke showed that smoothly fitting high-dimensional data points requires not just n parameters, but n × d parameters, where d is the dimension of the input (for example, 784 for a 784-pixel image). In other words, if you want a network to robustly memorize its training data, overparameterization is not just helpful; it’s mandatory. The proof relies on a curious fact about high-dimensional geometry: points scattered at random over the surface of a high-dimensional sphere end up almost all far from one another, separated by nearly the same large distance (viewed as vectors from the center, they are nearly orthogonal). That large separation between points means that fitting them all with a single smooth curve requires many extra parameters.
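That geometric fact is easy to probe numerically. The sketch below samples random points on a high-dimensional unit sphere and checks their pairwise distances; the dimension and sample count are arbitrary choices for illustration.

```python
# Random points on a high-dimensional unit sphere are almost all far apart,
# at nearly the same distance (about sqrt(2), i.e. nearly orthogonal as vectors).
import numpy as np

rng = np.random.default_rng(0)
d, n = 784, 200                           # input dimension and number of points (arbitrary)

points = rng.normal(size=(n, d))
points /= np.linalg.norm(points, axis=1, keepdims=True)   # project onto the unit sphere

gram = points @ points.T                                  # cosines between unit vectors
pair = np.sqrt(np.clip(2.0 - 2.0 * gram, 0.0, None))      # pairwise distances
dists = pair[np.triu_indices(n, k=1)]                     # keep each distinct pair once

print(dists.mean(), dists.std())          # mean close to sqrt(2) ~ 1.41, with tiny spread
```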
“The proof is very elementary — no heavy math, and it says something very general,” said Amin Karbasi of Yale University.
The result provides a new way to understand why the simple strategy of scaling up neural networks has been so effective.
Other research has revealed additional reasons why overparameterization is helpful. For example, it can improve the efficiency of the training process, as well as the ability of a network to generalize. While we now know that overparameterization is necessary for robustness, it is unclear how necessary robustness is for other things. But by connecting it to overparameterization, the new proof hints that robustness may be more important than was thought, a single key that unlocks many benefits.
“Robustness seems like a prerequisite to generalization,” said Bubeck. “If you have a system where you just slightly perturb it, and then it goes haywire, what kind of system is that? That’s not reasonable. I do think it’s a very foundational and basic requirement.”