#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <pthread.h>
#include <sys/time.h>

// Nonzero: run_test prints the array before and after each sort (argv[2]).
int verbose = 0;
// Number of chunks used by mergesort_pthread_init; chunk 0 is sorted by the
// calling thread and every other chunk by its own pthread (argv[3]).
int nthreads = 2;

/**
 * Prints how long an operation took, as "<name> took <seconds> seconds".
 *
 * @param name  label printed at the start of the line
 * @param start timestamp taken before the operation
 * @param end   timestamp taken after the operation
 */
void output_time_difference(char* name, struct timeval* start,
                            struct timeval* end) {
  // Difference the second and microsecond fields separately (instead of
  // converting each timestamp to one big number first); a negative
  // microsecond difference is absorbed when the two parts are combined.
  long whole_secs = end->tv_sec - start->tv_sec;
  long micros = end->tv_usec - start->tv_usec;
  double elapsed = (double)micros / 1000000;
  elapsed += whole_secs;
  printf("%s took %f seconds\n", name, elapsed);
}

/**
 * Reports whether arr[0..len-1] is in non-decreasing order.
 *
 * @param arr the array to inspect (not modified)
 * @param len number of elements; arrays of 0 or 1 elements count as sorted
 * @return 1 if sorted, 0 as soon as an out-of-order adjacent pair is found
 */
int check_sorted(int* arr, int len) {
  int idx = 0;
  while (idx + 1 < len) {
    if (arr[idx + 1] < arr[idx]) {
      return 0;  // found a descent
    }
    idx++;
  }
  return 1;
}

/**
 * Prints arr[0..len-1] to stdout on one line, each value followed by a
 * single space. No newline is printed.
 *
 * @param arr the array to print
 * @param len the number of elements to print
 */
void print_array(int* arr, int len) {
  int shown = 0;
  while (shown < len) {
    printf("%d ", arr[shown]);
    ++shown;
  }
}

/**
 * Allocates and returns a new array with the same contents.
 *
 * @param arr the array to copy (not modified)
 * @param len the length of the array
 * @return int* : the new array, which the caller must free(). Returns NULL
 *         if len is negative or the allocation fails. For len == 0 the
 *         result may be NULL or a unique pointer; both are safe to free().
 */
int* copy_array(const int* arr, int len) {
  if (len < 0) {
    return NULL;  // would otherwise become a huge size_t request
  }
  // calloc checks len * sizeof(int) for overflow, unlike
  // malloc(sizeof(int) * len).
  int* cparr = calloc((size_t)len, sizeof *cparr);
  // Skip memcpy when there is nothing to copy: memcpy with a NULL pointer
  // is undefined even for a zero size.
  if (cparr != NULL && len > 0) {
    memcpy(cparr, arr, (size_t)len * sizeof *cparr);
  }
  return cparr;
}

/**
 * Fills an array with pseudo-random data from rand().
 *
 * @param arr     the array to fill
 * @param len     the number of elements to fill
 * @param upperbd exclusive upper bound: each value is in [0, upperbd).
 *                Should be positive. If upperbd <= 0 there is no valid
 *                value, so the array is zero-filled instead of evaluating
 *                rand() % upperbd (undefined behavior when upperbd == 0).
 */
void fillWithRandom(int* arr, int len, int upperbd) {
  for (int i = 0; i < len; i++) {
    arr[i] = upperbd > 0 ? rand() % upperbd : 0;
  }
}

/**
 * Merge: the core step of mergesort. It combines two adjacent sorted runs
 * of arr into one sorted run.
 *
 * Preconditions:
 *   - arr[start..mid] (the "left run") is sorted in non-decreasing order
 *   - arr[mid+1..end] (the "right run") is sorted in non-decreasing order
 *   - start - 1 <= mid <= end, so either run may be empty
 * Postcondition: arr[start..end] is sorted. Only indices start..end of arr
 * and arr2 are read or written.
 *
 * @param arr   the array holding both runs; merged in place
 * @param arr2  auxiliary array (temporary storage); its entries start..end
 *              are overwritten
 * @param start first index of the left run
 * @param mid   last index of the left run
 * @param end   last index of the right run
 */
void merge(int* arr, int* arr2, int start, int mid, int end) {
  int left = start;     // cursor into the left run
  int right = mid + 1;  // cursor into the right run
  int out = start;      // next slot to fill in arr2

  // While both runs still have elements, move the smaller head into arr2.
  // On a tie, the right run's element is taken first.
  while (left <= mid && right <= end) {
    if (arr[right] <= arr[left]) {
      arr2[out++] = arr[right++];
    } else {
      arr2[out++] = arr[left++];
    }
  }

  // At most one run has elements left; append them unchanged.
  for (; left <= mid; left++) {
    arr2[out++] = arr[left];
  }
  for (; right <= end; right++) {
    arr2[out++] = arr[right];
  }

  // Copy the merged range from arr2 back into arr.
  for (int idx = start; idx <= end; idx++) {
    arr[idx] = arr2[idx];
  }
}

