PySpark - Split dataframe into equal number of rows
Last Updated: 18 Jul, 2021
When a dataset is huge, it can be convenient to split it into equal chunks and then process each chunk individually. This is possible when the operation on the dataframe is independent of the rows. Each chunk, or equally split dataframe, can then be processed in parallel, making more efficient use of resources. In this article, we will discuss how to split a PySpark dataframe into an equal number of rows.
Creating Dataframe for demonstration:
Python
# importing module
import pyspark

# importing SparkSession from the pyspark.sql module
from pyspark.sql import SparkSession

# creating a SparkSession and giving an app name
spark = SparkSession.builder.appName('sparkdf').getOrCreate()

# Column names for the dataframe
columns = ["Brand", "Product"]

# Row data for the dataframe
data = [
    ("HP", "Laptop"),
    ("Lenovo", "Mouse"),
    ("Dell", "Keyboard"),
    ("Samsung", "Monitor"),
    ("MSI", "Graphics Card"),
    ("Asus", "Motherboard"),
    ("Gigabyte", "Motherboard"),
    ("Zebronics", "Cabinet"),
    ("Adata", "RAM"),
    ("Transcend", "SSD"),
    ("Kingston", "HDD"),
    ("Toshiba", "DVD Writer")
]

# Create the dataframe using the above values
prod_df = spark.createDataFrame(data=data, schema=columns)

# View the dataframe
prod_df.show()
Output:

In the above code block, we have defined the schema structure for the dataframe and provided sample data. Our dataframe consists of 2 string-type columns with 12 records.
Example 1: Split dataframe using 'DataFrame.limit()'
We will make use of the DataFrame.limit() method, applied repeatedly, to create 'n' equal dataframes.
Syntax: DataFrame.limit(num)
where num limits the result count to the number specified.
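For instance, a minimal sketch using the prod_df created above: limit(3) returns a new dataframe containing at most 3 rows.
Python
# Fetch at most 3 rows from the dataframe created above
prod_df.limit(3).show()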
Code:
Python
# Define the number of splits you want
n_splits = 4

# Calculate the number of rows in each split
each_len = prod_df.count() // n_splits

# Create a copy of the original dataframe
copy_df = prod_df

# Iterate for each dataframe
i = 0
while i < n_splits:
    # Get the top `each_len` number of rows
    temp_df = copy_df.limit(each_len)

    # Truncate `copy_df` to remove
    # the contents fetched for `temp_df`
    copy_df = copy_df.subtract(temp_df)

    # View the dataframe
    temp_df.show(truncate=False)

    # Increment the split number
    i += 1
Output:
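A note on this approach: subtract() behaves like SQL EXCEPT DISTINCT, and each_len uses integer division, so the splits come out exact here only because the 12 rows are distinct and evenly divisible by n_splits. One way to keep any leftover rows, sketched below under the assumption that prod_df, n_splits and each_len are defined as in the example above, is to let the final split take everything that remains:
Python
# Assumes prod_df, n_splits and each_len from the example above
copy_df = prod_df

for i in range(n_splits):
    if i == n_splits - 1:
        # Last split: take all remaining rows, including any remainder
        temp_df = copy_df
    else:
        # Take the next `each_len` rows and remove them from `copy_df`
        temp_df = copy_df.limit(each_len)
        copy_df = copy_df.subtract(temp_df)
    temp_df.show(truncate=False)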

Example 2: Split the dataframe, perform the operation and concatenate the result
We will now split the dataframe into 'n' equal parts, perform a concatenation operation on each part individually, and then append each result to a `result_df`. This demonstrates how the previous code can be extended to run a dataframe operation separately on each split and then union the individual results into a new dataframe whose length equals that of the original dataframe.
Python
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.functions import concat, col, lit

# Define the number of splits you want
n_splits = 4

# Calculate the number of rows in each split
each_len = prod_df.count() // n_splits

# Create a copy of the original dataframe
copy_df = prod_df

# Function to modify columns of each individual split
def modify_dataframe(data):
    return data.select(
        concat(col("Brand"), lit(" - "), col("Product"))
    )

# Create an empty dataframe to
# store the concatenated results
schema = StructType([
    StructField('Brand - Product', StringType(), True)
])
result_df = spark.createDataFrame(data=[], schema=schema)

# Iterate for each dataframe
i = 0
while i < n_splits:
    # Get the top `each_len` number of rows
    temp_df = copy_df.limit(each_len)

    # Truncate `copy_df` to remove
    # the contents fetched for `temp_df`
    copy_df = copy_df.subtract(temp_df)

    # Perform the operation on the newly created dataframe
    temp_df_mod = modify_dataframe(data=temp_df)
    temp_df_mod.show(truncate=False)

    # Append the modified split to the result dataframe
    result_df = result_df.union(temp_df_mod)

    # Increment the split number
    i += 1

# View the final concatenated result
result_df.show(truncate=False)
Output:
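Since the 12 rows split evenly into 4 parts, the unioned result_df should contain exactly as many rows as the original dataframe. A quick sanity check (a small sketch, not part of the walkthrough above):
Python
# The concatenated result should match the source row count
print(result_df.count(), prod_df.count())
assert result_df.count() == prod_df.count()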