{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Decision Tree Regression with AdaBoost\n\nA decision tree is boosted using the AdaBoost.R2 [1] algorithm on a 1D\nsinusoidal dataset with a small amount of Gaussian noise.\n299 boosts (300 decision trees) are compared with a single decision tree\nregressor. As the number of boosts is increased, the regressor can fit more\ndetail.\n\n[1] H. Drucker, \"Improving Regressors using Boosting Techniques\", 1997.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Noel Dawe \n#\n# License: BSD 3 clause\n\n# importing necessary libraries\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom sklearn.tree import DecisionTreeRegressor\nfrom sklearn.ensemble import AdaBoostRegressor\n\n# Create the dataset\nrng = np.random.RandomState(1)\nX = np.linspace(0, 6, 100)[:, np.newaxis]\ny = np.sin(X).ravel() + np.sin(6 * X).ravel() + rng.normal(0, 0.1, X.shape[0])\n\n# Fit regression model\nregr_1 = DecisionTreeRegressor(max_depth=4)\n\nregr_2 = AdaBoostRegressor(\n DecisionTreeRegressor(max_depth=4), n_estimators=300, random_state=rng\n)\n\nregr_1.fit(X, y)\nregr_2.fit(X, y)\n\n# Predict\ny_1 = regr_1.predict(X)\ny_2 = regr_2.predict(X)\n\n# Plot the results\nplt.figure()\nplt.scatter(X, y, c=\"k\", label=\"training samples\")\nplt.plot(X, y_1, c=\"g\", label=\"n_estimators=1\", linewidth=2)\nplt.plot(X, y_2, c=\"r\", label=\"n_estimators=300\", linewidth=2)\nplt.xlabel(\"data\")\nplt.ylabel(\"target\")\nplt.title(\"Boosted Decision Tree Regression\")\nplt.legend()\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 0 }