From 84f64635a6d18c929fa38902b962a6fcf88370c8 Mon Sep 17 00:00:00 2001 From: liyin Date: Thu, 27 Jun 2019 14:47:20 +0800 Subject: [PATCH] Enhance concat opencl --- mace/ops/opencl/cl/concat.cl | 50 +++++++++++---------------------- mace/ops/opencl/image/concat.cc | 1 + 2 files changed, 18 insertions(+), 33 deletions(-) diff --git a/mace/ops/opencl/cl/concat.cl b/mace/ops/opencl/cl/concat.cl index 7f36c5b4..a4f4962e 100644 --- a/mace/ops/opencl/cl/concat.cl +++ b/mace/ops/opencl/cl/concat.cl @@ -27,6 +27,7 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS __read_only image2d_t input0, __read_only image2d_t input1, __private const int input0_chan, + __private const int input1_chan, __write_only image2d_t output) { const int chan_blk_idx = get_global_id(0); const int width_idx = get_global_id(1); @@ -40,7 +41,9 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS #endif const int width = global_size_dim1; + const int output_chan = input0_chan + input1_chan; const int input0_chan_blk = (input0_chan + 3) >> 2; + const int output_chan_blk = (output_chan + 3) >> 2; DATA_TYPE4 data = 0; #ifdef DIVISIBLE_FOUR @@ -54,7 +57,7 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS (int2)(mad24((chan_blk_idx - input0_chan_blk), width, width_idx), hb_idx)); } #else - if (chan_blk_idx + 1 < input0_chan_blk) { + if (chan_blk_idx < input0_chan_blk - 1) { data = READ_IMAGET(input0, SAMPLER, (int2)(mad24(chan_blk_idx, width, width_idx), hb_idx)); @@ -62,15 +65,22 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS const int in_chan_idx = chan_blk_idx - input0_chan_blk; DATA_TYPE4 data0 = READ_IMAGET(input1, SAMPLER, - (int2)(mad24(in_chan_idx, width, width_idx), hb_idx)); - DATA_TYPE4 data1 = READ_IMAGET(input1, - SAMPLER, - (int2)(mad24((in_chan_idx + 1), width, width_idx), hb_idx)); + (int2)(mad24(in_chan_idx, width, width_idx), + hb_idx)); + DATA_TYPE4 data1 = 0; + if (((in_chan_idx + 1) << 2) < input1_chan) { + data1 = READ_IMAGET(input1, + SAMPLER, + (int2)(mad24((in_chan_idx + 1), + width, + width_idx), hb_idx)); + } data = stitch_vector(data0, data1, input0_chan % 4, true); - } else { + } else { // if (chan_blk_idx == input0_chan_blk - 1) DATA_TYPE4 data0 = READ_IMAGET(input0, SAMPLER, - (int2)(mad24(chan_blk_idx, width, width_idx), hb_idx)); + (int2)(mad24(chan_blk_idx, width, width_idx), + hb_idx)); DATA_TYPE4 data1 = READ_IMAGET(input1, SAMPLER, (int2)(width_idx, hb_idx)); @@ -110,29 +120,3 @@ __kernel void concat_channel_multi(OUT_OF_RANGE_PARAMS WRITE_IMAGET(output, (int2)(pos, hb_idx), data); } - -//__kernel void concat_width(__read_only image2d_t input0, -// __read_only image2d_t input1, -// __private const int input0_width, -// __write_only image2d_t output) { -// const int chan_blk_idx = get_global_id(0); -// const int width_idx = get_global_id(1); -// const int width = get_global_size(1); -// const int hb_idx = get_global_id(2); -// -// const sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; -// -// DATA_TYPE4 data = 0; -// if (width_idx < input0_width) { -// data = READ_IMAGET(input0, -// SAMPLER, -// (int2)(chan_blk_idx * width + width_idx, hb_idx)); -// } else { -// data = READ_IMAGET(input1, -// SAMPLER, -// (int2)(chan_blk_idx * width + (width_idx - input0_width), hb_idx)); -// } -// -// WRITE_IMAGET(output, (int2)(chan_blk_idx * width + width_idx, hb_idx), data); -//} - diff --git a/mace/ops/opencl/image/concat.cc b/mace/ops/opencl/image/concat.cc index f4433b43..ebdd53b4 100644 --- a/mace/ops/opencl/image/concat.cc +++ b/mace/ops/opencl/image/concat.cc @@ -101,6 +101,7 @@ MaceStatus Concat2(OpContext *context, kernel->setArg(idx++, *(static_cast(input1->opencl_image()))); kernel->setArg(idx++, static_cast(input0->dim(3))); + kernel->setArg(idx++, static_cast(input1->dim(3))); kernel->setArg(idx++, *(static_cast(output->opencl_image()))); -- GitLab