提交 8499b852 编写于 作者: L liuqi

Remove the if-clause in the conv 3x3 OpenCL kernel.

上级 ad16a866
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
__kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * cin * kw * kh, cout/4 */ __read_only image2d_t filter, /* cout%4 * cin * kw * kh, cout/4 */
#ifdef BIAS #ifdef BIAS
__read_only image2d_t bias, /* cout%4 * cout/4 */ __read_only image2d_t bias, /* cout%4 * cout/4 */
#endif #endif
__write_only image2d_t output, __write_only image2d_t output,
__private const int in_height, __private const int in_height,
...@@ -35,11 +35,34 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] ...@@ -35,11 +35,34 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
out4 = out0; out4 = out0;
#endif #endif
int w0 = out_w_blk - padding_left; #define DEFINE_IN_WIDTH(i) \
int w1 = w0 + out_w_blks; in_width##i[1] = in_width##i[0] + 1; \
int w2 = w1 + out_w_blks; in_width##i[2] = in_width##i[0] + 2; \
int w3 = w2 + out_w_blks; in_width##i[0] = (in_width##i[0] < 0 || in_width##i[0] >= in_width) ? (INT_MIN) : in_width##i[0]; \
int w4 = w3 + out_w_blks; in_width##i[1] = (in_width##i[1] < 0 || in_width##i[1] >= in_width) ? (INT_MIN) : in_width##i[1]; \
in_width##i[2] = (in_width##i[2] < 0 || in_width##i[2] >= in_width) ? (INT_MIN) : in_width##i[2];
int in_width0[3];
int in_width1[3];
int in_width2[3];
int in_width3[3];
int in_width4[3];
in_width0[0] = out_w_blk - padding_left;
in_width1[0] = in_width0[0] + out_w_blks;
in_width2[0] = in_width1[0] + out_w_blks;
in_width3[0] = in_width2[0] + out_w_blks;
in_width4[0] = in_width3[0] + out_w_blks;
DEFINE_IN_WIDTH(0);
DEFINE_IN_WIDTH(1);
DEFINE_IN_WIDTH(2);
DEFINE_IN_WIDTH(3);
DEFINE_IN_WIDTH(4);
#undef DEFINE_IN_WIDTH
const int batch_idx = out_hb / out_height; const int batch_idx = out_hb / out_height;
const int height_idx = out_hb % out_height; const int height_idx = out_hb % out_height;
...@@ -64,37 +87,12 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] ...@@ -64,37 +87,12 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
in_idx = in_ch_blk * in_width; in_idx = in_ch_blk * in_width;
in_width_idx = w0 + width_idx;
// Judge the width border for padding input. // Judge the width border for padding input.
if (in_width_idx < 0 || in_width_idx >= in_width) { in0 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width0[width_idx], in_hb[hb_idx]));
in0 = 0; in1 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width1[width_idx], in_hb[hb_idx]));
} else { in2 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width2[width_idx], in_hb[hb_idx]));
in0 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx])); in3 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width3[width_idx], in_hb[hb_idx]));
} in4 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width4[width_idx], in_hb[hb_idx]));
in_width_idx = w1 + width_idx;
if (in_width_idx < 0 || in_width_idx >= in_width) {
in1 = 0;
} else {
in1 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx]));
}
in_width_idx = w2 + width_idx;
if (in_width_idx < 0 || in_width_idx >= in_width) {
in2 = 0;
} else {
in2 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx]));
}
in_width_idx = w3 + width_idx;
if (in_width_idx < 0 || in_width_idx >= in_width) {
in3 = 0;
} else {
in3 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx]));
}
in_width_idx = w4 + width_idx;
if (in_width_idx < 0 || in_width_idx >= in_width) {
in4 = 0;
} else {
in4 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx]));
}
int filter_idx = (in_ch_blk << 2) + (hb_idx * 3 + width_idx) * rounded_in_ch; int filter_idx = (in_ch_blk << 2) + (hb_idx * 3 + width_idx) * rounded_in_ch;
weights0 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 0, out_ch_blk)); weights0 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 0, out_ch_blk));
...@@ -134,31 +132,32 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] ...@@ -134,31 +132,32 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
} }
const int out_x_base = out_ch_blk * out_width; const int out_x_base = out_ch_blk * out_width;
int w = out_w_blk;
WRITE_IMAGET(output, WRITE_IMAGET(output,
(int2)(out_x_base + w0 + padding_left, out_hb), (int2)(out_x_base + w, out_hb),
out0); out0);
w1 += padding_left; w += out_w_blks;
if (w1 >= out_width) return; if (w >= out_width) return;
WRITE_IMAGET(output, WRITE_IMAGET(output,
(int2)(out_x_base + w1, out_hb), (int2)(out_x_base + w, out_hb),
out1); out1);
w2 += padding_left; w += out_w_blks;
if (w2 >= out_width) return; if (w >= out_width) return;
WRITE_IMAGET(output, WRITE_IMAGET(output,
(int2)(out_x_base + w2, out_hb), (int2)(out_x_base + w, out_hb),
out2); out2);
w3 += padding_left; w += out_w_blks;
if (w3 >= out_width) return; if (w >= out_width) return;
WRITE_IMAGET(output, WRITE_IMAGET(output,
(int2)(out_x_base + w3, out_hb), (int2)(out_x_base + w, out_hb),
out3); out3);
w4 += padding_left; w += out_w_blks;
if (w4 >= out_width) return; if (w >= out_width) return;
WRITE_IMAGET(output, WRITE_IMAGET(output,
(int2)(out_x_base + w4, out_hb), (int2)(out_x_base + w, out_hb),
out4); out4);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册