# -*- coding: utf-8 -*-
"""
Created on Mon Sep 17 23:52:08 2018

@author: Márton
"""

from sklearn.datasets import fetch_20newsgroups;
from sklearn.feature_extraction.text import TfidfVectorizer;
from sklearn.decomposition import TruncatedSVD;
from sklearn.pipeline import Pipeline;
import matplotlib.pyplot as plt;
import pandas as pd;
import numpy as np;

# Load the 20 Newsgroups train and test splits. Both are shuffled with the
# same fixed seed so the document order is reproducible between runs.
ds_train = fetch_20newsgroups(subset='train',
                              shuffle=True, random_state=2018)
ds_test = fetch_20newsgroups(subset='test',
                             shuffle=True, random_state=2018)

# Number of training documents per topic id, computed two ways:
# with NumPy ...
unique, counts = np.unique(ds_train.target, return_counts=True)

# ... and with a pandas crosstab whose single column is named "count".
target_freq_train = pd.crosstab(index=ds_train.target,
                                columns="count")

# barh draws horizontal bars: the topic names run along the y-axis and the
# bar length (frequency) along the x-axis, so the axis labels follow that.
fig = plt.figure(1)
plt.title('Frequency of topics in training dataset')
plt.xlabel('Frequency')
plt.ylabel('Topics')
plt.barh(ds_train.target_names, counts, align='center')
plt.show()

# Same chart, drawn from the pandas crosstab instead of the NumPy counts.
fig = plt.figure(2)
plt.title('Frequency of topics in training dataset')
plt.xlabel('Frequency')
plt.ylabel('Topics')
plt.barh(ds_train.target_names, target_freq_train['count'], align='center')
plt.show()
                                
# Raw documents to TF-IDF matrix (English stop words removed, smoothed IDF).
vectorizer = TfidfVectorizer(stop_words='english',
                             use_idf=True,
                             smooth_idf=True)

# Sparse document-term matrix of the training documents, kept for inspection.
# NOTE: the pipeline below fits the same vectorizer again on the same data.
TD_matrix = vectorizer.fit_transform(ds_train.data)

# SVD to reduce dimensionality. The randomized solver is seeded with the same
# value as the dataset loaders so the projection is reproducible.
svd_model = TruncatedSVD(n_components=100,
                         algorithm='randomized',
                         n_iter=10,
                         random_state=2018)

# Pipeline of TF-IDF + SVD, fit to and applied to the training documents.
svd_transformer = Pipeline([('tfidf', vectorizer),
                            ('svd', svd_model)])

svd_matrix = svd_transformer.fit_transform(ds_train.data)

# Visualization: first two SVD components, one point per document, coloured
# by topic id. A new figure number is used so the scatter is not drawn onto
# the bar chart that already owns figure 1.
fig = plt.figure(3)
plt.title('SVD plot for 20newsgroups')
plt.xlabel('SVD1')
plt.ylabel('SVD2')
plt.scatter(svd_matrix[:, 0], svd_matrix[:, 1], s=50, c=ds_train.target)
plt.show()