提交 8499b852 编写于 作者: L liuqi

Remove the if-clauses in the conv 3x3 OpenCL kernel.

上级 ad16a866
......@@ -3,7 +3,7 @@
__kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * cin * kw * kh, cout/4 */
#ifdef BIAS
__read_only image2d_t bias, /* cout%4 * cout/4 */
__read_only image2d_t bias, /* cout%4 * cout/4 */
#endif
__write_only image2d_t output,
__private const int in_height,
......@@ -35,11 +35,34 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
out4 = out0;
#endif
int w0 = out_w_blk - padding_left;
int w1 = w0 + out_w_blks;
int w2 = w1 + out_w_blks;
int w3 = w2 + out_w_blks;
int w4 = w3 + out_w_blks;
/*
 * DEFINE_IN_WIDTH(i): completes the three-entry array in_width<i>, whose
 * element [0] must already hold the base input column.
 *
 *  - [1] and [2] are set to [0] + 1 and [0] + 2. They are derived first, while
 *    [0] still holds its original value, because [0] itself may be overwritten
 *    by the range check below.
 *  - Every entry that falls outside [0, in_width) is then replaced by INT_MIN.
 *
 * NOTE(review): presumably INT_MIN is a sentinel that pushes the later image
 * read out of range so that the sampler supplies the padding value and no
 * per-read branch is needed -- confirm against the sampler's addressing mode,
 * which is not visible here.
 */
#define DEFINE_IN_WIDTH(i) \
in_width##i[1] = in_width##i[0] + 1; \
in_width##i[2] = in_width##i[0] + 2; \
in_width##i[0] = (in_width##i[0] < 0 || in_width##i[0] >= in_width) ? (INT_MIN) : in_width##i[0]; \
in_width##i[1] = (in_width##i[1] < 0 || in_width##i[1] >= in_width) ? (INT_MIN) : in_width##i[1]; \
in_width##i[2] = (in_width##i[2] < 0 || in_width##i[2] >= in_width) ? (INT_MIN) : in_width##i[2];
int in_width0[3];
int in_width1[3];
int in_width2[3];
int in_width3[3];
int in_width4[3];
in_width0[0] = out_w_blk - padding_left;
in_width1[0] = in_width0[0] + out_w_blks;
in_width2[0] = in_width1[0] + out_w_blks;
in_width3[0] = in_width2[0] + out_w_blks;
in_width4[0] = in_width3[0] + out_w_blks;
DEFINE_IN_WIDTH(0);
DEFINE_IN_WIDTH(1);
DEFINE_IN_WIDTH(2);
DEFINE_IN_WIDTH(3);
DEFINE_IN_WIDTH(4);
#undef DEFINE_IN_WIDTH
const int batch_idx = out_hb / out_height;
const int height_idx = out_hb % out_height;
......@@ -64,37 +87,12 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
in_idx = in_ch_blk * in_width;
in_width_idx = w0 + width_idx;
// Judge the width border for padding input.
if (in_width_idx < 0 || in_width_idx >= in_width) {
in0 = 0;
} else {
in0 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx]));
}
in_width_idx = w1 + width_idx;
if (in_width_idx < 0 || in_width_idx >= in_width) {
in1 = 0;
} else {
in1 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx]));
}
in_width_idx = w2 + width_idx;
if (in_width_idx < 0 || in_width_idx >= in_width) {
in2 = 0;
} else {
in2 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx]));
}
in_width_idx = w3 + width_idx;
if (in_width_idx < 0 || in_width_idx >= in_width) {
in3 = 0;
} else {
in3 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx]));
}
in_width_idx = w4 + width_idx;
if (in_width_idx < 0 || in_width_idx >= in_width) {
in4 = 0;
} else {
in4 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx]));
}
in0 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width0[width_idx], in_hb[hb_idx]));
in1 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width1[width_idx], in_hb[hb_idx]));
in2 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width2[width_idx], in_hb[hb_idx]));
in3 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width3[width_idx], in_hb[hb_idx]));
in4 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width4[width_idx], in_hb[hb_idx]));
int filter_idx = (in_ch_blk << 2) + (hb_idx * 3 + width_idx) * rounded_in_ch;
weights0 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 0, out_ch_blk));
......@@ -134,31 +132,32 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
}
const int out_x_base = out_ch_blk * out_width;
int w = out_w_blk;
WRITE_IMAGET(output,
(int2)(out_x_base + w0 + padding_left, out_hb),
(int2)(out_x_base + w, out_hb),
out0);
w1 += padding_left;
if (w1 >= out_width) return;
w += out_w_blks;
if (w >= out_width) return;
WRITE_IMAGET(output,
(int2)(out_x_base + w1, out_hb),
(int2)(out_x_base + w, out_hb),
out1);
w2 += padding_left;
if (w2 >= out_width) return;
w += out_w_blks;
if (w >= out_width) return;
WRITE_IMAGET(output,
(int2)(out_x_base + w2, out_hb),
(int2)(out_x_base + w, out_hb),
out2);
w3 += padding_left;
if (w3 >= out_width) return;
w += out_w_blks;
if (w >= out_width) return;
WRITE_IMAGET(output,
(int2)(out_x_base + w3, out_hb),
(int2)(out_x_base + w, out_hb),
out3);
w4 += padding_left;
if (w4 >= out_width) return;
w += out_w_blks;
if (w >= out_width) return;
WRITE_IMAGET(output,
(int2)(out_x_base + w4, out_hb),
(int2)(out_x_base + w, out_hb),
out4);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册