#include "dmi_api.h"

/*
 * Parameter record handed from DMI_main to every DMI_scaleunit worker.
 * DMI_main fills it in and writes it to DMI memory; each worker reads it
 * back through the scaleunit_addr it is given.
 */
typedef struct scaleunit_t
{
  int32_t pnum;          /* number of ranks that take part in the computation (argv[3]) */
  int32_t n;             /* matrix dimension: the matrices are n x n doubles (argv[4]) */
  int64_t a_addr;        /* DMI address of the region holding matrix A */
  int64_t b_addr;        /* DMI address of the region holding matrix B */
  int64_t c_addr;        /* DMI address of the region receiving the result matrix C */
  int64_t barrier_addr;  /* DMI address of the DMI_barrier_t initialised by DMI_main */
}scaleunit_t;

/* Per-thread accumulator of the time spent in the multiplication kernels
   (DMI_scaleunit adds time_diff(60) to it after each kernel run). */
__thread double t_calc = 0;

/* Prints an n x n matrix of doubles followed by a separator line. */
void print_matrix(double *matrix, int32_t n);
/* Returns the sum of all n * n elements of matrix. */
double sumof_matrix(double *matrix, int32_t n);
/* Reduces the per-rank calculation time t over the local barrier and lets
   rank 0 print the timing summary. */
void print_calctime(double total, double t, DMI_local_barrier_t *local_barrier, int my_rank, int pnum);

/*
 * Program entry point: parses the command line, maps the DMI regions used by
 * the workers, publishes the parameters and starts the workers.
 *
 * Usage: <prog> init_node_num thread_num pnum n
 *   init_node_num, thread_num  forwarded to DMI_rescale
 *   pnum                       number of participating ranks; must be a
 *                              positive square number not larger than
 *                              init_node_num * thread_num
 *   n                          matrix dimension; must be positive and
 *                              divisible by sqrt(pnum)
 *
 * Any invalid argument is reported through outn() and error().
 */
void DMI_main(int argc, char **argv)
{
  scaleunit_t scaleunit;
  int32_t n, init_node_num, thread_num, pnum, q, nq;
  int64_t scaleunit_addr, a_addr, b_addr, c_addr, barrier_addr;
  size_t block_bytes, matrix_bytes;
  
  if(argc != 5)
    {
      outn("usage : %s init_node_num thread_num pnum n", argv[0]);
      error();
    }
  
  init_node_num = atoi(argv[1]);
  thread_num = atoi(argv[2]);
  pnum = atoi(argv[3]);
  n = atoi(argv[4]);
  /* Reject non-positive values up front: pnum == 0 would make q == 0 and
     divide by zero below, and a negative pnum has no real square root. */
  if(pnum <= 0 || n <= 0)
    {
      outn("pnum and n must be positive");
      error();
    }
  /* Widen before multiplying so that large arguments cannot overflow int. */
  if((int64_t)pnum > (int64_t)init_node_num * thread_num)
    {
      outn("pnum > init_node_num * thread_num");
      error();
    }
  if(fabs(sqrt(pnum) - (int)sqrt(pnum)) > EPS12)
    {
      outn("pnum is not a square number");
      error();
    }
  q = (int)sqrt(pnum);
  /* The workers split each matrix into q x q blocks of nq x nq elements, so
     n has to be an exact multiple of q. */
  if(n % q != 0)
    {
      outn("n is not divisible by sqrt(pnum)");
      error();
    }
  nq = n / q;
  
  /* Compute the sizes in size_t: n * n and nq * nq overflow a 32-bit int
     once n exceeds 46340. */
  block_bytes = (size_t)nq * nq * sizeof(double);
  matrix_bytes = (size_t)n * n * sizeof(double);
  
  /* A and B are mapped as pnum pages of one block each; C, the parameter
     record and the barrier are mapped as a single page. */
  catch(DMI_mmap(&scaleunit_addr, sizeof(scaleunit_t), 1, NULL));
  catch(DMI_mmap(&barrier_addr, sizeof(DMI_barrier_t), 1, NULL));
  catch(DMI_mmap(&a_addr, block_bytes, pnum, NULL));
  catch(DMI_mmap(&b_addr, block_bytes, pnum, NULL));
  catch(DMI_mmap(&c_addr, matrix_bytes, 1, NULL));
  catch(DMI_barrier_init(barrier_addr));
  
  /* Publish the run parameters so that every worker can read them. */
  scaleunit.pnum = pnum;
  scaleunit.n = n;
  scaleunit.barrier_addr = barrier_addr;
  scaleunit.a_addr = a_addr;
  scaleunit.b_addr = b_addr;
  scaleunit.c_addr = c_addr;
  catch(DMI_write(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_EXCLUSIVE, NULL));
  
  /* NOTE(review): presumably runs DMI_scaleunit on init_node_num nodes with
     thread_num threads each and returns when they finish - confirm against
     the DMI documentation. */
  catch(DMI_rescale(scaleunit_addr, init_node_num, thread_num));
  
  /* Tear everything down in reverse order of creation. */
  catch(DMI_barrier_destroy(barrier_addr));
  catch(DMI_munmap(c_addr, NULL));
  catch(DMI_munmap(b_addr, NULL));
  catch(DMI_munmap(a_addr, NULL));
  catch(DMI_munmap(barrier_addr, NULL));
  catch(DMI_munmap(scaleunit_addr, NULL));
  return;
}

