// reduce.cl: OpenCL image kernel that reduces tiles of an input image with MIN, MAX, PROD, MEAN or SUM.
#include <common.h>

3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Reduction operator selection.
//
// REDUCE_TYPE picks the identity value (INIT_REDUCE_VALUE) and the binary
// combining operation (REDUCE_VALUE) used by the kernel below.
// NOTE(review): REDUCE_TYPE and DATA_TYPE4 are not defined in this file;
// presumably they come from build options and/or common.h -- confirm.
//
// Every macro argument is parenthesized so that an argument containing a
// lower-precedence operator (e.g. `a ? b : c`) is combined as a single operand.
#if REDUCE_TYPE == 1  // MIN
#define INIT_REDUCE_VALUE (DATA_TYPE4){MAXFLOAT, MAXFLOAT, MAXFLOAT, MAXFLOAT}
#define REDUCE_VALUE(x, y) fmin((x), (y))
#elif REDUCE_TYPE == 2  // MAX
#define INIT_REDUCE_VALUE (DATA_TYPE4){-MAXFLOAT, -MAXFLOAT, -MAXFLOAT, -MAXFLOAT}
#define REDUCE_VALUE(x, y) fmax((x), (y))
#elif REDUCE_TYPE == 3  // PROD
#define INIT_REDUCE_VALUE (DATA_TYPE4){1, 1, 1, 1}
#define REDUCE_VALUE(x, y) ((x) * (y))
#else  // MEAN or SUM
#define INIT_REDUCE_VALUE (DATA_TYPE4){0, 0, 0, 0}
#define REDUCE_VALUE(x, y) ((x) + (y))
#endif


L
liutuo 已提交
18 19 20
// Reduces rectangular tiles of `input` into single pixels of `output`.
//
// Work-item mapping:
//   get_global_id(0) -> ow, output column inside one channel block
//   get_global_id(1) -> oh, output row inside one batch
//   get_global_id(2) -> bc, split into batch b = bc / channel_blocks and
//                       channel block c = bc % channel_blocks
//
// Image addressing used by this kernel:
//   input  pixel: x = c * in_width  + w,  y = b * in_height  + h
//   output pixel: x = c * out_width + ow, y = b * out_height + oh
//
// Each work-item combines the input tile
//   [start_h, end_h) x [start_w, end_w)
// with REDUCE_VALUE, starting from INIT_REDUCE_VALUE. The nominal tile size is
// (in_height / out_height) x (in_width / out_width); the last output row and
// column absorb whatever input rows/columns remain after integer division.
//
// Parameters:
//   input          source image
//   out_height     output rows per batch
//   out_width      output columns per channel block
//   in_height      input rows per batch
//   in_width       input columns per channel block
//   org_height,
//   org_width      only used as the divisor (org_height * org_width) when
//                  REDUCE_TYPE == 0 and the output is 1x1
//   channel_blocks number of channel blocks folded into global dimension 2
//   output         destination image
//
// NOTE(review): OUT_OF_RANGE_PARAMS, GLOBAL_WORK_GROUP_SIZE_DIM3, READ_IMAGET,
// WRITE_IMAGET, SAMPLER and global_size_dim2 are not defined in this file;
// presumably they come from common.h -- confirm.
__kernel void reduce(OUT_OF_RANGE_PARAMS
                     GLOBAL_WORK_GROUP_SIZE_DIM3
                     __read_only image2d_t input,
                     __private const int out_height,
                     __private const int out_width,
                     __private const int in_height,
                     __private const int in_width,
                     __private const int org_height,
                     __private const int org_width,
                     __private const int channel_blocks,
                     __write_only image2d_t output) {
  const int ow = get_global_id(0);
  const int oh = get_global_id(1);
  const int bc = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
  // The global size may be padded past the real extent. Reject padding
  // work-items in every dimension: an ow >= out_width (or oh >= out_height)
  // would otherwise be written into a neighbouring channel block / batch
  // region of the output image, or outside of it.
  if (ow >= out_width || oh >= out_height || bc >= global_size_dim2) {
    return;
  }
#endif

  const int b = bc / channel_blocks;
  const int c = bc % channel_blocks;

  // Nominal tile extent and the origin of this work-item's tile.
  const int tile_w = in_width / out_width;
  const int tile_h = in_height / out_height;
  const int start_w = tile_w * ow;
  const int start_h = tile_h * oh;

  // The last tile along each axis is extended to the end of the input so the
  // remainder of the integer division is not dropped.
  const int size_w = select(tile_w, in_width - start_w, ow >= out_width - 1);
  const int size_h = select(tile_h, in_height - start_h, oh >= out_height - 1);
  const int end_h = start_h + size_h;
  const int end_w = start_w + size_w;

  DATA_TYPE4 out = INIT_REDUCE_VALUE;
#pragma unroll
  for (int h = start_h; h < end_h; ++h) {
    // The row coordinate does not depend on w; compute it once per row.
    const int in_y = mad24(b, in_height, h);
    for (int w = start_w; w < end_w; ++w) {
      const int in_x = mad24(c, in_width, w);
      const DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(in_x, in_y));
      out = REDUCE_VALUE(out, in);
    }
  }
#if REDUCE_TYPE == 0
  // Turn the accumulated sum into a mean, but only when the output is a
  // single pixel; for any other output size the sum is written unchanged.
  if (out_height == 1 && out_width == 1) {
    out = out / (org_height * org_width);
  }
#endif

  const int out_x = mad24(c, out_width, ow);
  const int out_y = mad24(b, out_height, oh);
  WRITE_IMAGET(output, (int2)(out_x, out_y), out);
}