{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# DL8.5 classifier : user specific error function based on transactions ID\nPyDL8.5 allows users to write their own error function. This example shows how \nto write an error function based on transaction identifiers. PyDL8.5 will determine\nthese transaction identifiers based on the occurrences of an itemset in the\ntraining data. \n\nThe error function is called very often, and calculating an error score based\non tids can be time consuming. For classification tasks, it is highly recommended\nnot to write an error function in Python that traverses the tids. \ncheck the plot_classifier_user_1.py example for a more efficient user-written\nerror function in classification settings.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom sklearn.metrics import confusion_matrix\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.metrics import accuracy_score\nimport time\nfrom dl85 import DL85Classifier\n\ndataset = np.genfromtxt(\"../datasets/anneal.txt\", delimiter=' ')\nX, y = dataset[:, 1:], dataset[:, 0]\nX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n\n\nprint(\"########################################################################################\\n\"\n      \"#      DL8.5 classifier : user specific error function based on transactions ids       #\\n\"\n      \"########################################################################################\")\n\n\n# return the error and the majority class\ndef error(tids, y):\n    classes, supports = np.unique(y.take(list(tids)), return_counts=True)\n    maxindex = np.argmax(supports)\n    return sum(supports) - supports[maxindex], classes[maxindex]\n\n\nclf = DL85Classifier(max_depth=2, error_function=lambda tids: error(tids, y_train), time_limit=600)\nstart = time.perf_counter()\nprint(\"Model building...\")\nclf.fit(X_train, y_train)\nduration = time.perf_counter() - start\nprint(\"Model built. Duration of building =\", round(duration, 4))\ny_pred = clf.predict(X_test)\nprint(\"Confusion Matrix below\")\nprint(confusion_matrix(y_test, y_pred))\nprint(\"Accuracy DL8.5 on training set =\", round(clf.accuracy_, 4))\nprint(\"Accuracy DL8.5 on test set =\", round(accuracy_score(y_test, y_pred), 4))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}