#include "dmi_api.h"

/* Cutoff used by quicksort(): a range whose (right - left) is at most
 * THRESHOLD is not partitioned, and a partition is only pushed on the work
 * stack when it is larger than THRESHOLD.  Whatever is left unsorted is
 * finished by the insertion-sort pass at the end of quicksort(). */
#define THRESHOLD 10

/* Parameter record shared between DMI_main() and DMI_scaleunit():
 * DMI_main() writes one instance to scaleunit_addr with DMI_write() and each
 * DMI_scaleunit() invocation reads it back with DMI_read(). */
typedef struct scaleunit_t
{
  int32_t async_flag;    /* nonzero: DMI_scaleunit() issues its exchange reads with a status handle and completes them with DMI_wait(); zero: the reads pass NULL as status */
  int32_t n;             /* number of elements each rank generates (the command-line n divided by pnum) */
  int32_t nn;            /* number of samples each rank draws from its local data */
  int64_t global_addr;   /* DMI address of the area where each rank publishes its n partitioned elements */
  int64_t sampling_addr; /* DMI address of the area holding the samples, then the splitters, then the per-rank first/last values */
  int64_t offset_addr;   /* DMI address of the area where each rank publishes its pnum + 1 bucket boundaries */
  int64_t barrier_addr;  /* DMI address of the barrier object passed to DMI_local_barrier_init() */
}scaleunit_t;

/* Per-thread accumulator (declared __thread).  DMI_scaleunit() adds the
 * value of time_diff(60) to it after each local computation phase and hands
 * the total to print_calctime(). */
__thread double t_calc = 0;

/* Forward declarations of the helpers defined further down in this file. */
void print(int n, int *x, int my_rank, char *str);
void inssort(int n, int *x);
void quicksort(int n, int *x);
void roughsort(int n, int *x, int nn, int *xx);
void check(int argc, char **argv);
void print_calctime(double total, double t, DMI_local_barrier_t *local_barrier, int my_rank, int pnum);

void DMI_main(int argc, char **argv)
{
  scaleunit_t scaleunit;
  int32_t n, nn, init_node_num, thread_num, pnum, async_flag;
  int64_t scaleunit_addr, global_addr, sampling_addr, offset_addr, barrier_addr;
  
  if(argc != 6)
    {
      errn("usage : %s init_node_num thread_num n nn async_flag", argv[0]);
      error();
    }
  
  init_node_num = atoi(argv[1]);
  thread_num = atoi(argv[2]);
  n = atoi(argv[3]);
  nn = atoi(argv[4]);
  async_flag = atoi(argv[5]);
  pnum = init_node_num * thread_num;
  if(n % pnum != 0)
    {
      errn("n % pnum != 0");
      return;
    }
  n /= pnum;
  
  catch(DMI_mmap(&scaleunit_addr, sizeof(scaleunit_t), 1, NULL));
  catch(DMI_mmap(&barrier_addr, sizeof(DMI_barrier_t), 1, NULL));
  catch(DMI_mmap(&global_addr, n * sizeof(int32_t), pnum, NULL));
  catch(DMI_mmap(&sampling_addr, nn * sizeof(int32_t) * pnum, 1, NULL));
  catch(DMI_mmap(&offset_addr, (pnum + 1) * sizeof(int32_t), pnum, NULL));
  catch(DMI_barrier_init(barrier_addr));
  
  scaleunit.n = n;
  scaleunit.nn = nn;
  scaleunit.global_addr = global_addr;
  scaleunit.sampling_addr = sampling_addr;
  scaleunit.offset_addr = offset_addr;
  scaleunit.barrier_addr = barrier_addr;
  catch(DMI_write(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_EXCLUSIVE, NULL));
  
  catch(DMI_rescale(scaleunit_addr, init_node_num, thread_num));
  
  catch(DMI_barrier_destroy(barrier_addr));
  catch(DMI_munmap(offset_addr, NULL));
  catch(DMI_munmap(sampling_addr, NULL));
  catch(DMI_munmap(global_addr, NULL));
  catch(DMI_munmap(barrier_addr, NULL));
  catch(DMI_munmap(scaleunit_addr, NULL));
  return;
}

