// Remember to compile with -pthread

#include <pthread.h>

#include <algorithm>
#include <chrono>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

using std::cout;
using std::endl;
using std::vector;
using std::random_shuffle;
using std::swap;
using std::is_sorted;
using namespace std::chrono;

// Number of elements that main() fills, shuffles and sorts.
constexpr int ARR_SIZE = 20000;

// Shifts the element at `cur` toward `lo` by repeated adjacent swaps.
// It stops as soon as the element to its left is not greater than it, or
// when it reaches `lo`. When [lo, cur) is already sorted, this places *cur
// at its sorted position, which is the inner step of an insertion sort.
void bubble_down(vector<int>::iterator cur, const vector<int>::iterator &lo) {
  for (; cur > lo; --cur) {
    auto prev = cur - 1;
    if (!(*cur < *prev)) {
      break;  // already in order relative to its left neighbour
    }
    swap(*cur, *prev);
  }
}

// Merges, in place, the two adjacent runs [lo, mid) and [mid, hi) into one
// sorted range [lo, hi), where mid = lo + (hi - lo) / 2. This is the same
// midpoint sort() uses when it splits a range.
//
// Precondition: both halves are already sorted in ascending order.
//
// std::inplace_merge is stable. It runs in linear time when it can obtain a
// temporary buffer and falls back to O(N log N) otherwise. Either way this
// avoids an O(|left| * |right|) element-by-element insertion, which would
// dominate the run time for large halves.
void merge(const vector<int>::iterator &lo, const vector<int>::iterator &hi) {
  const auto mid = lo + (hi - lo) / 2;
  std::inplace_merge(lo, mid, hi);
}

// Heap-allocated argument bundle handed to thread_sort through the void*
// parameter of pthread_create. It names the half-open range [lo, hi) that
// the worker should sort.
struct merge_params {
  vector<int>::iterator lo;  // first element of the range
  vector<int>::iterator hi;  // one past the last element of the range

  merge_params(vector<int>::iterator first, vector<int>::iterator last)
    : lo{first}, hi{last} {}
};

void* thread_sort(void* arg) {
  merge_params *p = reinterpret_cast<merge_params*>(arg);
  sort(p->lo, p->hi);

  delete p;
  return NULL;
}

void sort(const vector<int>::iterator &lo, const vector<int>::iterator &hi) {
  if (hi - lo > 500) {
    auto mid = (hi - lo) / 2 + lo;

    pthread_t t1;
    pthread_t t2;

    merge_params *p1 = new merge_params(lo, mid);
    merge_params *p2 = new merge_params(mid, hi);

    pthread_create(&t1, NULL, thread_sort, p1);
    pthread_create(&t2, NULL, thread_sort, p2);

    pthread_join(t1, NULL);
    pthread_join(t2, NULL);
  }
  merge(lo, hi);
}

// Benchmark driver. It fills a vector with 0..ARR_SIZE-1, shuffles it,
// times sort() on it, and reports whether the vector is sorted before and
// after.
int main(int /*argc*/, char** /*argv*/) {
  vector<int> vec(ARR_SIZE);
  std::iota(vec.begin(), vec.end(), 0);

  cout << "Shuffling..." << endl;
  // std::random_shuffle was deprecated in C++14 and removed in C++17.
  // A default-constructed mt19937 always starts from the same seed, so every
  // run sorts the same permutation and timings stay comparable.
  std::mt19937 rng;
  std::shuffle(vec.begin(), vec.end(), rng);
  cout << (is_sorted(vec.begin(), vec.end()) ? "" : "not ") << "sorted " << endl;

  cout << "Sorting..." << endl;

  steady_clock::time_point start = steady_clock::now();
  // Argument-dependent lookup also finds std::sort, but overload resolution
  // picks this file's non-template sort(), because a non-template wins an
  // otherwise equal match.
  sort(vec.begin(), vec.end());
  steady_clock::time_point end = steady_clock::now();

  cout << (is_sorted(vec.begin(), vec.end()) ? "" : "not ") << "sorted " << endl;
  duration<double> span = duration_cast<duration<double>>(end - start);
  cout << "Took: " << span.count() << " seconds" << endl;

  return 0;
}