# -*- coding: utf-8 -*-
"""
Created on Wed Nov 21 08:39:55 2018

@author: Márton
"""

import matplotlib.pyplot as plt
import numpy as np
import sklearn.cluster as cluster
import sklearn.metrics as metrics
import sklearn.mixture as mix
import sklearn.utils.random as rd
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import KMeans
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer  # importing text preprocessing


# Load the train and test splits of 20 Newsgroups, restricted to five topics
categories = [
    'rec.sport.baseball',
    'rec.sport.hockey',
    'sci.electronics',
    'sci.med',
    'sci.space',
]
# Both splits use the same category filter, shuffling and seed.
ds_train = fetch_20newsgroups(
    subset='train', categories=categories, shuffle=True, random_state=2018
)
ds_test = fetch_20newsgroups(
    subset='test', categories=categories, shuffle=True, random_state=2018
)

# Sizes: number of documents per split and number of target classes
n_train, n_test = len(ds_train.data), len(ds_test.data)
n_class = len(ds_train.target_names)

# Vectorization of the docs: TF-IDF weights with English stop words removed.
# min_df / max_df are floats, i.e. document-frequency proportions used to
# drop very rare and very common terms.
min_df = 0.05
max_df = 0.8
vectorizer = TfidfVectorizer(stop_words='english', min_df=min_df, max_df=max_df)
DT_train = vectorizer.fit_transform(ds_train.data)
# get_feature_names() was removed in scikit-learn 1.2; use its replacement
# when available and fall back to the old accessor on older releases.
if hasattr(vectorizer, 'get_feature_names_out'):
    vocabulary_list = vectorizer.get_feature_names_out().tolist()
else:
    vocabulary_list = vectorizer.get_feature_names()
vocabulary = np.asarray(vocabulary_list)  # vocabulary in 1D array
n_words = DT_train.shape[1]
# Only transform the test documents: they reuse the vocabulary fitted on train.
DT_test = vectorizer.transform(ds_test.data)

# Draw 100 random rows from each split. The draws are seeded so that runs
# are repeatable, like the seeded dataset shuffling and KMeans fits.
samp_ind1 = rd.sample_without_replacement(
    n_population=n_train, n_samples=100, random_state=2018
)
samp_ind2 = rd.sample_without_replacement(
    n_population=n_test, n_samples=100, random_state=2018
)
DT_train_sample = DT_train[samp_ind1, :]
DT_test_sample = DT_test[samp_ind2, :]

# document-term matrix in dense form
doc_term_train = DT_train.toarray()

# Default parameters
n_c = 2  # number of clusters

# Enter parameters from the console; an empty (or blank) answer keeps the default
user_input = input('Number of clusters [default:2]: ').strip()
if user_input:
    # Use a plain int: an 8-bit integer cannot represent counts above 127.
    # Non-numeric input still raises ValueError.
    n_c = int(user_input)

# KMeans clustering on the full (sparse) training matrix
kmeans = KMeans(n_clusters=n_c, n_init=3, max_iter=10, random_state=2020)
kmeans.fit(DT_train)
clabels_train = kmeans.predict(DT_train)
clabels_train1 = kmeans.labels_  # labels stored by fit()
clabels_test = kmeans.predict(DT_test)

# Contingency table of true newsgroup labels against KMeans cluster labels
cm_train = metrics.cluster.contingency_matrix(ds_train.target, clabels_train)

# Diagnostic: Davies-Bouldin score as a function of the number of clusters.
# KMeans is fitted on the training subsample and scored on both subsamples.
max_cluster = 30
cluster_counts = list(range(2, max_cluster + 2))  # 2 .. max_cluster + 1

# The subsamples do not change inside the loop, so densify them once.
train_sample_dense = DT_train_sample.toarray()
test_sample_dense = DT_test_sample.toarray()

DB_train = []
DB_test = []
for n_clus in cluster_counts:
    # KMeans is imported directly at the top of the file.
    kmeans = KMeans(n_clusters=n_clus, n_init=3, max_iter=10, random_state=2019)
    kmeans.fit(DT_train_sample)
    clabels_train = kmeans.labels_
    DB_train.append(metrics.davies_bouldin_score(train_sample_dense, clabels_train))
    clabels_test = kmeans.predict(DT_test_sample)
    DB_test.append(metrics.davies_bouldin_score(test_sample_dense, clabels_test))

