#include <common.h>

// Most negative finite value of the element type in use; the max-pooling
// path starts its running maximum from this value.
// The expansion is parenthesized so the leading unary minus cannot combine
// unexpectedly with whatever surrounds the macro at its use site.
#ifdef FP16
#define MIN_VALUE (-HALF_MAX)
#else
#define MIN_VALUE (-FLT_MAX)
#endif

// Counts the cells of a pool_size x pool_size window, whose top-left corner
// sits at (pos_h, pos_w), that fall inside the [0, h_size) x [0, w_size)
// area. Cells outside that area (i.e. padding) are not counted.
// Each extent is computed as (clipped end - clipped start); the extents are
// not clamped to zero, so the window is expected to overlap the area.
inline int calculate_avg_block_size(const int pool_size,
                                    const int pos_h,
                                    const int pos_w,
                                    const int h_size,
                                    const int w_size) {
  const int rows_inside = min(pos_h + pool_size, h_size) - max(0, pos_h);
  const int cols_inside = min(pos_w + pool_size, w_size) - max(0, pos_w);
  return rows_inside * cols_inside;
}

// 2D pooling over a square window: average pooling when POOL_AVG is defined,
// max pooling otherwise.
// Supported data type: half/float
//
// Work-item mapping (taken from the index arithmetic below):
//   dim 0: column-block index; block k covers image columns
//          [k * in_width, (k + 1) * in_width) in the input and
//          [k * out_width, (k + 1) * out_width) in the output
//          (presumably one block per group of 4 channels -- confirm with
//          the host-side image layout).
//   dim 1: output column inside the block; out_width is get_global_size(1).
//   dim 2: batch * out_height + output row.
__kernel void pooling(__read_only image2d_t input,
                      __private const int in_height,
                      __private const int in_width,
                      __private const int out_height,
                      __private const int pad_top,
                      __private const int pad_left,
                      __private const int stride,
                      __private const int pooling_size,
                      __write_only image2d_t output) {
  const int chan_blk = get_global_id(0);
  const int out_col = get_global_id(1);
  const int out_width = get_global_size(1);
  const int out_hb = get_global_id(2);

  // First input-image row belonging to this work-item's batch.
  const int batch_row_base = (out_hb / out_height) * in_height;
  // Top-left corner of the pooling window in un-padded input coordinates;
  // either value may be negative when the window starts inside the padding.
  const int win_top = (out_hb % out_height) * stride - pad_top;
  const int win_left = out_col * stride - pad_left;
  // First input-image column belonging to this column block.
  const int chan_col_base = chan_blk * in_width;

#ifdef POOL_AVG
  // Divisor: number of window cells that lie inside the input (padding
  // excluded).
  const int valid_cells = calculate_avg_block_size(pooling_size,
                                                   win_top, win_left,
                                                   in_height, in_width);
  DATA_TYPE4 res = 0;
  for (int dy = 0; dy < pooling_size; ++dy) {
    const int row = win_top + dy;
    // Out-of-range rows/columns are read at coordinate -1.
    // NOTE(review): this relies on SAMPLER (defined outside this file)
    // returning zero for out-of-image coordinates -- confirm.
    const int y = (row < 0 || row >= in_height) ? -1 : batch_row_base + row;
    for (int dx = 0; dx < pooling_size; ++dx) {
      const int col = win_left + dx;
      const int x = (col < 0 || col >= in_width) ? -1 : chan_col_base + col;
      res += READ_IMAGET(input, SAMPLER, (int2)(x, y));
    }
  }
  res /= valid_cells;
#else
  DATA_TYPE4 res = (DATA_TYPE4)(MIN_VALUE);
  for (int dy = 0; dy < pooling_size; ++dy) {
    const int row = win_top + dy;
    // Padding rows never contribute to the maximum.
    if (row < 0 || row >= in_height) {
      continue;
    }
    const int y = batch_row_base + row;
    for (int dx = 0; dx < pooling_size; ++dx) {
      const int col = win_left + dx;
      // Padding columns never contribute to the maximum.
      if (col < 0 || col >= in_width) {
        continue;
      }
      res = fmax(res, READ_IMAGET(input, SAMPLER, (int2)(chan_col_base + col, y)));
    }
  }
#endif

  WRITE_IMAGET(output, (int2)(chan_blk * out_width + out_col, out_hb), res);
}
