# -*- coding: utf-8 -*-
"""
Created on Thu Feb 25 09:26:58 2021

Task: Fitting logistic regression for Iris data  
Results: 2D plots of Iris data and logistic regression models for the
following scenarios: 
i) setosa against versicolor or virginica
ii) virginica against setosa or versicolor

Python tools    
Libraries: numpy, matplotlib, sklearn
Modules: pyplot, colors, datasets, linear_model
Classes: LogisticRegression
Functions: load_iris

@author: Márton Ispány
"""

import numpy as np;  # importing numerical library
import matplotlib.pyplot as plt;  # importing MATLAB-like plotting framework
import matplotlib.colors as col;  # importing coloring tools from MatPlotLib
from sklearn.datasets import load_iris; # importing Iris dataset
from sklearn.linear_model import LogisticRegression; # Class for logistic regression

# Load the Iris dataset and record its basic dimensions
iris = load_iris()
n, p = iris.data.shape  # number of records, number of attributes
k = len(iris.target_names)  # number of target classes

# Report the basic parameters of the dataset
print(f'Number of records:{n}')
print(f'Number of attributes:{p}')
print(f'Number of target classes:{k}')

# Scatterplot for two input attributes
res = 0.001;  #  resolution of the graph
# Default axis
x_axis = 0;  # x axis attribute (0,1,2,3)
y_axis = 1;  # y axis attribute (0,1,2,3)
# Enter axis from consol
user_input = input('X axis [0..3, default:0]: ');
if len(user_input) != 0 and np.int8(user_input)>=0 and np.int8(user_input)<=3 :
    x_axis = np.int8(user_input);
user_input = input('Y axis [0..3, default:1]: ');
if len(user_input) != 0 and np.int8(user_input)>=0 and np.int8(user_input)<=3 :
    y_axis = np.int8(user_input);    
# Colour each point by its species: setosa blue, versicolor red, virginica green
colors = ['blue', 'red', 'green']
species_cmap = col.ListedColormap(colors)
fig = plt.figure(1)
plt.title('Scatterplot for Iris dataset')
plt.xlabel(iris.feature_names[x_axis])
plt.ylabel(iris.feature_names[y_axis])
plt.scatter(iris.data[:, x_axis], iris.data[:, y_axis],
            s=50, c=iris.target, cmap=species_cmap)
plt.show()

# Binomial transformation: multiclass target -> 0/1 label.
#   scenario 1: 0 - setosa,                 1 - versicolor or virginica
#   scenario 2: 0 - setosa or versicolor,   1 - virginica
# Previously both transformations ran and the second silently overwrote the
# first; the switch makes the choice explicit. Scenario 2 is the default
# because that is what the script effectively computed.
scenario = 2
if scenario == 1:
    y = (iris.target > 0).astype(float)  # 0.0 for setosa, 1.0 otherwise
    target_names = [str(iris.target_names[0]),
                    iris.target_names[1] + ' or ' + iris.target_names[2]]
else:
    y = (iris.target > 1).astype(float)  # 1.0 for virginica, 0.0 otherwise
    target_names = [iris.target_names[0] + ' or ' + iris.target_names[1],
                    str(iris.target_names[2])]

# Fit a logistic regression that uses only the two selected attributes
X = iris.data[:, [x_axis, y_axis]]
logreg = LogisticRegression().fit(X, y)  # fit() returns the fitted estimator
intercept, coefficients = logreg.intercept_, logreg.coef_  # estimated intercept and slopes
accuracy = logreg.score(X, y)  # training accuracy against the binomial target
target_pred_logreg = logreg.predict(X)  # predicted binomial labels

# Visualize the data with the fitted decision boundary, i.e. the points where
# intercept + w0 * x + w1 * y == 0, solved for y over a grid of x values.
first_attr = X[:, 0]
xmin = first_attr.min() - 0.1
xmax = first_attr.max() + 0.1
base = np.arange(xmin, xmax, res)
fig = plt.figure(2)
plt.title('Scatterplot for Iris dataset with separating line')
plt.xlabel(iris.feature_names[x_axis])
plt.ylabel(iris.feature_names[y_axis])
plt.scatter(first_attr, X[:, 1], s=50, c=iris.target,
            cmap=col.ListedColormap(colors))
value = -(intercept[0] + coefficients[0, 0] * base) / coefficients[0, 1]
plt.scatter(base, value, s=5, color="black")
plt.show()

# Fitting logistic regression for all input attributes
logreg = LogisticRegression(solver='liblinear')  # instance of the LogisticRegression class
logreg.fit(iris.data, y)  # fitting the model to the binomial target
intercept = logreg.intercept_  # estimated intercept
coefficients = logreg.coef_  # estimated weights, one per attribute (binary model: shape (1, n_features))
# Score against the binomial target the model was trained on. Scoring against
# the 3-class iris.target would compare 0/1 predictions with labels 0/1/2, so
# every class-2 record would count as an error.
accuracy = logreg.score(iris.data, y)
target_pred_logreg = logreg.predict(iris.data)  # prediction of the binomial target
p_pred_logreg = logreg.predict_proba(iris.data)  # posterior probabilities of the two classes


# Visualize the all-attribute model in the chosen 2D plane.
# The decision boundary is intercept + sum_j w_j * x_j = 0. Only two
# attributes are plotted, so the others must be fixed at some value. They are
# held at their sample means: a slice through the centre of the data.
# (Fixing them implicitly at 0, an unrealistic measurement, shifted the line
# away from the plotted points.)
xmin = iris.data[:, x_axis].min() - 0.1
xmax = iris.data[:, x_axis].max() + 0.1
base = np.arange(xmin, xmax, res)
fig = plt.figure(3)
plt.title('Scatterplot for Iris dataset with separating line')
plt.xlabel(iris.feature_names[x_axis])
plt.ylabel(iris.feature_names[y_axis])
plt.scatter(iris.data[:, x_axis], iris.data[:, y_axis], s=50, c=iris.target,
            cmap=col.ListedColormap(colors))
hidden = [j for j in range(iris.data.shape[1]) if j not in (x_axis, y_axis)]
# Contribution of the intercept plus the hidden attributes held at their means
offset = intercept[0] + coefficients[0, hidden] @ iris.data[:, hidden].mean(axis=0)
value = -(offset + coefficients[0, x_axis] * base) / coefficients[0, y_axis]
plt.scatter(base, value, s=5, color="black")
plt.show()