int32_t DMI_scaleunit(int my_rank, int true_pnum, int64_t scaleunit_addr)
{
  int32_t i, j, k, n, x, y, z, q, nq, stage, pnum;
  int64_t a_addr, b_addr, c_addr;
  scaleunit_t scaleunit;
  double sum;
  double *original_a, *original_b, *original_c, *local_a, *local_b, *local_c;
  DMI_local_barrier_t local_barrier;
  
  catch(DMI_read(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_GET, NULL));
  bind_to_cpu(my_rank % PROCNUM);
  
  pnum = scaleunit.pnum;
  n = scaleunit.n;
  a_addr = scaleunit.a_addr;
  b_addr = scaleunit.b_addr;
  c_addr = scaleunit.c_addr;
  
  if(my_rank < pnum)
    {
      q = (int)sqrt(pnum);
      nq = n / q;
      x = my_rank % q;
      y = my_rank / q;
      
      catch(DMI_local_barrier_init(&local_barrier, scaleunit.barrier_addr));
      
      original_a = NULL;
      original_b = NULL;
      original_c = NULL;
      if(my_rank == 0)
        {
          original_a = (double*)my_malloc(n * n * sizeof(double));
          original_b = (double*)my_malloc(n * n * sizeof(double));
          original_c = (double*)my_malloc(n * n * sizeof(double));
          
          mrand_init(516);
          for(i = 0; i < pnum; i++)
            for(j = 0; j < nq; j++)
              for(k = 0; k < nq; k++)
                {
                  original_a[i * nq * nq + j * nq + k] = mrand_01();
                  original_b[i * nq * nq + j * nq + k] = mrand_01();
                  original_c[i * nq * nq + j * nq + k] = 0;
                }
          if(pnum * nq * nq != n * n) error();
          
#if 0
          {
            int i, j, k, x, y, z;
            double *a, *b, *c;
            
            for(i = 0; i < pnum; i++)
              for(j = 0; j < nq; j++)
                for(k = 0; k < nq; k++)
                  original_c[i * nq * nq + j * nq + k] = 0;
            
            for(x = 0; x < q; x++)
              for(y = 0; y < q; y++)
                for(stage = 0; stage < q; stage++)
                  {
                    z = (x + stage) % q;
                    
                    a = &original_a[(y * q + z) * nq * nq];
                    b = &original_b[(z * q + x) * nq * nq];
                    c = &original_c[(y * q + x) * nq * nq];
                    
                    for(i = 0; i < nq; i++)
                      for(k = 0; k < nq; k++)
                        for(j = 0; j < nq; j++)
                          {
                            c[i * nq + j] += a[i * nq + k] * b[k * nq + j];
                          }
                  }
            //print_matrix(original_a, n);
            //print_matrix(original_b, n);
            //print_matrix(original_c, n);
            sum = sumof_matrix(original_c, n);
            outn("# ans sum : %lf", sum);
          }
#endif
        }
      
      if(pnum == 1)
        {
          int i, j, k, x, y, z;
          double *a, *b, *c;
          
          time_lap(10);
          
          time_lap(60);
          for(x = 0; x < q; x++)
            for(y = 0; y < q; y++)
              for(z = 0; z < q; z++)
                {
                  a = &original_a[(y * q + z) * nq * nq];
                  b = &original_b[(z * q + x) * nq * nq];
                  c = &original_c[(y * q + x) * nq * nq];
                  
                  for(i = 0; i < nq; i++)
                    for(k = 0; k < nq; k++)
                      for(j = 0; j < nq; j++)
                        {
                          c[i * nq + j] += a[i * nq + k] * b[k * nq + j];
                        }
                }
          t_calc += time_diff(60);
          
          time_lap(11);
          print_calctime(time_ref(11) - time_ref(10), t_calc, &local_barrier, my_rank, pnum);
          
          //print_matrix(original_c, n);
          sum = sumof_matrix(original_c, n);
          outn("# ans sum : %lf", sum);
          
          my_free(original_c);
          my_free(original_b);
          my_free(original_a);
          return 0;
        }
      
      local_a = (double*)my_malloc(nq * nq * sizeof(double));
      local_b = (double*)my_malloc(nq * nq * sizeof(double));
      local_c = (double*)my_malloc(nq * nq * sizeof(double));
      
      for(i = 0; i < nq; i++)
        for(j = 0; j < nq; j++)
          local_c[i * nq + j] = 0;
      
      catch(DMI_local_barrier_sync(&local_barrier, pnum));
      
      time_lap(1);
      
      if(my_rank == 0)
        {
          catch(DMI_write(a_addr, n * n * sizeof(double), original_a, DMI_EXCLUSIVE, NULL));
          catch(DMI_write(b_addr, n * n * sizeof(double), original_b, DMI_EXCLUSIVE, NULL));
        }
      
      catch(DMI_local_barrier_sync(&local_barrier, pnum));
      
#if 1
      catch(DMI_write(a_addr + my_rank * nq * nq * sizeof(double), 0, NULL, DMI_EXCLUSIVE, NULL));
      
      catch(DMI_write(b_addr + my_rank * nq * nq * sizeof(double), 0, NULL, DMI_EXCLUSIVE, NULL));
      
      catch(DMI_local_barrier_sync(&local_barrier, pnum));
#endif
      time_lap(2);
      
      for(stage = 0; stage < q; stage++)
        {
          /*
            catch(DMI_local_barrier_sync(&local_barrier, pnum));
          */
          
          if(my_rank == 0)
            {
              outn("stage : %d", stage);
            }
          
          z = (x + stage) % q;
          
          catch(DMI_read(a_addr + (y * q + z) * nq * nq * sizeof(double), nq * nq * sizeof(double), local_a, DMI_INVALIDATE, NULL));
          catch(DMI_read(b_addr + (z * q + x) * nq * nq * sizeof(double), nq * nq * sizeof(double), local_b, DMI_INVALIDATE, NULL));
          
          time_lap(60);
          for(i = 0; i < nq; i++)
            for(k = 0; k < nq; k++)
              for(j = 0; j < nq; j++)
                {
                  local_c[i * nq + j] += local_a[i * nq + k] * local_b[k * nq + j];
                }
          t_calc += time_diff(60);
        }
      
      time_lap(3);
      
      catch(DMI_write(c_addr + my_rank * nq * nq * sizeof(double), nq * nq * sizeof(double), local_c, DMI_PUT, NULL));
      
      catch(DMI_local_barrier_sync(&local_barrier, pnum));
      
      if(my_rank == 0)
        {
          catch(DMI_read(c_addr, n * n * sizeof(double), original_c, DMI_INVALIDATE, NULL));
          
          time_lap(4);
          
          //print_matrix(original_c, n);
          sum = sumof_matrix(original_c, n);
          outn("# sum : %lf", sum);
          
          outn("scatter=%.12lf calc=%.12lf gather=%.12lf", time_ref(2) - time_ref(1), time_ref(3) - time_ref(2), time_ref(4) - time_ref(3));
          
          my_free(original_c);
          my_free(original_b);
          my_free(original_a);
        }
      print_calctime(time_ref(4) - time_ref(1), t_calc, &local_barrier, my_rank, pnum);
      
      catch(DMI_local_barrier_destroy(&local_barrier));
      
      my_free(local_a);
      my_free(local_b);
      my_free(local_c);
    }
  return 0;
}

