#include "dmi_api.h"
#include <stdint.h>

/* Numeric constants of the simulation. Units are not stated in this file;
   physical interpretations below are inferred from the values and should be
   confirmed (NOTE(review)). */
/* Value matches the Newtonian gravitational constant; not referenced by the
   code visible in this file. */
#define G 6.673e-11
/* Time step applied once per iteration when DMI_scaleunit advances positions
   and velocities. */
#define INTERVAL 1.0e-16
/* Value matches Avogadro's number; DMI_scaleunit divides 39.948 by it (and by
   1000.0) to obtain each particle's mass. */
#define ABG 6.022e23
/* Energy scale multiplying the pair force and pair potential evaluated in
   DMI_scaleunit. */
#define EPS 1.65e-21
/* Initial spacing between neighbouring particles on the cubic lattice; the
   length scale sigma is derived from it as R2 / cbrt(sqrt(2)). */
#define R2 19.4e-11

/* Parameter record that DMI_main writes into DMI memory and every
   DMI_scaleunit instance reads back at start-up. */
typedef struct scaleunit_t {
  int32_t n, niter;                /* lattice edge length; number of iterations */
  int64_t xyz_addr, barrier_addr;  /* DMI addresses of the coordinate array and the barrier */
} scaleunit_t;

/* State of a single simulated particle. */
typedef struct obj_t {
  double x, y, z;     /* position */
  double vx, vy, vz;  /* velocity */
  double ax, ay, az;  /* acceleration accumulated during the current iteration */
  double pe, ke;      /* potential and kinetic energy attributed to this particle */
  double m;           /* mass */
} obj_t;

/* Per-thread accumulator: DMI_scaleunit adds time_diff(60) to it once per
   iteration and passes t_calc / NITER to print_calctime at the end. */
__thread double t_calc = 0;

/* Forward declarations. print_matrix and sumof_matrix are neither defined nor
   called in the part of the file visible here — NOTE(review): possibly
   leftovers; confirm before removing. print_calctime is defined at the end of
   the file and called from DMI_scaleunit. */
void print_matrix(double *matrix, int32_t n);
double sumof_matrix(double *matrix, int32_t n);
void print_calctime(double total, double t, DMI_local_barrier_t *local_barrier, int my_rank, int pnum);

/*
 * Driver: parses the command line, allocates the shared DMI objects, publishes
 * the run parameters and launches the workers through DMI_rescale.
 *
 * Expected arguments (argc must be 5):
 *   argv[1] init_node_num  number of nodes
 *   argv[2] thread_num     threads per node
 *   argv[3] n              lattice edge length; n*n*n particles are simulated
 *   argv[4] niter          number of iterations
 *
 * All four values must be positive, n*n*n must be divisible by
 * init_node_num * thread_num, and 3*n*n*n must fit in int32_t because
 * DMI_scaleunit indexes the coordinate array with 32-bit integers.
 * Invalid input is reported with outn() followed by error().
 */
void DMI_main(int argc, char **argv)
{
  scaleunit_t scaleunit;
  int32_t n, init_node_num, thread_num, pnum, n3, niter;
  int64_t scaleunit_addr, xyz_addr, barrier_addr;
  
  if(argc != 5)
    {
      outn("usage : %s init_node_num thread_num n niter", argv[0]);
      error();
      /* error() is presumably fatal; the return is a backstop so argv is
         never dereferenced past argc if it is not. */
      return;
    }
  
  init_node_num = atoi(argv[1]);
  thread_num = atoi(argv[2]);
  n = atoi(argv[3]);
  niter = atoi(argv[4]);
  
  /* atoi() yields 0 for non-numeric text; a zero node or thread count would
     make the modulo and division below divide by zero. */
  if(init_node_num <= 0 || thread_num <= 0 || n <= 0 || niter <= 0)
    {
      outn("init_node_num, thread_num, n and niter must be positive integers");
      error();
      return;
    }
  
  /* Reject sizes whose products overflow int32_t. The n*n test runs first so
     that the 64-bit n*n*n product itself cannot overflow. */
  if((int64_t)init_node_num * thread_num > INT32_MAX
     || (int64_t)n * n > INT32_MAX / 3
     || (int64_t)n * n * n > INT32_MAX / 3)
    {
      outn("init_node_num * thread_num and 3 * n * n * n must fit in 32 bits");
      error();
      return;
    }
  
  n3 = n * n * n;
  pnum = init_node_num * thread_num;
  if(n3 % pnum != 0)
    {
      outn("n3 %% pnum != 0");
      error();
      return;
    }
  
  /* Shared objects: the parameter record, the coordinate array (x,y,z per
     particle, split into init_node_num equally sized pages) and the barrier. */
  catch(DMI_mmap(&scaleunit_addr, sizeof(scaleunit_t), 1, NULL));
  catch(DMI_mmap(&xyz_addr, n3 / init_node_num * 3 * sizeof(double), init_node_num, NULL));
  catch(DMI_mmap(&barrier_addr, sizeof(DMI_barrier_t), 1, NULL));
  catch(DMI_barrier_init(barrier_addr));
  
  /* Publish the run parameters so that every DMI_scaleunit can read them. */
  scaleunit.n = n;
  scaleunit.niter = niter;
  scaleunit.barrier_addr = barrier_addr;
  scaleunit.xyz_addr = xyz_addr;
  catch(DMI_write(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_EXCLUSIVE, NULL));
  
  /* Run the computation on init_node_num nodes with thread_num threads each. */
  catch(DMI_rescale(scaleunit_addr, init_node_num, thread_num));
  
  /* Tear down; the barrier is destroyed before its backing memory is unmapped. */
  catch(DMI_munmap(xyz_addr, NULL));
  catch(DMI_barrier_destroy(barrier_addr));
  catch(DMI_munmap(barrier_addr, NULL));
  catch(DMI_munmap(scaleunit_addr, NULL));
  return;
}

