{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Robust linear model estimation using RANSAC\n",
    "\n",
    "In this example we see how to robustly fit a linear model to faulty data using\n",
    "the RANSAC algorithm.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from sklearn import linear_model, datasets\n",
    "\n",
    "\n",
    "# Size of the synthetic data set and how many of its leading samples are\n",
    "# overwritten with outliers further down.\n",
    "n_samples = 1000\n",
    "n_outliers = 50\n",
    "\n",
    "\n",
    "# One-feature linear regression problem; coef=True also returns the\n",
    "# coefficient used to generate the data so it can be compared with the fits.\n",
    "X, y, coef = datasets.make_regression(\n",
    "    n_samples=n_samples,\n",
    "    n_features=1,\n",
    "    n_informative=1,\n",
    "    noise=10,\n",
    "    coef=True,\n",
    "    random_state=0,\n",
    ")\n",
    "\n",
    "# Add outlier data: replace the first n_outliers samples with a noisy cluster\n",
    "# around X = 3, y = -3. The global NumPy seed makes these draws repeatable.\n",
    "np.random.seed(0)\n",
    "X[:n_outliers] = 3 + 0.5 * np.random.normal(size=(n_outliers, 1))\n",
    "y[:n_outliers] = -3 + 10 * np.random.normal(size=n_outliers)\n",
    "\n",
    "# Fit line using all data (outliers included)\n",
    "lr = linear_model.LinearRegression()\n",
    "lr.fit(X, y)\n",
    "\n",
    "# Robustly fit linear model with RANSAC algorithm\n",
    "ransac = linear_model.RANSACRegressor()\n",
    "ransac.fit(X, y)\n",
    "# The masks reflect what RANSAC classified as inliers/outliers, which is not\n",
    "# necessarily the same set as the samples overwritten above.\n",
    "inlier_mask = ransac.inlier_mask_\n",
    "outlier_mask = np.logical_not(inlier_mask)\n",
    "\n",
    "# Predict data of estimated models.\n",
    "# linspace includes both endpoints, so the plotted lines span the whole range\n",
    "# of X; a unit-step arange would stop short of X.max().\n",
    "line_X = np.linspace(X.min(), X.max(), num=100)[:, np.newaxis]\n",
    "line_y = lr.predict(line_X)\n",
    "line_y_ransac = ransac.predict(line_X)\n",
    "\n",
    "# Compare estimated coefficients\n",
    "print(\"Estimated coefficients (true, linear regression, RANSAC):\")\n",
    "print(coef, lr.coef_, ransac.estimator_.coef_)\n",
    "\n",
    "# Scatter the samples coloured by RANSAC's classification, then overlay both\n",
    "# fitted lines for comparison.\n",
    "lw = 2\n",
    "plt.scatter(\n",
    "    X[inlier_mask], y[inlier_mask], color=\"yellowgreen\", marker=\".\", label=\"Inliers\"\n",
    ")\n",
    "plt.scatter(\n",
    "    X[outlier_mask], y[outlier_mask], color=\"gold\", marker=\".\", label=\"Outliers\"\n",
    ")\n",
    "plt.plot(line_X, line_y, color=\"navy\", linewidth=lw, label=\"Linear regressor\")\n",
    "plt.plot(\n",
    "    line_X,\n",
    "    line_y_ransac,\n",
    "    color=\"cornflowerblue\",\n",
    "    linewidth=lw,\n",
    "    label=\"RANSAC regressor\",\n",
    ")\n",
    "plt.legend(loc=\"lower right\")\n",
    "plt.xlabel(\"Input\")\n",
    "plt.ylabel(\"Response\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}