/**
 * Mergesort: a recursive function that sorts arr[start..end] (inclusive)
 * into non-decreasing order, in place.
 *
 * @param arr   the array to mergesort
 * @param arr2  an auxiliary array (used as temporary storage); entries
 *              start..end are overwritten
 * @param start the starting index of the part of arr to sort
 * @param end   the ending index of the part of arr to sort; if end <= start
 *              the part has at most one element and nothing is done
 */
void mergesort_rec(int* arr, int* arr2, int start, int end) {
  // Base case: part is length 0 or 1. No work to do.
  if (start >= end) {
    return;
  }
  // Same value as (start + end) / 2 for non-negative indices, but cannot
  // overflow int when start + end exceeds INT_MAX (very large arrays).
  int mid = start + (end - start) / 2;
  mergesort_rec(arr, arr2, start, mid);    // Recursively sort left half
  mergesort_rec(arr, arr2, mid + 1, end);  // Recursively sort right half
  merge(arr, arr2, start, mid, end);       // Merge left and right halves
}

/**
 * Header function for the serial mergesort algorithm.
 * Creates the auxiliary array and starts recursive mergesort on the entire
 * array.
 *
 * If the auxiliary array cannot be allocated, an error is printed to stderr
 * and arr is left unmodified.
 *
 * @param arr the array to mergesort (sorted in place)
 * @param len the length of the array
 */
void mergesort_init(int* arr, int len) {
  if (len < 2) {
    return;  // 0 or 1 elements: already sorted, no scratch space needed
  }
  // calloc checks len * sizeof(int) for overflow, unlike
  // malloc(sizeof(int) * len).
  int* arr2 = calloc((size_t)len, sizeof *arr2);  // auxiliary array
  if (arr2 == NULL) {
    fprintf(stderr, "mergesort_init: cannot allocate %d-element buffer\n",
            len);
    return;
  }
  mergesort_rec(arr, arr2, 0, len - 1);
  free(arr2);  // free auxiliary array
}
			 

// Solution added: multithreaded (pthreads) version of mergesort.
			 
/**
 * Arguments for one worker() thread: the thread runs
 * mergesort_rec(arr, arr2, start, end) on the inclusive index range
 * [start, end].
 */
struct workerinfo {
  int* arr;    // array being sorted (same pointer for every worker)
  int* arr2;   // auxiliary array (same pointer for every worker; each
               // worker only touches indices start..end)
  int start;   // first index of this worker's chunk (inclusive)
  int end;     // last index of this worker's chunk (inclusive)
};

void* worker(void* wptr) {
  struct workerinfo* w = (struct workerinfo*) wptr;
  mergesort_rec(w->arr, w->arr2, w->start, w->end);
  return NULL;
}


/**
 * Header function for the mergesort algorithm.
 * Creates the auxiliary array and starts recursive mergesort on the entire
 * array.
 * 
 * @param arr the array to mergesort
 * @param len the length of the array
 */
void mergesort_pthread_init(int* arr, int len) {
  int* arr2 = (int*) malloc(sizeof(int) * len); // auxiliary array

  int* starts = malloc(nthreads * sizeof(int));
  int* ends = malloc(nthreads * sizeof(int)); 
  for (int i = 0; i < nthreads; i++) {
    starts[i] = i * len / nthreads;
    ends[i] = (i+1) * len / nthreads - 1; 
  }
  pthread_t* tid = malloc(nthreads * sizeof(pthread_t));
  struct workerinfo* work = malloc(nthreads * sizeof(struct workerinfo));
  for (int i = 1; i < nthreads; i++) {
    work[i].arr = arr;
    work[i].arr2 = arr2;
    work[i].start = starts[i];
    work[i].end = ends[i];
    pthread_create(&tid[i], NULL, worker, &work[i]);
  }
  mergesort_rec(arr, arr2, starts[0], ends[0]);
  // Merge first two parts, then merge in 3rd part, then 4th, etc, 
  // when each is ready. Naive, but gets the job done.
  for (int i = 1; i < nthreads; i++) {
    pthread_join(tid[i], NULL);
    merge(arr, arr2, starts[0], ends[i-1], ends[i]);
  }
  free(starts);
  free(ends);
  free(tid);
  free(work);
  free(arr2);
}

