Medical Insurance Price Prediction using Machine Learning - Python
You have probably seen advertisements for medical insurance that promise financial help in case of a medical emergency. Anyone who purchases this type of insurance pays a monthly premium, and this premium amount varies widely depending on several factors.

In this article, we will extract insights from a dataset that contains details about the background of people purchasing medical insurance, along with the premium each person is charged, using Machine Learning in Python.
Importing Libraries and Dataset
Python libraries make it very easy for us to handle the data and perform typical and complex tasks with a single line of code.
- Pandas - This library helps to load the data frame in a 2D array format and has multiple functions to perform analysis tasks in one go.
- Numpy - Numpy arrays are very fast and can perform large computations in a very short time.
- Matplotlib/Seaborn - These libraries are used to draw visualizations.
Python
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
Now let's use a pandas DataFrame to load the dataset and look at its first five rows.
Python
df = pd.read_csv("insurance.csv")
df.head()
Output:
Data Set
Now, we can observe the data and its shape (rows x columns).
This dataset contains 1338 data points with 6 independent features and 1 target feature (charges).
Python
df.info()
Output:
Details about the columns of the dataset
From the above, we can see that the dataset contains 2 columns with float values, 3 with categorical values, and the rest with integer values.
Python
df.describe()
Output:
Descriptive statistical measures of the data
We can look at the descriptive statistical measures of the continuous data available in the dataset.
Exploratory Data Analysis
EDA is an approach to analyzing data using visual techniques. It is used to discover trends and patterns, or to check assumptions, with the help of statistical summaries and graphical representations. While performing EDA on this dataset we will look at the relations between the independent features, that is, how one affects another.
Python
df.isnull().sum()
Output:
Count of the null values column-wise
So, we can conclude that there are no null values in the given dataset.
Python
features = ['sex', 'smoker', 'region']

plt.subplots(figsize=(20, 10))
for i, col in enumerate(features):
    plt.subplot(1, 3, i + 1)
    x = df[col].value_counts()
    plt.pie(x.values, labels=x.index, autopct='%1.1f%%')
plt.show()
Output:
Pie charts for the sex, smoker, and region columns
The data provided to us is nearly equally distributed among the sex and region columns, but in the smoker column we can observe a ratio of about 80:20.
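To confirm that 80:20 split numerically rather than reading it off the chart, a quick check (a minimal sketch using the same df):
Python
# Fraction of non-smokers vs smokers in the dataset
print(df['smoker'].value_counts(normalize=True))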
Python
features = ['sex', 'children', 'smoker', 'region']

plt.subplots(figsize=(20, 10))
for i, col in enumerate(features):
    plt.subplot(2, 2, i + 1)
    df.groupby(col)['charges'].mean().astype(float).plot.bar()
plt.show()
Output:
Comparison of charges paid between different groups
Now let's look at some of the observations shown in the above graphs:
- Charges are on the higher side for males as compared to females, but the difference is not large.
- The premium charged to smokers is roughly three times that charged to non-smokers (quantified in the sketch below).
- Charges are approximately the same across the given four regions.
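To put a number on the smoker observation, we can compare group means directly (a small sketch using the same df):
Python
# Mean charges per smoker status -- quantifies the roughly 3x gap
print(df.groupby('smoker')['charges'].mean())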
Python
features = ['age', 'bmi']

plt.subplots(figsize=(17, 7))
for i, col in enumerate(features):
    plt.subplot(1, 2, i + 1)
    sns.scatterplot(data=df, x=col, y='charges', hue='smoker')
plt.show()
Output:
Scatter plots of the charges paid vs. age and BMI respectively
A clear distinction can be observed between the charges that smokers have to pay. We can also observe that as a person's age increases, the premium price goes up as well.
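A quick numeric check of the age trend seen in the scatter plot (a sketch; both columns are numeric, so a simple Pearson correlation applies):
Python
# Correlation between age and charges across all customers
print(df['age'].corr(df['charges']))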
Data Preprocessing
Data preprocessing is a technique to clean unusual data such as missing values, wrong data, wrongly formatted data, duplicated data, and outliers. In this dataset we can observe that there are no missing values or wrong data, so the only things we need to check for are duplicates and the presence of outliers.
Python
df.drop_duplicates(inplace=True)
sns.boxplot(df['age'])
Output:
Boxplot of age
We can see that there are no outliers present in the age column.
Python
sns.boxplot(df['bmi'])
Output:
Box plot of bmi
Due to the presence of outliers in the bmi column, we need to treat them. Since bmi is a continuous column, we will cap the outliers at IQR-based limits (computed below) instead of dropping those rows.
Python
Q1 = df['bmi'].quantile(0.25)
Q2 = df['bmi'].quantile(0.5)
Q3 = df['bmi'].quantile(0.75)
iqr = Q3 - Q1
lowlim = Q1 - 1.5 * iqr
upplim = Q3 + 1.5 * iqr
print(lowlim)
print(upplim)
Output:
13.674999999999994
47.31500000000001
Python
!pip install feature_engine
from feature_engine.outliers import ArbitraryOutlierCapper

arb = ArbitraryOutlierCapper(min_capping_dict={'bmi': 13.6749},
                             max_capping_dict={'bmi': 47.315})
df[['bmi']] = arb.fit_transform(df[['bmi']])
sns.boxplot(df['bmi'])
Output:
Box plot for bmi
Now we have successfully treated the outliers.
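If you would rather avoid the extra feature_engine dependency, the same capping can be done with plain pandas using the lowlim and upplim values computed earlier; a minimal equivalent sketch:
Python
# Cap bmi at the IQR-based limits without any extra dependency
df['bmi'] = df['bmi'].clip(lower=lowlim, upper=upplim)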
Data Wrangling
Data wrangling is the step where we check whether the data follows a normal or standard distribution and encode the discrete data for prediction.
Python
print(df['bmi'].skew())
print(df['age'].skew())
Output:
0.23289153320569975
0.054780773126998195
Data in both the age and BMI columns approximately follows a normal distribution, which is a good point with respect to the model's learning.
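Had the skew been substantially higher (say above 1), a log transform would be a common remedy. It is not needed here; the sketch below is for illustration only and does not modify df:
Python
# log1p compresses the right tail; shown only to illustrate the idea
print(np.log1p(df['bmi']).skew())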
Encoding
Encoding is to be done for the discrete categorical columns (sex, smoker, region).
Python
df['sex'] = df['sex'].map({'male': 0, 'female': 1})
df['smoker'] = df['smoker'].map({'yes': 1, 'no': 0})
df['region'] = df['region'].map({'northwest': 0, 'northeast': 1,
                                 'southeast': 2, 'southwest': 3})
Output:
Encoded Data
Now the discrete data is encoded, and the data preprocessing and data wrangling parts are complete. Next we can move on to model development.
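Note that the manual map() approach ties the model to these exact integer codes. An alternative, shown here only as a sketch on a freshly loaded copy of the raw data (it is not used in the rest of this article), is pandas one-hot encoding:
Python
# Sketch: one-hot encode the raw categorical columns instead of mapping
raw = pd.read_csv("insurance.csv")
raw_encoded = pd.get_dummies(raw, columns=['sex', 'smoker', 'region'],
                             drop_first=True)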
Python
sns.heatmap(df.corr(), annot=True)
plt.show()
Output:
Correlation matrix
Model Development
There are many state-of-the-art ML models available, but some models fit a given problem better than others. To make this decision, we split our data into training and validation sets, then use the validation data to choose the model with the highest performance.
Python
X = df.drop(['charges'], axis=1)
Y = df[['charges']]

from sklearn.linear_model import LinearRegression, Lasso
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import GradientBoostingRegressor
from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score

l1 = []
l2 = []
l3 = []
for i in range(40, 50):
    xtrain, xtest, ytrain, ytest = train_test_split(X, Y, test_size=0.2,
                                                    random_state=i)
    lrmodel = LinearRegression()
    lrmodel.fit(xtrain, ytrain)
    l1.append(lrmodel.score(xtrain, ytrain))
    l2.append(lrmodel.score(xtest, ytest))
    l3.append(cross_val_score(lrmodel, X, Y, cv=5).mean())

df1 = pd.DataFrame({'train acc': l1, 'test acc': l2, 'cvs': l3})
df1
Output:
Scores for various random_state number
After dividing the data into training and validation sets, which is considered good practice for stable and fast model training, we have identified the best random_state for this dataset as 42. Now we fix this random_state and try different ML algorithms in search of a better score.
Now let's train some state-of-the-art machine learning models on the training data and then use the validation data for choosing the best out of them for prediction.
Python
xtrain, xtest, ytrain, ytest = train_test_split(X, Y, test_size=0.2,
                                                random_state=42)
lrmodel = LinearRegression()
lrmodel.fit(xtrain, ytrain)
print(lrmodel.score(xtrain, ytrain))
print(lrmodel.score(xtest, ytest))
print(cross_val_score(lrmodel, X, Y, cv=5).mean())
Output:
Linear Regression:
0.7295415541376445
0.8062391115570589
0.7470697972809902
Python
from sklearn.metrics import r2_score

svrmodel = SVR()
svrmodel.fit(xtrain, ytrain)
ypredtrain1 = svrmodel.predict(xtrain)
ypredtest1 = svrmodel.predict(xtest)
print(r2_score(ytrain, ypredtrain1))
print(r2_score(ytest, ypredtest1))
print(cross_val_score(svrmodel, X, Y, cv=5).mean())
Output:
SVR:
-0.10151474302536445
-0.1344454720199666
-0.10374591327267262
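The negative R² values mean SVR is doing worse than simply predicting the mean charge. SVR is very sensitive to the scale of both the features and the target, and we passed them in raw. A hedged sketch of how scaling would typically help (default SVR parameters, untuned):
Python
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.compose import TransformedTargetRegressor

# Standardize the features via a pipeline and the target via
# TransformedTargetRegressor; SVR usually improves markedly with this.
scaled_svr = TransformedTargetRegressor(
    regressor=make_pipeline(StandardScaler(), SVR()),
    transformer=StandardScaler())
scaled_svr.fit(xtrain, ytrain)
print(scaled_svr.score(xtest, ytest))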
Python
rfmodel = RandomForestRegressor(random_state=42)
rfmodel.fit(xtrain, ytrain)
ypredtrain2 = rfmodel.predict(xtrain)
ypredtest2 = rfmodel.predict(xtest)
print(r2_score(ytrain, ypredtrain2))
print(r2_score(ytest, ypredtest2))
print(cross_val_score(rfmodel, X, Y, cv=5).mean())

from sklearn.model_selection import GridSearchCV
estimator = RandomForestRegressor(random_state=42)
param_grid = {'n_estimators': [10, 40, 50, 98, 100, 120, 150]}
grid = GridSearchCV(estimator, param_grid, scoring="r2", cv=5)
grid.fit(xtrain, ytrain)
print(grid.best_params_)

# Refit with the best parameters found by the grid search
rfmodel = RandomForestRegressor(random_state=42, n_estimators=120)
rfmodel.fit(xtrain, ytrain)
ypredtrain2 = rfmodel.predict(xtrain)
ypredtest2 = rfmodel.predict(xtest)
print(r2_score(ytrain, ypredtrain2))
print(r2_score(ytest, ypredtest2))
print(cross_val_score(rfmodel, X, Y, cv=5).mean())
Output:
RandomForestRegressor:
0.9738163260247533
0.8819423353068565
0.8363637309718952
Hyperparameter tuning:
{'n_estimators': 120}
0.9746383984429655
0.8822009842175969
0.8367438097052858
Python
gbmodel = GradientBoostingRegressor()
gbmodel.fit(xtrain, ytrain)
ypredtrain3 = gbmodel.predict(xtrain)
ypredtest3 = gbmodel.predict(xtest)
print(r2_score(ytrain, ypredtrain3))
print(r2_score(ytest, ypredtest3))
print(cross_val_score(gbmodel, X, Y, cv=5).mean())

estimator = GradientBoostingRegressor()
param_grid = {'n_estimators': [10, 15, 19, 20, 21, 50],
              'learning_rate': [0.1, 0.19, 0.2, 0.21, 0.8, 1]}
grid = GridSearchCV(estimator, param_grid, scoring="r2", cv=5)
grid.fit(xtrain, ytrain)
print(grid.best_params_)

# Refit with the best parameters found by the grid search
gbmodel = GradientBoostingRegressor(n_estimators=21, learning_rate=0.2)
gbmodel.fit(xtrain, ytrain)
ypredtrain3 = gbmodel.predict(xtrain)
ypredtest3 = gbmodel.predict(xtest)
print(r2_score(ytrain, ypredtrain3))
print(r2_score(ytest, ypredtest3))
print(cross_val_score(gbmodel, X, Y, cv=5).mean())
Output:
GradientBoostingRegressor:
0.8931345821166041
0.904261922040551
0.8549940291799407
Hyperparameter tuning:
{'learning_rate': 0.2, 'n_estimators': 21}
0.8682397447116927
0.9017109716082661
0.8606041910125791
Python
xgmodel = XGBRegressor()
xgmodel.fit(xtrain, ytrain)
ypredtrain4 = xgmodel.predict(xtrain)
ypredtest4 = xgmodel.predict(xtest)
print(r2_score(ytrain, ypredtrain4))
print(r2_score(ytest, ypredtest4))
print(cross_val_score(xgmodel, X, Y, cv=5).mean())

estimator = XGBRegressor()
param_grid = {'n_estimators': [10, 15, 20, 40, 50],
              'max_depth': [3, 4, 5],
              'gamma': [0, 0.15, 0.3, 0.5, 1]}
grid = GridSearchCV(estimator, param_grid, scoring="r2", cv=5)
grid.fit(xtrain, ytrain)
print(grid.best_params_)

# Refit with the best parameters found by the grid search
xgmodel = XGBRegressor(n_estimators=15, max_depth=3, gamma=0)
xgmodel.fit(xtrain, ytrain)
ypredtrain4 = xgmodel.predict(xtrain)
ypredtest4 = xgmodel.predict(xtest)
print(r2_score(ytrain, ypredtrain4))
print(r2_score(ytest, ypredtest4))
print(cross_val_score(xgmodel, X, Y, cv=5).mean())
Output:
XGBRegressor:
0.9944530188818493
0.8618686915522016
0.8104424308304893
Hyperparameter tuning:
{'gamma': 0, 'max_depth': 3, 'n_estimators': 15}
0.870691899927822
0.904151903449132
0.8600710679082143
Comparing All Models

Model | Train Accuracy | Test Accuracy | CV Score
---|---|---|---
LinearRegression | 0.729 | 0.806 | 0.747
SupportVectorMachine | -0.105 | -0.134 | -0.103
RandomForest | 0.974 | 0.882 | 0.836
GradientBoost | 0.868 | 0.901 | 0.860
XGBoost | 0.870 | 0.904 | 0.860
From the above table we can observe that XGBoost is the best model. Now we need to identify the important features for predicting charges.
Python
feats = pd.DataFrame(data=grid.best_estimator_.feature_importances_,
                     index=X.columns, columns=['Importance'])
feats
Output:
Features Importance
Python
important_features = feats[feats['Importance'] > 0.01]
important_features
Output:
Important Features
Final Model:
Python
df.drop(['sex', 'region'], axis=1, inplace=True)
Xf = df.drop(['charges'], axis=1)
xtrain, xtest, ytrain, ytest = train_test_split(Xf, Y, test_size=0.2,
                                                random_state=42)
finalmodel = XGBRegressor(n_estimators=15, max_depth=3, gamma=0)
finalmodel.fit(xtrain, ytrain)
ypredtrain4 = finalmodel.predict(xtrain)
ypredtest4 = finalmodel.predict(xtest)
print(r2_score(ytrain, ypredtrain4))
print(r2_score(ytest, ypredtest4))
print(cross_val_score(finalmodel, Xf, Y, cv=5).mean())
Final Model:
Train accuracy : 0.870691899927822
Test accuracy : 0.904151903449132
CV Score : 0.8600710679082143
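R² tells us the fraction of variance explained; an error metric in the same units as the charges is often easier to communicate. A small sketch computing the mean absolute error of the final model on the test split:
Python
from sklearn.metrics import mean_absolute_error

# Average absolute prediction error, in the same units as charges
print(mean_absolute_error(ytest, ypredtest4))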
Save Model:
Python
from pickle import dump

with open('insurancemodelf.pkl', 'wb') as f:
    dump(finalmodel, f)
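To use the saved model later (for example in a serving script), load it back with pickle; the filename matches the one used above:
Python
from pickle import load

with open('insurancemodelf.pkl', 'rb') as f:
    loaded_model = load(f)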
Predict on new data:
Python
new_data = pd.DataFrame({'age': 19, 'sex': 'male', 'bmi': 27.9,
                         'children': 0, 'smoker': 'yes',
                         'region': 'northeast'}, index=[0])
new_data['smoker'] = new_data['smoker'].map({'yes': 1, 'no': 0})
# Drop the columns the final model was not trained on
new_data = new_data.drop(['sex', 'region'], axis=1)
finalmodel.predict(new_data)
Output:
array([17483.12], dtype=float32)
Conclusion
Out of all the models, XGBoost gives the highest accuracy, which means the predictions made by this model are closest to the real values compared to the other models.
The dataset we used here was small, yet the conclusions we drew from it are quite similar to what is observed in real-life scenarios. With a bigger dataset we would be able to learn even deeper patterns in the relation between the independent features and the premium charged to buyers.