In this article, we will discuss how to split a dataset using scikit-learn's train_test_split().
The train_test_split() function is used to split our data into train and test sets. First, we need to divide our data into features (X) and labels (y). These are then divided into X_train, X_test, y_train, and y_test. The X_train and y_train sets are used for training and fitting the model. The X_test and y_test sets are used to test whether the model predicts the right outputs/labels. We can explicitly set the sizes of the train and test sets. It is common practice to keep the train sets larger than the test sets.
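A minimal sketch of the four-way split on hypothetical toy arrays. It shows the order in which train_test_split() returns its outputs, that test_size may be an integer row count as well as a fraction, and that each feature row stays paired with its label after shuffling.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data (hypothetical): 10 samples with 2 features each.
# Row i of X is [2i, 2i + 1] and y[i] = i, so pairing is easy to check.
X = np.arange(20).reshape(10, 2)
y = np.arange(10)

# test_size may be a float (fraction of rows) or an int (absolute number of rows)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=3, random_state=0)

print(X_train.shape, X_test.shape)  # (7, 2) (3, 2)
print(y_train.shape, y_test.shape)  # (7,) (3,)

# Features and labels are shuffled together, so each row keeps its own label
print((X_train[:, 0] // 2 == y_train).all())  # True
```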
In this step, we import the necessary packages and modules into the working Python environment.
Here, we load the CSV file using the pd.read_csv() method from pandas and get the shape of the dataset using the shape attribute.
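A short sketch of importing pandas, loading a CSV, and reading its shape. Because the original 'predictions.csv' is not available here, an in-memory string stands in for it; the column names and values are hypothetical. Note that shape is an attribute, so it is written without parentheses.

```python
import io

import pandas as pd

# Stand-in for 'predictions.csv': these column names and values are hypothetical
csv_text = """feature_1,feature_2,outcome
1.0,2.0,0
2.0,1.5,1
3.0,0.5,1
4.0,3.0,0
"""

df = pd.read_csv(io.StringIO(csv_text))

# shape is a (rows, columns) tuple attribute, not a method
print(df.shape)  # (4, 3)
```

With a real file on disk, `pd.read_csv("predictions.csv")` would replace the `io.StringIO` call.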
Here, we assign the X and y variables: X holds the feature (independent) variables and y holds the target (dependent) variable.
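One way to build X and y from a data frame, sketched with a small hypothetical frame: X keeps every column except the outcome as a 2-D DataFrame, and y is the outcome column alone as a 1-D Series.

```python
import pandas as pd

# Hypothetical data frame with two feature columns and one outcome column
df = pd.DataFrame({
    "feature_1": [1.0, 2.0, 3.0, 4.0],
    "feature_2": [0.5, 0.1, 0.9, 0.3],
    "outcome": [0, 1, 1, 0],
})

# Features (independent variables): every column except the outcome
X = df.drop(columns="outcome")
# Target (dependent variable): the outcome column alone
y = df["outcome"]

print(X.shape, y.shape)  # (4, 2) (4,)
```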
Here, the train_test_split() function from sklearn.model_selection is used to split our data into train and test sets, with the feature and target variables (X and y) passed as input. test_size determines the proportion of the data that goes into the test sets, and random_state seeds the random shuffling so that the split is reproducible.
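To illustrate what random_state buys, this sketch (on hypothetical toy arrays) calls train_test_split() twice with the same seed and checks that all four returned arrays are identical, so the same rows land in the test set each time.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 20 samples, 2 features
X = np.arange(40).reshape(20, 2)
y = np.arange(20)

# Two calls with the same random_state shuffle the rows identically
split_a = train_test_split(X, y, test_size=0.25, random_state=42)
split_b = train_test_split(X, y, test_size=0.25, random_state=42)

# Compare X_train, X_test, y_train, y_test pairwise
same = all(np.array_equal(a, b) for a, b in zip(split_a, split_b))
print(same)  # True
```

Leaving random_state unset lets the shuffle vary from run to run, which makes results harder to reproduce.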
In this example, the 'predictions.csv' file is read into a data frame. The df.shape attribute is used to retrieve the shape of the data frame, which is (13, 3). The feature columns are stored in the X variable and the outcome column in the y variable. X and y are passed to the train_test_split() function to split the data into train and test sets, and the random_state parameter is used for reproducibility. test_size is set to 0.25, which means 25% of the data goes into the test sets; because scikit-learn rounds the number of test rows up (0.25 × 13 = 3.25, rounded up to 4), 4 of the 13 rows go into the test sets. The remaining 9 rows (roughly 75% of the data, about 69% exactly) go into the train sets. The train sets are used to fit and train the machine learning model, and the test sets are used for evaluation.
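The row counts above can be checked with a hypothetical 13-row frame that mirrors the (13, 3) shape of 'predictions.csv' (the values here are made up). With test_size=0.25, scikit-learn computes the test count as ceil(0.25 × 13) = 4 and gives the remaining 9 rows to the train set.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical 13-row frame mirroring the (13, 3) shape described above
df = pd.DataFrame({
    "feature_1": np.arange(13),
    "feature_2": np.arange(13) * 2,
    "outcome": np.arange(13) % 2,
})
X = df[["feature_1", "feature_2"]]
y = df["outcome"]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

# Test rows are rounded up: ceil(0.25 * 13) = ceil(3.25) = 4; train gets the rest
print(len(X_train), len(X_test))  # 9 4
```

Every original row ends up in exactly one of the two sets, which can be confirmed by comparing the index labels of X_train and X_test.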
array([19.82000933, 14.23636718, 12.80417236, 7.75461569, 8.31672266,
15.4001915 , 11.6590983 , 15.22650923, 15.53524916, 19.46415132,
17.21364106, 16.69603229, 16.46449309, 10.15345178, 13.44695953,
24.71946196, 18.67190453, 15.85505154, 14.45450049, 9.91684409,
10.41647177, 4.61335238, 17.41531451, 17.31014955, 21.72288151,
5.87934089, 11.29101265, 17.88733657, 21.04225992, 12.32251227,
14.4099317 , 15.05829814, 10.2105313 , 7.28532072, 12.66133397,
23.25847491, 18.87101505, 4.55545854, 19.79603707, 9.21203026,
10.24668718, 8.96989469, 13.33515217, 20.69532628, 12.17013119,
21.69572633, 16.7346457 , 22.16358256, 5.34163764, 20.43470231,
7.58252563, 23.38775769, 10.2270323 , 12.33473902, 24.10480458,
9.88919804, 21.7781076 ])
2.7506859249500466