// Remember to compile with -pthread

#include <pthread.h>
#include <iostream>
#include <vector>
#include <algorithm>
#include <chrono>

using std::cout;
using std::endl;
using std::vector;
using std::random_shuffle;
using std::swap;
using std::is_sorted;
using namespace std::chrono;

// Number of integers that main() generates, shuffles and sorts.
constexpr int ARR_SIZE{20000};

// Sinks the element at `cur` towards `lo` by swapping it with its left
// neighbour for as long as it is strictly smaller than that neighbour.
//
// cur: position of the element to sink (taken by value; the caller's
//      iterator is not modified).
// lo:  lowest position the element may reach; nothing left of `lo` is
//      read or written.
//
// Equal elements are never swapped, so their relative order is preserved.
void bubble_down(vector<int>::iterator cur, const vector<int>::iterator &lo) {
  for (; cur > lo; --cur) {
    const auto left = cur - 1;
    if (!(*cur < *left)) {
      break;  // already in place relative to its left neighbour
    }
    swap(*cur, *left);
  }
}

// Merges the two halves of [lo, hi) in place.
//
// The split point is lo + (hi - lo) / 2. Every element from the split point
// up to (but excluding) `hi` is sunk, one after another, into the part of the
// range to its left via bubble_down(). The result is sorted provided the
// left half [lo, split) was sorted on entry.
void merge(const vector<int>::iterator &lo, const vector<int>::iterator &hi) {
  const auto split = lo + (hi - lo) / 2;
  for (auto it = split; it < hi; ++it) {
    bubble_down(it, lo);
  }
}

// Recursively sorts the half-open range [lo, hi) in ascending order.
//
// Ranges with at most one element (including reversed ranges, where
// hi - lo is negative) are left untouched. Otherwise the range is split at
// lo + (hi - lo) / 2, each half is sorted recursively, and the halves are
// combined with merge().
void sort(const vector<int>::iterator &lo, const vector<int>::iterator &hi) {
  const auto len = hi - lo;
  if (len > 1) {
    const auto split = lo + len / 2;
    sort(lo, split);
    sort(split, hi);
    merge(lo, hi);
  }
}

// Argument bundle passed to thread_sort() through pthread_create()'s void*
// parameter: the half-open range [lo, hi) that the worker should sort.
struct merge_params {
  vector<int>::iterator lo;  // first element of the range
  vector<int>::iterator hi;  // one past the last element of the range

  // Stores the two range boundaries as given; no validation is performed.
  merge_params(vector<int>::iterator first, vector<int>::iterator last)
    : lo{first}, hi{last} {}
};

// pthread entry point: sorts the range described by a heap-allocated
// merge_params and then frees it.
//
// arg: pointer to a merge_params created with `new`. Ownership passes to
//      this function, which deletes it; the caller must not touch it
//      afterwards. A null pointer is accepted and treated as "nothing to do".
//
// Always returns nullptr (the thread produces no result value).
void* thread_sort(void* arg) {
  // void* -> T* is a plain static_cast; reinterpret_cast is not needed.
  merge_params *p = static_cast<merge_params*>(arg);
  if (p != nullptr) {
    sort(p->lo, p->hi);
    delete p;
  }
  return nullptr;
}

int main(int argc, char** argv) {
  vector<int> vec(ARR_SIZE);

  int i = 0;
  for (int& n : vec) {
    n = i++;
  }

  cout << "Shuffling..." << endl;
  random_shuffle(vec.begin(), vec.end());
  cout << (is_sorted(vec.begin(), vec.end()) ? "" : "not ") << "sorted " << endl;

  cout << "Sorting..." << endl;

  steady_clock::time_point start = steady_clock::now();

  auto lo = vec.begin();
  auto hi = vec.end();
  auto mid = (hi - lo) / 2 + lo;
  auto q1 = (mid - lo) / 2 + lo;
  auto q3 = (hi - mid) / 2 + mid;

  pthread_t t1;
  pthread_t t2;
  pthread_t t3;
  pthread_t t4;

  merge_params *p1 = new merge_params(lo, q1);
  merge_params *p2 = new merge_params(q1, mid);
  merge_params *p3 = new merge_params(mid, q1);
  merge_params *p4 = new merge_params(q3, hi);

  pthread_create(&t1, NULL, thread_sort, p1);
  pthread_create(&t2, NULL, thread_sort, p2);
  pthread_create(&t3, NULL, thread_sort, p3);
  pthread_create(&t4, NULL, thread_sort, p4);

  pthread_join(t1, NULL);
  pthread_join(t2, NULL);
  pthread_join(t3, NULL);
  pthread_join(t4, NULL);

  merge(lo, mid);
  merge(mid, hi);
  merge(lo, hi);

  steady_clock::time_point end = steady_clock::now();

  cout << (is_sorted(vec.begin(), mid) ? "" : "not ") << "sorted " << endl;
  duration<double> span = duration_cast<duration<double>>(end - start);
  cout << "Took: " << span.count() << " seconds" << endl;

  return 0;
}