#include "dmi_api.h"

/* NOTE(review): THREAD_TOTAL and THREAD_NUM are not referenced by the code
   visible in this file -- confirm before removing them. */
#define THREAD_TOTAL (256)
#define THREAD_NUM   (8)

/* Window of the complex plane that is rendered: pixel column x maps to a
   real part in [REAL_MIN, REAL_MAX) and pixel row y maps to an imaginary
   part in [IMAG_MIN, IMAG_MAX).  Expansions are parenthesized so the
   negative values behave as single operands wherever they are used. */
#define REAL_MIN (-2)
#define REAL_MAX (0.5)
#define IMAG_MIN (-1.25)
#define IMAG_MAX (1.25)

/* Run parameters shared with every worker.  DMI_main fills one instance and
   writes it to DMI memory; each DMI_scaleunit reads it back. */
struct scaleunit_t
{
  int32_t width;        /* image width in pixels */
  int32_t height;       /* image height in pixels */
  int32_t iter_max;     /* iteration limit passed to calc_point */
  int32_t job_num;      /* number of horizontal bands the image is split into */
  int32_t print_flag;   /* nonzero: rank 0 prints the image as PPM */
  int64_t index_addr;   /* DMI address of the shared int64_t job counter */
  int64_t buf_addr;     /* DMI address of the shared width*height int32_t image */
  int64_t barrier_addr; /* DMI address of the shared DMI_barrier_t */
};
typedef struct scaleunit_t scaleunit_t;

/* Pixel rectangle covered by one job: (x, y) is its top-left corner and
   width x height its size.  DMI_scaleunit always uses full-width bands. */
struct job_t
{
  int x;
  int y;
  int width;
  int height;
};
typedef struct job_t job_t;

/* Time spent rendering pixels.  Declared __thread, so each thread has its own
   copy.  DMI_scaleunit adds time_diff(60) after every job and passes the total
   to print_calctime. */
__thread double t_calc = 0;

/* Escape-time colour of the point cx + cy*i, packed as 0x00RRGGBB (0 when the
   point does not escape within iter_max iterations). */
int calc_point(double cx, double cy, int iter_max);
/* Write a width x height image of packed pixels to stdout as a plain (P3) PPM. */
void print_image(int width, int height, int32_t *buf);
/* Reduce the per-thread calculation time t over pnum workers and, on rank 0,
   print the total, max/min/average calculation time and the remainder. */
void print_calctime(double total, double t, DMI_local_barrier_t *local_barrier, int my_rank, int pnum);

/*
 * Program entry point (presumably invoked by the DMI runtime with the
 * command-line arguments -- confirm).
 *
 * Parses the run parameters and allocates the shared DMI objects: the
 * parameter block, the image buffer, the job counter and the barrier.  It then
 * resets the job counter to 0, publishes the parameters and hands them to
 * DMI_rescale() with the requested node and thread counts.  Finally every
 * shared object is destroyed/unmapped.  The cleanup assumes DMI_rescale()
 * returns only after all workers have finished -- confirm.
 *
 * Usage: init_node_num thread_num width height job_num iter_max print_flag
 *   width, height : image size in pixels (must be positive)
 *   job_num       : number of horizontal bands; must be positive and must
 *                   divide height evenly
 *   iter_max      : iteration limit per point
 *   print_flag    : nonzero to print the finished image as PPM on stdout
 */
