# -*- coding: utf-8 -*-
"""
Created on Mon Oct  1 21:34:39 2018
Visualization of the 20newsgroups dataset
Count vectorization of the 20newsgroups dataset
@author: Márton
"""

from sklearn.datasets import fetch_20newsgroups;
import sklearn.feature_extraction.text as txt;  # importing text preprocessing
import sklearn.metrics.pairwise as pw;
from sklearn import decomposition as decomp;
import matplotlib.pyplot as plt;  # importing pyplot
import matplotlib.colors as col;  # importing coloring
import pandas as pd; # importing pandas
import numpy as np;  # importing numpy for arrays

# Importing the training and testing datasets
ds_train = fetch_20newsgroups(subset='train',
                              shuffle=True, random_state=2020);
ds_test = fetch_20newsgroups(subset='test',
                             shuffle=True, random_state=2020);
n_train = len(ds_train.data);
n_test = len(ds_test.data);
n_class = len(ds_train.target_names);
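
# Quick sanity check (illustrative, not required by the pipeline): report the
# corpus sizes so the later plots can be read in context.
print('Training posts:', n_train, '| test posts:', n_test, '| topics:', n_class);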

# Computing the frequencies in numpy for the training dataset
unique, counts = np.unique(ds_train.target, return_counts=True);                            

# Computing the frequencies in pandas for the training dataset
target_freq_train = pd.crosstab(index=ds_train.target,  # Making a crosstab
                              columns="count");      # Name the count column                             

# Drawing a horizontal barplot of topic frequencies (numpy counts)
fig = plt.figure(1);
plt.title('Frequency of topics in training dataset');
plt.xlabel('Frequency');
plt.ylabel('Topics');
plt.barh(ds_train.target_names, counts, align='center', color='blue');
plt.show(); 

# Drawing a horizontal barplot of topic frequencies (pandas counts)
fig = plt.figure(2);
plt.title('Frequency of topics in training dataset');
plt.xlabel('Frequency');
plt.ylabel('Topics');
plt.barh(ds_train.target_names,target_freq_train['count'], align='center');
plt.show();  

# Computing the frequencies in pandas for the test dataset
target_freq_test = pd.crosstab(index=ds_test.target,  # Making a crosstab
                              columns="count");      # Name the count column     

# Drawing a horizontal barplot of topic frequencies (pandas counts, test set)
fig = plt.figure(3);
plt.title('Frequency of topics in test dataset');
plt.xlabel('Frequency');
plt.ylabel('Topics');
plt.barh(ds_test.target_names,target_freq_test['count'], align='center');
plt.show();  


min_pr = 0.2;  # keep only terms that occur in at least 20% of the documents
vectorizer = txt.CountVectorizer(stop_words='english', min_df=min_pr);
DT_train = vectorizer.fit_transform(ds_train.data); 
vocabulary_dict = vectorizer.vocabulary_;
vocabulary_list = vectorizer.get_feature_names_out();  # get_feature_names() was removed in sklearn 1.2
vocabulary = np.asarray(vocabulary_list);  # vocabulary in 1D array
stopwords = vectorizer.stop_words_;
n_words = DT_train.shape[1];
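
# The document-term matrix is stored sparsely; a quick look at its shape and
# density (a sketch, not required by the pipeline) shows how much min_df pruned.
print('Document-term matrix shape:', DT_train.shape);
print('Density:', DT_train.nnz / (DT_train.shape[0] * DT_train.shape[1]));
print('Terms dropped as stop words or as too rare:', len(stopwords));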

# Transforming the document-term matrix to dense form
doc_term_train = DT_train.toarray();
# doc_term_train = DT_train.todense().getA();  # another way

# the pruned stop words as a list
stopwords_list = list(stopwords);
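
# Illustrative lookup (doc_id = 0 is an arbitrary choice): list the keywords
# that survive pruning in a single post, together with their counts.
doc_id = 0;
for t in doc_term_train[doc_id, :].nonzero()[0]:
    print(vocabulary[t], doc_term_train[doc_id, t]);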

# Visualising the keyword frequencies
keywords_freq = np.asarray(DT_train.sum(axis=0)).ravel();  # total count of each keyword
fig = plt.figure(4);
plt.title('Frequency of keywords in training dataset');
plt.xlabel('Frequency');
plt.ylabel('Words');
plt.barh(vocabulary,keywords_freq, align='center', color='blue');
plt.show(); 

# Visualising the histogram of keyword counts per document
docs_freq = np.asarray(DT_train.sum(axis=1)).ravel();  # total keyword count per post
fig = plt.figure(5);
plt.title('Histogram of keyword occurrences');
plt.xlabel('Number of occurrences of keywords in a document');
plt.ylabel('Frequency');
count, bins, ignored = plt.hist(docs_freq, 100, alpha=0.75);
plt.show();

# The first k most frequent keywords of each document
first_k_words = 3;
doc_ind = np.argsort(doc_term_train, axis=1);  # ascending by count
doc_ind = np.flip(doc_ind, axis=1);            # now descending
doc_freq_words = vocabulary[doc_ind[:, 0:first_k_words]];
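
# Minimal sanity check (illustrative; document index 0 is arbitrary): print the
# three most frequent keywords of one post next to its topic label.
print('Topic:', ds_train.target_names[ds_train.target[0]]);
print('Top keywords:', doc_freq_words[0, :]);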

# The first k posts that contain a given keyword most frequently
first_k_docs = 20;
word_ind = np.argsort(doc_term_train, axis=0);  # ascending by count
word_ind = np.flip(word_ind, axis=0);           # now descending
wind = 8;  # index of the chosen keyword
word_freq_docs = list();
for i in range(first_k_docs):
    word_freq_docs.append(ds_train.data[word_ind[i, wind]]);
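
# Illustrative check of the ranking above: show which keyword index 8 refers to
# and how often it occurs in the top-ranked post.
print('Keyword:', vocabulary[wind]);
print('Count in the top post:', doc_term_train[word_ind[0, wind], wind]);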

explained_variance_ratio_threshold = 0.7;
svd = decomp.TruncatedSVD(n_components=n_words-1);
svd.fit(DT_train);
cum_explained_variance_ratio_ = np.cumsum(svd.explained_variance_ratio_);
# number of components whose cumulative explained variance stays below the threshold
dim = int(np.sum(cum_explained_variance_ratio_ < explained_variance_ratio_threshold));
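
# Optional visual aid (a sketch, not part of the original pipeline): plot the
# cumulative explained variance ratio so the threshold choice can be inspected.
fig = plt.figure();  # unnumbered, to avoid clashing with the numbered figures
plt.title('Cumulative explained variance of the truncated SVD');
plt.xlabel('Number of components');
plt.ylabel('Cumulative explained variance ratio');
plt.plot(cum_explained_variance_ratio_);
plt.axhline(explained_variance_ratio_threshold, color='red', linestyle='--');
plt.show();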

svd = decomp.TruncatedSVD(n_components=dim);
TD_svd = svd.fit_transform(DT_train);

class_mean = np.zeros((n_class,dim));
for i in range(n_class):
    class_ind = (ds_train.target == i).astype(int);  # 0/1 weights selecting class i
    class_mean[i,:] = np.average(TD_svd, axis=0, weights=class_ind);
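
# Illustrative follow-up (a sketch using sklearn's pairwise utilities): measure
# how far apart the class means lie in the reduced space as a rough separability check.
class_dist = pw.euclidean_distances(class_mean);
masked = class_dist + np.eye(n_class) * class_dist.max();  # hide the zero diagonal
print('Closest pair of topic means:', np.unravel_index(np.argmin(masked), masked.shape));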

colors = plt.cm.tab20.colors;  # 20 distinct colors, one per topic

fig = plt.figure(6);
plt.title('Dimension reduction');
plt.xlabel('Dim1');
plt.ylabel('Dim2');
plt.scatter(TD_svd[:,0], TD_svd[:,1], s=50, c=ds_train.target,
            cmap=col.ListedColormap(colors));
plt.show();     

# transforming the test dataset            
DT_test = vectorizer.transform(ds_test.data);



# Cosine similarity between training and test documents (works on sparse input)
cos_sim = pw.cosine_similarity(DT_train, DT_test);
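
# Possible use of the similarity matrix (an illustrative sketch): for each test
# post, find the most similar training post and compare their topic labels.
best_match = np.argmax(cos_sim, axis=0);  # closest training doc per test doc
match_rate = np.mean(ds_train.target[best_match] == ds_test.target);
print('Share of test posts whose nearest training post has the same topic:', match_rate);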

       

# Plotting documents in the coordinates of two chosen keywords
fig = plt.figure(7);
x = 1;
y = 5;
plt.title('Documents in the space of words');
plt.xlabel(vocabulary[x]);
plt.ylabel(vocabulary[y]);
plt.scatter(doc_term_train[:,x], doc_term_train[:,y], s=50, c=ds_train.target,
            cmap=col.ListedColormap(colors));
plt.show();

                                 
