# -*- coding: utf-8 -*-
"""
Created on Wed Nov 21 08:39:55 2018

@author: Márton
"""

from sklearn.datasets import fetch_20newsgroups;
import sklearn.feature_extraction.text as txt;  # importing text preprocessing
import sklearn.cluster as cluster;
import sklearn.metrics as metrics;
from scipy.cluster.hierarchy import dendrogram
import sklearn.mixture as mix;
import sklearn.utils.random as rd;
import matplotlib.pyplot as plt;
import numpy as np;

# Load the training and test splits of 20newsgroups, restricted to three topics.
categories = ['alt.atheism', 'sci.space', 'rec.autos']

ds_train = fetch_20newsgroups(subset='train', shuffle=True,
                              categories=categories, random_state=2018)
ds_test = fetch_20newsgroups(subset='test', shuffle=True,
                             categories=categories, random_state=2018)

# Corpus sizes used throughout the script.
n_train = len(ds_train.data)
n_test = len(ds_test.data)
n_class = len(ds_train.target_names)

# TF-IDF vectorization of the documents: English stop words removed, and terms
# appearing in <1% or >80% of the training documents are dropped.
min_df = 0.01
max_df = 0.8
vectorizer = txt.TfidfVectorizer(stop_words='english',
                                 min_df=min_df, max_df=max_df)

DT_train = vectorizer.fit_transform(ds_train.data)
# NOTE(review): get_feature_names() was removed in scikit-learn 1.2 in favour
# of get_feature_names_out(); kept as-is for the sklearn version in use here.
vocabulary_list = vectorizer.get_feature_names()
vocabulary = np.asarray(vocabulary_list)  # vocabulary as a 1D array
n_words = DT_train.shape[1]

# Apply the already-fitted vectorizer to the test corpus (no refitting).
DT_test = vectorizer.transform(ds_test.data)

# Random 100-document subsamples of each split, without replacement.
samp_ind1 = rd.sample_without_replacement(n_population=n_train, n_samples=100)
samp_ind2 = rd.sample_without_replacement(n_population=n_test, n_samples=100)
DT_train_sample = DT_train[samp_ind1, :]
DT_test_sample = DT_test[samp_ind2, :]

# Training document-term matrix in dense form.
doc_term_train = DT_train.toarray()

# Elbow diagnostic for KMeans: record inertia (SSE) for 2..31 clusters on the
# 100-document subsamples of each split.
max_cluster = 30
cluster_counts = list(range(2, max_cluster + 2))
sse_train = []
sse_test = []
for n_clus in cluster_counts:
    kmeans = cluster.KMeans(n_clusters=n_clus, n_init=3, max_iter=10,
                            random_state=2019)
    kmeans.fit(DT_train_sample)
    sse_train.append(kmeans.inertia_)
    # NOTE(review): the model is refit on the test sample, so this is the test
    # sample's own inertia, not a generalization measure — confirm intended.
    kmeans.fit(DT_test_sample)
    sse_test.append(kmeans.inertia_)

fig = plt.figure(1)
plt.title('Diagnostic for KMeans method')
plt.xlabel('Number of clusters')
plt.ylabel('Inertia')
# Bug fix: plot against the actual cluster counts (2..31) rather than the list
# index (0..29), so the x-axis matches its label.
plt.plot(cluster_counts, sse_train, c='blue')
plt.plot(cluster_counts, sse_test, c='red')
plt.show()

# Final KMeans run with the chosen number of clusters on the full dense
# training matrix.
n_cl = 3
# Bug fix: the original passed n_clusters=n_clus — the leftover diagnostic-loop
# variable (31) — instead of the intended n_cl = 3 defined just above.
kmeans = cluster.KMeans(n_clusters=n_cl, n_init=3, max_iter=10,
                        random_state=2019)
kmeans.fit(doc_term_train)
doc_cluster_train = kmeans.predict(doc_term_train)
# Contingency table of true topic labels vs. predicted cluster ids.
cm_kmeans = metrics.cluster.contingency_matrix(ds_train.target,
                                               doc_cluster_train)

# Ward-linkage hierarchical clustering into 5 clusters on the dense matrix.
# NOTE(review): `affinity` was renamed to `metric` in newer scikit-learn;
# kept as-is for the version this script targets.
n_cl = 5
clustering = cluster.AgglomerativeClustering(n_clusters=n_cl,
                                             affinity='euclidean',
                                             linkage='ward')
clustering.fit(doc_term_train)
complete_label = clustering.labels_

# External validation of the partition against the true topic labels.
rand = metrics.adjusted_rand_score(ds_train.target, complete_label)
hom = metrics.homogeneity_score(ds_train.target, complete_label)
comp = metrics.completeness_score(ds_train.target, complete_label)
cont_mat = metrics.cluster.contingency_matrix(ds_train.target, complete_label)
    
# Soft clustering with a 3-component Gaussian mixture model.
n_clus = 3
gmm = mix.GaussianMixture(n_components=n_clus)
gmm.fit(doc_term_train)
doc_cluster_train = gmm.predict(doc_term_train)    # hard cluster assignments
doc_prob_train = gmm.predict_proba(doc_term_train) # per-component posteriors
cm = metrics.cluster.contingency_matrix(ds_train.target, doc_cluster_train)

def plot_dendrogram(model, **kwargs):
    """Draw a dendrogram for a fitted AgglomerativeClustering model.

    Builds a scipy-style linkage matrix from the model's ``children_`` and
    ``distances_`` attributes, counting the leaves under each merge, and
    forwards ``kwargs`` to :func:`scipy.cluster.hierarchy.dendrogram`.
    """
    n_samples = len(model.labels_)
    merges = model.children_

    # For each merge, count the original samples under the resulting node.
    # Child indices < n_samples are leaves; larger ones refer to earlier merges.
    counts = np.zeros(merges.shape[0])
    for row, (left, right) in enumerate(merges):
        total = 0
        for child in (left, right):
            total += 1 if child < n_samples else counts[child - n_samples]
        counts[row] = total

    # scipy linkage format: [child_a, child_b, merge_distance, leaf_count].
    linkage_matrix = np.column_stack([merges, model.distances_,
                                      counts]).astype(float)

    dendrogram(linkage_matrix, **kwargs)


# setting distance_threshold=0 ensures we compute the full tree.
model = cluster.AgglomerativeClustering(distance_threshold=0, n_clusters=None)

model = model.fit(doc_term_train)
plt.title('Hierarchical Clustering Dendrogram')
# plot the top three levels of the dendrogram
plot_dendrogram(model, truncate_mode='level', p=3)
plt.xlabel("Number of points in node (or index of point if no parenthesis).")
plt.show()