void DMI_main(int argc, char **argv)
{
  scaleunit_t scaleunit;
  int width, height, job_num, iter_max, print_flag, init_node_num, thread_num;
  int64_t buf_addr, index_addr, scaleunit_addr, index, barrier_addr;
  
  if(argc != 8)
    {
      outn("usage : %s init_node_num thread_num width height job_num iter_max print_flag", argv[0]);
      error();
      return;
    }
  
  init_node_num = atoi(argv[1]);
  thread_num = atoi(argv[2]);
  width = atoi(argv[3]);
  height = atoi(argv[4]);
  job_num = atoi(argv[5]);
  iter_max = atoi(argv[6]);
  print_flag = atoi(argv[7]);
  /* Must be checked before the modulo below: job_num == 0 would divide by
     zero, and non-positive sizes would give empty or negative bands. */
  if(width <= 0 || height <= 0 || job_num <= 0)
    {
      errn("width, height and job_num must be positive");
      error();
      return;
    }
  if(height % job_num != 0)
    {
      /* "%%" because errn takes a printf-style format string; a lone '%'
         here would start a conversion with no matching argument. */
      errn("height %% job_num != 0");
      error();
      return;
    }
  
  catch(DMI_mmap(&scaleunit_addr, sizeof(scaleunit_t), 1, NULL));
  /* job_num pages, each holding one band of width * (height / job_num)
     pixels; computed in 64 bits so large images do not overflow int. */
  catch(DMI_mmap(&buf_addr, (int64_t)width * height / job_num * sizeof(int32_t), job_num, NULL));
  catch(DMI_mmap(&index_addr, sizeof(int64_t), 1, NULL));
  catch(DMI_mmap(&barrier_addr, sizeof(DMI_barrier_t), 1, NULL));
  catch(DMI_barrier_init(barrier_addr));
  
  /* Shared job counter; workers claim jobs 0 .. job_num-1 with DMI_fad. */
  index = 0;
  catch(DMI_write(index_addr, sizeof(int64_t), &index, DMI_EXCLUSIVE, NULL));
  
  scaleunit.width = width;
  scaleunit.height = height;
  scaleunit.iter_max = iter_max;
  scaleunit.job_num = job_num;
  scaleunit.print_flag = print_flag;
  scaleunit.index_addr = index_addr;
  scaleunit.barrier_addr = barrier_addr;
  scaleunit.buf_addr = buf_addr;
  catch(DMI_write(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_EXCLUSIVE, NULL));
  
  catch(DMI_rescale(scaleunit_addr, init_node_num, thread_num));
  
  /* Release everything in reverse order of allocation, including the barrier
     object's own mapping. */
  catch(DMI_barrier_destroy(barrier_addr));
  catch(DMI_munmap(barrier_addr, NULL));
  catch(DMI_munmap(index_addr, NULL));
  catch(DMI_munmap(buf_addr, NULL));
  catch(DMI_munmap(scaleunit_addr, NULL));
  return;
}

/*
 * Per-worker body of the renderer.  It receives the DMI address of the
 * scaleunit_t that DMI_main passed to DMI_rescale(); presumably the DMI
 * runtime calls it once per worker -- confirm.
 *
 * Job loop:
 *   - Each worker takes the next job number from the shared counter at
 *     scaleunit.index_addr using DMI_fad with an increment of 1.
 *   - It renders that horizontal band of (height / job_num) full-width rows.
 *   - It writes the band to its place in the shared image at
 *     scaleunit.buf_addr.
 *   - The loop ends once the returned job number reaches job_num.
 *
 * After a barrier, rank 0 reads the whole image back and prints it as PPM
 * when print_flag is set.  Every worker then reports its accumulated
 * calculation time through print_calctime.
 *
 * my_rank        : this worker's rank; rank 0 does the final read/print
 * pnum           : number of workers taking part in the barriers/reductions
 * scaleunit_addr : DMI address of the scaleunit_t written by DMI_main
 * Returns 0.
 */