int32_t DMI_scaleunit(int my_rank, int pnum, int64_t scaleunit_addr)
{
  obj_t *objs;
  int32_t i, j, k, N, L, N3, NITER, tmp, left, right, iter;
  int64_t xyz_addr;
  scaleunit_t scaleunit;
  double f, ke, pe, ke_all, pe_all, sigma, t, r, r2, t1, t2, t3, t4;
  double *xyz;
  DMI_local_barrier_t local_barrier;
  
  catch(DMI_read(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_GET, NULL));
  bind_to_cpu(my_rank % PROCNUM);
  
  N = scaleunit.n;
  NITER = scaleunit.niter;
  xyz_addr = scaleunit.xyz_addr;
  N3 = N * N * N;
  L = R2 * (N - 1);
  
  catch(DMI_local_barrier_init(&local_barrier, scaleunit.barrier_addr));
  
  objs = (obj_t*)my_malloc(N3 * sizeof(obj_t));
  xyz = (double*)my_malloc(N3 * 3 * sizeof(double));
  
  for(i = 0; i < N; i++)
    {
      for(j = 0; j < N; j++)
        {
          for(k = 0; k < N; k++)
            {
              tmp = N * N * i + N * j + k;
              objs[tmp].x = R2 * i;
              objs[tmp].y = R2 * j;
              objs[tmp].z = R2 * k;
              objs[tmp].vx = 0;
              objs[tmp].vy = 0;
              objs[tmp].vz = 0;
              objs[tmp].ax = 0;
              objs[tmp].ay = 0;
              objs[tmp].az = 0;
              objs[tmp].pe = 0;
              objs[tmp].ke = 0;
              objs[tmp].m = 39.948 / ABG / 1000.0;
            }
        }
    }
  for(i = 0; i < N3 * 3; i++)
    {
      xyz[i] = 0;
    }
  
  left = N3 / pnum * my_rank;
  right = N3 / pnum * (my_rank + 1);
  sigma = R2 / cbrt(sqrt(2));
  
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  
  t3 = get_time();
  t = 0;
  for(iter = 0; iter < NITER; iter++)
    {
      t1 = get_time();
      
      time_lap(60);
      time_lap(30);
      pe = 0;
      for(i = left; i < right; i++)
        {
          objs[i].ax = 0;
          objs[i].ay = 0;
          objs[i].az = 0;
          objs[i].pe = 0;
          for(j = 0; j < N3; j++)
            {
              if(i != j)
                {
                  r2 = (objs[i].x - objs[j].x) * (objs[i].x - objs[j].x)
                    + (objs[i].y - objs[j].y) * (objs[i].y - objs[j].y)
                    + (objs[i].z - objs[j].z) * (objs[i].z - objs[j].z);
                  r = sqrt(r2);
                  
                  f = -24.0 * EPS * (2.0 * pow(sigma / r, 12.0) - pow(sigma / r, 6.0)) / r2;
                  
                  objs[i].ax += f * (objs[j].x - objs[i].x) / objs[i].m;
                  objs[i].ay += f * (objs[j].y - objs[i].y) / objs[i].m;
                  objs[i].az += f * (objs[j].z - objs[i].z) / objs[i].m;
                  
                  if(i < j)
                    {
                      objs[i].pe += 4 * EPS * (pow(sigma / r, 12.0) - pow(sigma/r, 6.0));
                    }
                }
            }
          pe += objs[i].pe;
        }
      
      ke = 0;
      for(i = left; i < right; i++)
        {
          tmp = INTERVAL * INTERVAL / 2;
          objs[i].x += objs[i].vx * INTERVAL + objs[i].ax * tmp;
          objs[i].y += objs[i].vy * INTERVAL + objs[i].ay * tmp;
          objs[i].z += objs[i].vz * INTERVAL + objs[i].az * tmp;
          
          objs[i].vx += objs[i].ax * INTERVAL;
          objs[i].vy += objs[i].ay * INTERVAL;
          objs[i].vz += objs[i].az * INTERVAL;
          
          objs[i].ke = objs[i].m * (objs[i].vx * objs[i].vx + objs[i].vy * objs[i].vy + objs[i].vz * objs[i].vz) / 2;
          
          ke += objs[i].ke;
          
          tmp = (i - left) * 3;
          xyz[tmp] = objs[i].x;
          xyz[tmp + 1] = objs[i].y;
          xyz[tmp + 2] = objs[i].z;
        }
      t_calc += time_diff(60);
      time_lap(31);
      
      catch(DMI_local_barrier_sync(&local_barrier, pnum));
      
      catch(DMI_write(xyz_addr + left * 3 * sizeof(double), (right - left) * 3 * sizeof(double), xyz, DMI_EXCLUSIVE, NULL));
      
      catch(DMI_local_barrier_sync(&local_barrier, pnum));
      
      _flag3 = TRUE;
      catch(DMI_read(xyz_addr, N3 * 3 * sizeof(double), xyz, DMI_INVALIDATE, NULL));
      _flag3 = FALSE;
      
      for(i = 0; i < N3; i++)
        {
          tmp = i * 3;
          objs[i].x = xyz[tmp];
          objs[i].y = xyz[tmp + 1];
          objs[i].z = xyz[tmp + 2];
        }
      
      catch(DMI_local_barrier_allreduce(&local_barrier, pnum, &ke, &ke_all, DMI_OP_SUM, DMI_TYPE_DOUBLE));
      catch(DMI_local_barrier_allreduce(&local_barrier, pnum, &pe, &pe_all, DMI_OP_SUM, DMI_TYPE_DOUBLE));
      
      t2 = get_time();
      if(my_rank == 0)
        {
          outn("iteration=%d time=%.12lf calc=%.12lf comm=%.12lf ke=%.12e pe=%.12e tot=%.12e", 
               iter, t2 - t1, time_ref(31) - time_ref(30), (t2 - t1) - (time_ref(31) - time_ref(30)), ke_all, pe_all, ke_all + pe_all);
        }
      
      t += INTERVAL;
    }
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  
  t4 = get_time();
  print_calctime((t4 - t3) / NITER, t_calc / NITER, &local_barrier, my_rank, pnum);
  
  catch(DMI_local_barrier_destroy(&local_barrier));
  
  my_free(xyz);
  my_free(objs);
  return 0;
}

/*
 * Reduces the per-participant compute time t across all pnum participants
 * (sum, maximum and minimum) and lets rank 0 print a one-line summary.
 *
 * total          per-iteration wall time measured by the caller
 * t              this participant's per-iteration compute time
 * local_barrier  barrier handle used for the three all-reductions
 * my_rank        rank of the caller; only rank 0 prints
 * pnum           number of participants
 *
 * Every participant must call this function, since each reduction is
 * collective over pnum participants.
 */
void print_calctime(double total, double t, DMI_local_barrier_t *local_barrier, int my_rank, int pnum)
{
  double calc_sum, calc_max, calc_min;
  double calc_avg, comm;

  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_sum, DMI_OP_SUM, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_max, DMI_OP_MAX, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_min, DMI_OP_MIN, DMI_TYPE_DOUBLE));

  if(my_rank != 0)
    {
      return;
    }

  /* Communication time is reported as wall time minus the slowest
     participant's compute time. */
  calc_avg = calc_sum / pnum;
  comm = total - calc_max;
  outn("pnum=%d time=%.12lf calc_max=%.12lf calc_min=%.12lf calc_avg=%.12lf comm=%.12lf", 
       pnum, total, calc_max, calc_min, calc_avg, comm);
}
