overfitting in machine learning

What is overfitting in machine learning?

Overfitting in machine learning occurs when a model excessively fits the training data, capturing both relevant patterns and inconsequential noise, resulting in inaccurate predictions of new data. Simpler models are less susceptible to capturing noise or irrelevant patterns, which increases the likelihood of generalizing effectively to unseen data.

For example, imagine a company using machine learning to select a few candidates to interview from a large set of resumes based solely on the resume content. The model can consider relevant factors, such as education, experience and skills. However, an overfit model also fixates on irrelevant details such as font choice, rejecting highly qualified applicants for using Helvetica rather than Times New Roman.

Why does overfitting occur?

Most factors contributing to overfitting can be found in the model, data or training methods. If a machine learning model is too complex, it memorizes training data closely rather than learning the relevant underlying pattern.

Figure: Examples of underfitting, right fit and overfitting.

If the training data contains too much noise or if the training data set is too small, the model does not have enough good data to distinguish between signal and noise. If a model trains for too long -- even with optimized data and model -- it starts learning the noise, and its performance on new data degrades the longer it trains. Another potential pitfall is repeatedly testing a model on the same validation data, which leads to implicit overfitting to a stale data set.

Overfitting vs. underfitting

Underfitting is the opposite of overfitting in that the machine learning model doesn't fit the training data closely enough, thus failing to learn the pattern in the data. Underfitting can be caused by using a too-simple model for a complex problem.

In the above example where a company evaluates resumes with machine learning, an underfit model is too simplistic and fails to capture the relationship between resume contents and job requirements. For example, the underfit model may select every resume containing the string Java, even when only JavaScript skills are required for the position. Because it keys on a single word and cannot distinguish Java from JavaScript, which are different skills, it fails to identify suitable candidates in both the training data and new data.

How to detect overfit models

One sign of an overfit model is when it performs well on the training data but poorly on new data. However, there are other methods to test the model's performance more effectively.
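The first sign mentioned above, a gap between performance on training data and on new data, can be sketched in a few lines. This is an illustrative toy, not a recommended workflow: the data generator, the 20% label-noise rate and the 1-nearest-neighbor "model" are all assumptions chosen because such a model memorizes its training set.

```python
import random

def nearest_neighbor_predict(train, x):
    """Return the label of the training point closest to x (1-nearest neighbor)."""
    closest = min(train, key=lambda point: abs(point[0] - x))
    return closest[1]

def accuracy(train, data):
    """Fraction of points in `data` whose label the 1-NN model predicts correctly."""
    hits = sum(nearest_neighbor_predict(train, x) == y for x, y in data)
    return hits / len(data)

def make_noisy_data(n, rng, noise=0.2):
    """Hypothetical data: the true label is 1 when x > 0.5, and a fraction
    `noise` of the labels is flipped at random."""
    data = []
    for _ in range(n):
        x = rng.random()
        y = int(x > 0.5)
        if rng.random() < noise:
            y = 1 - y  # label noise the model should ideally ignore
        data.append((x, y))
    return data

rng = random.Random(0)
train_set = make_noisy_data(200, rng)
test_set = make_noisy_data(200, rng)

train_acc = accuracy(train_set, train_set)  # the model has memorized these points
test_acc = accuracy(train_set, test_set)    # unseen points expose the gap
print(f"train accuracy: {train_acc:.2f}, held-out accuracy: {test_acc:.2f}")
```

Because each training point is its own nearest neighbor, the training accuracy is perfect even though a fifth of the labels are noise. The held-out accuracy is the more honest estimate.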

K-fold cross-validation is an essential tool in assessing the performance of a model. The training data is randomly split into K subsets of equal size, referred to as folds. One fold is held out for validation, and the model is trained on the remaining K-1 folds. The trained model is then evaluated on the held-out fold, and its performance metrics are recorded. This process is performed K times, using a different fold as the validation fold during each iteration. The performance metrics are then averaged to get a single overall performance measure for the model.

Chart showing k-fold validation and the selection of validation data.
Starting with training and validation (or test) data, the training data is split into equally sized chunks. During K-fold validation, one chunk is designated as validation data. As the process repeats, a new chunk is selected as validation each time.
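The K-fold procedure described above can be sketched with the standard library alone. The function names and the toy "mean predictor" model are hypothetical, chosen only to keep the example self-contained. In practice, a library routine would usually handle the splitting.

```python
import random

def k_fold_splits(n_samples, k, seed=0):
    """Yield (train_indices, validation_indices) for each of the k folds."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Fold i takes every k-th shuffled index, so fold sizes differ by at most one.
    folds = [indices[i::k] for i in range(k)]
    for i in range(k):
        validation = folds[i]
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        yield train, validation

def cross_validate(data, k, fit, score):
    """Train on k-1 folds, score on the held-out fold, and average over k rounds."""
    scores = []
    for train_idx, val_idx in k_fold_splits(len(data), k):
        model = fit([data[i] for i in train_idx])
        scores.append(score(model, [data[i] for i in val_idx]))
    return sum(scores) / k

def fit_mean(train_values):
    """Toy model: predict the mean of the training values."""
    return sum(train_values) / len(train_values)

def mean_squared_error(mean, validation_values):
    return sum((v - mean) ** 2 for v in validation_values) / len(validation_values)

values = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0]
average_mse = cross_validate(values, k=3, fit=fit_mean, score=mean_squared_error)
print(f"average validation MSE over 3 folds: {average_mse:.2f}")
```

Every sample is used for validation exactly once and for training K-1 times, which is what makes the averaged metric less dependent on any one split.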

Learning curves are another diagnostic. Two curves are generated for one analysis: one on the training data set, which shows how well the model is learning, and one on the validation set, which shows how well the model generalizes to new data. Each curve plots a performance metric, such as error or accuracy, against the number of training data points.

Two graphs showing underfit and overfit scenarios.
Differences between underfit and overfit scenarios, showing how the training error and validation errors increase, decrease and converge

As the data set grows, patterns in the performance metrics begin to emerge. When the training error and the validation error plateau, adding more data no longer meaningfully changes the fit. For an underfitting model, the two curves converge close together at a high error value. For an overfitting model, the training error is low, but a gap remains between the training and validation curves, indicating the model performs worse on the validation data than on the data it was trained on.
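As a rough illustration of how the points on a learning curve are computed, the sketch below fits a straight line to growing slices of a training pool and records the training and validation error at each size. The data generator (y = 3x + 1 plus noise), the sizes and all function names are assumptions made for the example.

```python
import random

def fit_line(points):
    """Least-squares fit of y = slope * x + intercept to (x, y) pairs."""
    n = len(points)
    mean_x = sum(x for x, _ in points) / n
    mean_y = sum(y for _, y in points) / n
    var_x = sum((x - mean_x) ** 2 for x, _ in points)
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in points)
    slope = cov_xy / var_x if var_x else 0.0
    return slope, mean_y - slope * mean_x

def mean_squared_error(model, points):
    slope, intercept = model
    return sum((slope * x + intercept - y) ** 2 for x, y in points) / len(points)

rng = random.Random(1)

def sample(n):
    """Hypothetical data: y = 3x + 1 plus Gaussian noise."""
    points = []
    for _ in range(n):
        x = rng.random()
        points.append((x, 3 * x + 1 + rng.gauss(0, 0.5)))
    return points

train_pool, validation = sample(200), sample(100)

learning_curve = []  # rows of (training size, training error, validation error)
for size in (5, 10, 20, 50, 100, 200):
    subset = train_pool[:size]
    model = fit_line(subset)
    learning_curve.append(
        (size, mean_squared_error(model, subset), mean_squared_error(model, validation))
    )

for size, train_err, val_err in learning_curve:
    print(f"n={size:3d}  train MSE={train_err:.3f}  validation MSE={val_err:.3f}")
```

Plotting the two error columns against the training size gives the pair of curves discussed above. Where they sit and how far apart they stay is what distinguishes an underfit from an overfit model.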

Organizations must improve the model and data to prevent overfitting in machine learning.

Improve the model

Here are some ways to refine and optimize models to decrease the risk of overfitting in machine learning.

Simpler models

Understanding the problem and selecting the appropriate machine learning algorithm are crucial at the beginning of a project. While cost evaluation and performance optimization are important, beginners should start with the simplest algorithm to avoid complications and improve generalization. Simple algorithms, such as k-means clustering or k-nearest neighbors, offer more straightforward interpretation and debugging.

Feature selection

In machine learning, features are the individual measurable properties or characteristics of the data used as inputs for training a model. Feature selection identifies which features are the most useful for the model to learn, which reduces the model's dimensionality.
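One simple way to rank features, shown here only as a sketch, is to score each one by the absolute value of its correlation with the target and keep the top few. This is just one family of approaches (often called a filter method). The resume columns and values below are invented for illustration.

```python
def pearson(xs, ys):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    std_x = sum((x - mean_x) ** 2 for x in xs) ** 0.5
    std_y = sum((y - mean_y) ** 2 for y in ys) ** 0.5
    if std_x == 0 or std_y == 0:
        return 0.0  # a constant column carries no information about the target
    return cov / (std_x * std_y)

def select_features(columns, target, top_k):
    """Return the names of the top_k columns most correlated with the target."""
    ranked = sorted(
        columns,
        key=lambda name: abs(pearson(columns[name], target)),
        reverse=True,
    )
    return ranked[:top_k]

# Hypothetical resume data: one list of values per feature.
resumes = {
    "years_experience": [1, 3, 5, 7, 9],
    "num_skills": [2, 3, 3, 5, 6],
    "font_size": [11, 12, 11, 12, 11],
}
interviewed = [0, 0, 1, 1, 1]

selected = select_features(resumes, interviewed, top_k=2)
print(selected)
```

In this toy data, the font size is only weakly related to the outcome, so it is the feature that gets dropped.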

Regularization

The risk of overfitting increases with the complexity of the model. Regularization places constraints on the model during training to limit its effective complexity.

During the training process, the weights of the machine learning model -- or coefficients -- are adjusted to minimize the loss function, which represents the difference between the predicted outputs of a model and the actual target values. The training objective can be represented by the following:

$$\min_{\vec{w}} L(\vec{w})$$

Regularization adds a new term $\alpha \lVert \vec{w} \rVert$ to the loss function and then solves for the set of weights that minimizes the output.

$$\min_{\vec{w}} L(\vec{w}) + \alpha \lVert \vec{w} \rVert$$

There are different ways to do this, depending on the model type.

Ridge regression

Ridge regression is a linear regression technique that adds the sum of the squares of the weights to the loss function during training, aiming to prevent overfitting by keeping the coefficients as small as possible without reducing them to zero.

$$\min_{\vec{w}} L(\vec{w}) + \alpha \sum_{i} w_i^2$$
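To see the shrinking effect of the ridge penalty, consider the simplest possible case: one feature, one weight and no intercept, with squared error as the loss. Under those assumptions (which are made here only for illustration), the minimizing weight has a closed form, and the sketch below shows it moving toward zero as α grows without ever reaching it.

```python
def ridge_weight(xs, ys, alpha):
    """Weight w minimizing sum((y - w*x)^2) + alpha * w^2 for a single feature.

    Setting the derivative to zero gives w = sum(x*y) / (sum(x*x) + alpha).
    """
    sum_xy = sum(x * y for x, y in zip(xs, ys))
    sum_xx = sum(x * x for x in xs)
    return sum_xy / (sum_xx + alpha)

xs = [1, 2, 3]
ys = [2, 4, 6]  # exactly y = 2x, so the unregularized weight is 2

for alpha in (0, 1, 14, 100):
    print(f"alpha={alpha:3d}  weight={ridge_weight(xs, ys, alpha):.3f}")
```

With α = 0 the result is the ordinary least-squares weight. Larger values of α pull it closer to zero, but the denominator stays finite, so the weight never becomes exactly zero.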

LASSO regression

Least absolute shrinkage and selection operator (LASSO) regression adds the sum of the absolute values of the model's weights to the loss function. This automatically performs feature selection by eliminating the weights of the least important features.

$$\min_{\vec{w}} L(\vec{w}) + \alpha \sum_{i} |w_i|$$
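The feature-eliminating behavior of LASSO can be seen in the same simplified setting of one feature, one weight, no intercept and squared-error loss. These are assumptions made for the sketch, not the general algorithm. There, the minimizing weight is a "soft-thresholded" version of the least-squares weight, and it becomes exactly zero when the feature's relationship to the target is weak relative to α.

```python
def lasso_weight(xs, ys, alpha):
    """Weight w minimizing sum((y - w*x)^2) + alpha * |w| for a single feature.

    The solution shrinks sum(x*y) toward zero by alpha / 2 and clips at zero.
    """
    sum_xy = sum(x * y for x, y in zip(xs, ys))
    sum_xx = sum(x * x for x in xs)
    threshold = alpha / 2.0
    if abs(sum_xy) <= threshold:
        return 0.0  # the penalty outweighs the fit improvement: feature dropped
    shrunk = abs(sum_xy) - threshold
    return (shrunk if sum_xy > 0 else -shrunk) / sum_xx

xs = [1, 2, 3]
strong_target = [2, 4, 6]        # strongly related to the feature
weak_target = [0.1, -0.1, 0.1]   # barely related to the feature

strong_weight = lasso_weight(xs, strong_target, alpha=1.0)
weak_weight = lasso_weight(xs, weak_target, alpha=1.0)
print(f"strong feature weight: {strong_weight:.3f}")
print(f"weak feature weight:   {weak_weight:.3f}")
```

The strong feature keeps a slightly reduced weight, while the weak feature's weight is set to exactly zero, which is the automatic feature selection described above.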

Elastic net regression

Elastic net regression adds a regularization term that is the sum of ridge and LASSO regression, introducing the hyperparameter γ, which controls the balance between ridge regression (γ = 1) and LASSO regression (γ = 0) and determines how much automatic feature selection is done on the model.

$$\min_{\vec{w}} L(\vec{w}) + \alpha \left( \gamma \sum_{i} w_i^2 + (1 - \gamma) \sum_{i} |w_i| \right)$$
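The three penalty terms can be compared side by side with a few small functions. The sketch follows the convention stated above, in which γ = 1 gives pure ridge and γ = 0 gives pure LASSO. Other sources and libraries may parameterize the mix differently, so the function names and signatures here are illustrative only.

```python
def ridge_penalty(weights, alpha):
    """alpha times the sum of squared weights."""
    return alpha * sum(w * w for w in weights)

def lasso_penalty(weights, alpha):
    """alpha times the sum of absolute weights."""
    return alpha * sum(abs(w) for w in weights)

def elastic_net_penalty(weights, alpha, gamma):
    """Blend of the two penalties: gamma = 1 is pure ridge, gamma = 0 is pure LASSO."""
    return gamma * ridge_penalty(weights, alpha) + (1 - gamma) * lasso_penalty(weights, alpha)

def regularized_loss(data_loss, penalty):
    """The quantity minimized during training: fit error plus penalty."""
    return data_loss + penalty

weights = [0.5, -2.0, 0.0]
alpha = 0.1

print("ridge:      ", ridge_penalty(weights, alpha))
print("lasso:      ", lasso_penalty(weights, alpha))
print("elastic net:", elastic_net_penalty(weights, alpha, gamma=0.5))
```

Note how the large weight (-2.0) dominates the ridge penalty because it is squared, while the LASSO penalty grows only linearly with each weight.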

Early stopping

This method works for iterative learning algorithms, such as gradient descent. A model learns over successive training iterations. Early in training, the prediction error on both the training and validation sets goes down. If training continues for too long, overfitting begins to occur: the training error keeps falling, but the error rate on the validation set starts to increase. Early stopping is a form of regularization that stops model training once the validation error reaches its minimum or when a plateau is detected.

Graph representing early stopping.
Early stopping: a model's performance metric plotted against training time
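Since the true minimum of the validation error is only known in hindsight, a common practical rule is to stop once the error has failed to improve for a set number of checks, often called the patience. The sketch below applies that rule to a list of per-epoch validation errors. The `patience` parameter and the error values are assumptions for the example.

```python
def early_stopping_epoch(validation_errors, patience=3):
    """Return (best_epoch, stop_epoch) for a sequence of per-epoch validation errors.

    Training stops at the first epoch where the error has not improved on the
    best value seen so far for `patience` consecutive epochs.
    """
    best_error = float("inf")
    best_epoch = -1
    for epoch, error in enumerate(validation_errors):
        if error < best_error:
            best_error, best_epoch = error, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    # The error was still improving (or patience never ran out) at the end.
    return best_epoch, len(validation_errors) - 1

# Hypothetical validation errors: falling at first, then rising as overfitting sets in.
errors = [0.90, 0.70, 0.55, 0.48, 0.45, 0.46, 0.47, 0.50, 0.56]
best, stopped = early_stopping_epoch(errors, patience=3)
print(f"best epoch: {best}, training stopped at epoch: {stopped}")
```

In a real training loop, the model's weights from the best epoch would typically be saved and restored after stopping.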

Dropout

Dropout is a regularization technique used in deep neural networks. Each neuron has a probability -- known as the dropout rate -- of being ignored, or "dropped out," at each training step. During training, each neuron is forced to adapt to the occasional absence of its neighbors and rely more on its inputs. This leads to a stronger, more resilient network with reduced susceptibility to minor input variations, which minimizes the risk of the network mistaking noise for meaningful data. Increasing the dropout rate can counter overfitting, while decreasing it can counter underfitting.
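The core of dropout is a random mask applied to a layer's outputs during training. The sketch below uses the "inverted dropout" variant, in which surviving activations are scaled up so their expected value is unchanged. That scaling is a common convention but is an assumption here, not something the description above specifies.

```python
import random

def dropout(activations, rate, rng, training=True):
    """Zero each activation with probability `rate` during training.

    Surviving activations are divided by (1 - rate) so the expected sum of the
    layer's output stays the same. At inference time the input passes through
    unchanged.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    if not training or rate == 0.0:
        return list(activations)
    keep = 1.0 - rate
    return [a / keep if rng.random() < keep else 0.0 for a in activations]

rng = random.Random(42)
layer_output = [1.0, 2.0, 3.0, 4.0]

print("training: ", dropout(layer_output, rate=0.5, rng=rng))
print("inference:", dropout(layer_output, rate=0.5, rng=rng, training=False))
```

A fresh mask is drawn on every call, so each training step sees a different thinned version of the network.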

Ensemble methods

Ensemble methods, typically applied toward the end of a machine learning project, aggregate predictions from multiple models. Depending on the method, this can reduce bias, variance or both, leading to more robust predictions. An example of an ensemble method is random forests, which build multiple decision trees during training. Each tree is trained on a random subset of the data and features. During prediction, the random forest aggregates the predictions of individual trees to produce a final prediction, often achieving high accuracy and robustness against overfitting.
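The aggregation step for a classification ensemble can be as simple as a majority vote. The sketch below shows only that step. The three threshold "models" are hypothetical stand-ins for trained trees, and no training or random sampling is performed.

```python
from collections import Counter

def majority_vote(predictions):
    """Return the label predicted by the most models."""
    return Counter(predictions).most_common(1)[0][0]

def ensemble_predict(models, x):
    """Ask every model for a prediction and aggregate by majority vote."""
    return majority_vote([model(x) for model in models])

# Three hypothetical one-rule classifiers that disagree near the boundary.
models = [
    lambda x: int(x > 0.4),
    lambda x: int(x > 0.5),
    lambda x: int(x > 0.7),
]

for x in (0.30, 0.45, 0.60, 0.90):
    votes = [model(x) for model in models]
    print(f"x={x:.2f}  votes={votes}  ensemble={ensemble_predict(models, x)}")
```

Using an odd number of models avoids ties in a two-class problem. For regression, the vote would typically be replaced by an average.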

Improve the data

The data is just as vital as the model, so organizations can do the following to improve data.

Training with more data

A large training data set provides a more comprehensive representation of the underlying problem, enabling the model to learn the true patterns and dependencies rather than memorizing specific instances.

Data augmentation

Data augmentation helps reduce overfitting by copying a training data instance and then altering the copy slightly, in a way the model can learn from but realistic enough that a human would not notice the change. The model has more opportunities to learn the desired pattern, while increasing its tolerance for different environments. Data augmentation is especially helpful in balancing a data set because it can add more examples of underrepresented data, helping to improve the model's ability to generalize across diverse scenarios and avoid biases in the training data.
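For image data, two simple alterations are mirroring the image and adding a small amount of pixel noise. The sketch below represents a grayscale image as a nested list of values between 0 and 1. The representation, the noise scale and the function names are assumptions for illustration, and whether a flip preserves the label depends on the task.

```python
import random

def horizontal_flip(image):
    """Mirror each row of a 2D image left to right."""
    return [row[::-1] for row in image]

def add_noise(image, rng, scale=0.05):
    """Shift each pixel by a small random amount, clamped to the range [0, 1]."""
    return [
        [min(1.0, max(0.0, pixel + rng.uniform(-scale, scale))) for pixel in row]
        for row in image
    ]

def augment(dataset, rng):
    """Return the original (image, label) pairs plus a flipped and a noisy copy of each."""
    augmented = list(dataset)
    for image, label in dataset:
        augmented.append((horizontal_flip(image), label))
        augmented.append((add_noise(image, rng), label))
    return augmented

rng = random.Random(0)
tiny_image = [[0.0, 0.5], [1.0, 0.25]]
dataset = [(tiny_image, "cat")]

augmented = augment(dataset, rng)
print(f"{len(dataset)} original example -> {len(augmented)} examples after augmentation")
```

Each altered copy keeps the label of the original, which is what lets the model see more variation without any new data being collected.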

This was last updated in May 2024
