Cross Validation
Definition
• Cross-validation is a technique for evaluating a machine learning model by repeatedly partitioning the original dataset into training and testing subsets, training on one subset and testing on the other. Averaging the results over several such splits makes the performance estimate more robust, so it does not depend on one particular choice of test data.
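To make the definition concrete, here is a minimal pure-Python sketch of k-fold cross-validation, assuming contiguous (unshuffled) folds. The helper names `k_fold_indices`, `cross_validate`, `fit_majority` and `predict_majority` are hypothetical. The "model" is a toy that always predicts the most common training label. Any real classifier could be plugged in through the `fit`/`predict` callables instead.

```python
def k_fold_indices(n_samples, k):
    """Split indices 0..n_samples-1 into k contiguous folds (no shuffling).
    Returns one (train_indices, test_indices) pair per fold."""
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    splits, start = [], 0
    for size in fold_sizes:
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        splits.append((train, test))
        start += size
    return splits

def cross_validate(X, y, k, fit, predict):
    """Train on k-1 folds, test on the held-out fold, return one accuracy per fold."""
    scores = []
    for train_idx, test_idx in k_fold_indices(len(X), k):
        model = fit([X[i] for i in train_idx], [y[i] for i in train_idx])
        preds = predict(model, [X[i] for i in test_idx])
        correct = sum(p == y[i] for p, i in zip(preds, test_idx))
        scores.append(correct / len(test_idx))
    return scores

# Hypothetical toy "model": remembers the most common label in the training data.
def fit_majority(X_train, y_train):
    return max(set(y_train), key=y_train.count)

def predict_majority(model, X_test):
    return [model] * len(X_test)

# Usage on a tiny made-up dataset: 10 samples, 5 folds of 2.
print(cross_validate(list(range(10)), [0] * 8 + [1] * 2, 5, fit_majority, predict_majority))
```

The last fold scores 0 because its two class-1 samples never appear in its own training data. This is exactly the kind of split-dependence that averaging over folds is meant to expose.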
Dry Run of Cross-Validation with K-Fold
• Context:
• Dataset: Iris dataset with 150 samples and 4 features (sepal length,
sepal width, petal length, petal width).
• Classes: 3 classes (0, 1, 2) representing different species of Iris.
• Model: RandomForestClassifier (default settings).
• Cross-Validation: KFold with 5 splits.
• Load Dataset: load the 150 Iris samples (4 features each) and their class labels.
• KFold Definition: KFold with 5 splits divides the 150 samples into 5 folds of 30 samples each. In every iteration:
• Training data: 120 samples (the 4 folds that are not held out).
• Testing data: 30 samples (the 1 fold held out for testing).
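The KFold split can be inspected directly with scikit-learn's `KFold`. This sketch assumes scikit-learn and NumPy are installed. For each fold it prints the training and test sizes, the range of test indices, and which classes appear in the test set. Because `load_iris` returns the samples ordered by class (50 of each), unshuffled folds are contiguous blocks containing only one or two classes. A shuffled variant is shown for comparison; its seed, `random_state=0`, is an arbitrary choice.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold

X, y = load_iris(return_X_y=True)  # X: (150, 4) features, y: labels 0, 1, 2

# Unshuffled KFold (the default): each test fold is a contiguous block of 30 indices.
fold_info = []
for fold, (train_idx, test_idx) in enumerate(KFold(n_splits=5).split(X), start=1):
    classes = np.unique(y[test_idx]).tolist()  # classes present in this test fold
    fold_info.append((fold, len(train_idx), len(test_idx),
                      int(test_idx[0]), int(test_idx[-1]), classes))
    print(f"Fold {fold}: train={len(train_idx)}, test={len(test_idx)}, "
          f"test indices {test_idx[0]}..{test_idx[-1]}, classes {classes}")

# Shuffled KFold: random_state=0 is an arbitrary seed so the split is reproducible.
shuffled_counts = []
for fold, (train_idx, test_idx) in enumerate(
        KFold(n_splits=5, shuffle=True, random_state=0).split(X), start=1):
    counts = np.bincount(y[test_idx], minlength=3).tolist()  # test samples per class
    shuffled_counts.append(counts)
    print(f"Shuffled fold {fold}: class counts in test set {counts}")
```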
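• Fold 1:
• Testing data: assume the test indices in the first fold are [0, 1, 2, ..., 29].
• Training data: the remaining 120 samples (indices 30 to 149).
• Training: the model is trained on these 120 training samples.
• Prediction: the trained model predicts labels for the 30 test samples.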
• Calculate the accuracy by comparing the predicted labels with the actual labels.
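• Assume, for illustration, that the actual labels are [0, 0, 1, 1, 2, ..., 2]. (With unshuffled KFold, indices 0 to 29 of the Iris dataset are in fact all class 0, because the samples are stored in class order. KFold(n_splits=5, shuffle=True) typically gives each test fold a mix of classes.)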
• If 27 out of 30 predictions are correct, the accuracy score for this fold would be
27/30 = 0.90.
• Store Score: record 0.90 as the score for fold 1.
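The fold-1 accuracy calculation can be sketched in plain Python. The 30 label values below are hypothetical, chosen only so that exactly 27 of the 30 predictions match. `accuracy` and `fold_scores` are illustrative names, not part of any library.

```python
# Hypothetical labels for the 30 test samples of fold 1 (illustration only).
y_true = [0] * 10 + [1] * 10 + [2] * 10
y_pred = list(y_true)
# Introduce 3 hypothetical mistakes so that 27 of the 30 predictions are correct.
y_pred[3], y_pred[15], y_pred[27] = 1, 2, 1

def accuracy(actual, predicted):
    """Fraction of positions where the predicted label equals the actual label."""
    correct = sum(a == p for a, p in zip(actual, predicted))
    return correct / len(actual)

fold_scores = []                 # one entry per fold ("Store Score")
acc = accuracy(y_true, y_pred)   # 27 / 30
fold_scores.append(acc)
print(f"{acc:.2f}")  # 0.90
```

scikit-learn's `sklearn.metrics.accuracy_score(y_true, y_pred)` computes the same fraction.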
• Fold 2:
• Testing data: another set of 30 samples, say indices [30, 31, 32, ..., 59].
• Training data: Remaining 120 samples.
• Training:
• The model is trained again from scratch on these new 120 training samples.
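• Prediction: the new model predicts labels for the 30 test samples, and the accuracy is computed as in fold 1.
• Store Score: record this accuracy as the score for fold 2.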
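Training "again from scratch" means no fitted state from the previous fold may carry over. One way to guarantee this in scikit-learn (assumed installed) is `sklearn.base.clone`, which returns an unfitted copy with the same hyperparameters. The variable names and the explicit fold-1 and fold-2 index arrays below are illustrative.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
template = RandomForestClassifier()  # default settings; never fitted itself

# Illustrative training indices: fold 1 tests on 0..29, fold 2 tests on 30..59.
fold1_train = np.arange(30, 150)
fold2_train = np.concatenate([np.arange(0, 30), np.arange(60, 150)])

model_fold1 = clone(template).fit(X[fold1_train], y[fold1_train])

# A fresh, untrained copy for fold 2: same hyperparameters, no fitted trees.
model_fold2 = clone(template)
fold2_was_untrained = not hasattr(model_fold2, "estimators_")
model_fold2.fit(X[fold2_train], y[fold2_train])
print("fold-2 model started untrained:", fold2_was_untrained)
```

For `RandomForestClassifier` with the default `warm_start=False`, calling `fit` again also rebuilds the forest from scratch. Cloning simply makes the intent explicit and leaves `template` untouched.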
• Fold 3:
• Testing data: the next 30 samples, say indices [60, 61, 62, ..., 89].
• Training data: Remaining 120 samples.
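• Training: the model is again trained from scratch on the 120 training samples.
• Prediction: the model predicts the 30 test samples, and the accuracy is computed.
• Store Score: record this accuracy as the score for fold 3.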
• Fold 4:
• Testing data: the next 30 samples, say indices [90, 91, 92, ..., 119].
• Training data: Remaining 120 samples.
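• Training: the model is again trained from scratch on the 120 training samples.
• Prediction: the model predicts the 30 test samples, and the accuracy is computed.
• Store Score: record this accuracy as the score for fold 4.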
• Fold 5:
• Testing data: the last 30 samples, say indices [120, 121, 122, ..., 149].
• Training data: Remaining 120 samples.
• Training:
• The model is trained one last time on the new training samples.
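• Prediction: the model predicts the last 30 test samples, and the accuracy is computed.
• Store Score: record this accuracy as the score for fold 5. Five scores are now stored, one per fold, and every sample has been used for testing exactly once.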
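The five iterations above can be written as one loop. This sketch assumes scikit-learn and NumPy are installed. It follows the dry run's setup, a default `RandomForestClassifier` with `KFold(n_splits=5)`, and uses `clone` so that each fold starts from an untrained model. Because the random forest is not seeded, the exact scores can differ between runs. Passing `random_state` to the classifier would make them reproducible.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold

X, y = load_iris(return_X_y=True)
model = RandomForestClassifier()  # default settings
kf = KFold(n_splits=5)            # 5 folds of 30 samples each

scores = []
for fold, (train_idx, test_idx) in enumerate(kf.split(X), start=1):
    fold_model = clone(model)                       # fresh, untrained copy per fold
    fold_model.fit(X[train_idx], y[train_idx])      # Training: 120 samples
    y_pred = fold_model.predict(X[test_idx])        # Prediction: 30 samples
    score = accuracy_score(y[test_idx], y_pred)     # Evaluation
    scores.append(score)                            # Store Score
    print(f"Fold {fold}: accuracy = {score:.3f}")

print("Scores:", np.round(scores, 3))
```

`sklearn.model_selection.cross_val_score(model, X, y, cv=kf)` performs the same train, predict and score loop and returns the five scores as an array. For classifiers the default score is accuracy.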
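• Mean score: average the five stored accuracies.
• Sum of squared deviations: subtract the mean from each fold's score, square each difference, and add the five results:
• 0.00018 + 0.000384 + 0.00287 + 0.00215 + 0.00018 = 0.005764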
• Variance: divide the sum of squared deviations by the number of folds, k = 5:
• 0.005764 / 5 ≈ 0.001153
• Standard Deviation: take the square root of the variance:
• √0.001153 ≈ 0.034
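The variance and standard-deviation arithmetic can be checked in Python. The squared deviations are the five values from the calculation in this section. The `summarize` helper is a hypothetical convenience function. It computes the mean and the population standard deviation (dividing by k, as above) directly from a list of fold scores, and it is shown here on two made-up scores.

```python
import math
import statistics

# Squared deviations of the five fold scores from their mean (from the dry run).
squared_devs = [0.00018, 0.000384, 0.00287, 0.00215, 0.00018]

total = sum(squared_devs)               # ≈ 0.005764
variance = total / len(squared_devs)    # population variance: divide by k = 5
std_dev = math.sqrt(variance)           # ≈ 0.034
print(f"sum={total:.6f}, variance={variance:.6f}, std={std_dev:.3f}")

def summarize(scores):
    """Mean and population standard deviation of per-fold scores."""
    return statistics.fmean(scores), statistics.pstdev(scores)

# Usage with two made-up scores: mean 0.9, population std 0.1.
mean, std = summarize([0.8, 1.0])
print(f"mean={mean:.2f}, std={std:.2f}")
```

NumPy's `np.std` also divides by k by default (`ddof=0`), so it matches this calculation. Setting `ddof=1` would divide by k − 1 = 4 instead, giving the sample standard deviation.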
Summary