int32_t DMI_scaleunit(int my_rank, int pnum, int64_t scaleunit_addr)
{
  DMI_local_barrier_t local_barrier;
  scaleunit_t scaleunit;
  job_t job;
  int c, x, y;
  int64_t index, band_bytes, image_bytes;
  double cx, cy, dx, dy;
  int32_t *buf;
  
  catch(DMI_read(scaleunit_addr, sizeof(scaleunit_t), &scaleunit, DMI_GET, NULL));
  bind_to_cpu(my_rank % PROCNUM);
  
  catch(DMI_local_barrier_init(&local_barrier, scaleunit.barrier_addr));
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  time_lap(10);
  
  /* Every job is a full-width band of the same height, so the geometry, the
     band size and the scratch buffer are set up once per worker rather than
     once per job.  Sizes are computed in 64 bits to avoid int overflow on
     large images. */
  job.x = 0;
  job.width = scaleunit.width;
  job.height = scaleunit.height / scaleunit.job_num;
  band_bytes = (int64_t)job.width * job.height * (int64_t)sizeof(int32_t);
  buf = (int32_t*)my_malloc(band_bytes);
  
  /* Distance in the complex plane between neighbouring pixels. */
  dx = (REAL_MAX - REAL_MIN) / scaleunit.width;
  dy = (IMAG_MAX - IMAG_MIN) / scaleunit.height;
  
  while(1)
    {
      catch(DMI_fad(scaleunit.index_addr, 1, &index, DMI_PUT, NULL));
      if(index >= scaleunit.job_num)
        {
          break;
        }
      /* index < job_num (an int32_t) here, so the cast is lossless; passing
         the int64_t itself to %d would be undefined behaviour. */
      outn("rank=%d job=%d", my_rank, (int)index);
      
      time_lap(60);
      job.y = (int)(job.height * index);
      
      c = 0;
      for(y = job.y; y < job.y + job.height; y++)
        {
          cy = IMAG_MIN + y * dy;
          for(x = job.x; x < job.x + job.width; x++)
            {
              cx = REAL_MIN + x * dx;
              buf[c++] = calc_point(cx, cy, scaleunit.iter_max);
            }
        }
      t_calc += time_diff(60);
      
      /* The band starts at pixel row job.y of the row-major shared image. */
      catch(DMI_write(scaleunit.buf_addr + (int64_t)job.y * job.width * (int64_t)sizeof(int32_t), band_bytes, buf, DMI_PUT, NULL));
    }
  my_free(buf);
  
  /* Wait until every band has been written before rank 0 reads the image. */
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  if(my_rank == 0)
    {
      image_bytes = (int64_t)scaleunit.width * scaleunit.height * (int64_t)sizeof(int32_t);
      buf = (int32_t*)my_malloc(image_bytes);
      catch(DMI_read(scaleunit.buf_addr, image_bytes, buf, DMI_GET, NULL));
      if(scaleunit.print_flag)
        {
          print_image(scaleunit.width, scaleunit.height, buf);
        }
      my_free(buf);
    }
  time_lap(11);
  print_calctime(time_ref(11) - time_ref(10), t_calc, &local_barrier, my_rank, pnum);
  
  catch(DMI_local_barrier_sync(&local_barrier, pnum));
  catch(DMI_local_barrier_destroy(&local_barrier));
  return 0;
}

/*
 * Escape-time colour of the point c = cx + cy*i.
 *
 * Iterates z <- z^2 + c starting from z = 0 and counts the completed steps
 * before |z|^2 reaches 4, up to at most iter_max steps.
 *
 * Returns 0 (black) when the point never reaches |z|^2 >= 4 within iter_max
 * steps.  Otherwise the step count n is scrambled into a hue
 * h = (n mod 360)^4 mod 360 by two squarings.  That hue is converted from HSV
 * with full saturation and value into RGB packed as 0x00RRGGBB.
 */
