# -*- coding: utf-8 -*-
"""
Created on Wed Nov  4 13:05:19 2020

@author: Márton
"""

from sklearn.datasets import load_digits; # importing datasets
from sklearn.tree import DecisionTreeClassifier, plot_tree;  # importing decision tree classifier
from sklearn.model_selection import train_test_split;
from sklearn.metrics import confusion_matrix;
from matplotlib import pyplot as plt;  # importing MATLAB-like plotting framework


# Load the scikit-learn digits dataset and record its dimensions.
digits = load_digits()
n, p = digits.data.shape  # n: number of records, p: number of attributes

# Hold out 20% of the records for testing; the fixed seed makes the
# split reproducible from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    digits.data,
    digits.target,
    test_size=0.2,
    random_state=2020,
)


# Hyperparameters of the decision tree
crit = 'entropy'  # split quality criterion
depth = 10        # maximum tree depth; this is the only limit on tree growth here

# Instance of the decision tree class.
# random_state is fixed (same seed as the train/test split) because the
# estimator permutes the candidate features at every split, so ties between
# equally good splits would otherwise be broken differently on each run.
class_tree = DecisionTreeClassifier(
    criterion=crit,
    max_depth=depth,
    random_state=2020,
)

# Fitting the decision tree (tree induction limited by max_depth; no
# cost-complexity pruning is requested, since ccp_alpha is left at its default).
class_tree.fit(X_train, y_train)
accuracy_entropy_train = class_tree.score(X_train, y_train)  # mean accuracy on the training set
accuracy_entropy_test = class_tree.score(X_test, y_test)     # mean accuracy on the held-out set

# Confusion matrix on the test set: rows are true labels, columns are predictions.
y_test_pred = class_tree.predict(X_test)
cm_test = confusion_matrix(y_test, y_test_pred)

# Visualizing the fitted decision tree
fig = plt.figure(1, figsize=(12, 6), dpi=100)
plot_tree(
    class_tree,
    # plot_tree expects one name per class. Passing str(array) would hand it a
    # single string such as "[0 1 2 ...]", which is indexed character by
    # character (or rejected outright by newer scikit-learn versions).
    class_names=[str(label) for label in digits.target_names],
    filled=True,
    fontsize=8,
)
# The path is relative, so the image is written to the current working directory.
fig.savefig('digits_tree_entropy.png')



