# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 12:43:23 2019

@author: Márton
"""

import numpy as np; 
import matplotlib.pyplot as plt;
from sklearn import datasets as ds;
from sklearn import model_selection as ms;
from sklearn import tree, naive_bayes, metrics;
import itertools;

# Load the digits dataset.
digits = ds.load_digits()
n = digits.data.shape[0]  # number of rows in the data matrix
p = digits.data.shape[1]  # number of columns in the data matrix

# One seed shared by the train/test split and the tree, so that repeated runs
# of this script produce the same model and the same metrics.
RANDOM_STATE = 2019

# Partition into training and test sets (30% of the rows are held out).
X_train, X_test, y_train, y_test = ms.train_test_split(
    digits.data, digits.target, test_size=0.3, random_state=RANDOM_STATE)

# Decision tree
# Initialize the decision tree object. scikit-learn permutes the features at
# every split, so without a fixed random_state the fitted tree (and therefore
# the accuracies and confusion matrices below) can change from run to run.
crit = 'entropy'
depth = 20
classification_tree = tree.DecisionTreeClassifier(
    criterion=crit, max_depth=depth, random_state=RANDOM_STATE)

# Train the decision tree. Growth is bounded by max_depth; no pruning step is
# configured here.
classification_tree = classification_tree.fit(X_train, y_train)
pred_tree_test = classification_tree.predict(X_test)
pred_tree_train = classification_tree.predict(X_train)

# Confusion matrices and accuracies for the test and training sets.
cm_tree_test = metrics.confusion_matrix(y_test, pred_tree_test)
cm_tree_train = metrics.confusion_matrix(y_train, pred_tree_train)
acc_tree_test = metrics.accuracy_score(y_test, pred_tree_test)
acc_tree_train = metrics.accuracy_score(y_train, pred_tree_train)
                                                      

# Visualisation of the confusion matrix

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Greens):
    """
    Print a confusion matrix and draw it as an annotated heat map.

    The matrix is drawn with pyplot on the current figure; creating the
    figure and calling ``plt.show()`` is left to the caller.

    Parameters
    ----------
    cm : 2-D array-like
        Confusion matrix. Rows are labelled "True label" and columns
        "Predicted label" on the plot.
    classes : sequence
        Tick labels used on both axes, one per class.
    normalize : bool, optional
        If True, every row is divided by its row total before printing and
        plotting. Rows whose total is zero are left as all zeros instead of
        becoming NaN.
    title : str, optional
        Title placed above the plot.
    cmap : optional
        Colormap passed to ``plt.imshow``.
    """
    # Accept nested lists as well as arrays.
    cm = np.asarray(cm)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class that never occurs among the true labels has a zero row
        # total; a plain division would yield 0/0 = NaN there, which would
        # also turn the colour threshold below into NaN. Divide only where
        # the total is non-zero and keep zeros elsewhere.
        cm = np.divide(cm.astype(float), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # The 'd' format code only works for integers; fall back to two decimals
    # for any non-integer matrix, not just for the normalized one.
    is_integer = np.issubdtype(cm.dtype, np.integer)
    fmt = 'd' if is_integer and not normalize else '.2f'

    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    
# Figure 1: raw (non-normalized) confusion matrix of the decision tree on the
# held-out test set.
plt.figure(1)
plot_confusion_matrix(
    cm_tree_test,
    classes=digits.target_names,
    title='Confusion matrix for test dataset (decision tree classifier)',
)
plt.show()

# Figure 2: the same plot for the training set.
plt.figure(2)
plot_confusion_matrix(
    cm_tree_train,
    classes=digits.target_names,
    title='Confusion matrix for train dataset (decision tree classifier)',
)
plt.show()