From 73ca2e006a0cdaba32089e564cf1f8bba66eaa06 Mon Sep 17 00:00:00 2001 From: xiebaiyuan Date: Tue, 24 Mar 2020 13:00:57 +0800 Subject: [PATCH] [LITE][OPENCL][Image]support multi batch conv2d 3x3 5x5 7x7 ,open lws,test=develop (#3258) --- .gitignore | 2 + .../cl_kernel/image/conv2d_3x3_opt_kernel.cl | 406 +++++++++++++++--- .../cl_kernel/image/conv2d_5x5_opt_kernel.cl | 252 +++++++++++ .../cl_kernel/image/conv2d_7x7_opt_kernel.cl | 252 +++++++++++ lite/kernels/opencl/conv_image_compute.cc | 12 +- lite/kernels/opencl/conv_image_compute.h | 2 +- .../kernels/opencl/conv_image_compute_test.cc | 26 +- 7 files changed, 876 insertions(+), 76 deletions(-) diff --git a/.gitignore b/.gitignore index ed131bdbba..9823f8c945 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,5 @@ metal/paddle-mobile-demo/paddle-mobile-demo/Resources metal/paddle-mobile-demo/paddle-mobile-demo/Resources/images metal/paddle-mobile-demo/paddle-mobile-demo/Resources/models metal/MobileNetDemo/MobileNetDemo/Resources + +build* diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl index d3a40272ad..79f3922e89 100644 --- a/lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl @@ -14,22 +14,22 @@ limitations under the License. */ #include -__kernel void conv2d_3x3_opt(__private const int item_ch, +__kernel void conv2d_3x3_opt(__private const int item_ch, __private const int item_w, - __private const int item_h, + __private const int item_h, __read_only image2d_t input_image, __read_only image2d_t filter_image, #if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, #endif - __write_only image2d_t output_image, + __write_only image2d_t output_image, __private const int stride, - __private const int pad, + __private const int pad, __private const int dilation, __private const int batch, __private const int in_ch, __private const int in_w, - __private const int in_h, + __private const int in_h, __private const int out_w, __private const int out_h) { @@ -61,7 +61,8 @@ __kernel void conv2d_3x3_opt(__private const int item_ch, #ifdef BIASE_CH CL_DTYPE4 output[5]; - output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0)); + output[0] = + READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0)); output[1] = output[0]; output[2] = output[0]; output[3] = output[0]; @@ -70,23 +71,33 @@ __kernel void conv2d_3x3_opt(__private const int item_ch, #elif defined(BIASE_ELE) CL_DTYPE4 output[5]; - output[0] = - READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_w_base_id + out_w_id0, item_h_id)); + output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id0, item_h_id)); if (out_w_id1 < out_w) { - output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, - (int2)(out_w_base_id + out_w_id1, item_h_id)); + output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id1, item_h_id)); } if (out_w_id2 < out_w) { - output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, - (int2)(out_w_base_id + out_w_id2, item_h_id)); + output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id2, item_h_id)); } if (out_w_id3 < out_w) { - output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, - (int2)(out_w_base_id + out_w_id3, item_h_id)); + output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id3, item_h_id)); } if (out_w_id4 < out_w) { - output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, - (int2)(out_w_base_id + out_w_id4, item_h_id)); + output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id4, item_h_id)); } #else CL_DTYPE4 output[5] = {0.0f}; @@ -109,54 +120,76 @@ __kernel void conv2d_3x3_opt(__private const int item_ch, int filter_w_val = ch * 3; for (int h = 0; h < 3; h++) { - int in_h_val = select(out_batch_id * in_h + in_h_id + h, -1, + int in_h_val = select(out_batch_id * in_h + in_h_id + h, + -1, (out_batch_id * in_h + in_h_id + h < 0 || out_batch_id * in_h + in_h_id + h >= in_h)); for (int w = 0; w < 3; w++) { - int in_w_val0 = select(in_w_base_id + in_w_id0 + w, -1, + int in_w_val0 = select(in_w_base_id + in_w_id0 + w, + -1, (in_w_id0 + w < 0 || in_w_id0 + w >= in_w)); - int in_w_val1 = select(in_w_base_id + in_w_id1 + w, -1, + int in_w_val1 = select(in_w_base_id + in_w_id1 + w, + -1, (in_w_id1 + w < 0 || in_w_id1 + w >= in_w)); - int in_w_val2 = select(in_w_base_id + in_w_id2 + w, -1, + int in_w_val2 = select(in_w_base_id + in_w_id2 + w, + -1, (in_w_id2 + w < 0 || in_w_id2 + w >= in_w)); - int in_w_val3 = select(in_w_base_id + in_w_id3 + w, -1, + int in_w_val3 = select(in_w_base_id + in_w_id3 + w, + -1, (in_w_id3 + w < 0 || in_w_id3 + w >= in_w)); - int in_w_val4 = select(in_w_base_id + in_w_id4 + w, -1, + int in_w_val4 = select(in_w_base_id + in_w_id4 + w, + -1, (in_w_id4 + w < 0 || in_w_id4 + w >= in_w)); - filter[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, - filter_image, sampler, + filter[0] = READ_IMG_TYPE( + CL_DTYPE_CHAR, + filter_image, + sampler, (int2)(filter_w_val + w, filter_h_val0 + h)); // in_ch:0-3,out_ch:0 - filter[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, - filter_image, sampler, + filter[1] = READ_IMG_TYPE( + CL_DTYPE_CHAR, + filter_image, + sampler, (int2)(filter_w_val + w, filter_h_val1 + h)); // in_ch:0-3,out_ch:1 - filter[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, - filter_image, sampler, + filter[2] = READ_IMG_TYPE( + CL_DTYPE_CHAR, + filter_image, + sampler, (int2)(filter_w_val + w, filter_h_val2 + h)); // in_ch:0-3,out_ch:2 - filter[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, - filter_image, sampler, + filter[3] = READ_IMG_TYPE( + CL_DTYPE_CHAR, + filter_image, + sampler, (int2)(filter_w_val + w, filter_h_val3 + h)); // in_ch:0-3,out_ch:3 - filter_trans[0] = (CL_DTYPE4)(filter[0].x, filter[1].x, filter[2].x, - filter[3].x); // in_ch:0,out_ch:0-3 - filter_trans[1] = (CL_DTYPE4)(filter[0].y, filter[1].y, filter[2].y, - filter[3].y); // in_ch:1,out_ch:0-3 - filter_trans[2] = (CL_DTYPE4)(filter[0].z, filter[1].z, filter[2].z, - filter[3].z); // in_ch:2,out_ch:0-3 - filter_trans[3] = (CL_DTYPE4)(filter[0].w, filter[1].w, filter[2].w, - filter[3].w); // in_ch:3,out_ch:0-3 - - input[0] = - READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val0, in_h_val)); - input[1] = - READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val1, in_h_val)); - input[2] = - READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val2, in_h_val)); - input[3] = - READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val3, in_h_val)); - input[4] = - READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val4, in_h_val)); + filter_trans[0] = (CL_DTYPE4)(filter[0].x, + filter[1].x, + filter[2].x, + filter[3].x); // in_ch:0,out_ch:0-3 + filter_trans[1] = (CL_DTYPE4)(filter[0].y, + filter[1].y, + filter[2].y, + filter[3].y); // in_ch:1,out_ch:0-3 + filter_trans[2] = (CL_DTYPE4)(filter[0].z, + filter[1].z, + filter[2].z, + filter[3].z); // in_ch:2,out_ch:0-3 + filter_trans[3] = (CL_DTYPE4)(filter[0].w, + filter[1].w, + filter[2].w, + filter[3].w); // in_ch:3,out_ch:0-3 + + input[0] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val0, in_h_val)); + input[1] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val1, in_h_val)); + input[2] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val2, in_h_val)); + input[3] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val3, in_h_val)); + input[4] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val4, in_h_val)); output[0] = mad(input[0].x, filter_trans[0], output[0]); output[1] = mad(input[1].x, filter_trans[0], output[1]); @@ -195,23 +228,278 @@ __kernel void conv2d_3x3_opt(__private const int item_ch, output[3] = activation_type4(output[3]); output[4] = activation_type4(output[4]); - WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id0, item_h_id), - output[0]); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id0, item_h_id), + output[0]); if (out_w_id1 < out_w) { - WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id1, item_h_id), - output[1]); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id1, item_h_id), + output[1]); } if (out_w_id2 < out_w) { - WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id2, item_h_id), - output[2]); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id2, item_h_id), + output[2]); } if (out_w_id3 < out_w) { - WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id3, item_h_id), - output[3]); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id3, item_h_id), + output[3]); } if (out_w_id4 < out_w) { - WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id4, item_h_id), - output[4]); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id4, item_h_id), + output[4]); } } +// support batch > 1 +__kernel void conv2d_3x3_multi_batch(__private const int item_ch, + __private const int item_w, + __private const int item_h, + __read_only image2d_t input_image, + __read_only image2d_t filter_image, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int pad, + __private const int dilation, + __private const int batch, + __private const int in_ch, + __private const int in_w, + __private const int in_h, + __private const int out_w, + __private const int out_h) { + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + // item_id + const int item_ch_id = get_global_id(0); + const int item_w_id = get_global_id(1); + const int item_h_id = get_global_id(2); + + // out_width_id_per_blk and out_batch_id + int out_batch_id = item_h_id / in_h; + int out_w_base_id = item_ch_id * out_w; + int out_w_id0 = item_w_id; + int out_w_id1 = out_w_id0 + item_w; + int out_w_id2 = out_w_id1 + item_w; + int out_w_id3 = out_w_id2 + item_w; + int out_w_id4 = out_w_id3 + item_w; + + // in_width_id_per_blk and in_height_id_per_batch + int in_h_id = (item_h_id % out_h) * stride - pad; + int in_w_id0 = item_w_id * stride - pad; + int in_w_id1 = in_w_id0 + item_w * stride; + int in_w_id2 = in_w_id1 + item_w * stride; + int in_w_id3 = in_w_id2 + item_w * stride; + int in_w_id4 = in_w_id3 + item_w * stride; + +#ifdef BIASE_CH + + CL_DTYPE4 output[5]; + output[0] = + READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0)); + output[1] = output[0]; + output[2] = output[0]; + output[3] = output[0]; + output[4] = output[0]; + +#elif defined(BIASE_ELE) + + CL_DTYPE4 output[5]; + output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id0, item_h_id)); + if (out_w_id1 < out_w) { + output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id1, item_h_id)); + } + if (out_w_id2 < out_w) { + output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id2, item_h_id)); + } + if (out_w_id3 < out_w) { + output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id3, item_h_id)); + } + if (out_w_id4 < out_w) { + output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id4, item_h_id)); + } +#else + CL_DTYPE4 output[5] = {0.0f}; +#endif + + CL_DTYPE4 filter[4] = {0.0f}; + CL_DTYPE4 filter_trans[4] = {0.0f}; + CL_DTYPE4 input[5] = {0.0f}; + + int filter_h_val0 = item_ch_id * 4 * 3; + int filter_h_val1 = filter_h_val0 + 3; + int filter_h_val2 = filter_h_val1 + 3; + int filter_h_val3 = filter_h_val2 + 3; + + for (int ch = 0; ch < (in_ch + 3) / 4; ch++) { + int ch_surplus = (ch + 1) * 4 - in_ch > 0 ? (ch + 1) * 4 - in_ch : 0; + + const int in_w_base_id = mul24(ch, in_w); + + int filter_w_val = ch * 3; + + for (int h = 0; h < 3; h++) { + int in_h_val = select( + out_batch_id * in_h + in_h_id + h, + -1, + (out_batch_id * in_h + in_h_id + h < out_batch_id * in_h || + out_batch_id * in_h + in_h_id + h >= (out_batch_id + 1) * in_h)); + + for (int w = 0; w < 3; w++) { + int in_w_val0 = select(in_w_base_id + in_w_id0 + w, + -1, + (in_w_id0 + w < 0 || in_w_id0 + w >= in_w)); + int in_w_val1 = select(in_w_base_id + in_w_id1 + w, + -1, + (in_w_id1 + w < 0 || in_w_id1 + w >= in_w)); + int in_w_val2 = select(in_w_base_id + in_w_id2 + w, + -1, + (in_w_id2 + w < 0 || in_w_id2 + w >= in_w)); + int in_w_val3 = select(in_w_base_id + in_w_id3 + w, + -1, + (in_w_id3 + w < 0 || in_w_id3 + w >= in_w)); + int in_w_val4 = select(in_w_base_id + in_w_id4 + w, + -1, + (in_w_id4 + w < 0 || in_w_id4 + w >= in_w)); + + filter[0] = READ_IMG_TYPE( + CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, filter_h_val0 + h)); // in_ch:0-3,out_ch:0 + filter[1] = READ_IMG_TYPE( + CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, filter_h_val1 + h)); // in_ch:0-3,out_ch:1 + filter[2] = READ_IMG_TYPE( + CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, filter_h_val2 + h)); // in_ch:0-3,out_ch:2 + filter[3] = READ_IMG_TYPE( + CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, filter_h_val3 + h)); // in_ch:0-3,out_ch:3 + + filter_trans[0] = (CL_DTYPE4)(filter[0].x, + filter[1].x, + filter[2].x, + filter[3].x); // in_ch:0,out_ch:0-3 + filter_trans[1] = (CL_DTYPE4)(filter[0].y, + filter[1].y, + filter[2].y, + filter[3].y); // in_ch:1,out_ch:0-3 + filter_trans[2] = (CL_DTYPE4)(filter[0].z, + filter[1].z, + filter[2].z, + filter[3].z); // in_ch:2,out_ch:0-3 + filter_trans[3] = (CL_DTYPE4)(filter[0].w, + filter[1].w, + filter[2].w, + filter[3].w); // in_ch:3,out_ch:0-3 + + input[0] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val0, in_h_val)); + input[1] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val1, in_h_val)); + input[2] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val2, in_h_val)); + input[3] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val3, in_h_val)); + input[4] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val4, in_h_val)); + + output[0] = mad(input[0].x, filter_trans[0], output[0]); + output[1] = mad(input[1].x, filter_trans[0], output[1]); + output[2] = mad(input[2].x, filter_trans[0], output[2]); + output[3] = mad(input[3].x, filter_trans[0], output[3]); + output[4] = mad(input[4].x, filter_trans[0], output[4]); + + if (ch_surplus < 3) { + output[0] = mad(input[0].y, filter_trans[1], output[0]); + output[1] = mad(input[1].y, filter_trans[1], output[1]); + output[2] = mad(input[2].y, filter_trans[1], output[2]); + output[3] = mad(input[3].y, filter_trans[1], output[3]); + output[4] = mad(input[4].y, filter_trans[1], output[4]); + } + if (ch_surplus < 2) { + output[0] = mad(input[0].z, filter_trans[2], output[0]); + output[1] = mad(input[1].z, filter_trans[2], output[1]); + output[2] = mad(input[2].z, filter_trans[2], output[2]); + output[3] = mad(input[3].z, filter_trans[2], output[3]); + output[4] = mad(input[4].z, filter_trans[2], output[4]); + } + if (ch_surplus < 1) { + output[0] = mad(input[0].w, filter_trans[3], output[0]); + output[1] = mad(input[1].w, filter_trans[3], output[1]); + output[2] = mad(input[2].w, filter_trans[3], output[2]); + output[3] = mad(input[3].w, filter_trans[3], output[3]); + output[4] = mad(input[4].w, filter_trans[3], output[4]); + } + } + } + } + + output[0] = activation_type4(output[0]); + output[1] = activation_type4(output[1]); + output[2] = activation_type4(output[2]); + output[3] = activation_type4(output[3]); + output[4] = activation_type4(output[4]); + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id0, item_h_id), + output[0]); + if (out_w_id1 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id1, item_h_id), + output[1]); + } + if (out_w_id2 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id2, item_h_id), + output[2]); + } + if (out_w_id3 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id3, item_h_id), + output[3]); + } + if (out_w_id4 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id4, item_h_id), + output[4]); + } +} diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_5x5_opt_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_5x5_opt_kernel.cl index 7d859a7b1c..4ed2e07202 100644 --- a/lite/backends/opencl/cl_kernel/image/conv2d_5x5_opt_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/conv2d_5x5_opt_kernel.cl @@ -233,6 +233,258 @@ __kernel void conv2d_5x5_opt(__private const int item_ch, output[3] = activation_type4(output[3]); output[4] = activation_type4(output[4]); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id0, item_h_id), + output[0]); + if (out_w_id1 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id1, item_h_id), + output[1]); + } + if (out_w_id2 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id2, item_h_id), + output[2]); + } + if (out_w_id3 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id3, item_h_id), + output[3]); + } + if (out_w_id4 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id4, item_h_id), + output[4]); + } +} +// support batch > 1 +__kernel void conv2d_5x5_multi_batch(__private const int item_ch, + __private const int item_w, + __private const int item_h, + __read_only image2d_t input_image, + __read_only image2d_t filter_image, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int pad, + __private const int dilation, + __private const int batch, + __private const int in_ch, + __private const int in_w, + __private const int in_h, + __private const int out_w, + __private const int out_h) { + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + // filter + const int filter_w = 5; + const int filter_h = 5; + + // item_id + const int item_ch_id = get_global_id(0); + const int item_w_id = get_global_id(1); + const int item_h_id = get_global_id(2); + + // out_width_id_per_blk and out_batch_id + int out_batch_id = item_h_id / in_h; + int out_w_base_id = item_ch_id * out_w; + int out_w_id0 = item_w_id; + int out_w_id1 = out_w_id0 + item_w; + int out_w_id2 = out_w_id1 + item_w; + int out_w_id3 = out_w_id2 + item_w; + int out_w_id4 = out_w_id3 + item_w; + + // in_width_id_per_blk and in_height_id_per_batch + int in_h_id = (item_h_id % out_h) * stride - pad; + int in_w_id0 = item_w_id * stride - pad; + int in_w_id1 = in_w_id0 + item_w * stride; + int in_w_id2 = in_w_id1 + item_w * stride; + int in_w_id3 = in_w_id2 + item_w * stride; + int in_w_id4 = in_w_id3 + item_w * stride; + +#ifdef BIASE_CH + + CL_DTYPE4 output[5]; + output[0] = + READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0)); + output[1] = output[0]; + output[2] = output[0]; + output[3] = output[0]; + output[4] = output[0]; + +#elif defined(BIASE_ELE) + + CL_DTYPE4 output[5]; + output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id0, item_h_id)); + if (out_w_id1 < out_w) { + output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id1, item_h_id)); + } + if (out_w_id2 < out_w) { + output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id2, item_h_id)); + } + if (out_w_id3 < out_w) { + output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id3, item_h_id)); + } + if (out_w_id4 < out_w) { + output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id4, item_h_id)); + } +#else + CL_DTYPE4 output[5] = {0.0f}; +#endif + + CL_DTYPE4 filter[4] = {0.0f}; + CL_DTYPE4 filter_trans[4] = {0.0f}; + CL_DTYPE4 input[5] = {0.0f}; + + int filter_h_val0 = item_ch_id * 4 * filter_h; + int filter_h_val1 = filter_h_val0 + filter_h; + int filter_h_val2 = filter_h_val1 + filter_h; + int filter_h_val3 = filter_h_val2 + filter_h; + + for (int ch = 0; ch < (in_ch + 3) / 4; ch++) { + int ch_surplus = (ch + 1) * 4 - in_ch > 0 ? (ch + 1) * 4 - in_ch : 0; + + const int in_w_base_id = mul24(ch, in_w); + + int filter_w_val = ch * filter_w; + + for (int h = 0; h < filter_h; h++) { + int in_h_val = select( + out_batch_id * in_h + in_h_id + h, + -1, + (out_batch_id * in_h + in_h_id + h < out_batch_id * in_h || + out_batch_id * in_h + in_h_id + h >= (out_batch_id + 1) * in_h)); + + for (int w = 0; w < filter_w; w++) { + int in_w_val0 = select(in_w_base_id + in_w_id0 + w, + -1, + (in_w_id0 + w < 0 || in_w_id0 + w >= in_w)); + int in_w_val1 = select(in_w_base_id + in_w_id1 + w, + -1, + (in_w_id1 + w < 0 || in_w_id1 + w >= in_w)); + int in_w_val2 = select(in_w_base_id + in_w_id2 + w, + -1, + (in_w_id2 + w < 0 || in_w_id2 + w >= in_w)); + int in_w_val3 = select(in_w_base_id + in_w_id3 + w, + -1, + (in_w_id3 + w < 0 || in_w_id3 + w >= in_w)); + int in_w_val4 = select(in_w_base_id + in_w_id4 + w, + -1, + (in_w_id4 + w < 0 || in_w_id4 + w >= in_w)); + + filter[0] = + READ_IMG_TYPE(CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, + filter_h_val0 + h)); // in_ch:0-3,out_ch:0 + filter[1] = + READ_IMG_TYPE(CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, + filter_h_val1 + h)); // in_ch:0-3,out_ch:1 + filter[2] = + READ_IMG_TYPE(CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, + filter_h_val2 + h)); // in_ch:0-3,out_ch:2 + filter[3] = + READ_IMG_TYPE(CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, + filter_h_val3 + h)); // in_ch:0-3,out_ch:3 + + filter_trans[0] = (CL_DTYPE4)(filter[0].x, + filter[1].x, + filter[2].x, + filter[3].x); // in_ch:0,out_ch:0-3 + filter_trans[1] = (CL_DTYPE4)(filter[0].y, + filter[1].y, + filter[2].y, + filter[3].y); // in_ch:1,out_ch:0-3 + filter_trans[2] = (CL_DTYPE4)(filter[0].z, + filter[1].z, + filter[2].z, + filter[3].z); // in_ch:2,out_ch:0-3 + filter_trans[3] = (CL_DTYPE4)(filter[0].w, + filter[1].w, + filter[2].w, + filter[3].w); // in_ch:3,out_ch:0-3 + + input[0] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val0, in_h_val)); + input[1] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val1, in_h_val)); + input[2] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val2, in_h_val)); + input[3] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val3, in_h_val)); + input[4] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val4, in_h_val)); + + output[0] = mad(input[0].x, filter_trans[0], output[0]); + output[1] = mad(input[1].x, filter_trans[0], output[1]); + output[2] = mad(input[2].x, filter_trans[0], output[2]); + output[3] = mad(input[3].x, filter_trans[0], output[3]); + output[4] = mad(input[4].x, filter_trans[0], output[4]); + + if (ch_surplus < 3) { + output[0] = mad(input[0].y, filter_trans[1], output[0]); + output[1] = mad(input[1].y, filter_trans[1], output[1]); + output[2] = mad(input[2].y, filter_trans[1], output[2]); + output[3] = mad(input[3].y, filter_trans[1], output[3]); + output[4] = mad(input[4].y, filter_trans[1], output[4]); + } + if (ch_surplus < 2) { + output[0] = mad(input[0].z, filter_trans[2], output[0]); + output[1] = mad(input[1].z, filter_trans[2], output[1]); + output[2] = mad(input[2].z, filter_trans[2], output[2]); + output[3] = mad(input[3].z, filter_trans[2], output[3]); + output[4] = mad(input[4].z, filter_trans[2], output[4]); + } + if (ch_surplus < 1) { + output[0] = mad(input[0].w, filter_trans[3], output[0]); + output[1] = mad(input[1].w, filter_trans[3], output[1]); + output[2] = mad(input[2].w, filter_trans[3], output[2]); + output[3] = mad(input[3].w, filter_trans[3], output[3]); + output[4] = mad(input[4].w, filter_trans[3], output[4]); + } + } + } + } + + output[0] = activation_type4(output[0]); + output[1] = activation_type4(output[1]); + output[2] = activation_type4(output[2]); + output[3] = activation_type4(output[3]); + output[4] = activation_type4(output[4]); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id0, item_h_id), diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_7x7_opt_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_7x7_opt_kernel.cl index 2adc5a947a..d82f4b4c96 100644 --- a/lite/backends/opencl/cl_kernel/image/conv2d_7x7_opt_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/conv2d_7x7_opt_kernel.cl @@ -233,6 +233,258 @@ __kernel void conv2d_7x7_opt(__private const int item_ch, output[3] = activation_type4(output[3]); output[4] = activation_type4(output[4]); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id0, item_h_id), + output[0]); + if (out_w_id1 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id1, item_h_id), + output[1]); + } + if (out_w_id2 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id2, item_h_id), + output[2]); + } + if (out_w_id3 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id3, item_h_id), + output[3]); + } + if (out_w_id4 < out_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, + output_image, + (int2)(out_w_base_id + out_w_id4, item_h_id), + output[4]); + } +} +// support batch > 1 +__kernel void conv2d_7x7_multi_batch(__private const int item_ch, + __private const int item_w, + __private const int item_h, + __read_only image2d_t input_image, + __read_only image2d_t filter_image, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int pad, + __private const int dilation, + __private const int batch, + __private const int in_ch, + __private const int in_w, + __private const int in_h, + __private const int out_w, + __private const int out_h) { + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + // filter + const int filter_w = 7; + const int filter_h = 7; + + // item_id + const int item_ch_id = get_global_id(0); + const int item_w_id = get_global_id(1); + const int item_h_id = get_global_id(2); + + // out_width_id_per_blk and out_batch_id + int out_batch_id = item_h_id / in_h; + int out_w_base_id = item_ch_id * out_w; + int out_w_id0 = item_w_id; + int out_w_id1 = out_w_id0 + item_w; + int out_w_id2 = out_w_id1 + item_w; + int out_w_id3 = out_w_id2 + item_w; + int out_w_id4 = out_w_id3 + item_w; + + // in_width_id_per_blk and in_height_id_per_batch + int in_h_id = (item_h_id % out_h) * stride - pad; + int in_w_id0 = item_w_id * stride - pad; + int in_w_id1 = in_w_id0 + item_w * stride; + int in_w_id2 = in_w_id1 + item_w * stride; + int in_w_id3 = in_w_id2 + item_w * stride; + int in_w_id4 = in_w_id3 + item_w * stride; + +#ifdef BIASE_CH + + CL_DTYPE4 output[5]; + output[0] = + READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0)); + output[1] = output[0]; + output[2] = output[0]; + output[3] = output[0]; + output[4] = output[0]; + +#elif defined(BIASE_ELE) + + CL_DTYPE4 output[5]; + output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id0, item_h_id)); + if (out_w_id1 < out_w) { + output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id1, item_h_id)); + } + if (out_w_id2 < out_w) { + output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id2, item_h_id)); + } + if (out_w_id3 < out_w) { + output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id3, item_h_id)); + } + if (out_w_id4 < out_w) { + output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, + bias, + sampler, + (int2)(out_w_base_id + out_w_id4, item_h_id)); + } +#else + CL_DTYPE4 output[5] = {0.0f}; +#endif + + CL_DTYPE4 filter[4] = {0.0f}; + CL_DTYPE4 filter_trans[4] = {0.0f}; + CL_DTYPE4 input[5] = {0.0f}; + + int filter_h_val0 = item_ch_id * 4 * filter_h; + int filter_h_val1 = filter_h_val0 + filter_h; + int filter_h_val2 = filter_h_val1 + filter_h; + int filter_h_val3 = filter_h_val2 + filter_h; + + for (int ch = 0; ch < (in_ch + 3) / 4; ch++) { + int ch_surplus = (ch + 1) * 4 - in_ch > 0 ? (ch + 1) * 4 - in_ch : 0; + + const int in_w_base_id = mul24(ch, in_w); + + int filter_w_val = ch * filter_w; + + for (int h = 0; h < filter_h; h++) { + int in_h_val = select( + out_batch_id * in_h + in_h_id + h, + -1, + (out_batch_id * in_h + in_h_id + h < out_batch_id * in_h || + out_batch_id * in_h + in_h_id + h >= (out_batch_id + 1) * in_h)); + + for (int w = 0; w < filter_w; w++) { + int in_w_val0 = select(in_w_base_id + in_w_id0 + w, + -1, + (in_w_id0 + w < 0 || in_w_id0 + w >= in_w)); + int in_w_val1 = select(in_w_base_id + in_w_id1 + w, + -1, + (in_w_id1 + w < 0 || in_w_id1 + w >= in_w)); + int in_w_val2 = select(in_w_base_id + in_w_id2 + w, + -1, + (in_w_id2 + w < 0 || in_w_id2 + w >= in_w)); + int in_w_val3 = select(in_w_base_id + in_w_id3 + w, + -1, + (in_w_id3 + w < 0 || in_w_id3 + w >= in_w)); + int in_w_val4 = select(in_w_base_id + in_w_id4 + w, + -1, + (in_w_id4 + w < 0 || in_w_id4 + w >= in_w)); + + filter[0] = + READ_IMG_TYPE(CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, + filter_h_val0 + h)); // in_ch:0-3,out_ch:0 + filter[1] = + READ_IMG_TYPE(CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, + filter_h_val1 + h)); // in_ch:0-3,out_ch:1 + filter[2] = + READ_IMG_TYPE(CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, + filter_h_val2 + h)); // in_ch:0-3,out_ch:2 + filter[3] = + READ_IMG_TYPE(CL_DTYPE_CHAR, + filter_image, + sampler, + (int2)(filter_w_val + w, + filter_h_val3 + h)); // in_ch:0-3,out_ch:3 + + filter_trans[0] = (CL_DTYPE4)(filter[0].x, + filter[1].x, + filter[2].x, + filter[3].x); // in_ch:0,out_ch:0-3 + filter_trans[1] = (CL_DTYPE4)(filter[0].y, + filter[1].y, + filter[2].y, + filter[3].y); // in_ch:1,out_ch:0-3 + filter_trans[2] = (CL_DTYPE4)(filter[0].z, + filter[1].z, + filter[2].z, + filter[3].z); // in_ch:2,out_ch:0-3 + filter_trans[3] = (CL_DTYPE4)(filter[0].w, + filter[1].w, + filter[2].w, + filter[3].w); // in_ch:3,out_ch:0-3 + + input[0] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val0, in_h_val)); + input[1] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val1, in_h_val)); + input[2] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val2, in_h_val)); + input[3] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val3, in_h_val)); + input[4] = READ_IMG_TYPE( + CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val4, in_h_val)); + + output[0] = mad(input[0].x, filter_trans[0], output[0]); + output[1] = mad(input[1].x, filter_trans[0], output[1]); + output[2] = mad(input[2].x, filter_trans[0], output[2]); + output[3] = mad(input[3].x, filter_trans[0], output[3]); + output[4] = mad(input[4].x, filter_trans[0], output[4]); + + if (ch_surplus < 3) { + output[0] = mad(input[0].y, filter_trans[1], output[0]); + output[1] = mad(input[1].y, filter_trans[1], output[1]); + output[2] = mad(input[2].y, filter_trans[1], output[2]); + output[3] = mad(input[3].y, filter_trans[1], output[3]); + output[4] = mad(input[4].y, filter_trans[1], output[4]); + } + if (ch_surplus < 2) { + output[0] = mad(input[0].z, filter_trans[2], output[0]); + output[1] = mad(input[1].z, filter_trans[2], output[1]); + output[2] = mad(input[2].z, filter_trans[2], output[2]); + output[3] = mad(input[3].z, filter_trans[2], output[3]); + output[4] = mad(input[4].z, filter_trans[2], output[4]); + } + if (ch_surplus < 1) { + output[0] = mad(input[0].w, filter_trans[3], output[0]); + output[1] = mad(input[1].w, filter_trans[3], output[1]); + output[2] = mad(input[2].w, filter_trans[3], output[2]); + output[3] = mad(input[3].w, filter_trans[3], output[3]); + output[4] = mad(input[4].w, filter_trans[3], output[4]); + } + } + } + } + + output[0] = activation_type4(output[0]); + output[1] = activation_type4(output[1]); + output[2] = activation_type4(output[2]); + output[3] = activation_type4(output[3]); + output[4] = activation_type4(output[4]); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id0, item_h_id), diff --git a/lite/kernels/opencl/conv_image_compute.cc b/lite/kernels/opencl/conv_image_compute.cc index 3e356df9d3..d664e37150 100644 --- a/lite/kernels/opencl/conv_image_compute.cc +++ b/lite/kernels/opencl/conv_image_compute.cc @@ -142,9 +142,10 @@ void ConvImageCompute::PrepareForRun() { filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); impl_ = &ConvImageCompute::DepthwiseConv2d; - } else if (kernel_h == 3 && kernel_h == 3) { + } else if (kernel_w == 3 && kernel_h == 3) { // conv2d_3x3 - kernel_func_names_.push_back("conv2d_3x3_opt"); + kernel_func_names_.push_back(bs > 1 ? "conv2d_3x3_multi_batch" + : "conv2d_3x3_opt"); kernel_func_paths_.push_back("image/conv2d_3x3_opt_kernel.cl"); CLImageConverterFolder converter; @@ -174,7 +175,9 @@ void ConvImageCompute::PrepareForRun() { impl_ = &ConvImageCompute::Conv2d5x5; #else // conv2d_5x5_opt - kernel_func_names_.push_back("conv2d_5x5_opt"); + + kernel_func_names_.push_back(bs > 1 ? "conv2d_5x5_multi_batch" + : "conv2d_5x5_opt"); kernel_func_paths_.push_back("image/conv2d_5x5_opt_kernel.cl"); CLImageConverterFolder converter; @@ -207,7 +210,8 @@ void ConvImageCompute::PrepareForRun() { #else // conv2d_7x7 - kernel_func_names_.push_back("conv2d_7x7_opt"); + kernel_func_names_.push_back(bs > 1 ? "conv2d_7x7_multi_batch" + : "conv2d_7x7_opt"); kernel_func_paths_.push_back("image/conv2d_7x7_opt_kernel.cl"); CLImageConverterFolder converter; diff --git a/lite/kernels/opencl/conv_image_compute.h b/lite/kernels/opencl/conv_image_compute.h index b87cbb9f16..57e4b91e0a 100644 --- a/lite/kernels/opencl/conv_image_compute.h +++ b/lite/kernels/opencl/conv_image_compute.h @@ -59,7 +59,7 @@ class ConvImageCompute : public KernelLite event_{new cl::Event}; Tensor filter_gpu_image_; Tensor bias_gpu_image_; - bool use_lws{false}; + bool use_lws{true}; }; } // namespace opencl diff --git a/lite/kernels/opencl/conv_image_compute_test.cc b/lite/kernels/opencl/conv_image_compute_test.cc index a5fb196c84..5563265198 100644 --- a/lite/kernels/opencl/conv_image_compute_test.cc +++ b/lite/kernels/opencl/conv_image_compute_test.cc @@ -510,7 +510,7 @@ TEST(conv2d, compute_image2d_3x3) { const int dilation = 1; const int stride = 2; const int group = 1; - for (int batch_size = 1; batch_size < 2; ++batch_size) { + for (int batch_size = 1; batch_size < 4; ++batch_size) { for (int oc = 1; oc < 10; oc += 1) { // oc for (int ih = 5; ih < 9; ih += 1) { // ih int iw = ih; @@ -532,7 +532,7 @@ const int stride = 2; #else // big scale with group const int stride = 1; const int group = 32 / 1; - const int batch_size = 1; + const int batch_size = 2; const int ic = 32 / 1; const int ih = 112 / 1; const int iw = 112 / 1; @@ -558,7 +558,8 @@ const int stride = 2; PRECISION(kFP16), DATALAYOUT(kImageDefault)); ASSERT_FALSE(kernels.empty()); - CHECK(batch_size == 1) << "conv3x3 only supprt batch_size == 1"; + // CHECK(batch_size == 1) << "conv3x3 only supprt + // batch_size == 1"; auto kernel = std::move(kernels.front()); SHADOW_LOG << "created conv2d kernel"; @@ -886,15 +887,16 @@ TEST(conv2d, compute_image2d_5x5) { // int loop_cnt = 0; #ifdef LOOP_TEST - for (int batch_size = 1; batch_size < 2; ++batch_size) { - for (int oc = 1; oc < 10; oc += 1) { // oc - for (int ih = 5; ih < 9; ih += 1) { // ih + for (int batch_size = 1; batch_size < 4; ++batch_size) { + for (int oc = 1; oc < 5; oc += 1) { // oc + for (int ih = 5; ih < 8; ih += 1) { // ih int iw = ih; - for (int ic = 2; ic < 10; ic += 1) { // ic + for (int ic = 2; ic < 6; ic += 1) { // ic for (bool bias_flag : {true, false}) { - for (std::string relu_flag : {/*true,*/ "relu"}) { + for (std::string relu_flag : {"" + "relu"}) { #else - const int batch_size = 1; + const int batch_size = 2; const int oc = 1; const int ih = 5; const int iw = 5; @@ -1231,15 +1233,15 @@ TEST(conv2d, compute_image2d_7x7) { // int loop_cnt = 0; #ifdef LOOP_TEST - for (int batch_size = 1; batch_size < 2; ++batch_size) { - for (int oc = 1; oc < 10; oc += 1) { // oc + for (int batch_size = 1; batch_size < 4; ++batch_size) { + for (int oc = 1; oc < 6; oc += 1) { // oc for (int ih = 7; ih < 8; ih += 1) { // ih int iw = ih; for (int ic = 2; ic < 4; ic += 1) { // ic for (bool bias_flag : {false, true}) { for (std::string relu_flag : {"", "relu"}) { #else - const int batch_size = 1; + const int batch_size = 2; const int oc = 1; const int ih = 7; const int iw = 7; -- GitLab