#include "dmi_api.h"

/*
 * Linear index of cell (x, y, z) in a grid stored plane by plane with
 * (n + 2) x (n + 2) cells per z-plane (x varies fastest, then y, then z).
 * Relies on a variable named `n` being in scope wherever it is expanded.
 * The arithmetic is widened to int64_t: in plain int, z * (n + 2) * (n + 2)
 * overflows (undefined behaviour) for the full (n + 2)^3 grid once n > ~1288.
 */
#define ind(x, y, z) ((int64_t)(z) * (n + 2) * (n + 2) + (int64_t)(y) * (n + 2) + (x))

/*
 * Run parameters that DMI_main writes to shared memory and every scale unit
 * reads back at start-up.
 */
typedef struct scaleunit_t {
  int32_t n, niter;                          /* interior grid edge length; number of iterations */
  int64_t p_addr, barrier_addr, edge_addr;   /* DMI addresses: grid, global barrier, halo-exchange buffer */
} scaleunit_t;

/* Per-thread running total of time spent in the stencil computation, accumulated from time_diff(). */
__thread double t_calc = 0.0;

/* Reduces each rank's calculation time over nprocs ranks; rank 0 prints the timing summary. */
void print_calctime(double avg_total, double avg_calc, DMI_local_barrier_t *barrier, int rank, int nprocs);

/*
 * Root entry point: sets up shared state, runs the scale units, tears down.
 *
 * Usage: <prog> node_num thread_num n niter
 *   node_num, thread_num  passed to DMI_rescale (pnum = node_num * thread_num)
 *   n                     interior grid edge length; the grid is (n + 2)^3
 *                         including the zero boundary layer
 *   niter                 number of stencil iterations
 *
 * Maps the grid (n + 2 blocks of one z-plane each), the halo-exchange buffer
 * (pnum blocks of two z-planes), a global barrier and the parameter record.
 * Initialises the grid to zero except row y == 1 of every interior plane
 * z = 1..n, which is set to 1. Then it runs the scale units via DMI_rescale
 * and releases everything.
 */
void DMI_main(int argc, char **argv)
{
  scaleunit_t scaleunit;
  int x, z, n, n2n2, node_num, thread_num, pnum, niter;
  int64_t p_addr, barrier_addr, scaleunit_addr, edge_addr;
  size_t i, cells, plane_bytes, grid_bytes;
  double *p, *row;

  if(argc != 5)
    {
      outn("usage : %s node_num thread_num n niter", argv[0]);
      error();
    }
  node_num = atoi(argv[1]);
  thread_num = atoi(argv[2]);
  n = atoi(argv[3]);
  niter = atoi(argv[4]);
  /* atoi() returns 0 for malformed input. Reject that and negatives now,
     rather than dividing by zero (n / pnum) or using negative sizes later. */
  if(node_num <= 0 || thread_num <= 0 || n <= 0 || niter <= 0)
    {
      outn("usage : %s node_num thread_num n niter (all arguments must be positive integers)", argv[0]);
      error();
    }
  pnum = node_num * thread_num;
  /* The scale units give n / pnum z-planes to every rank except the last.
     That would be zero planes (and a broken halo exchange) if n < pnum.
     NOTE(review): assumes DMI_rescale runs node_num * thread_num scale units. */
  if(n < pnum)
    {
      outn("n (%d) must be at least node_num * thread_num (%d)", n, pnum);
      error();
    }

  n2n2 = (n + 2) * (n + 2);
  /* Sizes are computed in size_t: (n + 2) * n2n2 overflows int once n > ~1288. */
  cells = (size_t)(n + 2) * (size_t)n2n2;
  plane_bytes = (size_t)n2n2 * sizeof(double);
  grid_bytes = cells * sizeof(double);

  catch(DMI_mmap(&scaleunit_addr, sizeof(scaleunit_t), 1, NULL));
  catch(DMI_mmap(&p_addr, plane_bytes, n + 2, NULL));
  catch(DMI_mmap(&edge_addr, 2 * plane_bytes, pnum, NULL));
  catch(DMI_mmap(&barrier_addr, sizeof(DMI_barrier_t), 1, NULL));
  catch(DMI_barrier_init(barrier_addr));

  p = (double*)my_malloc(grid_bytes);
  for(i = 0; i < cells; i++)
    {
      p[i] = 0;
    }
  /* Initial condition: row y == 1 of each interior plane z = 1..n is 1.
     `row` points at cell (0, 1, z), so row[x] is cell (x, 1, z). */
  for(z = 1; z <= n; z++)
    {
      row = p + (size_t)z * (size_t)n2n2 + (size_t)(n + 2);
      for(x = 1; x <= n; x++)
        {
          row[x] = 1;
        }
    }

  catch(DMI_write(p_addr, grid_bytes, p, DMI_EXCLUSIVE, NULL));
  my_free(p);

  scaleunit.p_addr = p_addr;
  scaleunit.barrier_addr = barrier_addr;
  scaleunit.edge_addr = edge_addr;
  scaleunit.n = n;
  scaleunit.niter = niter;
  catch(DMI_write(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_EXCLUSIVE, NULL));

  catch(DMI_rescale(scaleunit_addr, node_num, thread_num));

  catch(DMI_barrier_destroy(barrier_addr));
  catch(DMI_munmap(barrier_addr, NULL));
  catch(DMI_munmap(edge_addr, NULL));
  catch(DMI_munmap(p_addr, NULL));
  catch(DMI_munmap(scaleunit_addr, NULL));
  return;
}