/**
 * Test harness for a sorting algorithm. It times only the sort call, checks
 * that the result is sorted, and prints the verdict and the elapsed time.
 * When the global `verbose` is nonzero, the array is also printed before and
 * after sorting.
 *
 * @param sort the sorting algorithm to use, e.g. mergesort_init
 * @param arr  the array to sort (sorted in place)
 * @param len  the length of the array
 * @param name a name for this test, used in the timing line
 */
void run_test(void (*sort)(int*, int), int* arr, int len, char* name) {
  struct timeval t_begin;
  struct timeval t_finish;

  if (verbose) {
    printf("Starting array: \n");
    print_array(arr, len);
    printf("\n");
  }

  // Only the sort itself falls between the two timestamps.
  gettimeofday(&t_begin, NULL);
  sort(arr, len);
  gettimeofday(&t_finish, NULL);

  const char* verdict = check_sorted(arr, len) ? "Sort successful!\n"
                                               : "Sort did not work.\n";
  fputs(verdict, stdout);

  if (verbose) {
    printf("Sorted array: \n");
    print_array(arr, len);
    printf("\n");
  }
  output_time_difference(name, &t_begin, &t_finish);
}

/**
 * Parses s as a complete base-10 int with no trailing characters.
 *
 * @param s   the string to parse
 * @param out receives the value on success; left untouched on failure
 * @return 1 on success; 0 if s is empty, not a number, has trailing
 *         characters, or does not fit in an int
 */
static int parse_int_arg(const char* s, int* out) {
  char* end = NULL;
  errno = 0;
  long value = strtol(s, &end, 10);
  if (end == s || *end != '\0' || errno == ERANGE || value < INT_MIN ||
      value > INT_MAX) {
    return 0;
  }
  *out = (int)value;
  return 1;
}

/**
 * Usage: prog len [verbose] [nthreads]
 *
 * Fills an array of `len` random ints, then sorts one copy with the serial
 * mergesort and another with the threaded mergesort. Correctness and timing
 * are reported for each.
 *   len      - number of elements (non-negative integer)
 *   verbose  - 1 prints the array before & after each sort, 0 suppresses it
 *   nthreads - number of threads for the threaded sort (positive integer)
 *
 * @return 0 on success, -1 on bad arguments or allocation failure
 */
int main(int argc, char** argv) {
  // argv[0] is NULL when argc == 0, which printf("%s") must not receive.
  const char* prog = argc > 0 ? argv[0] : "mergesort";
  if (argc < 2) {
    fprintf(stderr, "Usage: %s len [verbose] [nthreads]\n", prog);
    return -1;
  }

  int len = 0;
  if (!parse_int_arg(argv[1], &len) || len < 0) {
    fprintf(stderr, "%s: len must be a non-negative integer, got '%s'\n",
            prog, argv[1]);
    return -1;
  }

  // The second argument can be used to turn on/off verbosity.
  // 1 means print the array before & after sort, 0 suppresses this.
  if (argc >= 3 && !parse_int_arg(argv[2], &verbose)) {
    fprintf(stderr, "%s: verbose must be an integer (0 or 1), got '%s'\n",
            prog, argv[2]);
    return -1;
  }

  // The third argument can be used to change the number of threads.
  if (argc >= 4 && (!parse_int_arg(argv[3], &nthreads) || nthreads < 1)) {
    fprintf(stderr, "%s: nthreads must be a positive integer, got '%s'\n",
            prog, argv[3]);
    return -1;
  }

  // calloc checks len * sizeof(int) for overflow. calloc(0, ...) may
  // legitimately return NULL, so only treat NULL as failure when len > 0.
  int* arr = calloc((size_t)len, sizeof *arr);  // array to sort
  if (len > 0 && arr == NULL) {
    fprintf(stderr, "%s: out of memory\n", prog);
    return -1;
  }
  // Values are drawn from [0, 10*len); clamp so 10*len cannot overflow int.
  int upperbd = len > INT_MAX / 10 ? INT_MAX : 10 * len;
  fillWithRandom(arr, len, upperbd);  // fill the array with some random data

  int* arr_serial = copy_array(arr, len);  // make a copy for each test
  int* arr_threaded = copy_array(arr, len);
  int status = 0;
  if (len > 0 && (arr_serial == NULL || arr_threaded == NULL)) {
    fprintf(stderr, "%s: out of memory\n", prog);
    status = -1;
  } else {
    // Run tests!
    run_test(mergesort_init, arr_serial, len, "serial");
    run_test(mergesort_pthread_init, arr_threaded, len, "threaded");
  }

  free(arr);
  free(arr_serial);
  free(arr_threaded);
  return status;
}