{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# DL8.5 Predictor used for a task involving labels\nPyDL8.5 allows users to write their own error function. This example shows how\nto write an error function based on transaction identifiers (tids). PyDL8.5\ndetermines these tids from the occurrences of an itemset in the training data.\n\nThe error function is called very often, and computing an error score from\ntids can be time-consuming. For classification tasks, it is therefore strongly\nrecommended not to write an error function in Python that operates on lists of tids.\nSee the plot_classifier_user_1.py example for a more efficient user-written\nerror function in classification settings.\n\nThis example also shows how to use the DL85Predictor class. With this class,\nthe labels do not need to be passed to DL8.5: the model is fitted on the features\nonly, and the error function looks up the labels itself. For classification tasks,\nthe error function can also be specified as a parameter of the DL85Classifier\nclass. In this case, a standard implementation is used for filling in the\nclass labels for the leaves of the tree.\n\nAnother example of a user-specified error function is given in plot_cluster_user.py.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import time\n\nimport numpy as np\nfrom sklearn.metrics import accuracy_score, confusion_matrix\nfrom sklearn.model_selection import train_test_split\nfrom dl85 import DL85Predictor\n\n# The first column of the dataset holds the label; the remaining columns are the features.\ndataset = np.genfromtxt(\"../datasets/anneal.txt\", delimiter=' ')\nX, y = dataset[:, 1:], dataset[:, 0]\nX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n\n\nprint(\"#####################################################################################\\n\"\n      \"#      DL8.5 Predictor used for classification : user specific error function       #\\n\"\n      \"#####################################################################################\")\n\n\n# User-defined error function: tids are the indices of the training rows covered\n# by an itemset. Return the error (number of rows outside the majority class)\n# and the majority class.\ndef error(tids):\n    # Look up the labels of the covered rows and count each class.\n    classes, supports = np.unique(y_train.take(list(tids)), return_counts=True)\n    maxindex = np.argmax(supports)\n    return supports.sum() - supports[maxindex], classes[maxindex]\n\n\nclf = DL85Predictor(max_depth=2, error_function=error, time_limit=600)\nstart = time.perf_counter()\nprint(\"Model building...\")\n# Only the features are passed to fit; the labels are read inside error().\nclf.fit(X_train)\nduration = time.perf_counter() - start\nprint(\"Model built. Duration of building =\", round(duration, 4))\nprint(\"Accuracy DL8.5 on training set =\", round(clf.accuracy_, 4))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}