In [13]:
# generate training data
import numpy as np
# Ground-truth line y = w0 + w1*x that the regressions below try to recover.
x = np.arange(10)
w0 = 3.4
w1 = 0.45
# Draw the noise from a dedicated, seeded generator so the targets (and every
# fit and figure derived from them) are identical after Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)
# The noise length follows x, so changing the sample count cannot cause a
# shape mismatch between the line and the noise.
y = w0 + w1*x + 0.3*rng.standard_normal(x.size)

import matplotlib.pyplot as plt
# Show the noisy samples as unconnected circle markers.
plt.plot(x, y, linestyle='None', marker='o')
plt.show()
In [14]:
# Closed-form least squares: solve for slope and intercept in a single call.
# Design matrix columns are [x, 1], so the solution vector is ordered (slope, intercept).
A = np.column_stack((x, np.ones(len(x))))
lstsq_result = np.linalg.lstsq(A, y, rcond=None)
w1_, w0_ = lstsq_result[0]

import matplotlib.pyplot as plt
# Samples as circle markers, least-squares line drawn in red on top.
plt.plot(x, y, linestyle='None', marker='o')
plt.plot(x, w0_ + w1_*x, color='r')
plt.show()
In [15]:
# State for the step-by-step regression cell below.
T = 100      # number of slots in each iterate history
t = 0        # index of the most recent iterate
eta = 0.6    # step size multiplying each update
# Iterate histories: slot 0 holds the starting guess (intercept 5, slope -1),
# every later slot starts at zero and is filled in one step at a time.
w0__ = np.concatenate(([5.0], np.zeros(T - 1)))
w1__ = np.concatenate(([-1.0], np.zeros(T - 1)))
In [25]:
# linear regression step-by-step
from numpy.linalg import inv
# One update of (w1, w0) per run of this cell.  The histories only have T
# slots, so check that slot t+1 exists; without the guard the cell raises an
# IndexError once it has been run T-1 times.
if t + 1 < T:
    # Residuals of the current iterate; A's columns are [x, 1], hence the
    # (slope, intercept) ordering of the parameter vector.
    res = A.dot(np.array([w1__[t], w0__[t]])) - y
    # Step direction 0.5 * (A^T A)^{-1} A^T res, obtained by solving the linear
    # system directly instead of forming the explicit inverse.
    grd = 0.5*np.linalg.solve(A.T.dot(A), A.T.dot(res))

    w0__[t+1] = w0__[t] - eta*grd[1]
    w1__[t+1] = w1__[t] - eta*grd[0]
    t = t + 1
else:
    print("Iterate history is full (T={}); no further update performed.".format(T))

import matplotlib.pyplot as plt
# Samples as circle markers, with the line given by the current iterate t in red.
plt.plot(x, y, linestyle='None', marker='o')
plt.plot(x, w0__[t] + w1__[t]*x, color='r')
plt.show()

import matplotlib
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
delta = 0.025
# Grid over the parameter plane: x_ spans slope (w1) values and y_ spans
# intercept (w0) values, matching the [x, 1] column order of A.
x_ = np.arange(-1, 2.0, delta)
y_ = np.arange(3.0, 5.0, delta)
X, Y = np.meshgrid(x_, y_)
# RSS surface: add one squared residual per training sample.  Iterating over
# the rows of A (rather than a hard-coded range(10)) follows the sample count.
Z = np.zeros(X.shape)
for (a_slope, a_bias), target in zip(A, y):
    Z = Z + (a_slope*X + a_bias*Y - target)**2
fig, ax = plt.subplots()
CS = ax.contour(X, Y, Z)
ax.clabel(CS, inline=True, fontsize=10)
# t already indexes the newest iterate, so the slice has to include index t;
# [0:t] left the most recent step off the plotted trajectory.
ax.plot(w1__[:t+1], w0__[:t+1], '-o')
ax.set_xlabel('w1 (slope)')
ax.set_ylabel('w0 (intercept)')
ax.set_title('RSS contours and iterate trajectory')
plt.show()
Out[25]:
[<matplotlib.lines.Line2D at 0x120015210>]

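This linear regression minimizes the residual sum of squares, RSS$=\sum_{i=1}^N (\hat{y}_i-y_i)^2$, where $\hat{y}_i = w_0 + w_1 x_i$ is the model's prediction for sample $i$.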

Now we introduce a single outlier and see how sensitive the RSS fit is to it.

In [29]:
# Augment the clean sample with a single outlier at (x, y) = (-4, 13).
x_ = np.append(x,-4)
y_ = np.append(y,13)
#y_ = np.append(y,7)   # alternative: an outlier closer to the true line

# Keep the outlier design matrix under its own name.  Overwriting `A` would
# pair an 11-row matrix with the 10-element `y` used by the step-by-step cell,
# so re-running that cell afterwards would fail on the shape mismatch.
A_outlier = np.vstack([x_, np.ones(len(x_))]).T
w1_, w0_ = np.linalg.lstsq(A_outlier, y_, rcond=None)[0]

import matplotlib.pyplot as plt
# Samples including the outlier as circle markers, refitted line in red.
plt.plot(x_, y_, linestyle='None', marker='o')
plt.plot(x_, w0_ + w1_*x_, color='r')
plt.show()
In [201]:
 
In [202]:
 
Out[202]:
array([[0.   , 0.025, 0.05 , ..., 7.925, 7.95 , 7.975],
       [0.   , 0.025, 0.05 , ..., 7.925, 7.95 , 7.975],
       [0.   , 0.025, 0.05 , ..., 7.925, 7.95 , 7.975],
       ...,
       [0.   , 0.025, 0.05 , ..., 7.925, 7.95 , 7.975],
       [0.   , 0.025, 0.05 , ..., 7.925, 7.95 , 7.975],
       [0.   , 0.025, 0.05 , ..., 7.925, 7.95 , 7.975]])
In [ ]: