Python - Scipy curve_fit with multiple independent variables
Last Updated: 26 Apr, 2025
Curve fitting examines the relationship between one or more predictors (independent variables) and a response variable (dependent variable), with the goal of defining a "best fit" model of the relationship. It is the process of constructing a mathematical function that best fits a series of data points, possibly subject to constraints.
Curve fit in Python
In Python, we can perform curve fitting using the scipy.optimize library.
Syntax:
scipy.optimize.curve_fit(f, xdata, ydata, p0=None, sigma=None, absolute_sigma=False, check_finite=True, bounds=(-inf, inf), method=None, jac=None, *, full_output=False, **kwargs)
Parameters:
- f (callable function): The model function, f(X, ...). It must take the independent variable as the first argument and the parameters to fit as separate remaining arguments.
- xdata (array_like or object): The independent variable where the data is measured. Should usually be an M-length sequence or a (k, M)-shaped array for functions with k predictors (multiple independent variables).
- ydata (array_like): The dependent data, a length-M array, nominally f(xdata, ...).
By default, curve_fit uses the non-linear least squares method to fit the function f to the data points.
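For example, with k = 2 predictors the independent data can be passed either as a tuple of 1-D arrays or stacked into a (2, M) array. Below is a minimal sketch of the stacked form; the model g, the grid of sample points, and the parameter values are illustrative choices, not part of the worked example that follows.
Python3
import numpy as np
from scipy.optimize import curve_fit

# illustrative model with two predictors; the first argument bundles them
def g(X, a, b):
    x1, x2 = X
    return a * x1 + b * x2**2

# build a grid of (x1, x2) points and flatten it to two length-M arrays
x1g, x2g = np.meshgrid(np.linspace(0, 1, 10), np.linspace(0, 2, 10))
x1, x2 = x1g.ravel(), x2g.ravel()

# synthetic observations with a little noise (true parameters: a=1.5, b=0.5)
ydata = g((x1, x2), 1.5, 0.5) + 0.01 * np.random.randn(x1.size)

# xdata as a (k, M)-shaped array: k = 2 predictors, M = 100 points
xdata = np.vstack((x1, x2))
popt, pcov = curve_fit(g, xdata, ydata)
print(popt)  # expected to be close to [1.5, 0.5]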
Defining the Model Function
We define the function (curve) to which we want to fit our data. Here, a and b are the parameters that define the curve. In this example, we choose y = a*x_1^2 + b*x_2^2 as our model function.
Python3
def f(X, a, b):
    x_1, x_2 = X
    return a*x_1**2 + b*x_2**2
Initializing the independent (X) and dependent (y) data
In this step, we initialize the independent data x_1 and x_2 using np.linspace(0, 4, 50), which creates an evenly spaced array over a specified interval. In the case of multiple independent variables, they are packed together as X = (x_1, x_2). We then generate the dependent data y from the model function and add noise (np.random.random(50)*4) to it.
Python3
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
x_1 = np.linspace(0, 4, 50)
x_2 = np.linspace(0, 4, 50)
X = (x_1, x_2)
y = f(X, 2, 4)
# Adding Noise
y = y + np.random.random(50)*4
# plotting the data points
fig = plt.figure()
fig.set_figwidth(40)
fig.set_figheight(10)
ax = plt.axes(projection='3d')
ax.set_xlabel('x_1', fontsize=12, color='green')
ax.set_ylabel('x_2', fontsize=12, color='green')
ax.set_zlabel('y', fontsize=12, color='green')
ax.scatter3D(x_1, x_2, y, color='green')
plt.title("Plot of data points")
plt.show()
Output:
[3-D scatter plot of the data points]
The curve_fit() method returns the following output:
- popt (array): Optimal values for the parameters so that the sum of the squared residuals of f(xdata, *popt) - ydata is minimized.
- pcov (2-D array): The estimated covariance of popt. The diagonals provide the variance of the parameter estimates.
Python3
popt, pcov = curve_fit(f, X, y)
popt
Output:
array([-96.89634526, 103.10365474])
Note that these values do not recover the true parameters (2 and 4). Because x_1 and x_2 hold identical values here, the model collapses to (a + b)*x^2, so only the sum a + b (roughly 6) is determined by the data; the individual values of a and b are arbitrary.
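The diagonal of pcov gives the variance of each parameter, so one-standard-deviation errors can be obtained with np.sqrt(np.diag(pcov)). In this particular fit, because x_1 and x_2 are identical, SciPy may report the covariance as indeterminate (filled with inf) and warn that it could not be estimated; with genuinely independent predictors the errors come out finite. A short sketch, reusing the popt and pcov from the call above:
Python3
import numpy as np

# one-standard-deviation uncertainties from the diagonal of the covariance matrix
perr = np.sqrt(np.diag(pcov))
print("fitted parameters:", popt)
print("standard errors:  ", perr)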
Visualizing Results
Now, using the parameters obtained from the curve_fit method, we plot the fitted curve in the 3-D plane with the plot3D method, alongside the original data points.
Python3
# plotting the data points and the fitted curve
fig = plt.figure()
fig.set_figwidth(40)
fig.set_figheight(10)
ax = plt.axes(projection='3d')
ax.set_title('Curve fit plot', fontsize=15)
ax.set_xlabel('x_1', fontsize=12, color='green')
ax.set_ylabel('x_2', fontsize=12, color='green')
ax.set_zlabel('y', fontsize=12, color='green')
ax.scatter3D(x_1, x_2, y, color='green')
ax.plot3D(x_1, x_2, popt[0]*(x_1**2)+popt[1]*(x_2**2), color='black')
plt.show()
Output:
[3-D plot of the data points (green) and the fitted curve (black)]
We can refine the fit with the help of other parameters such as p0 and bounds. p0 is the initial guess for the parameters (length N); if None, the initial values are all set to 1. bounds sets lower and upper bounds on the parameters; the default is no bounds.
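As a sketch of how these are passed in the same call (the initial guess and bound values below are illustrative choices, not taken from the example above):
Python3
# initial guess for (a, b) and box constraints 0 <= a, b <= 10
p0 = [1.0, 1.0]
popt, pcov = curve_fit(f, X, y, p0=p0, bounds=(0, 10))
print(popt)
Supplying bounds also changes the solver: curve_fit uses the Levenberg-Marquardt method ('lm') by default for unconstrained problems and switches to a trust-region method ('trf') when bounds are given.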