import numpy as np
import os.path
from time import time
import types
import matplotlib.pyplot as plt

import im_util
import interest_point

# Default figure size for every plot produced by this script.
plt.rcParams['figure.figsize'] = (16.0, 10.0)

# ---------------------------------------------------------------------------
# Test of convolve_1d
#
# The kernel [1, 3, 2] is asymmetric, so np.convolve (which flips the kernel)
# and np.correlate (which does not) produce different results. Comparing
# im_util.convolve_1d against both shows which operation it implements: if it
# matches one of them with 'same' alignment, that error is expected to be 0.
# ---------------------------------------------------------------------------
print('[ Test convolve_1d ]')
# Sparse 0/1 spike train (each sample is 1 with probability 0.2). A seeded
# generator makes the input -- and therefore the printed results --
# reproducible from run to run.
rng = np.random.default_rng(0)
x = (rng.random(20) > 0.8).astype(np.float32)
k = np.array([1,3,2])
print(x)
print(k)
y1 = im_util.convolve_1d(x, k)
y2 = np.convolve(x, k, 'same')
y3 = np.correlate(x, k, 'same')
print(y1)
print(y2)
print(y3)
print(' convolve error = ', np.sum((y1-y2)**2))
print(' correlate error = ', np.sum((y1-y3)**2))

"""
Test of convolve_image
"""
image_filename='data/test/100-0038_img.jpg'

print('[ Test convolve_image ]')
im = im_util.image_open(image_filename)
k = np.array([1,2,3,4,5,6,5,4,3,2,1])
#k = np.array([1,2,1])
print(' convolve_rows')
t0=time()
im1 = im_util.convolve_rows(im, k)
t1=time()
print(' % .2f secs' % (t1-t0))
print(' scipy convolve')
t0=time()
im2 = im_util.convolve(im, np.expand_dims(k,0))
t1=time()
print(' % .2f secs' % (t1-t0))
print(' convolve_image error =', np.sum((im1-im2)**2))

# optionally plot images for debugging
im1_norm=im_util.normalise_01(im1)
im2_norm=im_util.normalise_01(im2)
ax1,ax2=im_util.plot_two_images(im1_norm, im2_norm)

"""
Gaussian blurring test
"""
print('[ Test convolve_gaussian ]')

sigma=10.0
k=im_util.gauss_kernel(sigma)
print(' gauss kernel = ')
print(k)

t0=time()
im1 = im_util.convolve_gaussian(im, sigma)
t1=time()
print(' % .2f secs' % (t1-t0))

ax1,ax2=im_util.plot_two_images(im, im1)

"""
Gradient computation test
"""
print('[ Test gradient computation ]')
img = np.mean(im,2,keepdims=True)
Ix,Iy = im_util.compute_gradients(img)

# copy greyvalue to RGB channels
Ix_out = im_util.grey_to_rgb(im_util.normalise_01(Ix))
Iy_out = im_util.grey_to_rgb(im_util.normalise_01(Iy))

im_util.plot_two_images(Ix_out, Iy_out)

#i=np.array([[1,2],[3,46]])
#k=np.array([[1,-1]])
#a=im_util.convolve(i,k)
#print(i)
#print(k)
#print(a)

"""
Compute corner strength function
"""
print('[ Compute corner strength ]')
ip_ex = interest_point.InterestPointExtractor()
ip_fun = ip_ex.corner_function(img)

# normalise for display
[mn,mx]=np.percentile(ip_fun,[5,95])
small_val=1e-9
ip_fun_norm=(ip_fun-mn)/(mx-mn+small_val)
ip_fun_norm=np.maximum(np.minimum(ip_fun_norm,1.0),0.0)

"""
Find local maxima of corner strength
"""
print('[ Find local maxima ]')
row, col = ip_ex.find_local_maxima(ip_fun)
ip = np.stack((row,col))

ax1,ax2=im_util.plot_two_images(im_util.grey_to_rgb(ip_fun_norm),im)
interest_point.draw_interest_points_ax(ip, ax2)
print('done')