/*
 * Per-rank body of the sample sort.
 *
 * Reads the scaleunit_t record stored at scaleunit_addr and then:
 *   1. fills a local array with n random values and accumulates their sum;
 *   2. draws nn samples from that array and writes them to this rank's slot
 *      of the sampling area;
 *   3. on rank 0 only: reads all pnum * nn samples, sorts them, keeps every
 *      nn-th one (pnum - 1 values) as splitters and writes them back to the
 *      start of the sampling area;
 *   4. partitions the local array by the splitters with roughsort() and
 *      publishes the partitioned data and its pnum + 1 bucket boundaries;
 *   5. reads, from every rank, the bucket whose index is my_rank, and
 *      quicksorts the collected values;
 *   6. compares the sum of the generated values with the sum of the sorted
 *      values, checks the per-rank first/last values for ordering, and prints
 *      timings.
 *
 * my_rank        : index of this invocation; used for slot addressing and to
 *                  select the rank-0-only work
 * pnum           : participant count passed to every barrier sync / allreduce
 * scaleunit_addr : DMI address of the scaleunit_t record
 * returns        : always 0
 *
 * NOTE(review): presumably the DMI runtime calls this once per thread as a
 * result of DMI_rescale() - confirm against the DMI documentation.
 */
int32_t DMI_scaleunit(int my_rank, int pnum, int64_t scaleunit_addr)
{
  int n, nn, n2, i, k, rank, flag, ret;
  int *x, *ran, *rans, *x2, *xx, *x2_bounds, *offsets;
  int64_t global_addr, sampling_addr, offset_addr, barrier_addr;
  int64_t *addrs, *ptr_offsets, *sizes;
  double sum, sum2, sum_sum, sum2_sum;
  scaleunit_t scaleunit;
  DMI_local_barrier_t local_barrier;
  DMI_local_status_t *statuses;
  
  catch(DMI_read(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_GET, NULL));
  /* NOTE(review): PROCNUM is defined outside this file; presumably the number of CPUs per node - confirm. */
  bind_to_cpu(my_rank % PROCNUM);
  
  n = scaleunit.n;
  nn = scaleunit.nn;
  global_addr = scaleunit.global_addr;
  sampling_addr = scaleunit.sampling_addr;
  offset_addr = scaleunit.offset_addr;
  barrier_addr = scaleunit.barrier_addr;
  
  catch(DMI_local_barrier_init(&local_barrier, barrier_addr));
  
  /* The generator is seeded with the rank. */
  mrand_init(my_rank);
  
  /* Step 1: generate the local input; sum is compared against the sum of the sorted output at the end. */
  x = (int32_t*)my_malloc(n * sizeof(int32_t));
  sum = 0;
  for(i = 0; i < n; i++)
    {
      x[i] = mrand_int(0, (1 << 30) - 1);
      sum += x[i];
    }
  
  //print(n, x, my_rank, "x   : ");
  
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  
  /* Lap 1 is the start of the reported interval (see the time_ref(1) uses below). */
  time_lap(1);
  
  /* Step 2: draw nn samples.  Slot 60 brackets local computation; the
     time_diff(60) result is accumulated in t_calc. */
  time_lap(60);
  ran = (int*)my_malloc(nn * sizeof(int));
  for(i = 0; i < nn; i++)
    {
      ran[i] = x[mrand_int(0, n - 1)];
    }
  t_calc += time_diff(60);
  
  //print(nn, ran, my_rank, "ran : ");
  
  /* Each rank writes its nn samples to its own slot of the sampling area. */
  catch(DMI_write(sampling_addr + my_rank * nn * sizeof(int32_t), nn * sizeof(int32_t), ran, DMI_PUT, NULL));
  
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  
  /* xx has pnum + 1 entries: it first receives the pnum - 1 splitters and is
     later turned into the pnum + 1 bucket boundaries. */
  xx = (int*)my_malloc((pnum + 1) * sizeof(int));
  if(my_rank == 0)
    {
      /* Step 3: rank 0 selects the splitters from all samples. */
      rans = (int*)my_malloc(nn * pnum * sizeof(int));
      
      catch(DMI_read(sampling_addr, pnum * nn * sizeof(int32_t), rans, DMI_GET, NULL));
      
      time_lap(60);
      quicksort(nn * pnum, rans);
      t_calc += time_diff(60);
      
      /* Take the sorted samples at indices nn, 2*nn, ..., (pnum-1)*nn. */
      k = 0;
      for(i = nn; i < nn * pnum; i += nn)
        {
          xx[k++] = rans[i];
        }
      assert(k == pnum - 1);
      
      /* The splitters overwrite the beginning of the sampling area. */
      catch(DMI_write(sampling_addr, (pnum - 1) * sizeof(int32_t), xx, DMI_PUT, NULL));
      
      my_free(rans);
    }
  
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  
  /* Every rank (rank 0 included) reads the pnum - 1 splitters. */
  catch(DMI_read(sampling_addr, (pnum - 1) * sizeof(int32_t), xx, DMI_GET, NULL));
  
  time_lap(2);
  
  //print(pnum - 1, xx, my_rank, "xx  : ");
  
  /* Step 4: partition x by the splitters; roughsort() replaces each splitter
     value in xx with the index in x where the following bucket starts. */
  time_lap(60);
  roughsort(n, x, pnum - 1, xx);
  t_calc += time_diff(60);
  
  time_lap(3);
  
  /* Shift the pnum - 1 boundaries up by one position and add 0 in front and n
     at the end, so that xx[k] and xx[k + 1] delimit bucket k. */
  xx[pnum - 1] = n;
  for(rank = pnum - 1; rank >= 0; rank--)
    {
      xx[rank + 1] = xx[rank];
    }
  xx[0] = 0;
  //print(pnum + 1, xx, my_rank, "xx  : ");
  //print(n, x, my_rank, "x   : ");
  
  /* Publish the partitioned data in this rank's slot of the global area. */
  catch(DMI_write(global_addr + my_rank * n * sizeof(int32_t), n * sizeof(int32_t), x, DMI_EXCLUSIVE, NULL));
  
  /* Publish the pnum + 1 bucket boundaries in this rank's slot of the offset area. */
  catch(DMI_write(offset_addr + my_rank * (pnum + 1) * sizeof(int32_t), (pnum + 1) * sizeof(int32_t), xx, DMI_EXCLUSIVE, NULL));
  
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  
  /* Step 5 work arrays:
       offsets[2*r], offsets[2*r+1] : begin/end index of bucket my_rank inside rank r's data
       addrs[r]                     : byte offset of that bucket inside the global area
       sizes[r]                     : size of that bucket in bytes
       ptr_offsets[r]               : byte offset inside x2 where rank r's bucket is stored */
  offsets = (int32_t*)my_malloc(pnum * 2 * sizeof(int32_t));
  sizes = (int64_t*)my_malloc(pnum * sizeof(int64_t));
  ptr_offsets = (int64_t*)my_malloc((pnum + 1) * sizeof(int64_t));
  addrs = (int64_t*)my_malloc((pnum + 1) * sizeof(int64_t));
  /* NOTE(review): the allocation is sized and cast with status_t while the
     pointer is declared as DMI_local_status_t * - confirm both types match. */
  statuses = (status_t*)my_malloc(pnum * sizeof(status_t));
  
  if(scaleunit.async_flag)
    {
      /* Variant with status handles: issue all boundary reads first ... */
      for(rank = 0; rank < pnum; rank++)
        {
          catch(DMI_read(offset_addr + (rank * (pnum + 1) + my_rank) * sizeof(int32_t), sizeof(int32_t) * 2, &offsets[rank * 2], DMI_GET, &statuses[rank]));
        }
      
      /* ... then wait for each one and derive the location and size of the bucket. */
      n2 = 0;
      ptr_offsets[0] = 0;
      for(rank = 0; rank < pnum; rank++)
        {
          DMI_wait(&statuses[rank], &ret);
          catch(ret);
          
          addrs[rank] = rank * n * sizeof(int32_t) + offsets[rank * 2] * sizeof(int32_t);
          sizes[rank] = (offsets[rank * 2 + 1] - offsets[rank * 2]) * sizeof(int32_t);
          ptr_offsets[rank + 1] = ptr_offsets[rank] + sizes[rank];
          n2 += offsets[rank * 2 + 1] - offsets[rank * 2];
        }
      
      /* n2 is the total number of elements this rank receives. */
      x2 = (int32_t*)my_malloc(n2 * sizeof(int32_t));
      
      /* Issue all bucket reads, then wait for all of them. */
      for(rank = 0; rank < pnum; rank++)
        {
          catch(DMI_read(global_addr + addrs[rank], sizes[rank], x2 + ptr_offsets[rank] / sizeof(int32_t), DMI_GET, &statuses[rank]));
        }
      
      for(rank = 0; rank < pnum; rank++)
        {
          DMI_wait(&statuses[rank], &ret);
          catch(ret);
        }
    }
  else
    {
      /* Variant without status handles: same computation, each DMI_read is
         passed NULL as status and its result is used right after the call. */
      n2 = 0;
      ptr_offsets[0] = 0;
      for(rank = 0; rank < pnum; rank++)
        {
          catch(DMI_read(offset_addr + (rank * (pnum + 1) + my_rank) * sizeof(int32_t), sizeof(int32_t) * 2, &offsets[rank * 2], DMI_GET, NULL));
          
          addrs[rank] = rank * n * sizeof(int32_t) + offsets[rank * 2] * sizeof(int32_t);
          sizes[rank] = (offsets[rank * 2 + 1] - offsets[rank * 2]) * sizeof(int32_t);
          ptr_offsets[rank + 1] = ptr_offsets[rank] + sizes[rank];
          n2 += offsets[rank * 2 + 1] - offsets[rank * 2];
        }
      
      x2 = (int32_t*)my_malloc(n2 * sizeof(int32_t));
      
      for(rank = 0; rank < pnum; rank++)
        {
          catch(DMI_read(global_addr + addrs[rank], sizes[rank], x2 + ptr_offsets[rank] / sizeof(int32_t), DMI_GET, NULL));
        }
    }
  
  //print(pnum, rc, my_rank, "rc  : ");
  
  time_lap(4);
  
  //print(n2, x2, my_rank, "x2  : ");
  
  /* Sort the collected bucket locally. */
  time_lap(60);
  quicksort(n2, x2);
  t_calc += time_diff(60);
  
  //print(n2, x2, my_rank, "x2  : ");
  
  my_free(x);
  my_free(xx);
  my_free(ran);
  my_free(statuses);
  my_free(addrs);
  my_free(ptr_offsets);
  my_free(sizes);
  my_free(offsets);
  
  time_lap(5);
  
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
    
  /* Lap 6 is the end of the interval reported as "time" by print_calctime(). */
  time_lap(6);
  
  /* Step 6: verification.  sum2 accumulates the sorted values; whenever two
     neighbours are out of order it is overwritten with -100000000 (the loop
     keeps adding afterwards), presumably so that the sum comparison below
     fails. */
  sum2 = 0;
  for(i = 0; i < n2; i++)
    {
      sum2 += x2[i];
      if(i != 0 && x2[i - 1] > x2[i])
        {
          sum2 = -100000000;
        }
    }
  
  catch(DMI_local_barrier_allreduce(&local_barrier, pnum, &sum, &sum_sum, DMI_OP_SUM, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(&local_barrier, pnum, &sum2, &sum2_sum, DMI_OP_SUM, DMI_TYPE_DOUBLE));
  
  outn("rank %d : sampling=%.12lf rough=%.12lf alltoall=%.12lf qsort=%.12lf"
       , my_rank, time_ref(2) - time_ref(1), time_ref(3) - time_ref(2), time_ref(4) - time_ref(3), time_ref(5) - time_ref(4));
  
  /* The sampling area is reused once more: each rank stores the first and the
     last value of its sorted data at positions 2 * my_rank and 2 * my_rank + 1.
     NOTE(review): when n2 == 0 this reads x2[0] and x2[-1] - confirm that an
     empty bucket cannot occur, or guard it. */
  catch(DMI_write(sampling_addr + my_rank * 2 * sizeof(int32_t), sizeof(int32_t), &x2[0], DMI_PUT, NULL));
  catch(DMI_write(sampling_addr + (my_rank * 2 + 1) * sizeof(int32_t), sizeof(int32_t), &x2[n2 - 1], DMI_PUT, NULL));
  
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  
  if(my_rank == 0)
    {
      /* Rank 0 checks that the first/last values of all ranks form a
         non-decreasing sequence and that both global sums agree. */
      x2_bounds = (int32_t*)my_malloc(pnum * 2 * sizeof(int32_t));
      
      catch(DMI_read(sampling_addr, pnum * 2 * sizeof(int32_t), x2_bounds, DMI_GET, NULL));
      
      flag = 0;
      /* Relative comparison of the two sums; EPS6 is defined outside this file. */
      if(fabs(sum_sum / sum2_sum - 1) > EPS6)
        {
          flag = 1;
          outn("ERROR : sum_sum(%e) != sum2_sum(%e)", sum_sum, sum2_sum);
        }
      for(i = 1; i < pnum * 2; i++)
        {
          //out("%d ", x2_bounds[i - 1]);
          if(x2_bounds[i - 1] > x2_bounds[i])
            {
              flag = 1;
              outn("ERROR : x2_bounds[i - 1] > x2_bounds[i]");
            }
        }
      //outn("%d", x2_bounds[i - 1]);
      if(flag == 0)
        {
          outn("OK : %e", sum_sum);
        }
      
      my_free(x2_bounds);
    }
  /* Called by every rank: print_calctime() performs allreduce operations. */
  print_calctime(time_ref(6) - time_ref(1), t_calc, &local_barrier, my_rank, pnum);
  
  my_free(x2);
  
  catch(DMI_local_barrier_destroy(&local_barrier));
  return 0;  
}

/*
 * Debug helper: prints "[my_rank] str " followed by the first n values of x,
 * each followed by a space, and finishes the line with outn("").
 *
 * n       : number of values to print
 * x       : values to print
 * my_rank : rank shown in the bracketed prefix
 * str     : label printed after the prefix
 */
void print(int n, int *x, int my_rank, char *str)
{
  int idx = 0;

  printf("[%d] ", my_rank);
  printf("%s ", str);
  while(idx < n)
    {
      printf("%d ", x[idx]);
      idx++;
    }
  outn("");
}

/*
 * Sorts x[0 .. n-1] into non-decreasing order in place by insertion sort.
 * Does nothing when n is less than 2.
 *
 * n : number of elements
 * x : array to sort
 */
void inssort(int n, int *x)
{
  int pos;

  for(pos = 1; pos < n; pos++)
    {
      int key = x[pos];
      int hole = pos;

      /* Slide larger elements one slot to the right until key fits. */
      while(hole > 0 && x[hole - 1] > key)
        {
          x[hole] = x[hole - 1];
          hole--;
        }
      x[hole] = key;
    }
}

/*
 * Sorts x[0 .. n-1] into non-decreasing order in place.
 *
 * Non-recursive quicksort with an explicit stack of pending ranges.  The
 * pivot is the middle element of the current range.  After partitioning, the
 * larger side is pushed (only when it is bigger than THRESHOLD) and the
 * smaller side is processed next.  Ranges with (hi - lo) <= THRESHOLD are not
 * partitioned; the closing inssort() call puts everything in final order.
 *
 * n : number of elements
 * x : array to sort
 */
void quicksort(int n, int *x)
{
  int lo_stack[32], hi_stack[32];
  int depth = 0;
  int lo = 0;
  int hi = n - 1;

  for(;;)
    {
      int up, down, pivot;

      /* Current range is small enough: take the next pending one, if any. */
      if(hi - lo <= THRESHOLD)
        {
          if(depth == 0)
            {
              break;
            }
          depth--;
          lo = lo_stack[depth];
          hi = hi_stack[depth];
        }

      pivot = x[(lo + hi) / 2];
      up = lo;
      down = hi;
      for(;;)
        {
          int held;

          while(x[up] < pivot)
            {
              up++;
            }
          while(pivot < x[down])
            {
              down--;
            }
          if(up >= down)
            {
              break;
            }
          held = x[up];
          x[up] = x[down];
          x[down] = held;
          up++;
          down--;
        }

      if(up - lo > hi - down)
        {
          /* Left side is the larger one: defer it, go on with the right side. */
          if(up - lo > THRESHOLD)
            {
              lo_stack[depth] = lo;
              hi_stack[depth] = up - 1;
              depth++;
            }
          lo = down + 1;
        }
      else
        {
          /* Right side is at least as large: defer it, go on with the left side. */
          if(hi - down > THRESHOLD)
            {
              lo_stack[depth] = down + 1;
              hi_stack[depth] = hi;
              depth++;
            }
          hi = up - 1;
        }
    }

  /* Finish the short ranges that were left unsorted. */
  inssort(n, x);
}

/*
 * Partitions x[0 .. n-1] in place into nn + 1 consecutive buckets delimited
 * by the nn splitter values in xx, without sorting inside the buckets.
 *
 * On return, xx[k] no longer holds splitter k but the index in x where the
 * bucket following splitter k starts: every x[idx] with idx < xx[k] is
 * <= splitter k and every x[idx] with idx >= xx[k] is >= splitter k.
 * The boundaries are non-decreasing when the splitters are non-decreasing.
 *
 * The splitters are processed like a binary search tree: the current range of
 * x is partitioned by the middle splitter of the current splitter range, the
 * left half (elements and splitters) is pushed on an explicit stack and the
 * right half is processed next.
 *
 * n  : number of elements in x
 * x  : data to partition
 * nn : number of splitters; 0 leaves x and xx untouched
 * xx : in: splitter values in non-decreasing order; out: boundary indices
 */
void roughsort(int n, int *x, int nn, int *xx)
{
  int i, j, left, right, xxleft, xxright, xxmid, p, pivot, tmp;
  int lstack[32], rstack[32], xxlstack[32], xxrstack[32];
  
  p = 0;
  left = 0;
  right = n - 1;
  xxleft = 0;
  xxright = nn - 1;
  while(1)
    {
      /* No splitter left for the current range: resume a deferred left half. */
      if(xxleft > xxright)
        {
          if(p == 0) break;
          p--;
          left = lstack[p];
          right = rstack[p];
          xxleft = xxlstack[p];
          xxright = xxrstack[p];
        }
      xxmid = (xxleft + xxright) / 2;
      pivot = xx[xxmid];
      i = left;
      j = right;
      while(1)
        {
          /* The pivot need not occur in x, so both scans are bounds-checked.
             The forward scan must be allowed to pass x[right]: stopping at
             i == right without testing x[right] would drop that element from
             the deferred left half [left, i - 1] whenever the whole range is
             smaller than the pivot, leaving it in the wrong bucket. */
          while(i <= right && x[i] < pivot) i++;
          while(left <= j && pivot < x[j]) j--;
          if(i >= j) break;
          tmp = x[i];
          x[i] = x[j];
          x[j] = tmp;
          i++;
          j--;
        }
      /* Everything up to index j is <= pivot; the next bucket starts at j + 1. */
      xx[xxmid] = j + 1;
      if(xxleft <= xxmid - 1)
        {
          /* Defer the elements left of the boundary together with the smaller splitters. */
          lstack[p] = left;
          rstack[p] = i - 1;
          xxlstack[p] = xxleft;
          xxrstack[p] = xxmid - 1;
          p++;
        }
      /* Continue with the elements right of the boundary and the larger splitters. */
      left = j + 1;
      xxleft = xxmid + 1;
    }
  return;
}

/*
 * Stand-alone visual test of roughsort() and quicksort().
 *
 * Expects argv = { program, n, nn }.  Fills x with n random values, picks nn
 * of them as splitters, sorts the splitters, and prints:
 *   " x : "  the input values,
 *   "xx : "  the sorted splitters,
 *   "rx : "  x after roughsort(), with a '^' under every index that
 *            roughsort() reported as a bucket boundary,
 *   "sx : "  x after quicksort().
 *
 * The work arrays are fixed-size stack buffers, so n and nn are limited to
 * CHECK_CAPACITY; a splitter can only be drawn when x is not empty.
 */
void check(int argc, char **argv)
{
  enum { CHECK_CAPACITY = 200 };
  int n, nn, i, k;
  int x[CHECK_CAPACITY], xx[CHECK_CAPACITY];
  
  mrand_init((unsigned)time(NULL));
  
  if(argc != 3)
    {
      errn("usage : %s n nn", argv[0]);
      return;
    }
  n = atoi(argv[1]);
  nn = atoi(argv[2]);
  
  /* Reject sizes that would overrun x / xx or sample from an empty x. */
  if(n < 0 || n > CHECK_CAPACITY || nn < 0 || nn > CHECK_CAPACITY || (nn > 0 && n == 0))
    {
      errn("n and nn must be between 0 and %d, and n must be positive when nn is", CHECK_CAPACITY);
      return;
    }
  
  for(i = 0; i < n; i++)
    {
      x[i] = (int)(30 * ((double)rand() / RAND_MAX));
    }
  for(i = 0; i < nn; i++)
    {
      /* Divide by RAND_MAX + 1 so the index stays in [0, n - 1] even when
         rand() returns RAND_MAX. */
      xx[i] = x[(int)(n * ((double)rand() / ((double)RAND_MAX + 1.0)))];
    }
  quicksort(nn, xx);
  
  out(" x : ");
  for(i = 0; i < n; i++)
    {
      out("%02d ", x[i]);
    }
  outn("");
  
  out("xx : ");
  for(i = 0; i < nn; i++)
    {
      out("%d ", xx[i]);
    }
  outn("");
  
  /* After this call xx holds boundary indices instead of splitter values. */
  roughsort(n, x, nn, xx);
  
  out("rx : ");
  for(i = 0; i < n; i++)
    {
      out("%02d ", x[i]);
    }
  outn("");
  out("     ");
  /* Mark the column of each boundary index; xx is consumed in order. */
  k = 0;
  for(i = 0; i < n; i++)
    {
      if(k >= nn || i != xx[k])
        {
          out("   ");
        }
      else
        {
          out(" ^ ");
          k++;
        }
    }
  outn("");
  
  
  quicksort(n, x);
  
  out("sx : ");
  for(i = 0; i < n; i++)
    {
      out("%02d ", x[i]);
    }
  outn("");
  return;
}

/*
 * Reduces the per-rank computation time t over all pnum participants (sum,
 * maximum and minimum, in that order) and lets rank 0 print a summary line.
 * Every rank must call this, because each reduction goes through
 * DMI_local_barrier_allreduce().
 *
 * total         : elapsed time printed as "time"
 * t             : this rank's computation time
 * local_barrier : barrier handle used for the reductions
 * my_rank       : only rank 0 prints
 * pnum          : participant count; also the divisor for the average
 *
 * Printed "comm" is total minus the maximum computation time.
 */
void print_calctime(double total, double t, DMI_local_barrier_t *local_barrier, int my_rank, int pnum)
{
  double calc_sum, calc_max, calc_min;

  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_sum, DMI_OP_SUM, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_max, DMI_OP_MAX, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_min, DMI_OP_MIN, DMI_TYPE_DOUBLE));

  if(my_rank != 0)
    {
      return;
    }

  outn("pnum=%d time=%.12lf calc_max=%.12lf calc_min=%.12lf calc_avg=%.12lf comm=%.12lf", 
       pnum, total, calc_max, calc_min, calc_sum / pnum, total - calc_max);
}
