Cost Function in Linear Regression
Linear Regression is a method used to predict values by drawing the best-fit line through the data. When we first create a model, the predictions may not always match the actual data. To understand how well the model is performing, we use a cost function. In this article, we'll see what the cost function in linear regression is, how it works and why it's important for improving model accuracy.
The cost function aggregates the errors (the differences between predicted and actual values) across all data points. It measures how well the model's predictions match the actual data and guides the optimization of the parameters to minimize those errors and find the best fit.
How Does the Cost Function Work?
Let's understand it with an example: imagine we are building a linear regression model to predict house prices based on the size of the house (in square feet).
Here’s some training data:
| Size (sq. ft.) | True Price (in $1000s) |
|---|---|
| 500 | 50 |
| 1000 | 100 |
| 1500 | 150 |
| 2000 | 200 |
The linear regression equation is:
\hat{y} = w \cdot x
where,
- \hat y is the predicted house price
- x is the size of the house (input feature)
- w is the weight (slope of the line)
Our goal is to find the weight w that minimizes the difference between the predicted and actual prices. This difference is calculated using the cost function.
Understanding Mean Squared Error (MSE)
A commonly used cost function is Mean Squared Error (MSE). Because the errors are squared, larger errors are penalized more heavily, which pushes the model to focus on reducing its biggest mistakes between predictions and actual values.
The MSE formula is:
J(\theta) = \frac{1}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)^2
Where:
- J(\theta) is the cost function
- m is the number of data points
- h_{\theta}(x^{(i)}) is the predicted value for the i-th data point
- y^{(i)} is the actual value.
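Before working through the numbers by hand, here is a minimal sketch of this formula in Python (NumPy is used for convenience; the function name is illustrative, not from any particular library):

```python
import numpy as np

def mse(y_true, y_pred):
    """Mean Squared Error: average of the squared differences."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.mean((y_pred - y_true) ** 2)
```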
Let’s calculate the MSE for initial predictions with w=0.04. Using this value of w, we can predict the prices for the given house sizes:
| Size (x) (sq. ft.) | True Price (y) (in $1000s) | Predicted Price (ŷ) (in $1000s) |
|---|---|---|
| 500 | 50 | 0.04 × 500 = 20 |
| 1000 | 100 | 0.04 × 1000 = 40 |
| 1500 | 150 | 0.04 × 1500 = 60 |
| 2000 | 200 | 0.04 × 2000 = 80 |
Proceeding with our example, let's calculate the MSE for our predictions. For each data point, we square the error (true price minus predicted price) so that it is positive, then sum all the squared errors:
- x = 500: (50 − 20)^2 = (30)^2 = 900
- x = 1000: (100 − 40)^2 = (60)^2 = 3600
- x = 1500: (150 − 60)^2 = (90)^2 = 8100
- x = 2000: (200 − 80)^2 = (120)^2 = 14400
Now, sum the squared errors:
900+3600+8100+14400=27000
Divide by the number of points m=4:
MSE=\frac{27000}{4}=6750
Thus the MSE is 6750, representing the average squared error between the predicted and actual values. A high MSE shows that the model's predictions are far from the actual values. To reduce this error, we can adjust the parameter w using Gradient Descent.
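The hand calculation above is easy to verify with a few lines of Python (a quick sketch using the same numbers):

```python
import numpy as np

x = np.array([500, 1000, 1500, 2000], dtype=float)  # house sizes (sq. ft.)
y = np.array([50, 100, 150, 200], dtype=float)      # true prices (in $1000s)

w = 0.04                  # initial guess for the weight
y_pred = w * x            # predictions: [20, 40, 60, 80]

print(np.mean((y_pred - y) ** 2))  # 6750.0
```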
Role of Gradient Descent in Updating the Weights
Gradient descent is an optimization algorithm used to minimize the cost function and find the best-fit line for the model. Its goal is to iteratively adjust the weights of the model to reduce the error. Each iteration updates the weights in the direction that minimizes the cost function leading to the optimal set of parameters.
Gradient descent works as follows:
1. Start with an initial guess for the weight w
2. Calculate the MSE for the current weight w
3. Find the gradient of the cost function with respect to w
4. Update the weight using the formula:
w = w - \alpha \frac{\partial J(w)}{\partial w}
Where:
- \alpha is the learning rate (step size)
- \frac{\partial J(w)}{\partial w} is the gradient of the cost function
5. Repeat the process until the MSE converges to a minimum value or a set number of iterations is reached.
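To make the loop concrete, here is a minimal sketch in Python for the single-weight model from our example. For the MSE cost, the gradient with respect to w is \frac{2}{m} \sum_{i=1}^{m} \left( w \cdot x^{(i)} - y^{(i)} \right) x^{(i)}; the learning rate and iteration count below are illustrative choices, not prescribed values:

```python
import numpy as np

x = np.array([500, 1000, 1500, 2000], dtype=float)
y = np.array([50, 100, 150, 200], dtype=float)

w = 0.04        # step 1: initial guess for the weight
alpha = 1e-7    # learning rate, kept tiny because x values are large

for _ in range(100):
    y_pred = w * x                          # step 2: predictions at current w
    grad = 2 * np.mean((y_pred - y) * x)    # step 3: dJ/dw for the MSE cost
    w = w - alpha * grad                    # step 4: update the weight

print(w)  # converges toward 0.1, i.e. price = 0.1 * size
```

Note how sensitive the loop is to the learning rate: because the inputs are in the thousands, a step size that is too large makes the updates overshoot and diverge.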
Types of Cost Function in Linear Regression
While MSE is commonly used, there are other cost functions that can be used in linear regression, each with its own advantages. These include:
1. Mean Absolute Error (MAE)
Mean Absolute Error (MAE) calculates the average of the absolute differences between actual and predicted values. Unlike MSE, MAE does not penalize large errors as heavily which can be useful when dealing with outliers or when we want a simpler interpretation of error.
Formula for MAE is:
\text{MAE} = \frac{1}{n} \sum_{i=1}^{n} |y_i - \hat{y}_i|
This formula sums up all the absolute differences and divides by the total number of predictions to give us an average.
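A minimal sketch of this formula, applied to the running example (NumPy for convenience; the function name is illustrative):

```python
import numpy as np

def mae(y_true, y_pred):
    """Mean Absolute Error: average of the absolute differences."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.mean(np.abs(y_true - y_pred))

print(mae([50, 100, 150, 200], [20, 40, 60, 80]))  # 75.0
```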
2. Root Mean Squared Error (RMSE)
Root Mean Squared Error (RMSE) is the square root of the MSE, providing an error measure in the same units as the target variable. This makes it easier to interpret than MSE, as it returns an error value on the same scale as the dependent variable.
\text{RMSE} = \sqrt{\text{MSE}}
RMSE is important when we need an interpretable measure of error that maintains sensitivity to larger mistakes making it suitable for many regression tasks.
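Continuing the running example, RMSE is simply the square root of the MSE computed earlier (a quick sketch):

```python
import numpy as np

mse_value = 6750.0           # from the worked example above
rmse = np.sqrt(mse_value)
print(rmse)                  # ~82.16, in the same units as the price ($1000s)
```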
3. R-squared (Coefficient of Determination)
R-squared (R²) measures the proportion of the variance in the dependent variable that is explained by the independent variables in the model. It is a widely used metric for evaluating the explanatory power of a linear regression model, and it is most meaningful when the underlying relationship between the variables is approximately linear.
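A common formulation is R^2 = 1 - \frac{SS_{res}}{SS_{tot}}, where SS_{res} is the residual sum of squares and SS_{tot} is the total sum of squares around the mean. A minimal sketch on the running example (variable names are illustrative):

```python
import numpy as np

y_true = np.array([50, 100, 150, 200], dtype=float)
y_pred = 0.04 * np.array([500, 1000, 1500, 2000], dtype=float)

ss_res = np.sum((y_true - y_pred) ** 2)          # residual sum of squares
ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total sum of squares
print(1 - ss_res / ss_tot)  # -1.16: w = 0.04 fits worse than predicting the mean
```

A negative R² like this one signals that the current model explains the data worse than a horizontal line at the mean, which is consistent with the high MSE we found above.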
4. Huber Loss
Huber Loss is a blend of MSE and MAE which is designed to be less sensitive to outliers while still maintaining the benefits of both. It's useful when datasets contain outliers as it behaves like MSE for small errors and like MAE for large errors.
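The text above describes the behavior; a common definition uses a threshold \delta, with a quadratic region for errors where |e| \le \delta and a linear region beyond it. A minimal sketch (the default \delta = 1.0 is an illustrative choice):

```python
import numpy as np

def huber_loss(y_true, y_pred, delta=1.0):
    """Mean Huber loss: quadratic for small errors, linear for large ones."""
    error = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    quadratic = 0.5 * error ** 2                       # MSE-like near zero
    linear = delta * (np.abs(error) - 0.5 * delta)     # MAE-like in the tails
    return np.mean(np.where(np.abs(error) <= delta, quadratic, linear))
```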
Understanding different cost functions in linear regression is important for building effective predictive models. Each function has its strengths and weaknesses and the choice of which to use depends on the specific problem we are trying to solve.
By understanding and applying the right cost function, we can improve the accuracy and performance of linear regression models, leading to better predictions.
Related articles:
- ML | Common Loss Functions
- Loss Functions in Deep Learning
- Loss Function in TensorFlow