# Copyright 2017 Google Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path
import numpy as np
from time import time

import im_util
import interest_point
import geometry

class RANSAC:
  """
  Find 2-view consistent matches using RANSAC.

  Tunable settings live in self.params:
    num_iterations  -- number of random minimal (2-point) samples drawn.
    inlier_dist     -- a match counts as an inlier when the transformed point
                       lies strictly closer than this to its partner (same
                       units as the point coordinates).
    min_sample_dist -- minimal samples whose two points lie closer together
                       than this (in either image) are skipped as degenerate.
  """
  def __init__(self):
    self.params={}
    self.params['num_iterations']=2000
    self.params['inlier_dist']=10
    self.params['min_sample_dist']=2

  def consistent(self, H, p1, p2):
    """
    Find interest point pairs that are consistent with 2D transform H.
    Each point in p1 is matched to the point at the same index in p2.

    Inputs: H=transformation matrix (3,3)
            p1,p2=corresponding points in the two images of shape (2, N)

    Outputs: cons=boolean numpy array of shape (N,). Entry i is True when
    H maps p1[:, i] to within "inlier_dist" (strictly) of p2[:, i].

    Assumes that H maps from 1 to 2, i.e., hom(p2) ~= H hom(p1)
    hom(x) = homogeneous coordinates
    """
    inlier_dist = self.params['inlier_dist']
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)

    # Lift p1 to homogeneous coordinates (3, N) and map it through H.
    p1_hom = np.vstack([p1, np.ones((1, p1.shape[1]))])
    mapped = np.dot(H, p1_hom)

    # Return to Cartesian coordinates and measure the Euclidean distance to
    # the partner point. A zero third component (possible for a general
    # projective H) yields inf/nan, which compares False below. Such points
    # are therefore treated as outliers instead of raising warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
      mapped_xy = mapped[:2] / mapped[2]
      dist = np.sqrt(np.sum((mapped_xy - p2) ** 2, axis=0))
      cons = dist < inlier_dist

    return cons

  def compute_similarity(self,p1,p2):
    """
    Compute the least-squares similarity transform between pairs of points.

    Input: p1,p2=arrays of coordinates (2, n)
    n should be at least 2. With fewer points the linear system is
    underdetermined, and np.linalg.lstsq returns its minimum-norm solution.

    Output: Similarity matrix S (3, 3) of the form
              [[a, -b, tx],
               [b,  a, ty],
               [0,  0,  1]]

    Assume S maps from 1 to 2, i.e., hom(p2) = S hom(p1)
    """
    x1 = p1[0, :]
    y1 = p1[1, :]
    x2 = p2[0, :]
    y2 = p2[1, :]
    n = x1.shape[0]

    # Each correspondence contributes two linear equations in the unknowns
    # (a, b, tx, ty):
    #   x2 = a*x1 - b*y1 + tx        (even rows)
    #   y2 = b*x1 + a*y1 + ty        (odd rows)
    M = np.zeros((n * 2, 4))
    rhs = np.zeros((n * 2))

    M[0::2, 0] = x1
    M[0::2, 1] = -y1
    M[0::2, 2] = 1.0
    rhs[0::2] = x2

    M[1::2, 0] = y1
    M[1::2, 1] = x1
    M[1::2, 3] = 1.0
    rhs[1::2] = y2

    a, b, tx, ty = np.linalg.lstsq(M, rhs, rcond=None)[0]

    S = np.array([[a, -b, tx],
                  [b, a, ty],
                  [0.0, 0.0, 1.0]])
    return S

  def ransac_similarity(self, ip1, ipm):
    """
    Find 2-view consistent matches under a Similarity transform.

    Inputs: ip1=interest points (2, num_points)
            ipm=matching interest points (2, num_points)
            ip[0,:]=row coordinates, ip[1, :]=column coordinates

    Outputs: S_best=Similarity matrix (3,3) with the largest consensus set
             inliers_best=boolean numpy array (num_points,) marking the
             inliers of S_best

    If fewer than 2 matches are given, or no non-degenerate sample is ever
    drawn, S_best is the identity and inliers_best is all False.
    """
    num_points = ip1.shape[1]
    S_best = np.eye(3, 3)
    inliers_best = np.zeros(num_points, dtype=bool)
    best_count = 0

    n = 2  # a similarity transform is fixed by 2 point correspondences
    if num_points < n:
      return S_best, inliers_best

    num_iterations = int(self.params['num_iterations'])
    min_sample_dist = self.params['min_sample_dist']
    idx = np.arange(num_points)

    for _ in range(num_iterations):
      # Shuffle and take the first n indices: a random sample without
      # replacement.
      np.random.shuffle(idx)
      sample = idx[:n]
      ip1_select = ip1[:, sample]
      ipm_select = ipm[:, sample]

      # Near-coincident sample points give an ill-conditioned or degenerate
      # (near-zero scale) fit, so skip them.
      d1 = np.linalg.norm(ip1_select[:, 0] - ip1_select[:, 1])
      dm = np.linalg.norm(ipm_select[:, 0] - ipm_select[:, 1])
      if d1 < min_sample_dist or dm < min_sample_dist:
        continue

      # Fit to the minimal sample, then score it against every match.
      S = self.compute_similarity(ip1_select, ipm_select)
      inliers = self.consistent(S, ip1, ipm)
      count = int(np.sum(inliers))

      # Strictly greater: on ties, the earliest best model is kept.
      if count > best_count:
        best_count = count
        S_best = S
        inliers_best = inliers

    return S_best, inliers_best
