This article was originally published at https://www.sharpsightlabs.com
It’s the bane of most machine learning developers.
You build a model that performs so well on the training data, and think “I’ve done such a good job!” while patting yourself on the back.
Then, you test the model later, only to find that it performs terribly.
Welcome to overfitting.
Overfitting is probably one of the most important concepts to understand in machine learning, because it’s at the core of building models that work in the real world.
So in this blog post, I’m going to explain what overfitting is, what causes it, and a few high-level ways to diagnose it.
If you need something specific, just click on any of the following links.
Table of Contents:
- Causes of Overfitting
- How to Diagnose Overfitting
- How to Prevent and Fix Overfitting
- Frequently Asked Questions
Let’s jump in.
Introduction to Overfitting
In machine learning, overfitting is one of the most important concepts that you need to understand, especially in the beginning.
In supervised learning, our goal is to build a model that can predict things … either numbers (in the case of regression) or categories (in the case of classification).
And we want our models to predict well, which, at a high level, means that we want our models to achieve high values on machine learning evaluation metrics, such as precision, recall, accuracy, etc.
What happens with many beginners is that they inadvertently build models that are too good.
… a model that performs too well on a performance metric.
Which can be a very bad thing.
Now, I know what you’re thinking …
“Performing too well on a performance metric? That’s bad?”
It is, if your model is overfitting.
What is Overfitting
Overfitting is when a model learns to follow the patterns in the training data too closely, in a way that causes the model to fail when it’s used on new, previously unseen data.
It’s like memorizing very specific and narrow patterns that exist in the training data, when the real world problem that the model needs to work on is much more complicated.
This leads to an interesting phenomenon: the model performs extremely well on the training data, but performs poorly on new data that it’s never seen before.
A Simple Example
Let’s say that you’re learning a new subject. Something a little complicated, like calculus.
You’re learning calculus, and studying for a test.
The test will have 30 questions.
But you decide to only study 5 example questions to prepare for that test.
You spend several hours going over and over those 5 example questions, until you master them.
You study those 5 examples so much that you score perfectly on them, every time.
What will happen when you sit down to take the actual calculus test with 30 questions?
Will there be some other things on that test that weren’t represented in your 5 practice questions?
Might there be additional concepts and patterns in the broader calculus exam that were absent from your 5 practice questions?
So let’s say that you master those 5 training questions that you used to study and you scored 100% on them.
But, because the 30 question test is more complicated, you score only a 65% on the actual 30 question test.
That, my friend, is overfitting.
A More Technical Explanation of Overfitting
We train machine learning models to detect patterns in the data (at least, that’s what we do in supervised learning).
And we “train” our models with training data.
During this training process, the machine learning algorithms that we use to build our models adjust their internal parameters to reduce the difference between their predictions and the actual values of the target in the training data. This is what we mean by “fitting.” We’re literally trying to fit the model to the training data to make the errors between the predicted value and the actual value small.
Ideally as small as possible (but without overfitting).
An Example in Regression
Let’s say that we have some data.
There’s clearly a pattern in the data.
We want to be able to predict that pattern.
To do so, we build a model. In this case, we build a polynomial regression line to fit the training data.
And here, we have an image of the training datapoints, the model, and the associated error terms.
The error terms shown are the distance from the model to the training data.
Now, in regression, where we’re trying to predict a numeric value (like the values in this training data), we typically optimize the model to minimize mean squared error.
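Mean squared error is this quantity:
$$\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2$$
Where $y_i$ is the actual value (i.e., the value of the training datapoint), $\hat{y}_i$ is the value predicted by the model, and $n$ is the number of datapoints.
(Mean squared error is also called MSE.)
So the errors are represented by the value $(y_i - \hat{y}_i)$. And notice that in mean squared error, we’re minimizing the sum of the squared errors (averaged over the $n$ datapoints).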
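To make the calculation concrete, here is a minimal sketch of mean squared error in plain Python. The function name and the three example values are made up for illustration; they are not taken from the data in this post.

```python
def mean_squared_error(actual, predicted):
    """Average of the squared differences between actual and predicted values."""
    if len(actual) != len(predicted):
        raise ValueError("actual and predicted must have the same length")
    squared_errors = [(y - y_hat) ** 2 for y, y_hat in zip(actual, predicted)]
    return sum(squared_errors) / len(squared_errors)


# Made-up values, purely for illustration.
actual = [3.0, 5.0, 7.0]
predicted = [2.0, 5.0, 9.0]

# The errors are 1, 0, and -2, so the squared errors are 1, 0, and 4.
mse = mean_squared_error(actual, predicted)
print(mse)  # (1 + 0 + 4) / 3, roughly 1.667
```

Because the errors are squared, large misses count much more than small ones, and positive and negative errors can't cancel each other out.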
Next, what happens if we build a model that brings all of those error terms to 0?
Here, I’ve trained a very flexible neural network to almost perfectly fit the data.
So, as you can see, the model basically hits all of the training datapoints.
You might be thinking “Wow, that’s great right? The model fits PERFECTLY.”
It fits the training points.
But what happens when we evaluate the model on new, previously unseen data that was generated from the same process as the training data?
The fit against the new test datapoints is worse than the fit against the training datapoints.
We can actually measure this with mean squared error (which I’ll explain a little further down in the post).
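If you want to see this effect in numbers, here is a small sketch. It does not reproduce the neural network described above. Instead, as a stand-in for a "very flexible" model, it uses a polynomial that is forced through every training point (Lagrange interpolation). The noisy sine data, the seed, and all of the names are assumptions made up for this example.

```python
import math
import random


def mean_squared_error(actual, predicted):
    return sum((y - p) ** 2 for y, p in zip(actual, predicted)) / len(actual)


def interpolating_polynomial(train_x, train_y):
    """Return a predict function that passes exactly through every
    training point, i.e. a model that memorizes the training data."""
    def predict(x):
        total = 0.0
        for i, (xi, yi) in enumerate(zip(train_x, train_y)):
            weight = 1.0
            for j, xj in enumerate(train_x):
                if j != i:
                    weight *= (x - xj) / (xi - xj)
            total += yi * weight
        return total
    return predict


def noisy_sine(xs, rng):
    """Hypothetical data-generating process: a sine wave plus random noise."""
    return [math.sin(x) + rng.gauss(0, 0.3) for x in xs]


rng = random.Random(0)
train_x = [0.5 * i for i in range(12)]        # 0.0, 0.5, ..., 5.5
test_x = [0.5 * i + 0.25 for i in range(11)]  # new points between the training points
train_y = noisy_sine(train_x, rng)
test_y = noisy_sine(test_x, rng)              # same process, previously unseen

model = interpolating_polynomial(train_x, train_y)
train_mse = mean_squared_error(train_y, [model(x) for x in train_x])
test_mse = mean_squared_error(test_y, [model(x) for x in test_x])

print("training MSE:", train_mse)  # the model hits every training point
print("test MSE:", test_mse)       # larger, because the noise was memorized
```

The training error is zero by construction, so all it tells you is that the model memorized the data. The test error is the number that tells you how the model handles new datapoints.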
But, let’s first talk about what causes overfitting.
Causes of Overfitting
Now that we’ve discussed what overfitting is, let’s look at the causes of overfitting.
The primary causes are:
- Small and Limited Training Data
- High Model Complexity
- Inadequate Data Split into Training and Validation Sets
Let’s talk about each of these, one at a time.
Small and Limited Training Data
One of the most common causes of overfitting is limited training data.
This can include a small dataset, in the sense that there are a small number of total examples.
It can also mean a dataset with a limited range of examples; a training dataset that fails to capture the full range of variation in the problem space.
When we train an algorithm on a limited or narrow set of training examples, the algorithm will learn the patterns specific only to those examples, and will fail to capture patterns that may exist in the broader possible set of examples.
For instance, let’s say we’re training a computer vision system to detect faces. If we only train the system on people from 20-30 years old, it may later struggle with images outside of that range. In such a situation, we might train the model to have a very high training performance, but the model would likely have a bad general performance when tested on a previously unseen set of examples with a broader range of ages. That would be an example of overfitting.
That being said, you need to be careful that your training data is representative of the overall population of possible examples in your problem space.
You also need to be careful to have enough training data. This is particularly true for complex problem spaces, where the model needs to learn complex patterns. For example, deep learning models (i.e., deep neural networks) are excellent at learning complex patterns in data, but they notoriously need a lot of training data, otherwise they are prone to overfitting.
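Here is a rough numeric analogy for the narrow-training-range problem (like the 20-30 year old faces). It's only a sketch under made-up assumptions: the "real" pattern is y = x², the model is a straight line, and the training data covers only a small slice of the range the model will later face. All names and numbers are hypothetical.

```python
import random


def fit_line(xs, ys):
    """Ordinary least squares fit of y = slope * x + intercept."""
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    slope = (sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
             / sum((x - mean_x) ** 2 for x in xs))
    return slope, mean_y - slope * mean_x


def mean_squared_error(actual, predicted):
    return sum((y - p) ** 2 for y, p in zip(actual, predicted)) / len(actual)


def sample(n, low, high, rng):
    """Hypothetical process: y = x^2 plus a little noise."""
    xs = [rng.uniform(low, high) for _ in range(n)]
    ys = [x ** 2 + rng.gauss(0, 0.05) for x in xs]
    return xs, ys


rng = random.Random(1)
train_x, train_y = sample(50, 0.0, 1.0, rng)  # narrow slice of the problem space
broad_x, broad_y = sample(50, 0.0, 4.0, rng)  # the full range seen after deployment

slope, intercept = fit_line(train_x, train_y)


def predict(x):
    return slope * x + intercept


train_mse = mean_squared_error(train_y, [predict(x) for x in train_x])
broad_mse = mean_squared_error(broad_y, [predict(x) for x in broad_x])

print("MSE on the narrow training range:", train_mse)
print("MSE on the broader range:", broad_mse)
```

Inside the narrow range, a straight line looks like a good fit. It's only when the model sees the broader range that the patterns it never had a chance to learn show up as error.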
High Model Complexity
Having just mentioned deep learning, it’s a good time to mention how model complexity can contribute to overfitting.
In machine learning, models that are too “flexible” are more prone to overfitting. “Flexible” models include models that have a large number of learnable parameters, like deep neural networks, or models that can otherwise adapt themselves in very fine-grained ways to the training data, such as gradient boosted trees.
For example, and as I’ve already mentioned, deep neural networks with many layers and neurons tend to be highly flexible machine learning methods. On one hand, this enables them to find very complex patterns in data, which in turn, enables them to solve more complex problems. But on the other hand, this flexibility makes them prone to overfitting, where they adapt not only to the actual patterns in the data, but also adapt to the random “noise” in the data that they should ignore.
Ultimately, high model complexity can be good for solving complex problems, but when you build a complex model, you need to be more careful to avoid overfitting the training data.
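One way to get a feel for flexibility is to use a model where it's controlled by a single number. The sketch below uses k-nearest-neighbors regression, which isn't discussed in this post; I'm using it only because a small k makes the model very flexible (k = 1 simply copies the closest training point) and a large k makes it smoother. The data, the seed, and the chosen k values are assumptions for illustration.

```python
import math
import random


def knn_predict(x, train_x, train_y, k):
    """Average the targets of the k training points closest to x."""
    nearest = sorted(range(len(train_x)), key=lambda i: abs(train_x[i] - x))[:k]
    return sum(train_y[i] for i in nearest) / k


def mean_squared_error(actual, predicted):
    return sum((y - p) ** 2 for y, p in zip(actual, predicted)) / len(actual)


def make_data(n, rng):
    """Hypothetical process: a sine wave plus random noise."""
    xs = [rng.uniform(0, 6) for _ in range(n)]
    ys = [math.sin(x) + rng.gauss(0, 0.3) for x in xs]
    return xs, ys


rng = random.Random(2)
train_x, train_y = make_data(60, rng)
val_x, val_y = make_data(60, rng)

results = {}
for k in (1, 5, 15):  # smaller k = more flexible model
    train_pred = [knn_predict(x, train_x, train_y, k) for x in train_x]
    val_pred = [knn_predict(x, train_x, train_y, k) for x in val_x]
    results[k] = (mean_squared_error(train_y, train_pred),
                  mean_squared_error(val_y, val_pred))
    print(f"k={k:2d}  training MSE={results[k][0]:.3f}  validation MSE={results[k][1]:.3f}")
```

With k = 1, the training MSE is zero because every training point is its own nearest neighbor, which means the model has also absorbed the noise. Less flexible settings give up that perfect training score, and the validation column is where you'd look to judge whether the trade was worth it.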
A Bad Validation Set
The way that you split your data into training, validation, and test sets can also impact overfitting.
Specifically, problems with your validation set could cause problems with overfitting.
If the validation set fails to accurately represent the overall distribution of the data, the model may perform well on the validation set, but fail to generalize to the broader range of possible inputs and outputs seen in the overall problem space.
For example, imagine that you build a model to predict stock prices. If the validation data only contains data from a period of market stability, then model evaluation on that validation set will be a poor indicator of the model’s general performance across a wider range of market conditions. The model might perform well on the validation set, but then when actually deployed, it would likely perform much worse in volatile market conditions.
That said, you need to be careful that your validation set is representative of the general problem space, and that the validation set accurately represents the overall distribution of the data.
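A quick sanity check you can run is to compare a summary statistic of the validation set against the same statistic for all of your data. The sketch below does this for the stock example with made-up "daily returns". The `spread_ratio` helper, the numbers, and the idea of using standard deviation as the check are all my assumptions here; this is one simple heuristic, not a complete test of representativeness.

```python
import random
import statistics


def spread_ratio(validation, full_data):
    """Standard deviation of the validation values relative to the full data.
    A value far from 1.0 hints that the validation set is not representative."""
    return statistics.pstdev(validation) / statistics.pstdev(full_data)


rng = random.Random(3)

# Hypothetical daily returns (in percent): a calm period followed by a volatile one.
calm = [rng.gauss(0, 0.5) for _ in range(200)]
volatile = [rng.gauss(0, 3.0) for _ in range(200)]
returns = calm + volatile

calm_only_validation = calm[-100:]              # drawn from the calm period only
mixed_validation = rng.sample(returns, 100)     # drawn from all market conditions

calm_ratio = spread_ratio(calm_only_validation, returns)
mixed_ratio = spread_ratio(mixed_validation, returns)

print("calm-only validation set, spread ratio:", round(calm_ratio, 2))
print("mixed validation set, spread ratio:", round(mixed_ratio, 2))
```

The calm-only validation set has far less spread than the data as a whole, which is a warning sign that good validation scores may not carry over to volatile conditions. Note that the random sample is used here only to illustrate the check. With real time series data, you would typically keep the validation data in time order and make sure it spans a range of conditions.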
The Basics of How to Diagnose Overfitting
Next, let’s quickly talk about how to diagnose overfitting.
There are a variety of ways to diagnose overfitting, so let’s look at the 3 most important.
- Compare Performance on Training vs. Validation Data
- Learning Curves
- Cross Validation
Let’s look at each of these individually.
Compare Performance on Training vs. Validation Data
The most common and arguably the most direct way to diagnose overfitting is by comparing the performance of the model on the training data vs the validation data.
As I noted previously, and as I explained at some length in my blog post about training, validation, and test datasets, we commonly split our data into 3 datasets: one to train the model, one to “validate” the model, and one for a final test.
When we attempt to build a machine learning model to solve a particular task, we commonly build multiple models. We might try out several different algorithms (like logistic regression, decision trees, neural networks, etc). We might also experiment with a variety of settings for model hyperparameters (like tuning the learning rate, batch size, and number of layers of a neural network).
After we build these models, each one will report a particular performance on the training set.
For regression systems, we will probably compute training MSE (mean squared error).
But although the training MSE gives us some sense of how well the model is working, what really matters is the performance on the validation set.
So we’ll then use the validation dataset and feed it into the models that we’ve trained. And we can use the predictions from those models on the validation data to compute our evaluation metrics (e.g., accuracy, precision, recall, MSE, etc) all over again.
If there’s a big discrepancy between the performance on the training data vs the validation data, that’s a good indication that the model is overfitting.
For example, if a regression model has a very low training MSE, but a much larger validation MSE, then that’s a good signal of overfitting.
Similarly, if a classification model has a high training accuracy, but a much lower validation accuracy, it’s a good signal of overfitting.
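Here's what that comparison can look like for a classifier. This is a minimal sketch with made-up data: the label is 1 when x is above 0.5, and 20% of the labels are randomly flipped to simulate noise. The model is a 1-nearest-neighbor classifier, chosen only because it memorizes its training data. All names here are hypothetical.

```python
import random


def accuracy(actual, predicted):
    return sum(a == p for a, p in zip(actual, predicted)) / len(actual)


def predict_nearest(x, train_x, train_y):
    """1-nearest-neighbor classifier: copy the label of the closest training point."""
    nearest = min(range(len(train_x)), key=lambda i: abs(train_x[i] - x))
    return train_y[nearest]


def make_data(n, rng):
    """Hypothetical data: label is 1 when x > 0.5, with 20% of labels flipped."""
    xs = [rng.random() for _ in range(n)]
    ys = []
    for x in xs:
        label = int(x > 0.5)
        if rng.random() < 0.2:
            label = 1 - label
        ys.append(label)
    return xs, ys


rng = random.Random(4)
train_x, train_y = make_data(200, rng)
val_x, val_y = make_data(200, rng)

train_acc = accuracy(train_y, [predict_nearest(x, train_x, train_y) for x in train_x])
val_acc = accuracy(val_y, [predict_nearest(x, train_x, train_y) for x in val_x])
gap = train_acc - val_acc

print("training accuracy:", train_acc)
print("validation accuracy:", val_acc)
print("gap:", round(gap, 3))
```

The training accuracy is perfect because the model has memorized every training label, including the flipped ones. The gap between the two numbers is the thing to watch. How large a gap counts as "too large" depends on the problem, so treat any specific threshold as a judgment call.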
Learning Curves
A learning curve is a visualization technique that shows how the model performance (e.g., accuracy, MSE) changes over time during training. In these learning curves, we plot 2 curves: one for the training data and one for the validation data.
Additionally, there are actually a couple of variations of learning curves.
One type plots model performance (accuracy, loss, etc) vs the number of training epochs. This type of learning curve is appropriate for neural networks and deep networks, where we train the model over multiple epochs. We can then train the model over different numbers of epochs, score the performance on the training set and validation sets, and plot the model performance vs the number of epochs for both the training data and validation data.
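To show how the numbers behind an epoch-based learning curve get recorded, here's a small sketch. It trains a simple linear model with gradient descent (not a neural network) and scores the training and validation data after every epoch. A model this simple won't overfit, so the point is only the bookkeeping. The data, learning rate, and number of epochs are assumptions for illustration.

```python
import random


def mean_squared_error(actual, predicted):
    return sum((y - p) ** 2 for y, p in zip(actual, predicted)) / len(actual)


def make_data(n, rng):
    """Hypothetical process: y = 2x + 1 plus a little noise."""
    xs = [rng.uniform(0, 1) for _ in range(n)]
    ys = [2.0 * x + 1.0 + rng.gauss(0, 0.1) for x in xs]
    return xs, ys


rng = random.Random(5)
train_x, train_y = make_data(40, rng)
val_x, val_y = make_data(40, rng)

w, b = 0.0, 0.0
learning_rate = 0.1
train_curve, val_curve = [], []

for epoch in range(50):
    # One full-batch gradient descent step on the training MSE.
    errors = [w * x + b - y for x, y in zip(train_x, train_y)]
    grad_w = 2 * sum(e * x for e, x in zip(errors, train_x)) / len(train_x)
    grad_b = 2 * sum(errors) / len(train_x)
    w -= learning_rate * grad_w
    b -= learning_rate * grad_b

    # Score both datasets after every epoch; these two lists are the learning curves.
    train_curve.append(mean_squared_error(train_y, [w * x + b for x in train_x]))
    val_curve.append(mean_squared_error(val_y, [w * x + b for x in val_x]))

best_epoch = min(range(len(val_curve)), key=val_curve.__getitem__) + 1
print("epoch with the lowest validation MSE:", best_epoch)
```

Plotting `train_curve` and `val_curve` against the epoch number gives you the learning curve. With a more flexible model, you'd watch for the point where the validation curve stops following the training curve down.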
A different type of learning curve plots model performance vs dataset size. In this type, you can train the model multiple times with different numbers of examples in the training dataset, and evaluate the model on the training data and the validation data. Then you can plot the model performance for both training data and validation data vs the number of examples in the training set.
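The dataset-size version can be sketched the same way: retrain on progressively larger slices of the training data and score each model on its own training slice and on a fixed validation set. The model below is a 3-nearest-neighbors regressor, used only as a convenient stand-in. The data, the sizes, and the names are assumptions.

```python
import math
import random


def mean_squared_error(actual, predicted):
    return sum((y - p) ** 2 for y, p in zip(actual, predicted)) / len(actual)


def knn_predict(x, train_x, train_y, k=3):
    """Average the targets of the k training points closest to x."""
    nearest = sorted(range(len(train_x)), key=lambda i: abs(train_x[i] - x))[:k]
    return sum(train_y[i] for i in nearest) / k


def make_data(n, rng):
    """Hypothetical process: a sine wave plus random noise."""
    xs = [rng.uniform(0, 6) for _ in range(n)]
    ys = [math.sin(x) + rng.gauss(0, 0.3) for x in xs]
    return xs, ys


rng = random.Random(6)
train_x, train_y = make_data(80, rng)
val_x, val_y = make_data(80, rng)

curve = []
for size in (10, 20, 40, 80):
    sub_x, sub_y = train_x[:size], train_y[:size]  # first `size` training examples
    train_mse = mean_squared_error(sub_y, [knn_predict(x, sub_x, sub_y) for x in sub_x])
    val_mse = mean_squared_error(val_y, [knn_predict(x, sub_x, sub_y) for x in val_x])
    curve.append((size, train_mse, val_mse))
    print(f"n={size:2d}  training MSE={train_mse:.3f}  validation MSE={val_mse:.3f}")
```

Each row of `curve` is one point on the plot: training set size on the horizontal axis, with one line for training MSE and one for validation MSE.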
If a model is overfitting, the training curve will continue to improve, but the validation curve will plateau or degrade. This indicates the model’s decreasing ability to generalize to new data.
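You can also check for this pattern in code instead of eyeballing a plot. The helper below is a rough heuristic I'm making up for illustration: it flags a run when the training loss is still falling but the validation loss hasn't improved for a few epochs. The `patience` value and the loss numbers are hypothetical, not measured results.

```python
def epochs_since_best(val_losses):
    """Number of epochs since the validation loss last reached its minimum."""
    best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
    return len(val_losses) - 1 - best_epoch


def looks_overfit(train_losses, val_losses, patience=3):
    """Heuristic: training loss is still falling, but validation loss
    has not improved for at least `patience` epochs."""
    if len(train_losses) <= patience:
        return False  # not enough history to judge
    still_improving = train_losses[-1] < train_losses[-1 - patience]
    return still_improving and epochs_since_best(val_losses) >= patience


# Hypothetical loss values, made up for illustration only.
train_losses = [0.90, 0.60, 0.40, 0.30, 0.22, 0.16, 0.11, 0.08]
overfit_val = [0.95, 0.70, 0.55, 0.50, 0.52, 0.56, 0.61, 0.67]
healthy_val = [0.95, 0.70, 0.55, 0.45, 0.40, 0.36, 0.33, 0.31]

print(looks_overfit(train_losses, overfit_val))  # True: validation loss bottomed out at epoch 4
print(looks_overfit(train_losses, healthy_val))  # False: validation loss is still improving
```

This is the same idea that underlies early stopping: once the validation curve turns around, more training mostly helps the model fit the training data, not new data.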
Cross Validation
Finally, let’s briefly talk about cross validation.
Cross validation is a method that enables us to better evaluate how well a model will generalize.
Like a validation set, it allows us to see how a model performs on unseen data.
But, instead of “validating” the model with a new dataset one time … cross validation does this multiple times.
To do this, instead of having one separate validation set, you actually take the training data and divide it into multiple blocks. So, for example, in 10-fold cross validation, you divide the training data into 10 blocks.
Then you train the model on 9 of those blocks and validate on the 10th block. Then you compute the model performance (i.e., the model error).
But instead of stopping there, you rotate the blocks, so that you use a new block as the validation block, train the model on the remaining blocks, validate the model on the new validation block, and compute the performance again.
You keep choosing a new block as the validation block (one that has not been the validation block yet) and redo the training/validation/evaluation process.
So with cross validation, you end up training and validating the model multiple times.
If the model consistently performs well across all of the training/validation splits, then this suggests that the model generalizes well.
Conversely, if the model shows significant performance fluctuations across different training/validation splits, then the model may be overfitting.
Cross validation is particularly effective because it tests the ability of the model to adapt to different training datasets.
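The block rotation described above can be sketched in a few lines of plain Python. This is a simplified illustration, not a full implementation: the model is a straight-line fit, the data is made up, and `k_fold_scores` is a hypothetical helper name. In practice, you'd probably use a library routine for this instead of writing it by hand.

```python
import random
import statistics


def fit_line(xs, ys):
    """Ordinary least squares fit of y = slope * x + intercept."""
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    slope = (sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
             / sum((x - mean_x) ** 2 for x in xs))
    return slope, mean_y - slope * mean_x


def mean_squared_error(actual, predicted):
    return sum((y - p) ** 2 for y, p in zip(actual, predicted)) / len(actual)


def k_fold_scores(xs, ys, k=10, seed=0):
    """Return the validation MSE for each of the k train/validation rotations."""
    indices = list(range(len(xs)))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]  # k roughly equal blocks

    scores = []
    for fold in folds:  # each block takes one turn as the validation block
        held_out = set(fold)
        train_idx = [i for i in indices if i not in held_out]
        slope, intercept = fit_line([xs[i] for i in train_idx],
                                    [ys[i] for i in train_idx])
        predictions = [slope * xs[i] + intercept for i in fold]
        scores.append(mean_squared_error([ys[i] for i in fold], predictions))
    return scores


rng = random.Random(7)
xs = [rng.uniform(0, 10) for _ in range(100)]          # hypothetical inputs
ys = [3.0 * x + 2.0 + rng.gauss(0, 1.0) for x in xs]   # hypothetical targets

scores = k_fold_scores(xs, ys, k=10)
mean_score = statistics.mean(scores)
score_spread = statistics.stdev(scores)

print("validation MSE per fold:", [round(s, 2) for s in scores])
print("mean:", round(mean_score, 2), " spread:", round(score_spread, 2))
```

The mean tells you roughly how well the model performs on held-out data, and the spread tells you how much that performance moves around from one split to the next.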
I realize that cross validation is a little complicated, and it might be difficult to understand with my limited explanation here, so I’m going to write a full blog post about cross validation separately.
In this post, I’ve tried to explain the essentials of overfitting in clear, easy-to-understand language.
But, overfitting is a very complicated subject, and it’s really hard to cover everything in a single post.
That said, I’m going to write more about other aspects of overfitting in future blog posts, like:
- How to Diagnose Overfitting
- How to Prevent and Fix Overfitting
- Regularization as a Tool to Manage Overfitting
- … and more
Leave your Questions Below
Do you still have questions about overfitting?
What else do you want to know that I didn’t cover?
I want to hear from you.
Leave your questions, concerns and comments in the comments section at the bottom of the page.
Please visit source website for post related comments.