int calc_point(double cx, double cy, int iter_max)
{
  double re = 0;
  double im = 0;
  int steps = 0;

  while(steps < iter_max)
    {
      const double next_re = re * re - im * im + cx;
      const double next_im = 2 * re * im + cy;
      re = next_re;
      im = next_im;
      if(re * re + im * im >= 4)
        {
          break;
        }
      steps++;
    }

  if(steps == iter_max)
    {
      return 0;
    }

  int hue_seed = steps % 360;
  hue_seed = hue_seed * hue_seed % 360;
  const double h = hue_seed * hue_seed % 360;
  const double s = 1;
  const double v = 1;

  /* steps >= 0, so h is an integer in [0, 359] and sector is always 0..5. */
  const int sector = (int)(h / 60) % 6;
  const double f = h / 60 - sector;
  const double p = v * (1 - s);
  const double q = v * (1 - f * s);
  const double t = v * (1 - (1 - f) * s);

  /* (R, G, B) components for each 60-degree hue sector. */
  const double sector_rgb[6][3] = {
    { v, t, p },
    { q, v, p },
    { p, v, t },
    { p, q, v },
    { t, p, v },
    { v, p, q },
  };

  const int red = (int)(sector_rgb[sector][0] * 0xff);
  const int green = (int)(sector_rgb[sector][1] * 0xff);
  const int blue = (int)(sector_rgb[sector][2] * 0xff);

  return (red << 16) | (green << 8) | blue;
}

/*
 * Write a width x height image to stdout as a plain-text (P3) PPM with
 * maxval 255.
 *
 * width, height : image dimensions in pixels
 * buf           : width * height pixels in row-major order, each packed as
 *                 0x00RRGGBB (the format produced by calc_point)
 *
 * Pixels are written as "R G B " triples.  A newline follows every 30 pixels
 * and the pixel data always ends with a newline.
 * NOTE(review): the netpbm spec recommends lines of at most 70 characters;
 * 30 pixels per line can exceed that, though common readers accept it.
 */
void print_image(int width, int height, int32_t *buf)
{
  int x, y;
  long count;
  uint32_t color;
  
  printf("P3\n");
  printf("%d %d\n", width, height);
  printf("%d\n", 0xff);
  
  count = 0;
  for(y = 0; y < height; y++)
    {
      for(x = 0; x < width; x++)
        {
          /* Rows are width pixels long, so the row stride is width. */
          color = (uint32_t)buf[(size_t)y * (size_t)width + (size_t)x];
          /* Unpack with unsigned shifts and masks; left-shifting a signed
             int into its sign bit is undefined behaviour. */
          printf("%d %d %d ",
                 (int)((color >> 16) & 0xff),
                 (int)((color >> 8) & 0xff),
                 (int)(color & 0xff));
          count++;
          if(count % 30 == 0)
            {
              printf("\n");
            }
        }
    }
  /* Terminate a final partial line. */
  if(count % 30 != 0)
    {
      printf("\n");
    }
  return;
}

/*
 * Reduce this worker's calculation time t over all pnum workers and report
 * the result on rank 0.
 *
 * total         : wall-clock time of the whole run as measured by the caller
 * t             : this worker's accumulated calculation time
 * local_barrier : barrier used for the reductions
 * my_rank       : only rank 0 prints
 * pnum          : number of workers taking part in the reductions
 *
 * The printed "comm" figure is total minus the slowest worker's
 * calculation time.
 */
void print_calctime(double total, double t, DMI_local_barrier_t *local_barrier, int my_rank, int pnum)
{
  double calc_sum, calc_max, calc_min;
  double calc_avg, comm;

  /* Presumably collective across the pnum workers: every rank issues these
     three reductions in this same order, including ranks that print nothing. */
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_sum, DMI_OP_SUM, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_max, DMI_OP_MAX, DMI_TYPE_DOUBLE));
  catch(DMI_local_barrier_allreduce(local_barrier, pnum, &t, &calc_min, DMI_OP_MIN, DMI_TYPE_DOUBLE));

  if(my_rank != 0)
    {
      return;
    }

  calc_avg = calc_sum / pnum;
  comm = total - calc_max;
  outn("pnum=%d time=%.12lf calc_max=%.12lf calc_min=%.12lf calc_avg=%.12lf comm=%.12lf", 
       pnum, total, calc_max, calc_min, calc_avg, comm);
}
