提交 708aa21e 编写于 作者: 叶剑武

Merge branch 'opencl' into 'master'

Enhance the OpenCL concat kernel: pass input1's channel count and guard the read of input1's next channel block

See merge request !1152
......@@ -27,6 +27,7 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS
__read_only image2d_t input0,
__read_only image2d_t input1,
__private const int input0_chan,
__private const int input1_chan,
__write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
......@@ -40,7 +41,9 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS
#endif
const int width = global_size_dim1;
const int output_chan = input0_chan + input1_chan;
const int input0_chan_blk = (input0_chan + 3) >> 2;
const int output_chan_blk = (output_chan + 3) >> 2;
DATA_TYPE4 data = 0;
#ifdef DIVISIBLE_FOUR
......@@ -54,7 +57,7 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS
(int2)(mad24((chan_blk_idx - input0_chan_blk), width, width_idx), hb_idx));
}
#else
if (chan_blk_idx + 1 < input0_chan_blk) {
if (chan_blk_idx < input0_chan_blk - 1) {
data = READ_IMAGET(input0,
SAMPLER,
(int2)(mad24(chan_blk_idx, width, width_idx), hb_idx));
......@@ -62,15 +65,22 @@ __kernel void concat_channel(OUT_OF_RANGE_PARAMS
const int in_chan_idx = chan_blk_idx - input0_chan_blk;
DATA_TYPE4 data0 = READ_IMAGET(input1,
SAMPLER,
(int2)(mad24(in_chan_idx, width, width_idx), hb_idx));
DATA_TYPE4 data1 = READ_IMAGET(input1,
SAMPLER,
(int2)(mad24((in_chan_idx + 1), width, width_idx), hb_idx));
(int2)(mad24(in_chan_idx, width, width_idx),
hb_idx));
DATA_TYPE4 data1 = 0;
if (((in_chan_idx + 1) << 2) < input1_chan) {
data1 = READ_IMAGET(input1,
SAMPLER,
(int2)(mad24((in_chan_idx + 1),
width,
width_idx), hb_idx));
}
data = stitch_vector(data0, data1, input0_chan % 4, true);
} else {
} else { // if (chan_blk_idx == input0_chan_blk - 1)
DATA_TYPE4 data0 = READ_IMAGET(input0,
SAMPLER,
(int2)(mad24(chan_blk_idx, width, width_idx), hb_idx));
(int2)(mad24(chan_blk_idx, width, width_idx),
hb_idx));
DATA_TYPE4 data1 = READ_IMAGET(input1,
SAMPLER,
(int2)(width_idx, hb_idx));
......@@ -110,29 +120,3 @@ __kernel void concat_channel_multi(OUT_OF_RANGE_PARAMS
WRITE_IMAGET(output, (int2)(pos, hb_idx), data);
}
//__kernel void concat_width(__read_only image2d_t input0,
// __read_only image2d_t input1,
// __private const int input0_width,
// __write_only image2d_t output) {
// const int chan_blk_idx = get_global_id(0);
// const int width_idx = get_global_id(1);
// const int width = get_global_size(1);
// const int hb_idx = get_global_id(2);
//
// const sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
//
// DATA_TYPE4 data = 0;
// if (width_idx < input0_width) {
// data = READ_IMAGET(input0,
// SAMPLER,
// (int2)(chan_blk_idx * width + width_idx, hb_idx));
// } else {
// data = READ_IMAGET(input1,
// SAMPLER,
// (int2)(chan_blk_idx * width + (width_idx - input0_width), hb_idx));
// }
//
// WRITE_IMAGET(output, (int2)(chan_blk_idx * width + width_idx, hb_idx), data);
//}
......@@ -101,6 +101,7 @@ MaceStatus Concat2(OpContext *context,
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(input1->opencl_image())));
kernel->setArg(idx++, static_cast<int32_t>(input0->dim(3)));
kernel->setArg(idx++, static_cast<int32_t>(input1->dim(3)));
kernel->setArg(idx++,
*(static_cast<cl::Image2D *>(output->opencl_image())));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册