{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Multi-class AdaBoosted Decision Trees\n\nThis example reproduces Figure 1 of Zhu et al [1]_ and shows how boosting can\nimprove prediction accuracy on a multi-class problem. The classification\ndataset is constructed by taking a ten-dimensional standard normal distribution\nand defining three classes separated by nested concentric ten-dimensional\nspheres such that roughly equal numbers of samples are in each class (quantiles\nof the $\\chi^2$ distribution).\n\nThe performance of the SAMME and SAMME.R [1]_ algorithms are compared. SAMME.R\nuses the probability estimates to update the additive model, while SAMME  uses\nthe classifications only. As the example illustrates, the SAMME.R algorithm\ntypically converges faster than SAMME, achieving a lower test error with fewer\nboosting iterations. The error of each algorithm on the test set after each\nboosting iteration is shown on the left, the classification error on the test\nset of each tree is shown in the middle, and the boost weight of each tree is\nshown on the right. All trees have a weight of one in the SAMME.R algorithm and\ntherefore are not shown.\n\n.. [1] J. Zhu, H. Zou, S. Rosset, T. Hastie, \"Multi-class AdaBoost\", 2009.\n"
      ]
    },
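    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before boosting, the following cell is a small sanity check (a sketch, not\npart of the original example) of the dataset construction described above.\nThe squared norm of a ten-dimensional standard normal vector follows a\n$\\chi^2$ distribution with 10 degrees of freedom, so the nested spheres\nseparating the classes should have squared radii near the $1/3$ and $2/3$\nquantiles of $\\chi^2_{10}$, and the three classes should be roughly balanced.\nThe check uses ``scipy.stats.chi2`` for the reference quantiles."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# A minimal sanity check of the dataset construction (a sketch, not part of\n# the original example).\nimport numpy as np\nfrom scipy.stats import chi2\n\nfrom sklearn.datasets import make_gaussian_quantiles\n\nX_check, y_check = make_gaussian_quantiles(\n    n_samples=13000, n_features=10, n_classes=3, random_state=1\n)\n\n# The classes are defined by quantiles, so their sizes are near-equal.\nprint(\"samples per class:\", np.bincount(y_check))\n\n# Class 0 is the innermost sphere: its largest squared norm should sit near\n# the 1/3 quantile of chi2 with 10 degrees of freedom, and class 1's largest\n# squared norm near the 2/3 quantile.\nsq_norm = (X_check**2).sum(axis=1)\nprint(\"max ||x||^2 in class 0:\", sq_norm[y_check == 0].max())\nprint(\"chi2(10) quantile at 1/3:\", chi2.ppf(1 / 3, df=10))\nprint(\"max ||x||^2 in class 1:\", sq_norm[y_check == 1].max())\nprint(\"chi2(10) quantile at 2/3:\", chi2.ppf(2 / 3, df=10))"
      ]
    },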
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Author: Noel Dawe <noel.dawe@gmail.com>\n#\n# License: BSD 3 clause\n\nimport matplotlib.pyplot as plt\n\nfrom sklearn.datasets import make_gaussian_quantiles\nfrom sklearn.ensemble import AdaBoostClassifier\nfrom sklearn.metrics import accuracy_score\nfrom sklearn.tree import DecisionTreeClassifier\n\n\nX, y = make_gaussian_quantiles(\n    n_samples=13000, n_features=10, n_classes=3, random_state=1\n)\n\nn_split = 3000\n\nX_train, X_test = X[:n_split], X[n_split:]\ny_train, y_test = y[:n_split], y[n_split:]\n\nbdt_real = AdaBoostClassifier(\n    DecisionTreeClassifier(max_depth=2), n_estimators=300, learning_rate=1\n)\n\nbdt_discrete = AdaBoostClassifier(\n    DecisionTreeClassifier(max_depth=2),\n    n_estimators=300,\n    learning_rate=1.5,\n    algorithm=\"SAMME\",\n)\n\nbdt_real.fit(X_train, y_train)\nbdt_discrete.fit(X_train, y_train)\n\nreal_test_errors = []\ndiscrete_test_errors = []\n\nfor real_test_predict, discrete_train_predict in zip(\n    bdt_real.staged_predict(X_test), bdt_discrete.staged_predict(X_test)\n):\n    real_test_errors.append(1.0 - accuracy_score(real_test_predict, y_test))\n    discrete_test_errors.append(1.0 - accuracy_score(discrete_train_predict, y_test))\n\nn_trees_discrete = len(bdt_discrete)\nn_trees_real = len(bdt_real)\n\n# Boosting might terminate early, but the following arrays are always\n# n_estimators long. We crop them to the actual number of trees here:\ndiscrete_estimator_errors = bdt_discrete.estimator_errors_[:n_trees_discrete]\nreal_estimator_errors = bdt_real.estimator_errors_[:n_trees_real]\ndiscrete_estimator_weights = bdt_discrete.estimator_weights_[:n_trees_discrete]\n\nplt.figure(figsize=(15, 5))\n\nplt.subplot(131)\nplt.plot(range(1, n_trees_discrete + 1), discrete_test_errors, c=\"black\", label=\"SAMME\")\nplt.plot(\n    range(1, n_trees_real + 1),\n    real_test_errors,\n    c=\"black\",\n    linestyle=\"dashed\",\n    label=\"SAMME.R\",\n)\nplt.legend()\nplt.ylim(0.18, 0.62)\nplt.ylabel(\"Test Error\")\nplt.xlabel(\"Number of Trees\")\n\nplt.subplot(132)\nplt.plot(\n    range(1, n_trees_discrete + 1),\n    discrete_estimator_errors,\n    \"b\",\n    label=\"SAMME\",\n    alpha=0.5,\n)\nplt.plot(\n    range(1, n_trees_real + 1), real_estimator_errors, \"r\", label=\"SAMME.R\", alpha=0.5\n)\nplt.legend()\nplt.ylabel(\"Error\")\nplt.xlabel(\"Number of Trees\")\nplt.ylim((0.2, max(real_estimator_errors.max(), discrete_estimator_errors.max()) * 1.2))\nplt.xlim((-20, len(bdt_discrete) + 20))\n\nplt.subplot(133)\nplt.plot(range(1, n_trees_discrete + 1), discrete_estimator_weights, \"b\", label=\"SAMME\")\nplt.legend()\nplt.ylabel(\"Weight\")\nplt.xlabel(\"Number of Trees\")\nplt.ylim((0, discrete_estimator_weights.max() * 1.2))\nplt.xlim((-20, n_trees_discrete + 20))\n\n# prevent overlapping y-axis labels\nplt.subplots_adjust(wspace=0.25)\nplt.show()"
      ]
    }
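    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a complement to the right-hand panel above, the following cell is a\nsketch (not part of the original example) of the SAMME boost-weight formula\nfrom [1]_: a tree with weighted training error $err_m$ on a $K$-class problem\nreceives the weight $\\alpha_m = \\log\\frac{1 - err_m}{err_m} + \\log(K - 1)$,\nwhich scikit-learn additionally scales by ``learning_rate``. The extra\n$\\log(K - 1)$ term is what lets SAMME assign a positive weight to any tree\nthat beats random guessing ($err_m < (K - 1)/K$) rather than requiring error\nbelow one half. The cell recomputes the formula directly, without calling\nscikit-learn internals, and cross-checks it against the first tree of\n``bdt_discrete`` fitted above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# A minimal sketch of the SAMME boost-weight formula (Zhu et al., 2009),\n# cross-checked against the first tree of the fitted SAMME ensemble. This\n# recomputes the formula directly; it does not call sklearn internals.\nimport numpy as np\n\nn_classes = 3\n\n\ndef samme_weight(err, n_classes, learning_rate=1.0):\n    # alpha_m = learning_rate * (log((1 - err) / err) + log(K - 1));\n    # err must be below the chance level (K - 1) / K for a positive weight.\n    return learning_rate * (np.log((1.0 - err) / err) + np.log(n_classes - 1.0))\n\n\nerr_1 = bdt_discrete.estimator_errors_[0]\nprint(\"first-tree error:  \", err_1)\nprint(\"recomputed weight: \", samme_weight(err_1, n_classes, learning_rate=1.5))\nprint(\"sklearn weight:    \", bdt_discrete.estimator_weights_[0])"
      ]
    }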
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}