fig = plt.figure(1)
plt.title('Diagnostic for KMeans method')
plt.xlabel('Number of clusters')
plt.ylabel('DB score')
# Plot against the real cluster counts so the x axis matches its label.
plt.plot(cluster_counts, DB_train, c='blue', label='train')
plt.plot(cluster_counts, DB_test, c='red', label='test')
plt.legend()
plt.show()

# Ward-linkage agglomerative clustering of the dense training matrix
n_cl = 5
# Ward linkage only works with Euclidean distance, which is also the default,
# so the distance argument is omitted. Its keyword was renamed from `affinity`
# to `metric` across scikit-learn releases; omitting it works on all of them.
clustering = cluster.AgglomerativeClustering(n_clusters=n_cl, linkage='ward')
clustering.fit(doc_term_train)
complete_label = clustering.labels_

# Agreement between the clustering and the true newsgroup labels
rand = metrics.adjusted_rand_score(ds_train.target, complete_label)
hom = metrics.homogeneity_score(ds_train.target, complete_label)
comp = metrics.completeness_score(ds_train.target, complete_label)
cont_mat = metrics.cluster.contingency_matrix(ds_train.target, complete_label)
    
# Gaussian mixture model with three components on the dense training matrix
n_clus = 3
gmm = mix.GaussianMixture(n_components=n_clus)
gmm.fit(doc_term_train)

# Component label per document, and per-component posterior probabilities
doc_cluster_train = gmm.predict(doc_term_train)
doc_prob_train = gmm.predict_proba(doc_term_train)

# Contingency table of true newsgroup labels against GMM component labels
cm_train_gmm = metrics.cluster.contingency_matrix(
    ds_train.target, doc_cluster_train
)

def plot_dendrogram(model, **kwargs):
    """Plot the dendrogram of a fitted agglomerative clustering model.

    Builds a SciPy-style linkage matrix (children, merge distance, number of
    samples under the node) from the model's merge tree and hands it to
    ``scipy.cluster.hierarchy.dendrogram``.

    Parameters
    ----------
    model : object
        Fitted estimator exposing ``children_``, ``distances_`` and
        ``labels_``.
    **kwargs
        Forwarded unchanged to ``dendrogram``.

    Returns
    -------
    dict
        Whatever ``dendrogram`` returns, so callers can inspect the layout
        (for example with ``no_plot=True``).

    Raises
    ------
    AttributeError
        If ``model`` has no ``distances_`` attribute.
    """
    if not hasattr(model, "distances_"):
        raise AttributeError(
            "model has no 'distances_' attribute; fit it with "
            "distance_threshold set or compute_distances=True"
        )

    children = np.asarray(model.children_)
    n_samples = len(model.labels_)

    # Number of original samples under each merge node. Indices below
    # n_samples are leaves (count 1); larger indices refer to an earlier
    # merge, whose count has already been filled in.
    counts = np.zeros(children.shape[0])
    for i, merge in enumerate(children):
        counts[i] = sum(
            1 if child_idx < n_samples else counts[child_idx - n_samples]
            for child_idx in merge
        )

    linkage_matrix = np.column_stack(
        [children, model.distances_, counts]
    ).astype(float)

    # Plot the corresponding dendrogram
    return dendrogram(linkage_matrix, **kwargs)


# Setting distance_threshold=0 (with n_clusters=None) ensures we compute the
# full tree, which the dendrogram needs.
model = cluster.AgglomerativeClustering(distance_threshold=0, n_clusters=None)
model = model.fit(doc_term_train)

# Open a dedicated figure, as done with plt.figure(1) for the KMeans
# diagnostic, so the dendrogram is never drawn onto an already-open axes.
plt.figure(2)
plt.title('Hierarchical Clustering Dendrogram')
# plot the top three levels of the dendrogram
plot_dendrogram(model, truncate_mode='level', p=3)
plt.xlabel("Number of points in node (or index of point if no parenthesis).")
plt.show()