/*
 * One sweep of the 27-point averaging stencil over a local slab.
 *
 * p and q each hold planes + 2 z-planes of (n + 2) x (n + 2) doubles; planes 0
 * and planes + 1 are halos. Each interior cell of q gets the mean of the
 * 3x3x3 neighbourhood around the same cell of p. Only interior cells of q are
 * written. Returns the sum over the slab of (q - p)^2 per interior cell.
 * The parameter must be called `n` because the ind() macro refers to it.
 */
static double stencil27_sweep(const double *p, double *q, int n, int planes)
{
  double sub_sum = 0;
  double delta;
  int x, y, z;

  for(z = 1; z <= planes; z++)
    for(y = 1; y <= n; y++)
      for(x = 1; x <= n; x++)
        {
          q[ind(x, y, z)] = (p[ind(x, y, z)]
                             + p[ind(x - 1, y, z)] + p[ind(x + 1, y, z)]
                             + p[ind(x, y - 1, z)] + p[ind(x, y + 1, z)]
                             + p[ind(x, y, z - 1)] + p[ind(x, y, z + 1)]
                             + p[ind(x - 1, y - 1, z)] + p[ind(x - 1, y + 1, z)]
                             + p[ind(x - 1, y, z - 1)] + p[ind(x - 1, y, z + 1)]
                             + p[ind(x + 1, y - 1, z)] + p[ind(x + 1, y + 1, z)]
                             + p[ind(x + 1, y, z - 1)] + p[ind(x + 1, y, z + 1)]
                             + p[ind(x, y - 1, z - 1)] + p[ind(x, y - 1, z + 1)]
                             + p[ind(x, y + 1, z - 1)] + p[ind(x, y + 1, z + 1)]
                             + p[ind(x - 1, y - 1, z - 1)] + p[ind(x - 1, y - 1, z + 1)]
                             + p[ind(x - 1, y + 1, z - 1)] + p[ind(x - 1, y + 1, z + 1)]
                             + p[ind(x + 1, y - 1, z - 1)] + p[ind(x + 1, y - 1, z + 1)]
                             + p[ind(x + 1, y + 1, z - 1)] + p[ind(x + 1, y + 1, z + 1)]
                             ) / 27;
          delta = fabs(q[ind(x, y, z)] - p[ind(x, y, z)]);
          sub_sum += delta * delta;
        }
  return sub_sum;
}

/*
 * Body of scale unit my_rank (of pnum). It iterates the 27-point stencil on
 * this rank's slab of z-planes and exchanges one boundary plane with each
 * neighbour per iteration. Rank 0 prints per-iteration timing and residual.
 *
 * The rank owns global interior planes left+1 .. right. Its local arrays hold
 * global planes left .. right+1, so local plane z is global plane left + z,
 * and local planes 0 and planes+1 are halos.
 *
 * Returns 0.
 */