/*
 * Dumps an n x n matrix of doubles, one row per output line, followed by a
 * line of asterisks as separator.
 *
 *   matrix  n * n doubles, read in consecutive rows of n elements
 *   n       number of rows and columns; nothing but the separator is
 *           printed when n <= 0
 */
void print_matrix(double *matrix, int32_t n)
{
  int32_t row, col;
  double *line = matrix;
  
  /* Walk the buffer row by row instead of recomputing row * n + col. */
  for(row = 0; row < n; row++, line += n)
    {
      for(col = 0; col < n; col++)
        out("%lf ", line[col]);
      out("\n");
    }
  outn("***********");
}

/*
 * Adds up every element of an n x n matrix of doubles.
 *
 *   matrix  n * n doubles stored contiguously
 *   n       number of rows and columns
 *
 * Returns the sum, accumulated in storage order; 0 when n <= 0.
 */
double sumof_matrix(double *matrix, int32_t n)
{
  const double *cursor = matrix;
  double total = 0;
  int32_t row, col;
  
  /* The cursor visits the elements in the same order as index row * n + col. */
  for(row = 0; row < n; row++)
    for(col = 0; col < n; col++)
      total += *cursor++;
  return total;
}

/*
 * Reduces the per-rank calculation time over the local barrier and lets
 * rank 0 print the timing summary line.
 *
 *   total          elapsed time to report as "time" (only used by rank 0)
 *   t              this rank's calculation time
 *   local_barrier  barrier handle used for the three all-reduce calls
 *   my_rank        rank of the caller; only rank 0 prints
 *   pnum           participant count passed to the all-reduce calls and
 *                  divisor for the average
 *
 * NOTE(review): the all-reduce calls are presumably collective, so every one
 * of the pnum ranks has to call this function - confirm against the DMI docs.
 */
void print_calctime(double total, double t, DMI_local_barrier_t *local_barrier, int my_rank, int pnum)
{
  double calc_total, calc_longest, calc_shortest;

  /* Sum, maximum and minimum of t across the participants, in this order. */
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_total, DMI_OP_SUM, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_longest, DMI_OP_MAX, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_shortest, DMI_OP_MIN, DMI_TYPE_DOUBLE));

  if(my_rank != 0)
    return;

  /* "comm" is whatever part of the total the slowest rank did not spend
     calculating. */
  outn("pnum=%d time=%.12lf calc_max=%.12lf calc_min=%.12lf calc_avg=%.12lf comm=%.12lf", 
       pnum, total, calc_longest, calc_shortest, calc_total / pnum, total - calc_longest);
}
