# -*- coding: utf-8 -*-
"""
Created on Fri Oct 27 22:11:51 2017

Linear regression with polynomial basis functions (x**0 ... x**(B-1)) using scikit-learn
Author: Márton Ispány
"""

import numpy as np;
import matplotlib.pyplot as plt;
from sklearn import linear_model;

# Global parameters
N = 50            # number of training points
sig = 0.1         # standard deviation of the additive Gaussian noise
B = 20            # number of basis functions (monomials x**0 ... x**(B-1))
res = 1000        # resolution of functions, curves etc. in figures
end = 8 * np.pi   # endpoint of the interval
seed = None       # set to an int for reproducible noise; None leaves the global RNG as is

if seed is not None:
    np.random.seed(seed)

x = np.linspace(0, end, N)                  # input points
f = np.sin(x)                               # noise-free target values
f_err = f + np.random.normal(0, sig, N)     # training target values with noise

# Design matrix: column ind is the basis function x**ind evaluated at the
# input points, i.e. Phi[i, ind] = x[i]**ind (Vandermonde, increasing powers).
# NOTE: with end = 8*pi (about 25.1) and powers up to B-1 = 19, the column
# scales range from 1 to roughly 25**19 (about 4e26), so Phi is extremely
# ill-conditioned and the fitted coefficients are numerically fragile.
# Rescaling x (e.g. to [-1, 1]) before taking powers would span the same
# polynomial space with a far better conditioned basis.
Phi = np.vander(x, B, increasing=True)

# No separate intercept: the constant basis function x**0 (column 0 of Phi)
# already plays that role.
ols = linear_model.LinearRegression(fit_intercept=False)
ols.fit(Phi, f_err)
f_pred = ols.predict(Phi)    # fitted values at the training inputs

w = ols.coef_         # weights of the B basis functions
w0 = ols.intercept_   # always 0.0 here because fit_intercept=False

# Predict the fitted model at a single new input point x_new.
x_new = 2
# One-row design matrix: Phi_new[0, ind] = x_new**ind for ind = 0 .. B-1.
Phi_new = np.vander(np.array([x_new], dtype=float), B, increasing=True)
fx_pred = ols.predict(Phi_new)   # array holding the single prediction at x_new

# get_params is a method: call it to obtain the dict of the estimator's
# hyper-parameters (without the call, par would just be the bound method).
par = ols.get_params()

# Dense grid on [0, end] used for smooth curves in both figures
# (computed once instead of re-evaluating np.linspace for every call).
x_plot = np.linspace(0, end, res)

# Figure 1: true curve, fitted values at the training inputs (joined by
# straight segments), and the noisy training data.
plt.figure(1)
plt.plot(x_plot, np.sin(x_plot), color="green", label="sin(x)")       # sinus curve
plt.plot(x, f_pred, color="red", label="fit at training points")     # predicted points
plt.scatter(x, f_err, color="blue", label="training data")           # training points
plt.legend()
plt.show()

# Design matrix on the dense grid: Phi_res[i, ind] = x_plot[i]**ind.
Phi_res = np.vander(x_plot, B, increasing=True)
res_pred = ols.predict(Phi_res)   # fitted curve evaluated on the dense grid

# Figure 2: same data, but the fit is drawn on the dense grid so its
# behaviour between the training points is visible.
plt.figure(2)
plt.plot(x_plot, np.sin(x_plot), color="green", label="sin(x)")       # sinus curve
plt.plot(x_plot, res_pred, color="red", label="fitted curve")        # predicted curve
plt.scatter(x, f_err, color="blue", label="training data")           # training points
plt.legend()
plt.show()

 