# -*- coding: utf-8 -*-
"""
Spyder Editor
Created on Thu Nov 02 16:30:55 2017

XOR problem solved by SVM classification with a degree-2 polynomial kernel,
with a plot of the resulting decision regions
@author: Márton Ispány
"""

import numpy as np;  # importing numpy the scientific computing package with Python
import matplotlib.pyplot as plt; # importing pyplot the MATLAB-like plotting framework from matplotlib library
import matplotlib.colors as col; # importing coloring package
from sklearn import svm;  #  importing svm package from scikit-learn library

# Global parameters
res = 500  # resolution (points per axis) of the evaluation grid used in the figure

# XOR dataset: each row is (x1, x2, label). The label is -1 when x1 and x2
# have the same sign and +1 otherwise, so no straight line separates the classes.
xor = np.array([
    [-1, -1, -1],
    [-1, 1, 1],
    [1, -1, 1],
    [1, 1, -1],
])

# A degree-2 polynomial kernel contains the x1*x2 product term, which is
# exactly what distinguishes the two XOR classes.
xor_svm = svm.SVC(kernel='poly', degree=2, C=1.0)
xor_svm.fit(xor[:, :2], xor[:, 2])  # features: first two columns; target: last column
xor_pred = xor_svm.predict(xor[:, :2])  # predictions on the training points themselves

# Evaluation grid over the square [-1.1, 1.1] x [-1.1, 1.1], flattened to shape
# (res*res, 2). Row i*res + j holds (grid_axis[i], grid_axis[j]), so the first
# coordinate varies slowest (the same ordering as a kron-based construction).
# np.linspace includes both endpoints, so the grid really reaches 1.1 and is
# symmetric about the origin; an arange-based scaling would stop one step short.
grid_axis = np.linspace(-1.1, 1.1, res)
grid = np.stack(np.meshgrid(grid_axis, grid_axis, indexing='ij'), axis=-1).reshape(-1, 2)

grid_pred = xor_svm.predict(grid)  # predicted label for every grid point
color1 = ['red', 'blue']  # colors for the true labels -1 / +1 of the data points
color2 = ['lightpink', 'lightskyblue']  # colors for the predicted labels -1 / +1 on the grid
fig = plt.figure()

# Decision regions: reshape the flat (res*res,) predictions back onto the
# res x res grid and fill each class region. A two-band contourf is far cheaper
# than drawing res*res individual scatter markers. Levels -2, 0, 2 bracket the
# labels -1 and +1, so -1 gets color2[0] and +1 gets color2[1] even if only
# one class is predicted anywhere on the grid.
plt.contourf(grid[:, 0].reshape(res, res),
             grid[:, 1].reshape(res, res),
             grid_pred.reshape(res, res),
             levels=[-2, 0, 2], colors=color2)

# Original data points on top. vmin/vmax pin -1 -> color1[0] and +1 -> color1[1]
# instead of relying on auto-scaling from whatever labels happen to be present.
plt.scatter(xor[:, 0], xor[:, 1], s=50, c=xor[:, 2],
            cmap=col.ListedColormap(color1), vmin=-1, vmax=1)
plt.show()  # display the figure