int32_t DMI_scaleunit(int my_rank, int pnum, int64_t scaleunit_addr)
{
  scaleunit_t scaleunit;
  int ret, n2n2, n, iter, left, right, planes, niter;
  int64_t plane_bytes;
  size_t local_bytes;
  double sub_sum, sum, resid;
  double *p, *q, *p_tail;
  DMI_local_barrier_t local_barrier;
  DMI_local_status_t status1, status2;

  catch(DMI_read(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_GET, NULL));
  bind_to_cpu(my_rank % PROCNUM);

  /* t_calc is thread-local and never reset elsewhere. Start from zero so the
     average handed to print_calctime covers this run only. */
  t_calc = 0;

  n = scaleunit.n;
  niter = scaleunit.niter;
  n2n2 = (n + 2) * (n + 2);
  /* 64-bit byte sizes/offsets: left * n2n2 overflows int for large n. */
  plane_bytes = (int64_t)n2n2 * (int64_t)sizeof(double);

  /* Even split of the n interior planes; the last rank also takes the remainder. */
  left = n / pnum * my_rank;
  right = n / pnum * (my_rank + 1);
  if(my_rank == pnum - 1)
    {
      right = n;
    }
  planes = right - left;

  catch(DMI_local_barrier_init(&local_barrier, scaleunit.barrier_addr));

  local_bytes = (size_t)(planes + 2) * (size_t)plane_bytes;
  p = (double*)my_malloc(local_bytes);
  q = (double*)my_malloc(local_bytes);
  p_tail = p + (size_t)(planes + 2) * (size_t)n2n2;

  /* Load global planes left .. right+1: the owned planes plus both halos.
     On the first and last rank these halos are the global z == 0 / z == n+1
     boundary planes. No neighbour ever sends them, so they must be loaded
     here or the sweep would read uninitialized memory. */
  catch(DMI_read(scaleunit.p_addr + left * plane_bytes, (planes + 2) * plane_bytes, p, DMI_GET, NULL));
  /* The sweep writes only interior cells of q, but whole planes of q are
     copied back into p each iteration. Seed q from p so its x/y boundary
     cells keep the boundary values instead of garbage. */
  memcpy(q, p, local_bytes);

  catch(DMI_local_barrier_sync(&local_barrier, pnum));

  time_lap(20);
  for(iter = 0; iter < niter; iter++)
    {
      time_lap(10);

      /* Publish our first owned plane to edge slot 2*my_rank and our last
         owned plane to slot 2*my_rank + 1. */
      catch(DMI_write(scaleunit.edge_addr + (my_rank * 2) * plane_bytes, plane_bytes, p + n2n2, DMI_EXCLUSIVE, &status1));
      catch(DMI_write(scaleunit.edge_addr + (my_rank * 2 + 1) * plane_bytes, plane_bytes, p_tail - 2 * n2n2, DMI_EXCLUSIVE, &status2));
      DMI_wait(&status1, &ret);
      catch(ret);
      DMI_wait(&status2, &ret);
      catch(ret);

      catch(DMI_local_barrier_sync(&local_barrier, pnum));

      /* Fill halos: the left neighbour's last plane goes into local plane 0,
         the right neighbour's first plane into local plane planes + 1. */
      if(my_rank >= 1)
        {
          catch(DMI_read(scaleunit.edge_addr + ((my_rank - 1) * 2 + 1) * plane_bytes, plane_bytes, p, DMI_GET, &status1));
        }
      if(my_rank <= pnum - 2)
        {
          catch(DMI_read(scaleunit.edge_addr + ((my_rank + 1) * 2) * plane_bytes, plane_bytes, p_tail - n2n2, DMI_GET, &status2));
        }
      if(my_rank >= 1)
        {
          DMI_wait(&status1, &ret);
          catch(ret);
        }
      if(my_rank <= pnum - 2)
        {
          DMI_wait(&status2, &ret);
          catch(ret);
        }

      /* No rank may overwrite its edge slots (next iteration) before its
         neighbours have finished reading them. */
      catch(DMI_local_barrier_sync(&local_barrier, pnum));
      time_lap(11);
      time_lap(60);

      sub_sum = stencil27_sweep(p, q, n, planes);
      memcpy(p + n2n2, q + n2n2, (size_t)planes * (size_t)plane_bytes);
      t_calc += time_diff(60);

      catch(DMI_local_barrier_allreduce(&local_barrier, pnum, &sub_sum, &sum, DMI_OP_SUM, DMI_TYPE_DOUBLE));
      resid = sqrt(sum / n / n / n);
      time_lap(12);

      if(my_rank == 0)
        {
          outn("iteration=%d time=%.12lf calc=%.12lf comm=%.12lf resid=%e",
               iter, time_ref(12) - time_ref(10), time_ref(12) - time_ref(11), time_ref(11) - time_ref(10), resid);
        }
    }
  print_calctime(time_diff(20) / niter, t_calc / niter, &local_barrier, my_rank, pnum);

  catch(DMI_local_barrier_destroy(&local_barrier));
  my_free(q);
  my_free(p);
  return 0;
}

/*
 * Combine each rank's calculation time `t` over pnum participants with three
 * DMI_local_barrier_allreduce calls (sum, max, min). Presumably every rank
 * must call this. Rank 0 then prints the total time, the max/min/mean
 * calculation time, and a communication figure: the total minus the slowest
 * rank's calculation time.
 */
void print_calctime(double total, double t, DMI_local_barrier_t *local_barrier, int my_rank, int pnum)
{
  double calc_sum, calc_max, calc_min;

  /* Collective calls: keep the SUM, MAX, MIN order identical on every rank. */
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_sum, DMI_OP_SUM, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_max, DMI_OP_MAX, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_min, DMI_OP_MIN, DMI_TYPE_DOUBLE));

  if(my_rank != 0)
    {
      return;
    }
  outn("pnum=%d time=%.12lf calc_max=%.12lf calc_min=%.12lf calc_avg=%.12lf comm=%.12lf", 
       pnum, total, calc_max, calc_min, calc_sum / pnum, total - calc_max);
}
