#include "dmi_api.h"

/* NOTE(review): THREAD_TOTAL and THREAD_NUM are not referenced anywhere in
 * the code visible in this file - confirm whether they are still needed. */
#define THREAD_TOTAL 256
#define THREAD_NUM 8

/* Constants of the pseudo-random generator: S is the starting state that
 * randlc_nth() jumps ahead from, A is the multiplier used by both randlc()
 * and randlc_nth(). */
#define S 271828183.0
#define A 1220703125.0

/* Per-thread accumulator; DMI_scaleunit() adds time_diff(60) to it after its
 * compute loop and passes the value to print_calctime(). */
__thread double t_calc = 0;

/* Forward declarations of the helpers defined after DMI_scaleunit(). */
double randlc(double XX);
double randlc_nth(int64_t nth);
void print_calctime(double total, double t, DMI_local_barrier_t *local_barrier, int my_rank, int pnum);

/* Parameter record that DMI_main() writes into DMI memory and that every
 * DMI_scaleunit() instance reads back at start-up. */
typedef struct scaleunit_t
{
  int64_t n;             /* total number of loop iterations, split across ranks */
  int64_t barrier_addr;  /* DMI address of the barrier object initialised by DMI_main() */
}scaleunit_t;

/*
 * Program entry point called with the command-line arguments.
 *
 * argv[1]  init_node_num  passed through to DMI_rescale()
 * argv[2]  thread_num     passed through to DMI_rescale()
 * argv[3]  n              base-2 logarithm of the problem size; the number
 *                         of iterations shared by all ranks is 2^n
 *
 * Allocates the shared parameter record and the barrier in DMI memory,
 * publishes the parameters, runs the computation through DMI_rescale() and
 * finally releases the DMI resources again.
 */
void DMI_main(int argc, char **argv)
{
  scaleunit_t scaleunit;
  int32_t init_node_num, thread_num;
  int log2_n;
  int64_t scaleunit_addr, barrier_addr, n;
  
  if(argc != 4)
    {
      outn("usage : %s init_node_num thread_num n", argv[0]);
      error();
      return; /* defensive: never touch argv[1..3] if error() were to return */
    }
  init_node_num = atoi(argv[1]);
  thread_num = atoi(argv[2]);
  log2_n = atoi(argv[3]);
  
  /* Shifting by a negative count or by >= 64 is undefined behaviour, and a
   * shift by 63 does not fit a positive int64_t, so reject those up front. */
  if(log2_n < 0 || log2_n > 62)
    {
      outn("n must be in the range 0..62 (the problem size is 2^n), got %d", log2_n);
      error();
      return; /* defensive, see above */
    }
  n = (int64_t)1 << log2_n;
  
  catch(DMI_mmap(&scaleunit_addr, sizeof(scaleunit_t), 1, NULL));
  catch(DMI_mmap(&barrier_addr, sizeof(DMI_barrier_t), 1, NULL));
  catch(DMI_barrier_init(barrier_addr));
  
  /* Publish the parameters so that every DMI_scaleunit() can read them. */
  scaleunit.barrier_addr = barrier_addr;
  scaleunit.n = n;
  catch(DMI_write(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_EXCLUSIVE, NULL));
  
  catch(DMI_rescale(scaleunit_addr, init_node_num, thread_num));
  
  catch(DMI_barrier_destroy(barrier_addr));
  catch(DMI_munmap(barrier_addr, NULL));
  catch(DMI_munmap(scaleunit_addr, NULL));
  return;
}

/*
 * Per-rank worker.
 *
 * my_rank         index of this worker, 0 .. pnum-1
 * pnum            number of workers taking part in the collectives
 * scaleunit_addr  DMI address of the scaleunit_t written by DMI_main()
 *
 * Each rank handles the iteration range [n*my_rank/pnum, n*(my_rank+1)/pnum).
 * Every iteration draws two uniform numbers, maps them to (-1, 1) and, when
 * the point lies inside the unit circle, turns it into a pair of Gaussian
 * deviates with the polar transform.  The deviates are summed and counted
 * into ten bins indexed by the integer part of max(|xk|, |yk|).  The partial
 * results are then combined across ranks and rank 0 prints them.
 *
 * Always returns 0.
 */
