提交 708aa21e 编写于 作者: 叶剑武

Merge branch 'opencl' into 'master'

Enhance concat opencl

See merge request !1152
...@@ -27,6 +27,7 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS ...@@ -27,6 +27,7 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS
__read_only image2d_t input0, __read_only image2d_t input0,
__read_only image2d_t input1, __read_only image2d_t input1,
__private const int input0_chan, __private const int input0_chan,
__private const int input1_chan,
__write_only image2d_t output) { __write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0); const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1); const int width_idx = get_global_id(1);
...@@ -40,7 +41,9 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS ...@@ -40,7 +41,9 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS
#endif #endif
const int width = global_size_dim1; const int width = global_size_dim1;
const int output_chan = input0_chan + input1_chan;
const int input0_chan_blk = (input0_chan + 3) >> 2; const int input0_chan_blk = (input0_chan + 3) >> 2;
const int output_chan_blk = (output_chan + 3) >> 2;
DATA_TYPE4 data = 0; DATA_TYPE4 data = 0;
#ifdef DIVISIBLE_FOUR #ifdef DIVISIBLE_FOUR
...@@ -54,7 +57,7 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS ...@@ -54,7 +57,7 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS
(int2)(mad24((chan_blk_idx - input0_chan_blk), width, width_idx), hb_idx)); (int2)(mad24((chan_blk_idx - input0_chan_blk), width, width_idx), hb_idx));
} }
#else #else
if (chan_blk_idx + 1 < input0_chan_blk) { if (chan_blk_idx < input0_chan_blk - 1) {
data = READ_IMAGET(input0, data = READ_IMAGET(input0,
SAMPLER, SAMPLER,
(int2)(mad24(chan_blk_idx, width, width_idx), hb_idx)); (int2)(mad24(chan_blk_idx, width, width_idx), hb_idx));
...@@ -62,15 +65,22 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS ...@@ -62,15 +65,22 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS
const int in_chan_idx = chan_blk_idx - input0_chan_blk; const int in_chan_idx = chan_blk_idx - input0_chan_blk;
DATA_TYPE4 data0 = READ_IMAGET(input1, DATA_TYPE4 data0 = READ_IMAGET(input1,
SAMPLER, SAMPLER,
(int2)(mad24(in_chan_idx, width, width_idx), hb_idx)); (int2)(mad24(in_chan_idx, width, width_idx),
DATA_TYPE4 data1 = READ_IMAGET(input1, hb_idx));
SAMPLER, DATA_TYPE4 data1 = 0;
(int2)(mad24((in_chan_idx + 1), width, width_idx), hb_idx)); if (((in_chan_idx + 1) << 2) < input1_chan) {
data1 = READ_IMAGET(input1,
SAMPLER,
(int2)(mad24((in_chan_idx + 1),
width,
width_idx), hb_idx));
}
data = stitch_vector(data0, data1, input0_chan % 4, true); data = stitch_vector(data0, data1, input0_chan % 4, true);
} else { } else { // if (chan_blk_idx == input0_chan_blk - 1)
DATA_TYPE4 data0 = READ_IMAGET(input0, DATA_TYPE4 data0 = READ_IMAGET(input0,
SAMPLER, SAMPLER,
(int2)(mad24(chan_blk_idx, width, width_idx), hb_idx)); (int2)(mad24(chan_blk_idx, width, width_idx),
hb_idx));
DATA_TYPE4 data1 = READ_IMAGET(input1, DATA_TYPE4 data1 = READ_IMAGET(input1,
SAMPLER, SAMPLER,
(int2)(width_idx, hb_idx)); (int2)(width_idx, hb_idx));
...@@ -110,29 +120,3 @@ __kernel void concat_channel_multi(OUT_OF_RANGE_PARAMS ...@@ -110,29 +120,3 @@ __kernel void concat_channel_multi(OUT_OF_RANGE_PARAMS
WRITE_IMAGET(output, (int2)(pos, hb_idx), data); WRITE_IMAGET(output, (int2)(pos, hb_idx), data);
} }
//__kernel void concat_width(__read_only image2d_t input0,
// __read_only image2d_t input1,
// __private const int input0_width,
// __write_only image2d_t output) {
// const int chan_blk_idx = get_global_id(0);
// const int width_idx = get_global_id(1);
// const int width = get_global_size(1);
// const int hb_idx = get_global_id(2);
//
// const sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
//
// DATA_TYPE4 data = 0;
// if (width_idx < input0_width) {
// data = READ_IMAGET(input0,
// SAMPLER,
// (int2)(chan_blk_idx * width + width_idx, hb_idx));
// } else {
// data = READ_IMAGET(input1,
// SAMPLER,
// (int2)(chan_blk_idx * width + (width_idx - input0_width), hb_idx));
// }
//
// WRITE_IMAGET(output, (int2)(chan_blk_idx * width + width_idx, hb_idx), data);
//}
...@@ -101,6 +101,7 @@ MaceStatus Concat2(OpContext *context, ...@@ -101,6 +101,7 @@ MaceStatus Concat2(OpContext *context,
kernel->setArg(idx++, kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(input1->opencl_image()))); *(static_cast<const cl::Image2D *>(input1->opencl_image())));
kernel->setArg(idx++, static_cast<int32_t>(input0->dim(3))); kernel->setArg(idx++, static_cast<int32_t>(input0->dim(3)));
kernel->setArg(idx++, static_cast<int32_t>(input1->dim(3)));
kernel->setArg(idx++, kernel->setArg(idx++,
*(static_cast<cl::Image2D *>(output->opencl_image()))); *(static_cast<cl::Image2D *>(output->opencl_image())));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册