PySpark | Linear Regression using Apache Spark MLlib
Last Updated : 19 Jul, 2019
Problem Statement: Build a predictive model for a shipping company that estimates how many crew members a ship requires.
The dataset contains 159 instances with 9 columns.
The columns are Ship_name, Cruise_line, Age, Tonnage, passengers, length, cabins, passenger_density and crew, where crew is the value to predict.

Let's build a linear regression model that predicts the number of crew members.
Attached dataset: cruise_ship_info
Python
import pyspark
from pyspark.sql import SparkSession
#SparkSession is the entry point to Spark's DataFrame and MLlib APIs
#it acts as the gateway to the Spark libraries used below
#create (or reuse) a SparkSession
spark=SparkSession.builder.appName('cruise_ship_crew_model').getOrCreate()
#read the input csv file into a Spark dataframe
#inferSchema=True detects numeric columns, header=True uses the first row as column names
#a raw string (r'...') keeps the Windows backslashes in the path literal
df=spark.read.csv(r'D:\python coding\pyspark_tutorial\Linear regression\cruise_ship_info.csv',
                  inferSchema=True,header=True)
df.show(10)
Output :
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
| Ship_name|Cruise_line|Age| Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
| Journey| Azamara| 6|30.276999999999997| 6.94| 5.94| 3.55| 42.64|3.55|
| Quest| Azamara| 6|30.276999999999997| 6.94| 5.94| 3.55| 42.64|3.55|
|Celebration| Carnival| 26| 47.262| 14.86| 7.22| 7.43| 31.8| 6.7|
| Conquest| Carnival| 11| 110.0| 29.74| 9.53| 14.88| 36.99|19.1|
| Destiny| Carnival| 17| 101.353| 26.42| 8.92| 13.21| 38.36|10.0|
| Ecstasy| Carnival| 22| 70.367| 20.52| 8.55| 10.2| 34.29| 9.2|
| Elation| Carnival| 15| 70.367| 20.52| 8.55| 10.2| 34.29| 9.2|
| Fantasy| Carnival| 23| 70.367| 20.56| 8.55| 10.22| 34.23| 9.2|
|Fascination| Carnival| 19| 70.367| 20.52| 8.55| 10.2| 34.29| 9.2|
| Freedom| Carnival| 6|110.23899999999999| 37.0| 9.51| 14.87| 29.79|11.5|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
Python
#prints structure of dataframe along with datatype
df.printSchema()
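The printed schema is not reproduced here. To illustrate what inferSchema=True does before printSchema() reports the column types, the stdlib-only sketch below picks the narrowest type that parses every value in a column: integer first, then double, then string. This is an analogy built on assumptions, not Spark's actual inference code. infer_column_type and infer_schema are hypothetical helper names.

```python
# Rough stdlib-only analogy for CSV schema inference (NOT Spark's algorithm):
# pick the narrowest type that parses every value in a column.
import csv
import io

def infer_column_type(values):
    """Return 'integer', 'double' or 'string' for a list of raw CSV strings."""
    def all_parse(cast):
        try:
            for v in values:
                cast(v)
            return True
        except ValueError:
            return False

    if all_parse(int):
        return "integer"
    if all_parse(float):
        return "double"
    return "string"

def infer_schema(csv_text):
    """Map each header name to an inferred type (hypothetical helper)."""
    rows = list(csv.reader(io.StringIO(csv_text)))
    header, data = rows[0], rows[1:]
    columns = zip(*data)  # transpose rows into columns
    return {name: infer_column_type(list(col)) for name, col in zip(header, columns)}

# Two rows copied from the df.show() output above (subset of columns)
sample = """Ship_name,Cruise_line,Age,Tonnage,crew
Journey,Azamara,6,30.276999999999997,3.55
Celebration,Carnival,26,47.262,6.7
"""
print(infer_schema(sample))
```

On these two rows, Ship_name and Cruise_line are inferred as string, Age as integer, and Tonnage and crew as double. This agrees with the Row values printed later in the article, for example Age=6 and crew=3.55.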
Python
#list all columns of the dataframe; the model's features are chosen from these
df.columns
Python
#columns identified as features are as below:
#['Cruise_line','Age','Tonnage','passengers','length','cabins','passenger_density']
#Spark MLlib expects every feature value to be numeric
#feature 'Cruise_line' is of string datatype
#StringIndexer maps each distinct string to a numeric index
#(by default the most frequent cruise line gets index 0.0)
from pyspark.ml.feature import StringIndexer
indexer=StringIndexer(inputCol='Cruise_line',outputCol='cruise_cat')
indexed=indexer.fit(df).transform(df)
#fit() learns the string-to-index mapping, transform() returns a new dataframe
#the new dataframe contains an extra numeric column 'cruise_cat'
#cruise_cat can now be assembled into the feature vector with the other numeric columns
for item in indexed.head(5):
    print(item)
    print('\n')
Output :
Row(Ship_name='Journey', Cruise_line='Azamara', Age=6,
Tonnage=30.276999999999997, passengers=6.94, length=5.94,
cabins=3.55, passenger_density=42.64, crew=3.55, cruise_cat=16.0)
Row(Ship_name='Quest', Cruise_line='Azamara', Age=6,
Tonnage=30.276999999999997, passengers=6.94, length=5.94,
cabins=3.55, passenger_density=42.64, crew=3.55, cruise_cat=16.0)
Row(Ship_name='Celebration', Cruise_line='Carnival', Age=26,
Tonnage=47.262, passengers=14.86, length=7.22,
cabins=7.43, passenger_density=31.8, crew=6.7, cruise_cat=1.0)
Row(Ship_name='Conquest', Cruise_line='Carnival', Age=11,
Tonnage=110.0, passengers=29.74, length=9.53,
cabins=14.88, passenger_density=36.99, crew=19.1, cruise_cat=1.0)
Row(Ship_name='Destiny', Cruise_line='Carnival', Age=17,
Tonnage=101.353, passengers=26.42, length=8.92,
cabins=13.21, passenger_density=38.36, crew=10.0, cruise_cat=1.0)
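By default, StringIndexer uses stringOrderType='frequencyDesc'. The most frequent label gets index 0.0, the next most frequent gets 1.0, and so on. Recent Spark documentation states that ties are broken alphabetically. The pure-Python sketch below imitates that rule on a small made-up list of cruise lines. fit_string_indexer and transform are hypothetical helpers, not PySpark functions.

```python
# Hypothetical pure-Python imitation of StringIndexer's default ordering
# ('frequencyDesc'): most frequent label -> 0.0, ties sorted alphabetically.
from collections import Counter

def fit_string_indexer(labels):
    """Return a dict mapping each distinct label to a float index."""
    counts = Counter(labels)
    # sort by descending count, then alphabetically for equal counts
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}

def transform(labels, mapping):
    """Replace each label by its index, like the 'cruise_cat' column."""
    return [mapping[label] for label in labels]

# Toy data only; the real indices depend on counts in the full dataset
cruise_lines = ["Azamara", "Azamara", "Carnival", "Carnival", "Carnival", "Celebrity"]
mapping = fit_string_indexer(cruise_lines)
print(mapping)                           # {'Carnival': 0.0, 'Azamara': 1.0, 'Celebrity': 2.0}
print(transform(cruise_lines, mapping))  # [1.0, 1.0, 0.0, 0.0, 0.0, 2.0]
```

The index reflects only how often a cruise line appears. A linear model nevertheless treats cruise_cat as an ordered magnitude. One-hot encoding, for example with pyspark.ml.feature.OneHotEncoder, is a common alternative worth considering.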
Python
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
#creating vectors from features
#Apache MLlib expects the input features in vector form
assembler=VectorAssembler(inputCols=['Age',
'Tonnage',
'passengers',
'length',
'cabins',
'passenger_density',
'cruise_cat'],outputCol='features')
output=assembler.transform(indexed)
output.select('features','crew').show(5)
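VectorAssembler concatenates the listed input columns, in the order given, into a single vector column. The sketch below mimics this for the first row of the indexed dataframe using plain Python lists. assemble and FEATURE_COLS are hypothetical names. Spark stores the actual result as an ML Vector, not a list.

```python
# Hypothetical stand-in for VectorAssembler: concatenate the chosen numeric
# columns of one row, in the given order, into a single list of floats.
FEATURE_COLS = ['Age', 'Tonnage', 'passengers', 'length',
                'cabins', 'passenger_density', 'cruise_cat']

def assemble(row, input_cols=FEATURE_COLS):
    """Return the feature vector for one row (a dict of column -> value)."""
    return [float(row[col]) for col in input_cols]

# First row of the indexed dataframe printed earlier
journey = {'Ship_name': 'Journey', 'Cruise_line': 'Azamara', 'Age': 6,
           'Tonnage': 30.276999999999997, 'passengers': 6.94, 'length': 5.94,
           'cabins': 3.55, 'passenger_density': 42.64, 'crew': 3.55,
           'cruise_cat': 16.0}
features = assemble(journey)  # the 'features' column
label = journey['crew']       # the label column
print(features, label)
```

The string columns Ship_name and Cruise_line are excluded. Cruise_line reaches the model only through its numeric index cruise_cat, and crew stays separate as the label.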
Python
#final data consists of the features vector and the label, crew
final_data=output.select('features','crew')
#split data at random into roughly 70% train and 30% test
train_data,test_data=final_data.randomSplit([0.7,0.3])
train_data.describe().show()
Python
test_data.describe().show()
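randomSplit([0.7,0.3]) assigns rows to the splits at random. As a result, the training and test sets contain only approximately 70% and 30% of the rows, and their contents change from run to run. The sketch below shows the general idea using Python's random module. It normalises the weights, draws one uniform number per row, and places the row in the split whose cumulative interval contains that number. This is a simplified assumption, not Spark's exact sampling code, and random_split is a hypothetical helper.

```python
# Simplified illustration of weighted random splitting (not Spark's exact code).
import random

def random_split(rows, weights, seed=None):
    total = sum(weights)
    # cumulative upper bounds of each split, e.g. [0.7, 1.0]
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    bounds[-1] = 1.0  # guard against floating-point round-off
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()  # uniform in [0.0, 1.0)
        for i, upper in enumerate(bounds):
            if u < upper:
                splits[i].append(row)
                break
    return splits

train, test = random_split(list(range(159)), [0.7, 0.3], seed=42)
print(len(train), len(test))  # close to, but usually not exactly, 111 and 48
```

PySpark supports the same kind of reproducibility through a seed, for example final_data.randomSplit([0.7,0.3], seed=42).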
Python
#import LinearRegression library
from pyspark.ml.regression import LinearRegression
#creating an object of class LinearRegression
#the object takes the features column and the label column as arguments
ship_lr=LinearRegression(featuresCol='features',labelCol='crew')
#pass train_data to train the model
trained_ship_model=ship_lr.fit(train_data)
#evaluate the trained model on the training data (R squared)
ship_results=trained_ship_model.evaluate(train_data)
print('R squared (train) :',ship_results.r2)
#in the original run R2 was about 0.92: the model explains roughly 92% of the
#variance in crew on the training data (R2 is not a classification accuracy)
#evaluate on the held-out test data for a less optimistic estimate
test_results=trained_ship_model.evaluate(test_data)
print('R squared (test) :',test_results.r2)
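With the default regParam=0.0, LinearRegression fits ordinary least squares. ship_results.r2 is the coefficient of determination, R² = 1 - SS_res / SS_tot. The plain-Python sketch below computes both for a single feature (cabins) on the five ships shown earlier. Spark solves the multi-feature problem, so these numbers will not match the article's model. fit_simple_ols and r_squared are hypothetical helpers.

```python
# Single-feature ordinary least squares plus R^2, in plain Python.
# Illustrates the quantities only; Spark fits all seven features at once.

def fit_simple_ols(x, y):
    """Return (slope, intercept) minimising the sum of squared residuals."""
    n = len(x)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    sxx = sum((xi - mean_x) ** 2 for xi in x)
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept

def r_squared(y, y_pred):
    """R^2 = 1 - SS_res / SS_tot."""
    mean_y = sum(y) / len(y)
    ss_res = sum((yi - pi) ** 2 for yi, pi in zip(y, y_pred))
    ss_tot = sum((yi - mean_y) ** 2 for yi in y)
    return 1.0 - ss_res / ss_tot

# cabins and crew for the five ships printed earlier
cabins = [3.55, 3.55, 7.43, 14.88, 13.21]
crew = [3.55, 3.55, 6.7, 19.1, 10.0]
slope, intercept = fit_simple_ols(cabins, crew)
predicted = [slope * c + intercept for c in cabins]
print(round(slope, 3), round(intercept, 3), round(r_squared(crew, predicted), 3))
```

Computed on the same rows used for fitting, R² is optimistic. Applying the same formula to predictions on unseen rows, as evaluate(test_data) does, gives a fairer estimate.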
Python
#testing the model on unlabeled data
#create unlabeled data from test_data by keeping only the features column
unlabeled_data=test_data.select('features')
unlabeled_data.show(5)
Python
#transform() adds a 'prediction' column with the estimated crew for each test row
predictions=trained_ship_model.transform(unlabeled_data)
predictions.show()
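For each row, transform() computes the prediction as the intercept plus the dot product of the learned coefficients with the features vector. The fitted values are available as trained_ship_model.coefficients and trained_ship_model.intercept. The sketch below shows that arithmetic in plain Python. The coefficient values are invented for illustration and are not the article's fitted model. predict is a hypothetical helper.

```python
# Prediction of a fitted linear model for one row:
# prediction = intercept + sum(coef_j * feature_j)

def predict(coefficients, intercept, features):
    if len(coefficients) != len(features):
        raise ValueError("coefficients and features must have the same length")
    return intercept + sum(c * f for c, f in zip(coefficients, features))

# Made-up 7-coefficient model (one coefficient per assembled feature)
coefficients = [0.01, 0.0, -0.1, 0.4, 0.8, 0.0, 0.05]
intercept = -1.0
journey_features = [6.0, 30.276999999999997, 6.94, 5.94, 3.55, 42.64, 16.0]
print(predict(coefficients, intercept, journey_features))
```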