diff --git a/mace/kernels/opencl/cl/conv_2d_3x3.cl b/mace/kernels/opencl/cl/conv_2d_3x3.cl index 30a5cdd7e5ea16b33a9bc6a58459b7f874e2e4ec..765582682a5261426a3c0a19292a5a5a2168775b 100644 --- a/mace/kernels/opencl/cl/conv_2d_3x3.cl +++ b/mace/kernels/opencl/cl/conv_2d_3x3.cl @@ -3,7 +3,7 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __read_only image2d_t filter, /* cout%4 * cin * kw * kh, cout/4 */ #ifdef BIAS - __read_only image2d_t bias, /* cout%4 * cout/4 */ + __read_only image2d_t bias, /* cout%4 * cout/4 */ #endif __write_only image2d_t output, __private const int in_height, @@ -35,11 +35,34 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] out4 = out0; #endif - int w0 = out_w_blk - padding_left; - int w1 = w0 + out_w_blks; - int w2 = w1 + out_w_blks; - int w3 = w2 + out_w_blks; - int w4 = w3 + out_w_blks; +#define DEFINE_IN_WIDTH(i) \ + in_width##i[1] = in_width##i[0] + 1; \ + in_width##i[2] = in_width##i[0] + 2; \ + in_width##i[0] = (in_width##i[0] < 0 || in_width##i[0] >= in_width) ? (INT_MIN) : in_width##i[0]; \ + in_width##i[1] = (in_width##i[1] < 0 || in_width##i[1] >= in_width) ? (INT_MIN) : in_width##i[1]; \ + in_width##i[2] = (in_width##i[2] < 0 || in_width##i[2] >= in_width) ? (INT_MIN) : in_width##i[2]; + + int in_width0[3]; + int in_width1[3]; + int in_width2[3]; + int in_width3[3]; + int in_width4[3]; + in_width0[0] = out_w_blk - padding_left; + in_width1[0] = in_width0[0] + out_w_blks; + in_width2[0] = in_width1[0] + out_w_blks; + in_width3[0] = in_width2[0] + out_w_blks; + in_width4[0] = in_width3[0] + out_w_blks; + DEFINE_IN_WIDTH(0); + + DEFINE_IN_WIDTH(1); + + DEFINE_IN_WIDTH(2); + + DEFINE_IN_WIDTH(3); + + DEFINE_IN_WIDTH(4); + +#undef DEFINE_IN_WIDTH const int batch_idx = out_hb / out_height; const int height_idx = out_hb % out_height; @@ -64,37 +87,12 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] in_idx = in_ch_blk * in_width; - in_width_idx = w0 + width_idx; // Judge the width border for padding input. - if (in_width_idx < 0 || in_width_idx >= in_width) { - in0 = 0; - } else { - in0 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx])); - } - in_width_idx = w1 + width_idx; - if (in_width_idx < 0 || in_width_idx >= in_width) { - in1 = 0; - } else { - in1 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx])); - } - in_width_idx = w2 + width_idx; - if (in_width_idx < 0 || in_width_idx >= in_width) { - in2 = 0; - } else { - in2 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx])); - } - in_width_idx = w3 + width_idx; - if (in_width_idx < 0 || in_width_idx >= in_width) { - in3 = 0; - } else { - in3 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx])); - } - in_width_idx = w4 + width_idx; - if (in_width_idx < 0 || in_width_idx >= in_width) { - in4 = 0; - } else { - in4 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width_idx, in_hb[hb_idx])); - } + in0 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width0[width_idx], in_hb[hb_idx])); + in1 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width1[width_idx], in_hb[hb_idx])); + in2 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width2[width_idx], in_hb[hb_idx])); + in3 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width3[width_idx], in_hb[hb_idx])); + in4 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width4[width_idx], in_hb[hb_idx])); int filter_idx = (in_ch_blk << 2) + (hb_idx * 3 + width_idx) * rounded_in_ch; weights0 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 0, out_ch_blk)); @@ -134,31 +132,32 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] } const int out_x_base = out_ch_blk * out_width; + int w = out_w_blk; WRITE_IMAGET(output, - (int2)(out_x_base + w0 + padding_left, out_hb), + (int2)(out_x_base + w, out_hb), out0); - w1 += padding_left; - if (w1 >= out_width) return; + w += out_w_blks; + if (w >= out_width) return; WRITE_IMAGET(output, - (int2)(out_x_base + w1, out_hb), + (int2)(out_x_base + w, out_hb), out1); - w2 += padding_left; - if (w2 >= out_width) return; + w += out_w_blks; + if (w >= out_width) return; WRITE_IMAGET(output, - (int2)(out_x_base + w2, out_hb), + (int2)(out_x_base + w, out_hb), out2); - w3 += padding_left; - if (w3 >= out_width) return; + w += out_w_blks; + if (w >= out_width) return; WRITE_IMAGET(output, - (int2)(out_x_base + w3, out_hb), + (int2)(out_x_base + w, out_hb), out3); - w4 += padding_left; - if (w4 >= out_width) return; + w += out_w_blks; + if (w >= out_width) return; WRITE_IMAGET(output, - (int2)(out_x_base + w4, out_hb), + (int2)(out_x_base + w, out_hb), out4); }