int32_t DMI_scaleunit(int my_rank, int pnum, int64_t scaleunit_addr)
{
  DMI_local_barrier_t local_barrier;
  scaleunit_t scaleunit;
  int ql[10], ql_sum[10];
  int q;
  int64_t i, left, right, n;
  double t, xk, yk, abs_xk, abs_yk, x_sum, y_sum, x_sum_sum, y_sum_sum, r, x, y, tmp;
  
  catch(DMI_read(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_GET, NULL));
  bind_to_cpu(my_rank % PROCNUM);
  n = scaleunit.n;
  
  catch(DMI_local_barrier_init(&local_barrier, scaleunit.barrier_addr));
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  /* Timer slot 10 marks the start of the measured region, slot 11 its end
   * (their difference is handed to print_calctime below); slot 60 brackets
   * the pure compute loop. */
  time_lap(10);
  
  time_lap(60);
  left = n * my_rank / pnum;
  right = n * (my_rank + 1) / pnum;
  if(my_rank == pnum - 1)
    {
      right = n;
    }
  /* Two random numbers are consumed per iteration, so this rank starts at
   * position 2 * left of the random sequence. */
  r = randlc_nth(left * 2);
  x_sum = 0;
  y_sum = 0;
  for(q = 0; q < 10; q++)
    {
      ql[q] = 0;
    }
  for(i = left; i < right; i++)
    {
      r = randlc(r);
      x = 2 * r - 1;
      r = randlc(r);
      y = 2 * r - 1;
      
      t = x * x + y * y;
      if(t <= 1)
        {
          tmp = sqrt(-2 * log(t) / t);
          xk = x * tmp;
          yk = y * tmp;
          x_sum += xk;
          y_sum += yk;
          
          /* NOTE(review): the bin index is not range-checked against the
           * ten entries of ql[]. */
          abs_xk = fabs(xk);
          abs_yk = fabs(yk);
          if(abs_xk < abs_yk)
            {
              ql[(int)abs_yk]++;
            }
          else
            {
              ql[(int)abs_xk]++;
            }
        }
    }
  t_calc += time_diff(60);
  
  /* Combine the per-rank partial results; every rank must issue the same
   * sequence of collective calls. */
  for(q = 0; q < 10; q++)
    {
      catch(DMI_local_barrier_allreduce(&local_barrier, pnum, &ql[q], &ql_sum[q], DMI_OP_SUM, DMI_TYPE_INT));
    }
  catch(DMI_local_barrier_allreduce(&local_barrier, pnum, &x_sum, &x_sum_sum, DMI_OP_SUM, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(&local_barrier, pnum, &y_sum, &y_sum_sum, DMI_OP_SUM, DMI_TYPE_DOUBLE));
  time_lap(11);
  
  if(my_rank == 0)
    {
      /* q is an int, so it matches the %d conversion exactly. */
      for(q = 0; q < 10; q++)
        {
          outn("ql[%d] = %d", q, ql_sum[q]);
        }
      outn("sum of x = %.16lf", x_sum_sum);
      outn("sum of y = %.16lf", y_sum_sum);
    }
  print_calctime(time_ref(11) - time_ref(10), t_calc, &local_barrier, my_rank, pnum);
  
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  catch(DMI_local_barrier_destroy(&local_barrier));
  return 0;
}

/*
 * Advance the linear congruential generator by one step.
 *
 * XX is the current state scaled into [0, 1): the integer state is
 * x = XX * 2^46.  The function returns ((A * x) mod 2^46) * 2^-46, i.e. the
 * next state in the same scaled form.  The product is formed from 23-bit
 * halves so that every intermediate value stays exactly representable in a
 * double.
 *
 * Assumes 0 <= XX < 1 with XX * 2^46 integral; the (int) truncations below
 * rely on that range.
 *
 * The function keeps no mutable state, so it can be called from several
 * threads at once.
 */
double randlc(double XX)
{
  /* Exact powers of two, fixed at compile time instead of being built
   * lazily in shared statics (which raced when threads entered together). */
  static const double R23 = 1.0 / 8388608.0;         /* 2^-23 */
  static const double T23 = 8388608.0;               /* 2^23  */
  static const double R46 = 1.0 / 70368744177664.0;  /* 2^-46 */
  static const double T46 = 70368744177664.0;        /* 2^46  */
  double T1, T2, T3, T4, A1, A2, X1, X2, Z, X;
  
  /* Split the multiplier: A = 2^23 * A1 + A2. */
  T1 = R23 * A;
  A1 = (int)T1;
  A2 = A - T23 * A1;
  
  /* Split the state: x = 2^23 * X1 + X2. */
  XX = T46 * XX;
  T1 = R23 * XX;
  X1 = (int)T1;
  X2 = XX - T23 * X1;
  
  /* Z = (A1 * X2 + A2 * X1) mod 2^23; the A1 * X1 term is a multiple of
   * 2^46 and therefore drops out. */
  T1 = A1 * X2 + A2 * X1;
  T2  = (int)(R23 * T1);
  Z = T1 - T23 * T2;
  
  /* X = (2^23 * Z + A2 * X2) mod 2^46. */
  T3 = T23 * Z + A2 * X2;
  T4  = (int)(R46 * T3);
  X = T3 - T46 * T4;
  return R46 * X;
}

/*
 * Return (a * b) mod 2^46 for integral a, b in [0, 2^46).
 * Both factors are split into 23-bit halves so that every intermediate
 * value is exactly representable in a double.
 */
static double randlc_mulmod46(double a, double b)
{
  static const double R23 = 1.0 / 8388608.0;         /* 2^-23 */
  static const double T23 = 8388608.0;               /* 2^23  */
  static const double R46 = 1.0 / 70368744177664.0;  /* 2^-46 */
  static const double T46 = 70368744177664.0;        /* 2^46  */
  double A1, A2, B1, B2, T1, T2, T3, T4, T5;
  
  A1 = (int)(R23 * a);
  A2 = a - T23 * A1;
  B1 = (int)(R23 * b);
  B2 = b - T23 * B1;
  T1 = A1 * B2 + A2 * B1;
  T2 = (int)(R23 * T1);
  T3 = T1 - T23 * T2;
  T4 = T23 * T3 + A2 * B2;
  T5 = (int)(R46 * T4);
  return T4 - T46 * T5;
}

/*
 * Jump ahead in the random sequence.
 *
 * Returns the generator state after nth steps from the seed S, scaled into
 * [0, 1): ((S * A^nth) mod 2^46) * 2^-46.  The power is built by repeated
 * squaring, so the cost grows with the number of bits in nth.
 *
 * nth <= 0 returns the scaled seed itself (no steps taken).
 *
 * The function keeps no mutable state, so it can be called from several
 * threads at once.
 */
double randlc_nth(int64_t nth)
{
  static const double R46 = 1.0 / 70368744177664.0;  /* 2^-46 */
  double B, T;
  
  B = S;  /* running result, starts at the seed */
  T = A;  /* A^(2^k) mod 2^46 for the bit currently being examined */
  
  while(nth > 0)
    {
      if(nth & 1)
        {
          B = randlc_mulmod46(B, T);
        }
      T = randlc_mulmod46(T, T);
      nth >>= 1;
    }
  return R46 * B;
}

/*
 * Collect timing statistics over all ranks and let rank 0 print them.
 *
 * total          elapsed time of the whole measured region on this rank
 * t              compute time of this rank
 * local_barrier  barrier handle used for the three reductions
 * my_rank, pnum  rank of the caller and number of participating ranks
 *
 * Every rank has to call this function because it performs collective
 * reductions (sum, max, min of t) before anything is printed.
 */
void print_calctime(double total, double t, DMI_local_barrier_t *local_barrier, int my_rank, int pnum)
{
  double calc_sum, calc_max, calc_min;

  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_sum, DMI_OP_SUM, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_max, DMI_OP_MAX, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_min, DMI_OP_MIN, DMI_TYPE_DOUBLE));

  /* Only rank 0 reports; everybody else is done after the reductions. */
  if(my_rank != 0)
    {
      return;
    }

  /* comm is reported as the total time minus the slowest rank's compute time. */
  outn("pnum=%d time=%.12lf calc_max=%.12lf calc_min=%.12lf calc_avg=%.12lf comm=%.12lf", 
       pnum, total, calc_max, calc_min, calc_sum / pnum, total - calc_max);
}
