# -*- coding: utf-8 -*-
"""
Created on Mon Mar 23 13:18:03 2020

Principal Component (PC) analysis of digits dataset

@author: Márton Ispány
"""

from sklearn import datasets as ds; # importing datasets
from sklearn import model_selection as ms; # importing model selection tools
from sklearn import decomposition as decomp;  # importing dimension reduction 
import numpy as np;  # importing numpy the scientific computing package with Python
from matplotlib import pyplot as plt;  # importing the MATLAB-like plotting tool


# Load the bundled digits dataset; digits.data holds one flattened image per row
digits = ds.load_digits()
n, p = digits.data.shape  # n: number of records, p: number of attributes

# Visualizing one digit image.
# Pixel intensities in load_digits run from 0 to 16, so subtracting from 16
# inverts the image (strokes become low values) without producing negatives.
image_ind = 10  # index of the image
plt.matshow(16 - digits.images[image_ind])
plt.title('Digit label: {}'.format(digits.target[image_ind]))
plt.show()

# Hold out 30% of the records as a test set; the fixed seed makes the split reproducible
X_train, X_test, y_train, y_test = ms.train_test_split(
    digits.data,
    digits.target,
    test_size=0.3,
    random_state=2020,
)

# Fit a PCA keeping all components, using the training set only
# (PCA.fit returns the fitted estimator itself, so the call can be chained)
pca = decomp.PCA().fit(X_train)

# Visualizing the explained variance ratio: the fraction of the total variance
# captured by each individual principal component (components are ordered by it)
fig = plt.figure(2)
plt.title('Explained variance ratio plot')
var_ratio = pca.explained_variance_ratio_
x_pos = np.arange(len(var_ratio)) + 1  # 1-based component indices for the x axis
plt.xlabel('Principal Components')
plt.ylabel('Explained variance ratio')
plt.bar(x_pos, var_ratio, align='center', alpha=0.5)
plt.show()

# Visualizing the cumulative explained variance ratio: bar k shows the fraction
# of the total variance captured by the first k principal components together
fig = plt.figure(3)
plt.title('Cumulative explained variance ratio plot')
cum_var_ratio = np.cumsum(var_ratio)
x_pos = np.arange(len(cum_var_ratio)) + 1  # k = number of leading components summed
plt.xlabel('Number of principal components')
plt.ylabel('Cumulative explained variance ratio')
plt.bar(x_pos, cum_var_ratio, align='center', alpha=0.5)
plt.show()

# Visualizing the training set in the plane of the first two PCs, colored by digit.
# vmin/vmax pin digit k to the k-th color of the 10-color 'tab10' map regardless
# of which labels occur in y_train, and center the colorbar ticks on the colors.
PC_train = pca.transform(X_train)
fig = plt.figure(4)
plt.title('Scatterplot for training digits dataset')
plt.xlabel('PC1')
plt.ylabel('PC2')
sc_train = plt.scatter(PC_train[:, 0], PC_train[:, 1], s=50, c=y_train,
                       cmap='tab10', vmin=-0.5, vmax=9.5)
plt.colorbar(sc_train, ticks=range(10), label='Digit')
plt.show()

# Visualizing the test set in the plane of the first two PCs (projected with the
# PCA fitted on the training set), colored by digit.
# vmin/vmax pin digit k to the k-th color of the 10-color 'tab10' map regardless
# of which labels occur in y_test, and center the colorbar ticks on the colors.
PC_test = pca.transform(X_test)
fig = plt.figure(5)
plt.title('Scatterplot for test digits dataset')
plt.xlabel('PC1')
plt.ylabel('PC2')
sc_test = plt.scatter(PC_test[:, 0], PC_test[:, 1], s=50, c=y_test,
                      cmap='tab10', vmin=-0.5, vmax=9.5)
plt.colorbar(sc_test, ticks=range(10), label='Digit')
plt.show()

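# Compare the last two figures: both sets are projected with the PCA fitted on the
# training set only, so test digits should fall into the same color clusters as
# the training digits of the same class.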