# -*- coding: utf-8 -*-
"""
Created on Sun Oct 23 20:35:55 2016

Support Vector Machine (SVM) classification of the Iris data using
scikit-learn, the machine learning toolkit for Python
@author: Márton
"""

from sklearn import datasets as ds;
from sklearn import svm;
#from sklearn import decomposition as decomp;
#from sklearn import cross_validation as cv;
import numpy as np;
import matplotlib.pyplot as plt;
import matplotlib.colors as col;
import copy;

# Global parameters
res = 1000  # number of sample points used when drawing curves in figures

# Load the iris dataset (every sample is kept; no train/test split is made)
iris = ds.load_iris()

# Scatterplot of two input attributes, colored by target class
x_axis = 0  # index of the attribute shown on the x axis (0, 1, 2 or 3)
y_axis = 3  # index of the attribute shown on the y axis (0, 1, 2 or 3)
colors = ['red', 'green', 'blue']  # one color per target value
fig = plt.figure(1)
plt.scatter(
    iris.data[:, x_axis],
    iris.data[:, y_axis],
    s=50,
    c=iris.target,
    cmap=col.ListedColormap(colors),
)
plt.title('Scatterplot for iris dataset')
plt.xlabel(iris.feature_names[x_axis])
plt.ylabel(iris.feature_names[y_axis])
plt.show()

# Joining the Versicolor and Virginica classes into one class (label 1)
iris2 = copy.deepcopy(iris)  # deep copy so the original dataset is untouched
N = np.size(iris2.target)  # number of samples
iris2.target[iris2.target == 2] = 1  # relabel every class-2 sample as class 1
# Build a fresh name array rather than assigning into a slice of the old one:
# NumPy string arrays have a fixed item width, so writing the longer merged
# name in place would be silently truncated back to the original name.
iris2.target_names = np.array([
    str(iris2.target_names[0]),
    str(iris2.target_names[1]) + ' and virginica',
])

# Scatterplot of two input attributes for the merged (two-class) dataset
x_axis = 0  # x axis attribute (0,1,2,3)
y_axis = 3  # y axis attribute (0,1,2,3)
colors = ['red', 'blue']  # one color per merged target value (0 and 1)
fig = plt.figure(1)
plt.title('Scatterplot for iris dataset')
plt.xlabel(iris2.feature_names[x_axis])
plt.ylabel(iris2.feature_names[y_axis])
# Color by the merged labels so that each of the two colors corresponds to
# exactly one label, instead of pushing three labels through a two-color map.
plt.scatter(
    iris2.data[:, x_axis],
    iris2.data[:, y_axis],
    s=50,
    c=iris2.target,
    cmap=col.ListedColormap(colors),
)
plt.show()

# Linear support vector classification on two input attributes
x_axis = 0  # index of the attribute used as first feature (0, 1, 2 or 3)
y_axis = 3  # index of the attribute used as second feature (0, 1, 2 or 3)
# Two-column feature matrix: column 0 is the x attribute, column 1 the y one
data2 = np.column_stack((iris2.data[:, x_axis], iris2.data[:, y_axis]))
iris_svm = svm.LinearSVC(C=0.1)
iris_svm.fit(data2, iris2.target)

# Scatterplot of the merged classes together with the SVM decision boundary

colors = ['red', 'blue']  # colors for the two target values
# Extend the plotted x range 0.5 beyond the data on BOTH sides
left = min(iris2.data[:, x_axis]) - 0.5
right = max(iris2.data[:, x_axis]) + 0.5
w_x = iris_svm.coef_[0][0]  # weight of the x-axis attribute
w_y = iris_svm.coef_[0][1]  # weight of the y-axis attribute
bias = iris_svm.intercept_
# Value of w_x*x + w_y*y + bias for every sample
sign = w_x * data2[:, 0] + w_y * data2[:, 1] + bias
# Decision boundary: w_x*x + w_y*y + bias = 0, solved for y over the x range
boundary_x = np.linspace(left, right, res)
plane = -(w_x * boundary_x + bias) / w_y
fig = plt.figure(2)
plt.title('Scatterplot for iris dataset')
plt.xlabel(iris2.feature_names[x_axis])
plt.ylabel(iris2.feature_names[y_axis])
# Color by the merged labels the classifier was trained on
plt.scatter(
    iris2.data[:, x_axis],
    iris2.data[:, y_axis],
    s=50,
    c=iris2.target,
    cmap=col.ListedColormap(colors),
)
plt.scatter(boundary_x, plane, s=10, color='black')
plt.show()