# -*- coding: utf-8 -*-
"""
Created on Wed Oct  9 09:07:49 2019

@author: Márton
"""

from sklearn.datasets import fetch_20newsgroups;
import sklearn.feature_extraction.text as txt;
import sklearn.metrics.pairwise as pw;
from sklearn import decomposition as dc;
import matplotlib.pyplot as plt;
import matplotlib.colors as col;
import pandas as pd;
import numpy as np;
from numpy import linalg as la;
from mpl_toolkits.mplot3d import Axes3D;

# Restrict the 20 Newsgroups corpus to two closely related religion groups
# and download (or load from the local cache) the train and test splits.
categories = ['alt.atheism', 'talk.religion.misc']
ds_train, ds_test = (
    fetch_20newsgroups(subset=split_name, categories=categories)
    for split_name in ('train', 'test')
)
# Number of documents in each split.
n_train, n_test = len(ds_train.data), len(ds_test.data)
        
        
# Build the TF-IDF document-term matrix for the training documents.
# English stop words are removed; max_df=0.8 / min_df=0.2 additionally drop
# terms whose document frequency is above 80% or below 20% of the documents.
vectorizer = txt.TfidfVectorizer(stop_words='english', max_df=0.8, min_df=0.2)
DT_train = vectorizer.fit_transform(ds_train.data)  # sparse (n_train, n_words)
vocabulary_dict = vectorizer.vocabulary_  # term -> column index
# get_feature_names() was removed in scikit-learn 1.2 in favour of
# get_feature_names_out(); prefer the new API and fall back on old releases.
if hasattr(vectorizer, 'get_feature_names_out'):
    vocabulary_list = list(vectorizer.get_feature_names_out())
else:
    vocabulary_list = vectorizer.get_feature_names()
vocabulary = np.asarray(vocabulary_list)  # vocabulary in 1D array
# Terms discarded by the max_df/min_df cut-offs (this is not the English
# stop-word list passed via stop_words=).
stopwords = vectorizer.stop_words_
n_words = DT_train.shape[1]

# Document-term matrix as a dense ndarray (rows: documents, columns: terms).
doc_term_train = DT_train.toarray()

# Discarded terms as a list.
stopwords_list = list(stopwords)

# For every training document, the first_k_words terms with the highest
# TF-IDF weight (not raw frequency), ordered from highest to lowest weight.
first_k_words = 3
# Column indices of each row sorted by descending weight: argsort sorts
# ascending, so reverse along the term axis.
doc_ind = np.flip(np.argsort(doc_term_train, axis=1), axis=1)
# Fancy indexing with an (n_docs, k) index array yields an (n_docs, k) array
# of term strings. NOTE(review): a document with fewer than k non-zero
# weights gets arbitrary zero-weight terms in its trailing slots.
doc_freq_words = vocabulary[doc_ind[:, :first_k_words]]

# Full PCA (all components) on the dense TF-IDF matrix, used to inspect how
# much of the total variance each principal component accounts for.
pca = dc.PCA().fit(doc_term_train)  # fit() returns the fitted estimator
var_ratio = pca.explained_variance_ratio_
x_pos = np.arange(var_ratio.size)  # one bar per component

fig = plt.figure(1)
plt.title('Explained variance ratio plot')
plt.xticks(x_pos, x_pos + 1)  # label components 1..N instead of 0..N-1
plt.xlabel('Principal Components')
plt.ylabel('Variance')
plt.bar(x_pos, var_ratio, align='center', alpha=0.5)
plt.show()

# Project the training documents onto the first three principal components
# and draw them in 3D, one colour per newsgroup with a legend naming it.
pca = dc.PCA(n_components=3)
doc_pc = pca.fit_transform(doc_term_train)  # shape (n_docs, 3)

fig = plt.figure(2)
# On older Matplotlib releases the '3d' projection is registered by the
# mpl_toolkits.mplot3d import at the top of the file.
ax = fig.add_subplot(111, projection='3d')

# Scatter each class separately so it gets its own colour and legend entry;
# target values are indices into ds_train.target_names.
for class_idx, class_name in enumerate(ds_train.target_names):
    class_mask = ds_train.target == class_idx
    ax.scatter(doc_pc[class_mask, 0], doc_pc[class_mask, 1],
               doc_pc[class_mask, 2], label=class_name)
ax.legend()

ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')

plt.show()
