From 820eb6d4fc1287f052a8a65d885e493694cfb67c Mon Sep 17 00:00:00 2001
From: HappyAngel
Date: Fri, 22 Nov 2019 21:15:17 +0800
Subject: [PATCH] update conv 2-pad to 4-pad (#2404)

* fix conv 2-pad to 4-pad
* fix compute conv shape
* fix pad, test=develop
* rename conv_depthwise_3x3s1_fp.cc to conv3x3s1p01_depthwise_fp32.cc to
  distinguish it from conv3x3s1_depthwise_fp32.cc
* delete debug printf in conv3x3s1, test=develop
* delete debug printf, test=develop
* delete gem_sdot.h, test=develop
  it is copied from __gemm_sdot_meta_.h
* update compute padding, test=develop
* fix padding size: it must be 2 or 4, test=develop
* fix format in operators/conv_op.cc, test=develop
* change #if 0 to #if 1, test=develop
* do the 2-pad to 4-pad conversion in AttachImpl, test=develop
* fix clang-format error in tests/math/conv_compute_test, test=develop
* fix x86 test result error, test=develop
* add asymmetric padding test case in lite/tests/math/conv_compute_test.cc, test=develop
* change paddings type to support dynamic modification, test=develop
* fix x86 build error in conv_compute_test, test=develop
* fix opencl build error, test=develop
* fix opencl build error, test=develop
* fix opencl/conv_compute build error, test=develop
* fix opencl/conv_compute build error, test=develop
* fix format in kernels/opencl/conv_compute_test, test=develop
* fix build error, test=develop
  fix build error in kernels/x86/conv_compute.h
---
 .../arm/math/conv3x3s1_direct_fp32.cc         |  10 +-
 .../arm/math/conv3x3s1_direct_int8.cc         |   5 +-
 .../arm/math/conv3x3s1px_depthwise_fp32.cc    |   7 +-
 .../arm/math/conv3x3s2_direct_fp32.cc         |  10 +-
 .../arm/math/conv3x3s2_direct_int8.cc         |  10 +-
 .../arm/math/conv3x3s2px_depthwise_fp32.cc    |   5 +-
 lite/backends/arm/math/conv_block_utils.h     | 984 ++++++++----------
 lite/backends/arm/math/conv_depthwise.h       |  32 -
 lite/backends/arm/math/conv_impl.cc           | 215 ++--
 lite/backends/arm/math/conv_winograd_3x3.cc   |   6 +-
 lite/backends/cuda/math/cudnn_conv.cc         |  26 +-
 lite/backends/fpga/KD/pes/conv_process.hpp    |  21 +-
 .../fpga/KD/pes/depthwise_conv_pe.hpp         |  11 +-
 lite/kernels/arm/conv_compute.cc              |  58 +-
 lite/kernels/arm/conv_depthwise.cc            |  35 +-
 lite/kernels/arm/conv_gemmlike.h              |  17 +-
 lite/kernels/arm/conv_transpose_compute.cc    |  22 +-
 .../arm/conv_transpose_compute_test.cc        |  19 +-
 lite/kernels/cuda/conv_compute.cc             |  25 +-
 lite/kernels/cuda/conv_compute_test.cc        |   3 +-
 lite/kernels/fpga/conv_compute.cc             |   7 +
 lite/kernels/fpga/conv_compute_test.cc        |  20 +-
 lite/kernels/npu/bridges/conv_op.cc           |  23 +-
 lite/kernels/npu/bridges/conv_op_test.cc      |   5 +-
 lite/kernels/npu/bridges/conv_transpose_op.cc |  13 +-
 .../npu/bridges/conv_transpose_op_test.cc     |   3 +-
 lite/kernels/opencl/conv_compute.cc           |  28 +-
 lite/kernels/opencl/conv_compute_test.cc      |  14 +-
 .../opencl/depthwise_conv2d_compute.cc        |   2 +-
 .../opencl/depthwise_conv2d_compute_test.cc   |   3 +-
 lite/kernels/x86/conv_compute.h               |  16 +-
 lite/kernels/x86/conv_compute_test.cc         |   6 +-
 lite/kernels/xpu/bridges/conv_op.cc           |  16 +-
 lite/kernels/xpu/bridges/conv_op_test.cc      |   5 +-
 lite/operators/conv_op.cc                     |  28 +-
 lite/operators/conv_op.h                      |  20 +-
 lite/operators/conv_transpose_op.cc           |  29 +-
 lite/operators/op_params.h                    |  14 +-
 lite/tests/math/conv_compute_test.cc          | 212 ++--
 lite/tests/math/conv_int8_compute_test.cc     | 173 +--
 .../tests/math/conv_transpose_compute_test.cc |  16 +-
 41 files changed, 1135 insertions(+), 1039 deletions(-)
 mode change 100755 => 100644 lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp
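The whole patch follows one convention: the old two-value `paddings` attribute `[pad_h, pad_w]` becomes four values `[pad_top, pad_bottom, pad_left, pad_right]`, held behind a `std::shared_ptr` so that passes can modify the pads dynamically (see the op_params.h and conv_op.cc entries above). A minimal sketch of the 2-pad to 4-pad expansion that AttachImpl performs; the helper name is hypothetical and not part of the patch:

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Expand [pad_h, pad_w] to [pad_top, pad_bottom, pad_left, pad_right].
std::shared_ptr<std::vector<int>> ExpandPaddings(std::vector<int> pads) {
  assert((pads.size() == 2 || pads.size() == 4) &&
         "padding size must be 2 or 4");
  if (pads.size() == 2) {
    // A symmetric 2-pad duplicates each value for both sides of its axis.
    pads = {pads[0], pads[0], pads[1], pads[1]};
  }
  return std::make_shared<std::vector<int>>(pads);
}
```

Under this layout `pad_h` is read from index 0 and `pad_w` from index 2, which is why the kernels below change `param.paddings[1]` to `paddings[2]` for the width pad while the height pad stays at index 0.

diff --git a/lite/backends/arm/math/conv3x3s1_direct_fp32.cc b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc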
index 6a1fa37681..b4972a1eca 100644 --- a/lite/backends/arm/math/conv3x3s1_direct_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc @@ -35,9 +35,10 @@ size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param, auto dim_in = param.x->dims(); auto dim_out = param.output->dims(); const int threads = ctx->threads(); + auto paddings = *param.paddings; int llc_size = ctx->llc_size() / sizeof(float); - const int pad_w = param.paddings[1]; - const int pad_h = param.paddings[0]; + const int pad_w = paddings[2]; + const int pad_h = paddings[0]; int ow = dim_out[3]; int oh = dim_out[2]; int ic = dim_in[1]; @@ -74,9 +75,10 @@ void conv_3x3s1_direct_fp32(const float* i_data, ARMContext* ctx) { const int threads = ctx->threads(); int l2_size = ctx->llc_size() / sizeof(float); + auto paddings = *param.paddings; - const int pad_h = param.paddings[0]; - const int pad_w = param.paddings[1]; + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; const int wout_round = ROUNDUP(ow, OUT_W_BLOCK); const int win_round = wout_round + 2; bool flag_relu = param.fuse_relu; diff --git a/lite/backends/arm/math/conv3x3s1_direct_int8.cc b/lite/backends/arm/math/conv3x3s1_direct_int8.cc index f966313e11..64e72bc441 100644 --- a/lite/backends/arm/math/conv3x3s1_direct_int8.cc +++ b/lite/backends/arm/math/conv3x3s1_direct_int8.cc @@ -41,10 +41,11 @@ void conv_3x3s1_direct_int8(const int8_t* din, const operators::ConvParam& param, Context* ctx, const float* scale) { + auto paddings = *param.paddings; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias; - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + int pad_h = paddings[0]; + int pad_w = paddings[2]; const int threads = ctx->threads(); int llc_size = ctx->llc_size() / 4; diff --git a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc index 99aeea8bde..08e5efecd7 100644 --- a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc @@ -39,8 +39,11 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, const operators::ConvParam& param, ARMContext* ctx) { int threads = ctx->threads(); - const int pad_h = param.paddings[0]; - const int pad_w = param.paddings[1]; + + auto paddings = *param.paddings; + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; + const int out_c_block = 4; const int out_h_kernel = 2; const int out_w_kernel = 4; diff --git a/lite/backends/arm/math/conv3x3s2_direct_fp32.cc b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc index 8260718a50..807135f57d 100644 --- a/lite/backends/arm/math/conv3x3s2_direct_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc @@ -32,10 +32,11 @@ size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param, ARMContext* ctx) { auto dim_in = param.x->dims(); auto dim_out = param.output->dims(); + auto paddings = *param.paddings; const int threads = ctx->threads(); int llc_size = ctx->llc_size() / sizeof(float); - const int pad_w = param.paddings[1]; - const int pad_h = param.paddings[0]; + const int pad_w = paddings[2]; + const int pad_h = paddings[0]; int ow = dim_out[3]; int oh = dim_out[2]; int ic = dim_in[1]; @@ -73,10 +74,11 @@ void conv_3x3s2_direct_fp32(const float* i_data, //! 3x3s2 convolution, implemented by direct algorithm //! prepack input to tmp buffer //! 
write output to tmp buffer
+  auto paddings = *param.paddings;
   const int threads = ctx->threads();
   int l2_size = ctx->llc_size() / sizeof(float);
-  const int pad_w = param.paddings[1];
-  const int pad_h = param.paddings[0];
+  const int pad_w = paddings[2];
+  const int pad_h = paddings[0];
   const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
   const int win_round = wout_round * 2 /*stride_w*/ + 1;
   bool flag_relu = param.fuse_relu;
diff --git a/lite/backends/arm/math/conv3x3s2_direct_int8.cc b/lite/backends/arm/math/conv3x3s2_direct_int8.cc
index 01b7a812eb..26829544bf 100644
--- a/lite/backends/arm/math/conv3x3s2_direct_int8.cc
+++ b/lite/backends/arm/math/conv3x3s2_direct_int8.cc
@@ -46,10 +46,11 @@ void conv_3x3s2_direct_int8(const int8_t* din,
   //! 3x3s2 int8 convolution, implemented by direct algorithm
   //! prepack input to tmp buffer
   //! write output to tmp buffer
+  auto paddings = *param.paddings;
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias;
-  int pad_h = param.paddings[0];
-  int pad_w = param.paddings[1];
+  int pad_h = paddings[0];
+  int pad_w = paddings[2];
   const int threads = ctx->threads();
   int llc_size = ctx->llc_size() / 4;
@@ -472,10 +473,11 @@ void conv_3x3s2_direct_int8(const int8_t* din,
   //! 3x3s2 int8 convolution, implemented by direct algorithm
   //! prepack input to tmp buffer
   //! write output to tmp buffer
+  auto paddings = *param.paddings;
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias;
-  int pad_h = param.paddings[0];
-  int pad_w = param.paddings[1];
+  int pad_h = paddings[0];
+  int pad_w = paddings[2];
   const int threads = ctx->threads();
   //! set 1/4 l2 cache
   int llc_size = ctx->llc_size() / 4;
diff --git a/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc
index 2d75323a96..9852c0f84e 100644
--- a/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc
+++ b/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc
@@ -39,9 +39,10 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
                                const float* bias,
                                const operators::ConvParam& param,
                                ARMContext* ctx) {
+  auto paddings = *param.paddings;
   int threads = ctx->threads();
-  const int pad_h = param.paddings[0];
-  const int pad_w = param.paddings[1];
+  const int pad_h = paddings[0];
+  const int pad_w = paddings[2];
   const int out_c_block = 4;
   const int out_h_kernel = 1;
   const int out_w_kernel = 4;
diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h
index b2d16d18d2..24b99692cc 100644
--- a/lite/backends/arm/math/conv_block_utils.h
+++ b/lite/backends/arm/math/conv_block_utils.h
@@ -722,7 +722,57 @@ inline bool write_to_output_c1_fp32(const float* din,
   }
   return true;
 }
-
+#ifdef __aarch64__
+#define NCHWC2_TRANS_FP32_COMPUTE                                          \
+  "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1*/      \
+  "movi v20.4s, #0 \n"               /* for relu */                        \
+  "1: \n"                            /* main loop*/                        \
+  "trn1 v2.4s, v0.4s, v1.4s \n"      /* trans q0, q1*/                     \
+  "trn2 v3.4s, v0.4s, v1.4s \n"      /* trans q0, q1*/                     \
+  "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1*/      \
+  "trn1 v4.2d, v2.2d, v3.2d \n"      /* trans q8, q10*/                    \
+  "trn2 v5.2d, v2.2d, v3.2d \n"      /* trans q8, q10*/
+
+#define NCHWC2_TRANS_FP32_RELU                                             \
+  "fmax v2.4s, v4.4s, v20.4s \n" /*relu*/                                  \
+  "fmax v3.4s, v5.4s, v20.4s \n" /*relu*/
+
+#define NCHWC2_TRANS_FP32_STORE                                            \
+  "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/                        \
+                                                                           \
+  "str q2, [%[doutc0r0]], #16 \n" /* store c0r0*/                          \
+  "str q3, [%[doutc1r0]], #16 \n" /* store c1r0*/                          \
+                                                                           \
+  "bne 1b \n" /* jump to main loop*/
+#else
+#define
NCHWC2_TRANS_FP32_COMPUTE \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, " \ + "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" \ + "vmov.u32 q15, #0 @ dump zero\n" \ + "1: @ main loop\n" \ + "vtrn.32 d0, d1 @ trans data:c0r0, c0r1, " \ + "c1r0, c1r1 \n" \ + "vtrn.32 d2, d3 @ trans data:c0r2, c0r3, " \ + "c1r2, c1r3 \n" \ + \ + "vswp d1, d2 @ swap data\n" + +#define NCHWC2_TRANS_FP32_RELU \ + "vmax.f32 q0, q0, q15 @ relu\n" \ + "vmax.f32 q1, q1, q15 @ relu\n" + +#define NCHWC2_TRANS_FP32_STORE \ + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " \ + "pointer\n" \ + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add " \ + "pointer\n" \ + \ + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \ + \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" \ + \ + "bne 1b @ jump to main loop\n" +#endif /*wirte result in outputs * input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] */ @@ -777,127 +827,41 @@ inline bool write_to_output_c2_fp32(const float* din, int cnt_loop = cnt; if (flag_relu) { #ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1, - c1r1, , c0r2, c1r2, c0r3, - c1r3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v2.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v3.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1, - c1r1, , c0r2, c1r2, c0r3, - c1r3 */ - "trn1 v4.2d, v2.2d, v3.2d \n" /* trans q8, q10*/ - "trn2 v5.2d, v2.2d, v3.2d \n" /* trans q8, q10*/ - - "fmax v2.4s, v4.4s, v20.4s \n" /*relu*/ - "fmax v3.4s, v5.4s, v20.4s \n" /*relu*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - - "str q2, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q3, [%[doutc1r0]], #16 \n" /* store c2r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v20"); + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v20"); #else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, " - "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 d0, d1 @ trans data:c0r0, c0r1, " - "c1r0, c1r1 \n" - "vtrn.32 d2, d3 @ trans data:c0r2, c0r3, " - "c1r2, c1r3 \n" - - "vswp d1, d2 @ swap data\n" - - "vmax.f32 q0, q0, q15 @ relu\n" - "vmax.f32 q1, q1, q15 @ relu\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add " - "pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! 
@ load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q15"); + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); #endif } else { #ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1, - c1r1, , c0r2, c1r2, c0r3, - c1r3 */ - "1: \n" /* main loop*/ - "trn1 v2.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v3.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1, - c1r1, , c0r2, c1r2, c0r3, - c1r3 */ - "trn1 v4.2d, v2.2d, v3.2d \n" /* trans q8, q10*/ - "trn2 v5.2d, v2.2d, v3.2d \n" /* trans q8, q10*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - - "str q4, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q5, [%[doutc1r0]], #16 \n" /* store c2r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5"); + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5"); #else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, " - "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" - "1: @ main loop\n" - "vtrn.32 d0, d1 @ trans data:c0r0, c0r1, " - "c1r0, c1r1 \n" - "vtrn.32 d2, d3 @ trans data:c0r2, c0r3, " - "c1r2, c1r3 \n" - - "vswp d1, d2 @ swap data\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add " - "pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! 
@ load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q15"); + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); #endif } } @@ -922,6 +886,70 @@ inline bool write_to_output_c2_fp32(const float* din, return true; } +#ifdef __aarch64__ +#define NCHWC4_TRANS_FP32_COMPUTE \ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ \ + "movi v20.4s, #0 \n" /* for relu */ \ + "1: \n" /* main loop*/ \ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ \ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ \ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ \ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ \ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ \ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ \ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ \ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ \ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + +#define NCHWC4_TRANS_FP32_RELU \ + "fmax v16.4s, v16.4s, v20.4s \n" /*relu*/ \ + "fmax v17.4s, v17.4s, v20.4s \n" /*relu*/ \ + "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ \ + "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/ + +#define NCHWC4_TRANS_FP32_STORE \ + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ \ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ \ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ \ + \ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \ + "bne 1b \n" /* jump to main loop*/ +#else +#define NCHWC4_TRANS_FP32_COMPUTE \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" \ + "vmov.u32 q15, #0 @ dump zero\n" \ + "1: @ main loop\n" \ + "vtrn.32 q0, q1 @ trans data:c00c01c20c21 " \ + "\n" \ + "vtrn.32 q2, q3 @ trans data:c02c03c22c23 " \ + "\n" \ + \ + "vswp d1, d4 @ swap data\n" \ + "vswp d3, d6 @ swap data\n" + +#define NCHWC4_TRANS_FP32_RELU \ + "vmax.f32 q0, q0, q15 @ relu\n" \ + "vmax.f32 q1, q1, q15 @ relu\n" \ + "vmax.f32 q2, q2, q15 @ relu\n" \ + "vmax.f32 q3, q3, q15 @ relu\n" + +#define NCHWC4_TRANS_FP32_STORE \ + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" \ + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" \ + "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" \ + "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" \ + \ + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \ + \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d4-d7}, [%[ptr_din]]! 
@load data \n" \ + \ + "bne 1b @ jump to main loop\n" +#endif /*wirte result in outputs * input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] */ @@ -983,176 +1011,79 @@ inline bool write_to_output_c4_fp32(const float* din, int cnt_loop = cnt; if (flag_relu) { #ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "fmax v16.4s, v16.4s, v20.4s \n" /*relu*/ - "fmax v17.4s, v17.4s, v20.4s \n" /*relu*/ - "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ - "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/ - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v16", - "v17", - "v18", - "v19", - "v20"); + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU + NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v16", + "v17", + "v18", + "v19", + "v20"); #else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q0, q1 @ trans data:c00c01c20c21 " - "\n" - "vtrn.32 q2, q3 @ trans data:c02c03c22c23 " - "\n" - - "vswp d1, d4 @ swap data\n" - "vswp d3, d6 @ swap data\n" - - "vmax.f32 q0, q0, q15 @ relu\n" - "vmax.f32 q1, q1, q15 @ relu\n" - "vmax.f32 q2, q2, q15 @ relu\n" - "vmax.f32 q3, q3, q15 @ relu\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" - "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" - "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" - "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q15"); + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU + NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); #endif } else { #ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", - "v1", - "v2", - "v3", - "v8", - "v9", - "v10", - "v11", - "v16", - "v17", - "v18", - "v19"); + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v8", + "v9", + "v10", + "v11", + "v16", + "v17", + "v18", + "v19"); #else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "1: @ main loop\n" - "vtrn.32 q0, q1 @ trans data:c00c01c20c21 " - "\n" - "vtrn.32 q2, q3 @ trans data:c02c03c22c23 " - "\n" - - "vswp d1, d4 @ swap data\n" - "vswp d3, d6 @ swap data\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" - "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" - "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" - "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3"); + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3"); #endif } } @@ -1182,6 +1113,120 @@ inline bool write_to_output_c4_fp32(const float* din, return true; } +#ifdef __aarch64__ +#define NCHWC8_TRANS_FP32_COMPUTE \ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ \ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ \ + "movi v20.4s, #0 \n" /* for relu */ \ + "1: \n" /* main loop*/ \ + "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ \ + "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ \ + "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ \ + "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ \ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \ + \ + "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ \ + "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ \ + "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ \ + "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ \ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ \ + \ + "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ \ + "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ \ + "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ \ + "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ \ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \ + \ + "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ \ + "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ \ + "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ \ + "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ \ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + +#define NCHWC8_TRANS_FP32_RELU \ + "fmax v16.4s, v16.4s, v20.4s \n" /*relu*/ \ + "fmax v17.4s, v17.4s, v20.4s \n" /*relu*/ \ + "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ \ + "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/ \ + \ + "fmax v8.4s, v8.4s, v20.4s \n" /*relu*/ \ + "fmax v9.4s, v9.4s, v20.4s \n" /*relu*/ \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ + +#define NCHWC8_TRANS_FP32_STORE \ + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ \ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ \ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ \ + \ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \ + "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ \ + "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ \ + "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ \ + "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ \ + \ + "bne 1b \n" /* jump to main loop*/ +#else +#define NCHWC8_TRANS_FP32_COMPUTE \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" \ + "vmov.u32 q15, #0 @ dump zero\n" \ + "1: @ main loop\n" \ + "vtrn.32 q0, q2 @ trans q0, q2 \n" \ + "vtrn.32 q4, q6 @ trans q4, q6 \n" \ + "vswp.32 d1, d8 @ swap d1, d8 \n" \ + "vswp.32 d5, d12 @ swap d5, d12\n" \ + \ + "vtrn.32 q1, q3 @ trans q1, q3 \n" \ + "vtrn.32 q5, q7 @ trans q5, q7 \n" \ + "vswp.32 d3, d10 @ swap d3, d10\n" \ + "vswp.32 d7, d14 @ swap d7, d14\n" + +#define NCHWC8_TRANS_FP32_RELU \ + "vmax.f32 q0, q0, q15 @ relu\n" \ + "vmax.f32 q1, q1, q15 @ relu\n" \ + "vmax.f32 q2, q2, q15 @ relu\n" \ + "vmax.f32 q3, q3, q15 @ relu\n" \ + \ + "vmax.f32 q4, q4, q15 @ relu\n" \ + "vmax.f32 q5, q5, q15 @ relu\n" \ + "vmax.f32 q6, q6, q15 @ relu\n" \ + "vmax.f32 q7, q7, q15 @ relu\n" + +#define NCHWC8_TRANS_FP32_STORE \ + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \ + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " \ + "pointer\n" \ + "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add " \ + "pointer\n" \ + "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add " \ + "pointer\n" \ + "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add " \ + "pointer\n" \ + \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" \ + \ + "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " \ + "pointer\n" \ + "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add " \ + "pointer\n" \ + "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add " \ + "pointer\n" \ + "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add " \ + "pointer\n" \ + \ + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" \ + \ + "bne 1b @ jump to main loop\n" + +#endif /*wirte result in outputs * input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w] */ @@ -1261,158 +1306,54 @@ inline bool write_to_output_c8_fp32(const float* din, if (cnt > 0) { int cnt_loop = cnt; #ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ - "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ - "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ - "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ - "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ - "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ - "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "fmax v16.4s, v16.4s, v20.4s \n" /*relu*/ - "fmax v17.4s, v17.4s, v20.4s \n" /*relu*/ - "fmax v18.4s, 
v18.4s, v20.4s \n" /*relu*/ - "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/ - - "fmax v8.4s, v8.4s, v20.4s \n" /*relu*/ - "fmax v9.4s, v9.4s, v20.4s \n" /*relu*/ - "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ - "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ - - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ - "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ - "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ - "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); + asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU + NCHWC8_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); #else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q0, q2 @ trans q0, q2 \n" - "vtrn.32 q4, q6 @ trans q4, q6 \n" - "vswp.32 d1, d8 @ swap d1, d8 \n" - "vswp.32 d5, d12 @ swap d5, d12\n" - - "vtrn.32 q1, q3 @ trans q1, q3 \n" - "vtrn.32 q5, q7 @ trans q5, q7 \n" - "vswp.32 d3, d10 @ swap d3, d10\n" - "vswp.32 d7, d14 @ swap d7, d14\n" - - "vmax.f32 q0, q0, q15 @ relu\n" - "vmax.f32 q1, q1, q15 @ relu\n" - "vmax.f32 q2, q2, q15 @ relu\n" - "vmax.f32 q3, q3, q15 @ relu\n" - - "vmax.f32 q4, q4, q15 @ relu\n" - "vmax.f32 q5, q5, q15 @ relu\n" - "vmax.f32 q6, q6, q15 @ relu\n" - "vmax.f32 q7, q7, q15 @ relu\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add " - "pointer\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - - "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add " - "pointer\n" - - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q4", "q15"); + asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU + NCHWC8_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q4", "q15"); #endif } if (we > width) { @@ -1468,138 +1409,53 @@ inline bool write_to_output_c8_fp32(const float* din, if (cnt > 0) { int cnt_loop = cnt; #ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ - "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ - "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ - "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ - "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ - "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ - "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ - "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ - "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ - "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); + asm volatile(NCHWC8_TRANS_FP32_COMPUTE 
NCHWC8_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); #else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - "1: @ main loop\n" - "vtrn.32 q0, q2 @ trans q0, q2 \n" - "vtrn.32 q4, q6 @ trans q4, q6 \n" - "vswp.32 d1, d8 @ swap d1, d8 \n" - "vswp.32 d5, d12 @ swap d5, d12\n" - - "vtrn.32 q1, q3 @ trans q1, q3 \n" - "vtrn.32 q5, q7 @ trans q5, q7 \n" - "vswp.32 d3, d10 @ swap d3, d10\n" - "vswp.32 d7, d14 @ swap d7, d14\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add " - "pointer\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - - "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add " - "pointer\n" - - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q4"); + asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q4"); #endif } if (we > width) { diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h index 1a23982cd5..b6c3478880 100644 --- a/lite/backends/arm/math/conv_depthwise.h +++ b/lite/backends/arm/math/conv_depthwise.h @@ -85,38 +85,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, bool flag_relu, ARMContext* ctx); -void conv_depthwise_3x3p0_fp32(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_3x3p1_fp32(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - template void conv_depthwise_3x3s1_int8(Dtype* dout, const int8_t* din, diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index 8618baf286..dc68e65f42 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -107,29 +107,35 @@ void im2col(const Dtype* data_im, int width, int kernel_h, int kernel_w, - int pad_h, - int pad_w, + int pad_top, + int pad_bottom, + int pad_left, + int pad_right, int stride_h, int stride_w, int dilation_h, int dilation_w, Dtype* data_col) { const int output_h = - (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + (height + pad_top + pad_bottom - (dilation_h * (kernel_h - 1) + 1)) / + stride_h + + 1; const int output_w = - (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + (width + pad_left + pad_right - (dilation_w * (kernel_w - 1) + 1)) / + stride_w + + 1; const int channel_size = height * width; for (int channel = channels; channel--; data_im += channel_size) { for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) { for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { - int input_row = -pad_h + kernel_row * dilation_h; + int input_row = -pad_top + kernel_row * dilation_h; for (int output_rows = output_h; output_rows; output_rows--) { if (!is_a_ge_zero_and_a_lt_b(input_row, height)) { for (int output_cols = output_w; output_cols; output_cols--) { *(data_col++) = 0; } } else { - int input_col = -pad_w + kernel_col * dilation_w; + int input_col = -pad_left + kernel_col * dilation_w; for (int output_col = output_w; output_col; output_col--) { if (is_a_ge_zero_and_a_lt_b(input_col, width)) { *(data_col++) = data_im[input_row * width + input_col]; @@ -361,6 +367,9 @@ void conv_im2col_gemm(const float* i_data, float* tmp_work_space = ctx->workspace_data() + 
ctx->llc_size() / sizeof(float); + + auto paddings = *param.paddings; + auto dilations = *param.dilations; //! use gemv when the output channel size = 1 for (int b = 0; b < num; ++b) { // dC @@ -378,12 +387,14 @@ void conv_im2col_gemm(const float* i_data, win, kernel_h, kernel_w, - param.paddings[0], - param.paddings[1], + paddings[0], + paddings[1], + paddings[2], + paddings[3], param.strides[0], param.strides[1], - param.dilations[0], - param.dilations[1], + dilations[0], + dilations[1], dB); if (n == 1) { @@ -435,14 +446,16 @@ void conv_im2col_gemm_int8(const int8_t* i_data, const float* scale) { int group = param.groups; auto filter_dims = param.filter->dims(); + auto paddings = *param.paddings; + auto dilations = *param.dilations; int kernel_h = filter_dims[2]; int kernel_w = filter_dims[3]; int stride_h = param.strides[0]; int stride_w = param.strides[1]; - int dila_h = param.dilations[0]; - int dila_w = param.dilations[1]; - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + int dila_h = dilations[0]; + int dila_w = dilations[1]; + int pad_h = paddings[0]; + int pad_w = paddings[2]; const int m = oc / group; const int n = oh * ow; const int k = ic * kernel_h * kernel_w / group; @@ -483,7 +496,9 @@ void conv_im2col_gemm_int8(const int8_t* i_data, kernel_h, kernel_w, pad_h, + paddings[1], pad_w, + paddings[3], stride_h, stride_w, dila_h, @@ -563,90 +578,83 @@ void conv_depthwise_3x3_fp32(const void* din, const operators::ConvParam& param, ARMContext* ctx, const float* scale) { - const int pad_h = param.paddings[0]; - const int pad_w = param.paddings[1]; - if (pad_w != pad_h) { - LOG(FATAL) << "fp32 depthwise conv3x3 pad_w: " << pad_w - << ", pad_h: " << pad_h << " must be equal"; - return; - } + auto paddings = *param.paddings; + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; int stride = param.strides[1]; int pad = pad_w; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; - if (stride == 1 && pad < 2) { // support pad = [0, 1] - conv_depthwise_3x3s1_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, - reinterpret_cast(weights), - bias, - pad, - flag_bias, - flag_relu, - ctx); - } else if (stride == 2 && pad < 2) { // support pad = [0, 1] - conv_depthwise_3x3s2_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, - reinterpret_cast(weights), - bias, - pad, - flag_bias, - flag_relu, - ctx); - } else { - LOG(FATAL) << "fp32 depthwise conv3x3 stride: " << stride - << " or pad(<2): " << pad << " unsupported"; - } -#if 0 - if (pad == 1) { - conv_depthwise_3x3p1_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, - reinterpret_cast(weights), - bias, - stride, - flag_bias, - flag_relu, - ctx); - } else if (pad == 0 && h_in > 2) { - conv_depthwise_3x3p0_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, - reinterpret_cast(weights), - bias, - stride, - flag_bias, - flag_relu, - ctx); + bool pads_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); + if (stride == 1) { + if (pads_equal && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1] + conv_depthwise_3x3s1_fp32(reinterpret_cast(din), + reinterpret_cast(dout), + num, + ch_out, + h_out, + w_out, + ch_in, + h_in, + w_in, + reinterpret_cast(weights), + bias, + pad, + flag_bias, + flag_relu, + ctx); + } else { + 
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
+                                reinterpret_cast<float*>(dout),
+                                num,
+                                ch_out,
+                                h_out,
+                                w_out,
+                                ch_in,
+                                h_in,
+                                w_in,
+                                reinterpret_cast<const float*>(weights),
+                                bias,
+                                param,
+                                ctx);
+    }
+
+  } else if (stride == 2) {
+    if (pads_equal && pad_h == pad_w && (pad < 2)) {  // support pad = [0, 1]
+      conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
+                                reinterpret_cast<float*>(dout),
+                                num,
+                                ch_out,
+                                h_out,
+                                w_out,
+                                ch_in,
+                                h_in,
+                                w_in,
+                                reinterpret_cast<const float*>(weights),
+                                bias,
+                                pad,
+                                flag_bias,
+                                flag_relu,
+                                ctx);
+    } else {
+      conv_3x3s2_depthwise_fp32(reinterpret_cast<const float*>(din),
+                                reinterpret_cast<float*>(dout),
+                                num,
+                                ch_out,
+                                h_out,
+                                w_out,
+                                ch_in,
+                                h_in,
+                                w_in,
+                                reinterpret_cast<const float*>(weights),
+                                bias,
+                                param,
+                                ctx);
+    }
   } else {
-    LOG(FATAL) << "unsupport this type 3x3 dw conv";
+    LOG(FATAL) << "fp32 depthwise conv3x3 stride: " << stride << " unsupported";
   }
-#endif
 }
 
 void conv_depthwise_5x5_fp32(const void* din,
@@ -663,7 +671,8 @@ void conv_depthwise_5x5_fp32(const void* din,
                              const operators::ConvParam& param,
                              ARMContext* ctx,
                              const float* scale) {
-  int pad = param.paddings[1];
+  auto paddings = *param.paddings;
+  int pad = paddings[0];
   int stride = param.strides[1];
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
@@ -719,8 +728,9 @@ void conv_depthwise_3x3_int8_fp32(const void* din,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx,
                                  const float* scale) {
-  int pad_h = param.paddings[0];
-  int pad_w = param.paddings[1];
+  auto paddings = *param.paddings;
+  int pad_h = paddings[0];
+  int pad_w = paddings[2];
   int stride = param.strides[1];
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
@@ -777,8 +787,9 @@ void conv_depthwise_3x3_int8_int8(const void* din,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx,
                                  const float* scale) {
-  int pad_h = param.paddings[0];
-  int pad_w = param.paddings[1];
+  auto paddings = *param.paddings;
+  int pad_h = paddings[0];
+  int pad_w = paddings[2];
   int stride = param.strides[1];
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
@@ -835,8 +846,9 @@ void conv_depthwise_5x5_int8_fp32(const void* din,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx,
                                  const float* scale) {
-  int pad_h = param.paddings[0];
-  int pad_w = param.paddings[1];
+  auto paddings = *param.paddings;
+  int pad_h = paddings[0];
+  int pad_w = paddings[2];
   int stride = param.strides[1];
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
@@ -876,8 +888,9 @@ void conv_depthwise_5x5_int8_int8(const void* din,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx,
                                  const float* scale) {
-  int pad_h = param.paddings[0];
-  int pad_w = param.paddings[1];
+  auto paddings = *param.paddings;
+  int pad_h = paddings[0];
+  int pad_w = paddings[2];
   int stride = param.strides[1];
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
diff --git a/lite/backends/arm/math/conv_winograd_3x3.cc b/lite/backends/arm/math/conv_winograd_3x3.cc
index 87b08f6310..894b946a32 100644
--- a/lite/backends/arm/math/conv_winograd_3x3.cc
+++ b/lite/backends/arm/math/conv_winograd_3x3.cc
@@ -37,9 +37,9 @@ void conv_winograd3x3(const float* din,
                       const operators::ConvParam& param,
                       ARMContext* ctx) {
   int threads = ctx->threads();
-
-  const int pad_h = param.paddings[0];
-  const int pad_w = param.paddings[1];
+  auto paddings = *param.paddings;
+  const int pad_h = paddings[0];
+  const int pad_w = paddings[2];
   int size_in_channel = win * hin;
   int size_out_channel = wout * hout;
   bool flag_relu = param.fuse_relu;
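The im2col and shape hunks above all rely on the same output-size formula with independent per-side pads. A small self-contained check of that arithmetic (names local to this sketch, not patch code):

```cpp
#include <cstdio>

// Output extent for one spatial dimension, matching the updated im2col:
// (in + pad_a + pad_b - (dilation * (k - 1) + 1)) / stride + 1
int ConvOutSize(int in, int pad_a, int pad_b, int k, int dilation, int stride) {
  int k_eff = dilation * (k - 1) + 1;  // effective kernel extent
  return (in + pad_a + pad_b - k_eff) / stride + 1;
}

int main() {
  // 3x3, stride 1: symmetric pad 1 keeps the size; asymmetric 0/1 shrinks
  // it by one instead of two.
  std::printf("%d\n", ConvOutSize(8, 1, 1, 3, 1, 1));  // 8
  std::printf("%d\n", ConvOutSize(8, 0, 1, 3, 1, 1));  // 7
  std::printf("%d\n", ConvOutSize(8, 0, 0, 3, 1, 1));  // 6
  return 0;
}
```

The cuDNN hunks that follow keep only `paddings[0]` and `paddings[2]` because `cudnnSetConvolution2dDescriptor()` accepts a single pad per spatial dimension; truly asymmetric pads would need the input padded explicitly first.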
diff --git a/lite/backends/cuda/math/cudnn_conv.cc b/lite/backends/cuda/math/cudnn_conv.cc
index 72ed3951f6..a4f33f467f 100644
--- a/lite/backends/cuda/math/cudnn_conv.cc
+++ b/lite/backends/cuda/math/cudnn_conv.cc
@@ -31,6 +31,9 @@ bool CudnnConv2D::create(const operators::ConvParam& param,
   auto o_dims = param.output->dims();
   int batch = x_dims[0];
+  auto paddings = *param.paddings;
+  auto dilations = *param.dilations;
+
   int iw = x_dims[3];  // nchw
   int ih = x_dims[2];
   int ic = x_dims[1];
@@ -41,10 +44,10 @@ bool CudnnConv2D::create(const operators::ConvParam& param,
   int kh = w_dims[2];
   int sw = param.strides[1];
   int sh = param.strides[0];
-  int pw = param.paddings[1];
-  int ph = param.paddings[0];
-  int dw = param.dilations[1];
-  int dh = param.dilations[0];
+  int pw = paddings[2];
+  int ph = paddings[0];
+  int dw = dilations[1];
+  int dh = dilations[0];
 
   CHECK(ic % param.groups == 0)
       << "The conv input channel shoud be divide group number.";
@@ -133,8 +136,8 @@ bool CudnnConv2D::create(const operators::ConvParam& param,
   this->fwd_algo_ = algo_cache.GetAlgorithm(x_dims.Vectorize(),
                                             w_dims.Vectorize(),
                                             param.strides,
-                                            param.paddings,
-                                            param.dilations,
+                                            *param.paddings,
+                                            *param.dilations,
                                             0,
                                             search_func);
@@ -311,12 +314,15 @@ bool CudnnConv2DInt8::create(const operators::ConvParam& param,
   int kw = w_dims[2];
   int kh = w_dims[1];
+  auto paddings = *param.paddings;
+  auto dilations = *param.dilations;
+
   int sw = param.strides[1];
   int sh = param.strides[0];
-  int pw = param.paddings[1];
-  int ph = param.paddings[0];
-  int dw = param.dilations[1];
-  int dh = param.dilations[0];
+  int pw = paddings[2];
+  int ph = paddings[0];
+  int dw = dilations[1];
+  int dh = dilations[0];
   std::vector<float> weight_scale = param.weight_scale;
   float input_scale = param.input_scale;
diff --git a/lite/backends/fpga/KD/pes/conv_process.hpp b/lite/backends/fpga/KD/pes/conv_process.hpp
index fd17218d06..23332b422d 100644
--- a/lite/backends/fpga/KD/pes/conv_process.hpp
+++ b/lite/backends/fpga/KD/pes/conv_process.hpp
@@ -294,10 +294,17 @@ inline void split_filter_num(const ConvParam& c_param) {
   args.image.channels = input->shape().channel();
   args.image.width = input->shape().width();
   args.image.height = input->shape().height();
-  args.image.pad_width = param.paddings[1];
-  args.image.pad_height = param.paddings[0];
+  auto paddings = *param.paddings;
+  args.image.pad_width = paddings[2];
+  args.image.pad_height = paddings[0];
   args.output.address = out_address;
   args.output.scale_address = out_scale_address;
+  bool pad_equal =
+      ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
+  if (!pad_equal) {
+    LOG(FATAL) << "Unsupported asymmetric padding: " << paddings[0] << ", "
+               << paddings[1] << ", " << paddings[2] << ", " << paddings[3];
+  }
   param.splitParams().push_back(conv_param);
 }
@@ -372,10 +379,18 @@ inline void split_channel(const ConvParam& c_param) {
   args.image.channels = conv_param->input.shape().channel();
   args.image.width = conv_param->input.shape().width();
   args.image.height = conv_param->input.shape().height();
-  args.image.pad_width = param.paddings[1];
-  args.image.pad_height = param.paddings[0];
+  auto paddings = *param.paddings;
+  args.image.pad_width = paddings[2];
+  args.image.pad_height = paddings[0];
+
   args.output.address = conv_param->output.mutableData();
   args.output.scale_address = conv_param->output.scale();
+  bool pad_equal =
+      ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
+  if (!pad_equal) {
+    LOG(FATAL) << "Unsupported asymmetric padding: " << paddings[0] << ", "
+               << paddings[1] << ", " << paddings[2] << ", " << paddings[3];
+  }
   param.splitParams().push_back(conv_param);
 }
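The same four-value symmetry guard now appears in both FPGA helpers above and again in the depthwise PE below. A tiny predicate would capture the intent; this is a sketch, not code from the patch:

```cpp
#include <vector>

// [top, bottom, left, right] is symmetric when both per-axis pairs match;
// the FPGA path bails out with LOG(FATAL) on anything else.
inline bool PadsSymmetric(const std::vector<int>& pads) {
  return pads.size() == 4 && pads[0] == pads[1] && pads[2] == pads[3];
}
```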
" << paddings[0] << ", " + << paddings[1] << ", " << paddings[2] << ", " << paddings[3]; + } param.splitParams().push_back(conv_param); } } diff --git a/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp b/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp old mode 100755 new mode 100644 index 9d7b9b544b..f86806102d --- a/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp +++ b/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp @@ -61,14 +61,21 @@ class DepthwiseConvPE : public PE { args.image.channels = input->shape().channel(); args.image.height = input->shape().height(); args.image.width = input->shape().width(); - args.image.pad_width = param.paddings[0]; - args.image.pad_height = param.paddings[1]; + auto paddings = *param.paddings; + args.image.pad_width = param.paddings[2]; + args.image.pad_height = param.paddings[0]; args.image.scale_address = input->scale(); args.output.address = output->data(); args.output.scale_address = output->scale(); args.out_width = param.output->shape().width(); args.out_height = param.output->shape().height(); args.sub_conv_num = 1; + bool pad_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); + if (!pad_equal) { + LOG(FATA) << "This pad not support ! " << paddings[0] << ", " + << paddings[1] << ", " << paddings[2] << ", " << paddings[3]; + } param.args = args; inplace_.relu_enable = param_.relu.enabled; diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index ebb96e21d5..799e8e2122 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -32,13 +32,17 @@ void ConvCompute::PrepareForRun() { auto w_dims = param.filter->dims(); auto& ctx = this->ctx_->template As(); + auto paddings = *param.paddings; + auto dilations = *param.dilations; int ic = w_dims[1] * param.groups; int oc = w_dims[0]; int kh = w_dims[2]; // oihw int kw = w_dims[3]; - int pad = param.paddings[0]; + int pad = paddings[0]; int stride = param.strides[0]; + bool pads_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); int chin = param.x->dims()[1]; int hin = param.x->dims()[2]; int win = param.x->dims()[3]; @@ -46,16 +50,18 @@ void ConvCompute::PrepareForRun() { int hout = param.output->dims()[2]; int wout = param.output->dims()[3]; - bool kps_equal = (param.paddings[0] == param.paddings[1]) && - (param.strides[0] == param.strides[1]) && (kw == kh); - bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); + bool pads_all_equal = (pads_equal && paddings[0] == paddings[2]); + + bool kps_equal = (param.strides[0] == param.strides[1]) && (kw == kh); + bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); bool flag_dw_3x3 = (kw == 3 && kh == 3 && (stride == 1 || stride == 2)); - bool flag_dw_5x5 = - (kw == 5 && stride == 1) || (kw == 5 && stride == 2 && pad == 2); + bool flag_dw_5x5 = pads_all_equal && ((kw == 5 && stride == 1) || + (kw == 5 && stride == 2 && pad == 2)); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; /// select conv impl - if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { + if (param.groups == ic && ic == oc && kps_equal && pads_equal && + no_dilation && flag_dw) { /// dw conv impl impl_ = new DepthwiseConv; VLOG(3) << "invoking dw conv"; @@ -92,22 +98,29 @@ void ConvCompute::PrepareForRun() { auto& ctx = this->ctx_->template As(); + auto paddings = *param.paddings; + auto dilations = *param.dilations; + bool pads_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); int ic = param.groups * w_dims[1]; int oc = 
w_dims[0]; int kh = w_dims[2]; // oihw int kw = w_dims[3]; - int ph = param.paddings[1]; - int pw = param.paddings[0]; + int ph = paddings[0]; + int pw = paddings[2]; int sh = param.strides[1]; int sw = param.strides[0]; + bool pads_all_equal = (pads_equal && paddings[0] == paddings[2]); bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); - bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); - bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2); - bool flag_dw_5x5 = (kw == 5 && sw == 1); + bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); + bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2)); + bool flag_dw_5x5 = pads_all_equal && + ((kw == 5 && sw == 1) || (kw == 5 && sw == 2 && pw == 2)); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; - if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { + if (param.groups == ic && ic == oc && kps_equal && pads_equal && + no_dilation && flag_dw) { impl_ = new DepthwiseConv; VLOG(3) << "Run DepthwiseConv Int8"; } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && @@ -130,23 +143,30 @@ void ConvCompute::PrepareForRun() { auto w_dims = param.filter->dims(); auto& ctx = this->ctx_->template As(); + auto paddings = *param.paddings; + auto dilations = *param.dilations; + bool pads_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); int ic = w_dims[1] * param.groups; int oc = w_dims[0]; int kh = w_dims[2]; // oihw int kw = w_dims[3]; - int ph = param.paddings[1]; - int pw = param.paddings[0]; + int ph = paddings[0]; + int pw = paddings[2]; int sh = param.strides[1]; int sw = param.strides[0]; + bool pads_all_equal = (pads_equal && paddings[0] == paddings[2]); bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); - bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); - bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2); - bool flag_dw_5x5 = (kw == 5 && sw == 1); + bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); + bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2)); + bool flag_dw_5x5 = pads_all_equal && + ((kw == 5 && sw == 1) || (kw == 5 && sw == 2 && pw == 2)); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; - if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { + if (param.groups == ic && ic == oc && kps_equal && pads_equal && + no_dilation && flag_dw) { impl_ = new DepthwiseConv; VLOG(3) << "Run DepthwiseConv Int8"; } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc index 6a20d607e3..e2eaef51dd 100644 --- a/lite/kernels/arm/conv_depthwise.cc +++ b/lite/kernels/arm/conv_depthwise.cc @@ -31,19 +31,28 @@ void DepthwiseConv::PrepareForRun() { // select dw conv kernel if (kw == 3) { VLOG(5) << "invoke 3x3 dw conv fp32"; - // trans weights - constexpr int cblock = 4; - auto oc = w_dims[0]; - auto kh = w_dims[2]; - auto cround = ROUNDUP(oc, cblock); - weights_.Resize({cround, 1, kh, kw}); - // auto w_data = weights_.mutable_data(); - // auto w_data_in = param.filter->data(); - // lite::arm::math::conv_trans_weights_numc( - // w_data_in, w_data, oc, 1, cblock, kh * kw); - impl_ = lite::arm::math::conv_depthwise_3x3_fp32; - flag_trans_weights_ = false; - // flag_trans_weights_ = true; + auto paddings = *param.paddings; + bool pads_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); + + if (pads_equal && paddings[0] == paddings[2] && + 
(paddings[0] == 0 || paddings[0] == 1)) { + impl_ = lite::arm::math::conv_depthwise_3x3_fp32; + flag_trans_weights_ = false; + } else { + // trans weights + constexpr int cblock = 4; + auto oc = w_dims[0]; + auto kh = w_dims[2]; + auto cround = ROUNDUP(oc, cblock); + weights_.Resize({cround, 1, kh, kw}); + auto w_data = weights_.mutable_data(); + auto w_data_in = param.filter->data(); + lite::arm::math::conv_trans_weights_numc( + w_data_in, w_data, oc, 1, cblock, kh * kw); + impl_ = lite::arm::math::conv_depthwise_3x3_fp32; + flag_trans_weights_ = true; + } } else if (kw == 5) { VLOG(5) << "invoke 5x5 dw conv fp32"; impl_ = lite::arm::math::conv_depthwise_5x5_fp32; diff --git a/lite/kernels/arm/conv_gemmlike.h b/lite/kernels/arm/conv_gemmlike.h index e00b8de6f4..5e59eb8d17 100644 --- a/lite/kernels/arm/conv_gemmlike.h +++ b/lite/kernels/arm/conv_gemmlike.h @@ -52,12 +52,19 @@ class GemmLikeConv : public KernelLite { int oc = o_dims[1]; int kw = w_dims[3]; int kh = w_dims[2]; + + auto paddings = *param.paddings; + auto dilations = *param.dilations; + int sw = param.strides[1]; int sh = param.strides[0]; - int pw = param.paddings[1]; - int ph = param.paddings[0]; - int dw = param.dilations[1]; - int dh = param.dilations[0]; + int pw = paddings[2]; + int ph = paddings[0]; + int dw = dilations[1]; + int dh = dilations[0]; + + bool pads_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); int m = oc / param.groups; int k = ic * kh * kw / param.groups; @@ -66,7 +73,7 @@ class GemmLikeConv : public KernelLite { bool kps_equal = (pw == ph) && (sw == sh) && (kw == kh); bool ks_equal = (sw == sh) && (kw == kh); //! select conv gemmlike kernel - if (kw == 1 && sw == 1 && pw == 0 && kps_equal) { + if (kw == 1 && sw == 1 && pw == 0 && kps_equal && pads_equal) { //! 
1x1s1p0 gemmlike conv flag_1x1gemm_ = true; } else { diff --git a/lite/kernels/arm/conv_transpose_compute.cc b/lite/kernels/arm/conv_transpose_compute.cc index 5a18499c85..53e50b8800 100644 --- a/lite/kernels/arm/conv_transpose_compute.cc +++ b/lite/kernels/arm/conv_transpose_compute.cc @@ -76,19 +76,27 @@ void Conv2DTransposeCompute::Run() { bool fuse_relu = param.fuse_relu; bool flag_bias = (param.bias != nullptr); + auto paddings = *param.paddings; + auto dilations = *param.dilations; + int m = chout * kw * kh / group; int n = hin * win; int k = chin / group; + + bool pads_equal = + (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]); + int group_size_in = win * hin * chin / group; int group_size_out = wout * hout * chout / group; int group_size_coldata = m * n; + + bool pads_all_equal = pads_equal && (paddings[0] == paddings[2]); int hblock = lite::arm::math::get_hblock(&ctx); int m_roundup = hblock * ((m + hblock - 1) / hblock); int group_size_weights = ((m_roundup * k + 15) / 16) * 16; bool flag_1x1s1p1 = (kw == 1) && (kh == 1) && (param.strides[0] == 1) && - (param.strides[1] == 1) && (param.paddings[0] == 0) && - (param.paddings[1] == 0) && (param.dilations[0] == 1) && - (param.dilations[1] == 1); + (param.strides[1] == 1) && pads_all_equal && + (dilations[0] == 1) && (dilations[1] == 1); ctx.ExtendWorkspace(sizeof(float) * group * m * n); auto din = param.x->data(); @@ -129,12 +137,12 @@ void Conv2DTransposeCompute::Run() { wout, kh, kw, - param.paddings[0], - param.paddings[1], + paddings[0], + paddings[2], param.strides[0], param.strides[1], - param.dilations[0], - param.dilations[1], + dilations[0], + dilations[1], dout_batch); } if (flag_bias) { diff --git a/lite/kernels/arm/conv_transpose_compute_test.cc b/lite/kernels/arm/conv_transpose_compute_test.cc index 298c651d9f..53c5543aa1 100644 --- a/lite/kernels/arm/conv_transpose_compute_test.cc +++ b/lite/kernels/arm/conv_transpose_compute_test.cc @@ -194,15 +194,18 @@ void conv2d_transpose_compute_ref(const operators::ConvParam& param) { } int group = param.groups; + auto paddings = *param.paddings; + auto dilations = *param.dilations; + int kernel_h = param.filter->dims()[2]; int kernel_w = param.filter->dims()[3]; int stride_h = param.strides[0]; int stride_w = param.strides[1]; - int dila_h = param.dilations[0]; - int dila_w = param.dilations[1]; + int dila_h = dilations[0]; + int dila_w = dilations[1]; - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + int pad_h = paddings[0]; + int pad_w = paddings[2]; bool flag_bias = (param.bias != nullptr); bool flag_relu = param.fuse_relu; @@ -332,10 +335,14 @@ TEST(conv2d_transpose_arm, compute) { param.bias = &bias; } param.fuse_relu = flag_relu; - param.paddings = std::vector({padding, padding}); + std::vector paddings = { padding, padding, padding, padding}; param.strides = std::vector({stride, stride}); + std::vector dilations = {dilation, dilation}; + param.paddings = + std::make_shared>(paddings); param.dilations = - std::vector({dilation, dilation}); + std::make_shared>(dilations); param.groups = group; conv2d_transpose.SetParam(param); conv2d_transpose.Launch();
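With four pad entries, the transposed-conv fast path above is gated on all of them being equal (pads_all_equal) rather than on the two old pad entries being zero. A hedged restatement of that gate as a standalone predicate; the function name is invented for illustration and is not part of the patch.

#include <vector>

// True when the 1x1, stride-1 transposed-conv fast path applies: unit
// kernel, unit strides and dilations, and one uniform pad on all sides.
inline bool Deconv1x1S1FastPath(int kw, int kh,
                                const std::vector<int>& strides,
                                const std::vector<int>& pads,
                                const std::vector<int>& dilations) {
  bool pads_all_equal = (pads[0] == pads[1]) && (pads[2] == pads[3]) &&
                        (pads[0] == pads[2]);
  return kw == 1 && kh == 1 && strides[0] == 1 && strides[1] == 1 &&
         pads_all_equal && dilations[0] == 1 && dilations[1] == 1;
}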
diff --git a/lite/kernels/cuda/conv_compute.cc b/lite/kernels/cuda/conv_compute.cc index eea81602dd..468ed0cbd0 100644 --- a/lite/kernels/cuda/conv_compute.cc +++ b/lite/kernels/cuda/conv_compute.cc @@ -21,10 +21,14 @@ namespace lite { namespace kernels { namespace cuda { -inline int ConvOutputSize( int input_size, int filter_size, int dilation, int padding, int stride) { +inline int ConvOutputSize(int input_size, + int filter_size, + int dilation, + int pad_left, + int pad_right, + int stride) { const int dkernel = dilation * (filter_size - 1) + 1; - int output_size = (input_size + 2 * padding - dkernel) / stride + 1; + int output_size = (input_size + pad_left + pad_right - dkernel) / stride + 1; CHECK_GT_OR_FALSE(output_size, 0); return output_size; @@ -50,11 +54,15 @@ void ConvComputeInt8::PrepareForRun() { const auto filter_dims = param.filter->dims(); std::vector output_shape({in_dims[0]}); + auto paddings = *param.paddings; + auto dilations = *param.dilations; + for (size_t i = 0; i < param.strides.size(); ++i) { output_shape.push_back(ConvOutputSize(in_dims[i + 1], filter_dims[i + 1], - param.dilations[i], - param.paddings[i], + dilations[i], + paddings[2 * i], + paddings[2 * i + 1], param.strides[i])); } output_shape.push_back(filter_dims[0]); @@ -71,12 +79,15 @@ void ConvComputeInt8::Run() { const auto in_dims = param.x->dims(); const auto filter_dims = param.filter->dims(); std::vector output_shape({in_dims[0]}); + auto paddings = *param.paddings; + auto dilations = *param.dilations; for (size_t i = 0; i < param.strides.size(); ++i) { output_shape.push_back(ConvOutputSize(in_dims[i + 1], filter_dims[i + 1], - param.dilations[i], - param.paddings[i], + dilations[i], + paddings[2 * i], + paddings[2 * i + 1], param.strides[i])); } output_shape.push_back(filter_dims[0]);
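The updated ConvOutputSize generalizes the symmetric formula (in + 2 * pad - dkernel) / stride + 1 by summing the two sides independently. A self-contained sketch with worked numbers; the values and the helper name are chosen for illustration only.

#include <cassert>

// out = (in + pad_left + pad_right - (dilation * (k - 1) + 1)) / stride + 1
inline int ConvOutputSizeSketch(int in, int k, int dilation, int pad_left,
                                int pad_right, int stride) {
  const int dkernel = dilation * (k - 1) + 1;
  return (in + pad_left + pad_right - dkernel) / stride + 1;
}

int main() {
  assert(ConvOutputSizeSketch(32, 3, 1, 0, 1, 2) == 16);  // asymmetric pad
  assert(ConvOutputSizeSketch(32, 3, 1, 1, 1, 1) == 32);  // symmetric "same"
  return 0;
}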
" << paddings[0] << ", " << paddings[1] + << ", " << paddings[2] << ", " << paddings[3]; + } fill_scale_bias_const(&conv_param); conv_param.bias()->copyFrom(param.bias->ZynqTensor()); conv_param.relu.enabled = param.fuse_relu; diff --git a/lite/kernels/fpga/conv_compute_test.cc b/lite/kernels/fpga/conv_compute_test.cc index f166974cc9..1e05c1fa0c 100644 --- a/lite/kernels/fpga/conv_compute_test.cc +++ b/lite/kernels/fpga/conv_compute_test.cc @@ -141,13 +141,15 @@ void conv_compute_ref(const operators::ConvParam& param) { int group = param.groups; int kernel_w = param.filter->dims()[2]; int kernel_h = param.filter->dims()[3]; + + auto paddings = *param.paddings; + auto dilations = *para.dilations; int stride_w = param.strides[0]; int stride_h = param.strides[1]; - int dila_w = param.dilations[0]; - int dila_h = param.dilations[1]; - - int pad_w = param.paddings[0]; - int pad_h = param.paddings[1]; + int dila_w = dilations[0]; + int dila_h = dilations[1]; + int pad_w = paddings[2]; + int pad_h = paddings[0]; bool flag_bias = (param.bias != nullptr); bool flag_relu = param.fuse_relu; @@ -277,10 +279,14 @@ TEST(conv_fpga, compute) { param.bias = &bias; } param.fuse_relu = flag_relu; - param.paddings = std::vector({padding, padding}); + std::vector paddings = { + padding, padding, padding, padding}; param.strides = std::vector({stride, stride}); + std::vector dilations = {dilation, dilation}; + param.paddings = + std::make_shared>(paddings); param.dilations = - std::vector({dilation, dilation}); + std::make_shared>(dilations); param.groups = group; conv.SetParam(param); conv.Launch(); diff --git a/lite/kernels/npu/bridges/conv_op.cc b/lite/kernels/npu/bridges/conv_op.cc index 32f4d511d5..6a8db88472 100644 --- a/lite/kernels/npu/bridges/conv_op.cc +++ b/lite/kernels/npu/bridges/conv_op.cc @@ -42,9 +42,9 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, auto bs = input_dims[0]; auto ic = input_dims[1]; auto oc = filter_dims[0]; - CHECK_EQ(input_dims.size(), 4); - CHECK_EQ(output_dims.size(), 4); - CHECK_EQ(filter_dims.size(), 4); + CHECK_EQ(input_dims.size(), 4L); + CHECK_EQ(output_dims.size(), 4L); + CHECK_EQ(filter_dims.size(), 4L); CHECK_EQ(output_dims[0], bs); CHECK_EQ(output_dims[1], oc); auto strides = op_info->GetAttr>("strides"); @@ -52,9 +52,16 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, auto groups = op_info->GetAttr("groups"); auto dilations = op_info->GetAttr>("dilations"); auto fuse_relu = op_info->GetAttr("fuse_relu"); - CHECK_EQ(strides.size(), 2); - CHECK_EQ(paddings.size(), 2); - CHECK_EQ(dilations.size(), 2); + CHECK_EQ(strides.size(), 2L); + CHECK_EQ(paddings.size(), 4L); + CHECK_EQ(dilations.size(), 2L); + + bool pad_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); + if (!pad_equal) { + LOG(FATA) << "This pad not support ! 
" << paddings[0] << ", " << paddings[1] + << ", " << paddings[2] << ", " << paddings[3]; + } // check depthwise mode, and decide whether use ConvolutionDepthwise Op bool use_depthwise_conv = @@ -134,7 +141,7 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, depthwise_conv_node->set_attr_pad_mode(5); // VALID depthwise_conv_node->set_attr_group(groups); depthwise_conv_node->set_attr_pad(ge::AttrValue::LIST_INT( - {paddings[0], paddings[0], paddings[1], paddings[1]})); + {paddings[0], paddings[0], paddings[2], paddings[2]})); depthwise_conv_node->set_attr_dilation( ge::AttrValue::LIST_INT({dilations[0], dilations[1]})); depthwise_conv_node->set_attr_stride( @@ -161,7 +168,7 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, common_conv_node->set_attr_pad_mode(0); // NOTSET common_conv_node->set_attr_group(groups); common_conv_node->set_attr_pad(ge::AttrValue::LIST_INT( - {paddings[0], paddings[0], paddings[1], paddings[1]})); + {paddings[0], paddings[0], paddings[2], paddings[2]})); common_conv_node->set_attr_dilation( ge::AttrValue::LIST_INT({dilations[0], dilations[1]})); common_conv_node->set_attr_stride( diff --git a/lite/kernels/npu/bridges/conv_op_test.cc b/lite/kernels/npu/bridges/conv_op_test.cc index 26309aa9e2..909061d2ba 100644 --- a/lite/kernels/npu/bridges/conv_op_test.cc +++ b/lite/kernels/npu/bridges/conv_op_test.cc @@ -54,7 +54,7 @@ void conv_ref(const std::shared_ptr op) { int stride_h = strides[0]; int dila_w = dilations[1]; int dila_h = dilations[0]; - int pad_w = paddings[1]; + int pad_w = paddings[2]; int pad_h = paddings[0]; int batch_size = input_dims[0]; int in_ch_size = input_dims[1]; @@ -175,7 +175,8 @@ void test_conv(int bs, opdesc.SetOutput("Output", {output_var_name}); opdesc.SetAttr("dilations", std::vector({dilation, dilation})); opdesc.SetAttr("strides", std::vector({stride, stride})); - opdesc.SetAttr("paddings", std::vector({padding, padding})); + opdesc.SetAttr("paddings", + std::vector({padding, padding, padding, padding})); opdesc.SetAttr("groups", groups); opdesc.SetAttr("fuse_relu", static_cast(fuse_relu)); if (has_bias) { diff --git a/lite/kernels/npu/bridges/conv_transpose_op.cc b/lite/kernels/npu/bridges/conv_transpose_op.cc index 04f75a91b8..52b38bb505 100644 --- a/lite/kernels/npu/bridges/conv_transpose_op.cc +++ b/lite/kernels/npu/bridges/conv_transpose_op.cc @@ -44,14 +44,19 @@ node_map_type ConvTransposeConverter( auto groups = op_info->GetAttr("groups"); auto dilations = op_info->GetAttr>("dilations"); auto fuse_relu = op_info->GetAttr("fuse_relu"); - CHECK_EQ(strides.size(), 2); - CHECK_EQ(paddings.size(), 2); - CHECK_EQ(dilations.size(), 2); + CHECK_EQ(strides.size(), 2L); + CHECK_EQ(paddings.size(), 4L); + CHECK_EQ(dilations.size(), 2L); // create deconv node auto conv_transpose_node = std::make_shared(unique_op_type); - + bool pad_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); + if (!pad_equal) { + LOG(FATA) << "This pad not support ! 
" << paddings[0] << ", " << paddings[1] + << ", " << paddings[2] << ", " << paddings[3]; + } // create input sizes node to describe the dimensions of input tensor std::vector output_shape; output_shape.push_back(input_shape[0]); diff --git a/lite/kernels/npu/bridges/conv_transpose_op_test.cc b/lite/kernels/npu/bridges/conv_transpose_op_test.cc index a009ef588e..f96e57c06f 100644 --- a/lite/kernels/npu/bridges/conv_transpose_op_test.cc +++ b/lite/kernels/npu/bridges/conv_transpose_op_test.cc @@ -278,7 +278,8 @@ void test_conv_transpose(int bs, opdesc.SetOutput("Output", {output_var_name}); opdesc.SetAttr("dilations", std::vector({dilation, dilation})); opdesc.SetAttr("strides", std::vector({stride, stride})); - opdesc.SetAttr("paddings", std::vector({padding, padding})); + opdesc.SetAttr("paddings", + std::vector({padding, padding, padding, padding})); opdesc.SetAttr("groups", groups); opdesc.SetAttr("fuse_relu", static_cast(fuse_relu)); if (has_bias) { diff --git a/lite/kernels/opencl/conv_compute.cc b/lite/kernels/opencl/conv_compute.cc index 04a78face2..e13d12ec22 100644 --- a/lite/kernels/opencl/conv_compute.cc +++ b/lite/kernels/opencl/conv_compute.cc @@ -38,15 +38,20 @@ void ConvCompute::PrepareForRun() { int w_out = output_dims[3]; int kernel_h = filter_dims[2]; // oihw int kernel_w = filter_dims[3]; - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + auto paddings = *param.paddings; + auto dilations = *param.dilations; int stride_h = param.strides[0]; int stride_w = param.strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[2]; int groups = param.groups; bool relu_fused = param.fuse_relu; - bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); + bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); bool zero_pad = (pad_h == 0) && (pad_w == 0); + bool pad_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); + VLOG(3) << "Is relu fused? / " << (relu_fused ? "Yes" : "No"); VLOG(3) << "groups:" << groups << " stride_h:" << stride_h << " stride_w:" << stride_w << " pad_h:" << pad_h @@ -60,7 +65,7 @@ void ConvCompute::PrepareForRun() { << filter_dims[2] << " " << filter_dims[3]; if (kernel_h == 1 && kernel_w == 1 && stride_h == 1 && stride_w == 1 && - zero_pad && no_dilation) { + zero_pad && no_dilation && pad_equal) { // conv2d_1x1 kernel_func_names_.push_back("gemm_batch"); kernel_func_paths_.push_back("buffer/fc_kernel.cl"); @@ -70,7 +75,7 @@ void ConvCompute::PrepareForRun() { build_options_.push_back("-DCL_DTYPE=float"); } impl_ = &ConvCompute::Conv2d1x1; - } else { + } else if (pad_equal) { kernel_func_names_.push_back("im2col"); kernel_func_names_.push_back("gemm_batch"); kernel_func_paths_.push_back("buffer/im2col_kernel.cl"); @@ -85,6 +90,9 @@ void ConvCompute::PrepareForRun() { col_buffer_.reset(new lite::Tensor); col_buffer_->Resize({bs, c_in, kernel_h * kernel_w, h_out * w_out}); col_buffer_->mutable_data(TARGET(kOpenCL)); + } else { + LOG(FATAL) << "This pad not support ! 
" << paddings[0] << ", " + << paddings[1] << ", " << paddings[2] << ", " << paddings[3]; } for (size_t i = 0; i < kernel_func_names_.size(); i++) { @@ -102,17 +110,19 @@ void ConvCompute::GemmlikeConv2d() { int c_in = x_dims[1]; int h_in = x_dims[2]; int w_in = x_dims[3]; + auto paddings = *param.paddings; + auto dilations = *param.dilations; int c_out = output_dims[1]; int h_out = output_dims[2]; int w_out = output_dims[3]; int kernel_h = filter_dims[2]; int kernel_w = filter_dims[3]; - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + int pad_h = paddings[0]; + int pad_w = paddings[2]; int stride_h = param.strides[0]; int stride_w = param.strides[1]; - int dilation_h = param.dilations[0]; - int dilation_w = param.dilations[1]; + int dilation_h = dilations[0]; + int dilation_w = dilations[1]; auto* x_buf = param.x->data(); auto* filter_buf = param.filter->data(); diff --git a/lite/kernels/opencl/conv_compute_test.cc b/lite/kernels/opencl/conv_compute_test.cc index a7417e3525..3bc7a0734d 100644 --- a/lite/kernels/opencl/conv_compute_test.cc +++ b/lite/kernels/opencl/conv_compute_test.cc @@ -24,7 +24,6 @@ namespace lite { #define A(i, j) a[i * lda + j] #define B(i, j) cur_b[i * ldb + j] #define C(i, j) cur_c[i * ldc + j] - template static void conv_basic(const Dtype1* din, Dtype2* dout, @@ -227,10 +226,12 @@ TEST(conv2d, compute_conv2d_1x1) { param.bias = bias_flag ? &bias : nullptr; param.output = &out; param.strides = {stride, stride}; - param.paddings = {pad, pad}; + std::vector paddings = {pad, pad, pad, pad}; param.groups = group; - param.dilations = {dilation, dilation}; + std::vector dilations = {dilation, dilation}; param.fuse_relu = relu_flag; + param.paddings = std::make_shared>(paddings); + param.dilations = std::make_shared>(dilations); kernel->SetParam(param); std::unique_ptr conv_context(new KernelContext); @@ -454,11 +455,14 @@ TEST(conv2d, compute_conv2d_gemm) { param.bias = bias_flag ? 
&bias : nullptr; param.output = &out; param.strides = {stride, stride}; - param.paddings = {pad, pad}; + std::vector paddings = {pad, pad, pad, pad}; param.groups = group; - param.dilations = {dilation, dilation}; + std::vector dilations = {dilation, dilation}; param.fuse_relu = relu_flag; + param.paddings = std::make_shared>(paddings); + param.dilations = std::make_shared>(dilations); + kernel->SetParam(param); std::unique_ptr conv_context(new KernelContext); context->As().CopySharedTo( diff --git a/lite/kernels/opencl/depthwise_conv2d_compute.cc b/lite/kernels/opencl/depthwise_conv2d_compute.cc index 62734610e2..ed942d7f0c 100644 --- a/lite/kernels/opencl/depthwise_conv2d_compute.cc +++ b/lite/kernels/opencl/depthwise_conv2d_compute.cc @@ -44,7 +44,7 @@ class DepthwiseConv2dCompute auto x_dims = param.x->dims(); auto filter_dims = param.filter->dims(); auto output_dims = param.output->dims(); - auto paddings = param.paddings; + auto paddings = *param.paddings; auto strides = param.strides; auto& context = ctx_->As(); diff --git a/lite/kernels/opencl/depthwise_conv2d_compute_test.cc b/lite/kernels/opencl/depthwise_conv2d_compute_test.cc index a189acaf91..3556d1abed 100644 --- a/lite/kernels/opencl/depthwise_conv2d_compute_test.cc +++ b/lite/kernels/opencl/depthwise_conv2d_compute_test.cc @@ -105,7 +105,8 @@ TEST(depthwise_conv2d, compute) { param.x = &input; param.filter = &filter; param.output = &output; - param.paddings = std::vector{0, 0}; + std::vector paddings = {0, 0}; + param.paddings = std::make_shared>(paddings); param.strides = std::vector{1, 1}; std::unique_ptr context(new KernelContext); diff --git a/lite/kernels/x86/conv_compute.h b/lite/kernels/x86/conv_compute.h index 62f5887947..063e66a1c1 100644 --- a/lite/kernels/x86/conv_compute.h +++ b/lite/kernels/x86/conv_compute.h @@ -67,7 +67,7 @@ class Conv2dCompute : public KernelLite { lite::DDim col_shape(col_shape_vec); lite::DDim col_matrix_shape = col_shape.Flatten2D(data_dim); bool is_expand = IsExpand( - filter_shape_vec, param.strides, param.paddings, param.dilations); + filter_shape_vec, param.strides, *param.paddings, *param.dilations); lite::Tensor col; lite::Tensor col_matrix; if (is_expand) { @@ -103,7 +103,7 @@ class Conv2dCompute : public KernelLite { lite::Tensor in_slice = in_batch.Slice(static_cast(g * in_step), static_cast((g + 1) * in_step)); - + auto paddings = *param.paddings; if (!is_expand) { col.ShareDataWith(in_slice); col_matrix.ShareDataWith(col); @@ -112,20 +112,18 @@ class Conv2dCompute : public KernelLite { // im2col im2col(context, in_slice, - param.dilations, + *param.dilations, param.strides, - std::vector{param.paddings[0], - param.paddings[1], - param.paddings[0], - param.paddings[1]}, + std::vector{ + paddings[0], paddings[2], paddings[0], paddings[2]}, &(col)); } else if (data_dim == 3U) { // vol2col vol2col(context, in_slice, - param.dilations, + *param.dilations, param.strides, - param.paddings, + *param.paddings, &(col)); } diff --git a/lite/kernels/x86/conv_compute_test.cc b/lite/kernels/x86/conv_compute_test.cc index f2dde962b9..2827c6577e 100644 --- a/lite/kernels/x86/conv_compute_test.cc +++ b/lite/kernels/x86/conv_compute_test.cc @@ -73,9 +73,11 @@ TEST(conv2d_x86, run_test) { param.bias = &b; param.output = &out; param.strides = {1, 1}; - param.paddings = {0, 0}; + std::vector paddings = {0, 0, 0, 0}; param.groups = 1; - param.dilations = {1, 1}; + std::vector dilations = {1, 1}; + param.paddings = std::make_shared>(paddings); + param.dilations = std::make_shared>(dilations); 
LOG(INFO) << 123; std::unique_ptr ctx(new KernelContext); ctx->As(); diff --git a/lite/kernels/xpu/bridges/conv_op.cc b/lite/kernels/xpu/bridges/conv_op.cc index 2c758cf950..9acb0e4e3d 100644 --- a/lite/kernels/xpu/bridges/conv_op.cc +++ b/lite/kernels/xpu/bridges/conv_op.cc @@ -46,17 +46,25 @@ node_map_type ConvConverter(const std::shared_ptr op, auto groups = op_info->GetAttr("groups"); auto dilations = op_info->GetAttr>("dilations"); auto fuse_relu = op_info->GetAttr("fuse_relu"); - CHECK_EQ(strides.size(), 2); - CHECK_EQ(paddings.size(), 2); - CHECK_EQ(dilations.size(), 2); + CHECK_EQ(strides.size(), 2L); + CHECK_EQ(paddings.size(), 4L); + CHECK_EQ(dilations.size(), 2L); std::vector output_shape({bs, oc}); for (size_t i = 0; i < 2; i++) { const int dkernel = dilations[i] * (filter_dims[2 + i] - 1) + 1; output_shape.push_back( - (input_dims[i + 2] + 2 * paddings[i] - dkernel) / strides[i] + 1); + (input_dims[i + 2] + paddings[2 * i] + paddings[2 * i + 1] - dkernel) / strides[i] + 1); } DDim output_dims(output_shape); + bool pads_equal = + (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]); + if (!pads_equal) { + LOG(FATAL) << "Padding requires pad_top==pad_bottom and pad_left==pad_right."; + } + // check context CHECK(graph_ctx != nullptr); CHECK(graph_ctx->builder != nullptr); diff --git a/lite/kernels/xpu/bridges/conv_op_test.cc b/lite/kernels/xpu/bridges/conv_op_test.cc index ebdb67bd0d..70929ffcd5 100644 --- a/lite/kernels/xpu/bridges/conv_op_test.cc +++ b/lite/kernels/xpu/bridges/conv_op_test.cc @@ -54,7 +54,7 @@ void conv_ref(const std::shared_ptr op) { int stride_h = strides[0]; int dila_w = dilations[1]; int dila_h = dilations[0]; - int pad_w = paddings[1]; + int pad_w = paddings[2]; int pad_h = paddings[0]; int batch_size = input_dims[0]; int in_ch_size = input_dims[1]; @@ -175,7 +175,8 @@ void test_conv(int bs, opdesc.SetOutput("Output", {output_var_name}); opdesc.SetAttr("dilations", std::vector({dilation, dilation})); opdesc.SetAttr("strides", std::vector({stride, stride})); - opdesc.SetAttr("paddings", std::vector({padding, padding})); + opdesc.SetAttr("paddings", std::vector({padding, padding, padding, padding})); opdesc.SetAttr("groups", groups); opdesc.SetAttr("fuse_relu", static_cast(fuse_relu)); if (has_bias) { diff --git a/lite/operators/conv_op.cc b/lite/operators/conv_op.cc index ceca1a61ce..56bb95da01 100644 --- a/lite/operators/conv_op.cc +++ b/lite/operators/conv_op.cc @@ -39,11 +39,15 @@ bool ConvOpLite::CheckShape() const { return true; } -inline int ConvOutputSize( int input_size, int filter_size, int dilation, int padding, int stride) { +inline int ConvOutputSize(int input_size, + int filter_size, + int dilation, + int pad_left, + int pad_right, + int stride) { const int dkernel = dilation * (filter_size - 1) + 1; - int output_size = (input_size + 2 * padding - dkernel) / stride + 1; - // CHECK_GT_OR_FALSE(output_size, 0); + int output_size = + (input_size + (pad_left + pad_right) - dkernel) / stride + 1; return output_size; } @@ -61,8 +65,11 @@ inline void UpdatePaddingAndDilation(std::vector* paddings, int pad_sum = std::max((out_size - 1) * strides[i] + ksize[i] - data_dims[i + 2], (int64_t)0); + int pad_0 = pad_sum / 2; + int pad_1 = pad_sum - pad_0; // pad - *(paddings->begin() + i) = pad_sum / 2; + *(paddings->begin() + i * 2) = pad_0; + *(paddings->begin() + i * 2 + 1) = pad_1; // dilation *(dilations->begin() + i) = 1; }
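UpdatePaddingAndDilation above splits the total SAME padding so that the extra unit of an odd pad_sum lands on the trailing (bottom/right) side. A worked sketch under the same convention; the helper name is invented for illustration.

#include <algorithm>
#include <cassert>

// For SAME padding: out = ceil(in / stride), and the total pad needed is
// (out - 1) * stride + k - in, clamped at zero, split floor/remainder.
inline void SamePadSplit(int in, int k, int stride, int* pad_0, int* pad_1) {
  int out = (in + stride - 1) / stride;
  int pad_sum = std::max((out - 1) * stride + k - in, 0);
  *pad_0 = pad_sum / 2;
  *pad_1 = pad_sum - *pad_0;
}

int main() {
  int p0 = 0, p1 = 0;
  SamePadSplit(6, 4, 3, &p0, &p1);
  assert(p0 == 0 && p1 == 1);  // odd total pad goes to the trailing side
  SamePadSplit(7, 3, 2, &p0, &p1);
  assert(p0 == 1 && p1 == 1);
  return 0;
}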
@@ -77,18 +84,21 @@ bool ConvOpLite::InferShape() const { const auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims(); - UpdatePaddingAndDilation(&param_.paddings, - &param_.dilations, + UpdatePaddingAndDilation(param_.paddings.get(), + param_.dilations.get(), param_.strides, padding_algorithm_, in_dims, filter_dims); std::vector output_shape({in_dims[0], filter_dims[0]}); + auto paddings = *param_.paddings; + auto dilations = *param_.dilations; for (size_t i = 0; i < param_.strides.size(); ++i) { output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], - param_.dilations[i], - param_.paddings[i], + dilations[i], + paddings[i * 2], + paddings[i * 2 + 1], param_.strides[i])); } diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index 1d6e1c9349..24848803fb 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include <memory> #include <string> #include "lite/core/kernel.h" @@ -47,9 +48,10 @@ class ConvOpLite : public OpLite { param_.output = scope->FindVar(Out)->GetMutable(); param_.strides = op_desc.GetAttr>("strides"); - param_.paddings = op_desc.GetAttr>("paddings"); + auto paddings = op_desc.GetAttr>("paddings"); param_.groups = op_desc.GetAttr("groups"); - param_.dilations = op_desc.GetAttr>("dilations"); + auto dilations = op_desc.GetAttr>("dilations"); + param_.dilations = std::make_shared>(dilations); // optional params std::vector input_arg_names = op_desc.InputArgumentNames(); @@ -109,6 +111,20 @@ class ConvOpLite : public OpLite { param_.output_scale = op_desc.GetAttr("output_scale"); } } + + // 2-pad to 4-pad + if (paddings.size() == 2L) { + for (size_t i = 0; i < param_.strides.size(); ++i) { + int copy_pad = *(paddings.begin() + 2 * i); + paddings.insert(paddings.begin() + 2 * i + 1, copy_pad); + } + } else { + if (paddings.size() != 4L) { + LOG(FATAL) + << "Paddings size should be the same or twice as the strides size."; + } + } + param_.paddings = std::make_shared>(paddings); return true; }
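The 2-pad to 4-pad expansion in AttachImpl duplicates each axis entry in place, turning {pad_h, pad_w} into {pad_h, pad_h, pad_w, pad_w}. A standalone sketch of the same loop; the function name is invented for illustration.

#include <cassert>
#include <vector>

// Expand a per-axis pad vector {ph, pw} to {ph, ph, pw, pw}; a vector that
// is already in 4-pad form is returned unchanged.
inline std::vector<int> ExpandPaddings(std::vector<int> paddings) {
  if (paddings.size() == 4) return paddings;
  for (size_t i = 0; i < 2; ++i) {
    int copy_pad = *(paddings.begin() + 2 * i);
    paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
  }
  return paddings;
}

int main() {
  assert((ExpandPaddings({1, 2}) == std::vector<int>{1, 1, 2, 2}));
  assert((ExpandPaddings({0, 0, 1, 1}) == std::vector<int>{0, 0, 1, 1}));
  return 0;
}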
- #include "lite/operators/conv_transpose_op.h" +#include #include "lite/core/op_lite.h" #include "lite/core/op_registry.h" @@ -32,7 +32,6 @@ bool ConvTransposeOpLite::CheckShape() const { CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size()); CHECK_OR_FALSE(in_dims.size() - param_.strides.size() == 2U); - CHECK_EQ_OR_FALSE(param_.paddings.size(), param_.strides.size()); CHECK_OR_FALSE(in_dims[1] % param_.groups == 0); return true; @@ -42,13 +41,16 @@ bool ConvTransposeOpLite::InferShape() const { const auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims(); + auto paddings = *param_.paddings; + auto dilations = *param_.dilations; + std::vector output_shape; output_shape.push_back(in_dims[0]); output_shape.push_back(filter_dims[1] * param_.groups); for (int i = 0; i < param_.strides.size(); i++) { - int kernel_extent = param_.dilations[i] * (filter_dims[i + 2] - 1) + 1; + int kernel_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1; int output_len = (in_dims[i + 2] - 1) * param_.strides[i] + kernel_extent - - 2 * param_.paddings[i]; + (paddings[2 * i] + paddings[2 * i + 1]); output_shape.push_back(output_len); } @@ -68,9 +70,24 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc, param_.output = scope->FindVar(Out)->GetMutable(); param_.strides = op_desc.GetAttr>("strides"); - param_.paddings = op_desc.GetAttr>("paddings"); + auto paddings = op_desc.GetAttr>("paddings"); param_.groups = op_desc.GetAttr("groups"); - param_.dilations = op_desc.GetAttr>("dilations"); + auto dilations = op_desc.GetAttr>("dilations"); + + // 2-pad to 4-pad + if (paddings.size() == 2L) { + for (size_t i = 0; i < 2L; ++i) { + int copy_pad = *(paddings.begin() + 2 * i); + paddings.insert(paddings.begin() + 2 * i + 1, copy_pad); + } + } else { + if (paddings.size() != 4L) { + LOG(FATAL) + << "Paddings size should be the same or twice as the input size."; + } + } + param_.paddings = std::make_shared>(paddings); + param_.dilations = std::make_shared>(dilations); // optional params std::vector input_arg_names = op_desc.InputArgumentNames(); diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index d455743c4d..0e9f85e060 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -254,9 +254,19 @@ struct ConvParam { lite::Tensor* residualData{nullptr}; lite::Tensor* output{}; std::vector strides{1, 1}; - std::vector paddings{0, 0}; + /* paddings type change + * from std::vector to std::shared_ptr> + * to support dynamically modify padding + * let kernel param and operator param Synchronous update + */ + std::shared_ptr> paddings; int groups{1}; - std::vector dilations{1, 1}; + /* dilations type change + * from std::vector to std::shared_ptr> + * to support dynamically modify padding + * let kernel param and operator param Synchronous update + */ + std::shared_ptr> dilations; bool fuse_relu_before_depthwise_conv{false}; bool use_mkldnn{false}; bool fuse_relu{false}; // only used in mkldnn kernel diff --git a/lite/tests/math/conv_compute_test.cc b/lite/tests/math/conv_compute_test.cc index bfb74e6e0a..194d7ab1c3 100644 --- a/lite/tests/math/conv_compute_test.cc +++ b/lite/tests/math/conv_compute_test.cc @@ -64,21 +64,25 @@ using paddle::lite::Timer; DDim compute_out_dim(const DDim& dim_in, const paddle::lite::operators::ConvParam& param) { DDim dim_out = dim_in; + auto paddings = *param.paddings; + auto dilations = *param.dilations; dim_out[1] = param.filter->dims()[0]; auto kernel_h = param.filter->dims()[2]; auto kernel_w = 
param.filter->dims()[3]; auto h = dim_in[2]; auto w = dim_in[3]; - int dila_h = param.dilations[0]; - int dila_w = param.dilations[1]; - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + int dila_h = dilations[0]; + int dila_w = dilations[1]; + int pad_top = paddings[0]; + int pad_bottom = paddings[1]; + int pad_left = paddings[2]; + int pad_right = paddings[3]; int stride_h = param.strides[0]; int stride_w = param.strides[1]; auto kernel_exten = dila_h * (kernel_h - 1) + 1; - auto hout = (h + 2 * pad_h - kernel_exten) / stride_h + 1; + auto hout = (h + pad_top + pad_bottom - kernel_exten) / stride_h + 1; kernel_exten = dila_w * (kernel_w - 1) + 1; - auto wout = (w + 2 * pad_w - kernel_exten) / stride_w + 1; + auto wout = (w + pad_left + pad_right - kernel_exten) / stride_w + 1; dim_out[2] = hout; dim_out[3] = wout; return dim_out; @@ -110,8 +114,8 @@ void test_conv_fp32(const std::vector& input_dims, param.bias->set_precision(PRECISION(kFloat)); } param.strides = strides; - param.paddings = pads; - param.dilations = dilas; + param.paddings = std::make_shared>(pads); + param.dilations = std::make_shared>(dilas); param.fuse_relu = flag_relu; param.groups = group; @@ -162,7 +166,7 @@ void test_conv_fp32(const std::vector& input_dims, param.output->Resize(dim_out); paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f); - // paddle::lite::fill_tensor_const(*param.x, 1.f); + // paddle::lite::fill_tensor_const(*param.x, 1.f); auto din = param.x->data(); Tensor tout_basic; @@ -189,7 +193,7 @@ void test_conv_fp32(const std::vector& input_dims, strides[0], dilas[1], dilas[0], - pads[1], + pads[2], pads[0], flag_bias, flag_relu); @@ -235,7 +239,8 @@ void test_conv_fp32(const std::vector& input_dims, LOG(FATAL) << "test fp32 conv: input: " << dim_in << ", output: " << dim_out << ", weight dim: " << weight_dim - << ", pad: " << pads[0] << ", " << pads[1] + << ", pad: " << pads[0] << ", " << pads[1] << ", " + << pads[2] << ", " << pads[3] << ", stride: " << strides[0] << ", " << strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1] << ", bias: " << (flag_bias ? 
"true" : "false") @@ -280,27 +285,33 @@ void test_conv_fp32(const std::vector& input_dims, TEST(TestConv3x3DW, test_conv3x3_depthwise) { if (FLAGS_basic_test) { for (auto& stride : {1, 2}) { - for (auto& pad : {0, 1}) { - for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { - for (auto& c : {1, 3, 5, 8, 16, 32}) { - std::vector dims; - DDim weights_dim({c, 1, 3, 3}); - for (auto& batch : {1, 2}) { - for (auto& h : {1, 3, 15, 19, 28, 32, 75}) { - dims.push_back(DDim({batch, c, h, h})); + for (auto& pad_left : {0, 1, 2}) { + for (auto& pad_right : {0, 1, 2}) { + for (auto& pad_top : {0, 1, 2}) { + for (auto& pad_bottom : {0, 1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + for (auto& c : {1, 3, 5, 8, 16, 32}) { + std::vector dims; + DDim weights_dim({c, 1, 3, 3}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 15, 19, 28, 32, 75}) { + dims.push_back(DDim({batch, c, h, h})); + } + } + test_conv_fp32(dims, + weights_dim, + c, + {stride, stride}, + {pad_top, pad_bottom, pad_left, pad_right}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } } } - test_conv_fp32(dims, - weights_dim, - c, - {stride, stride}, - {pad, pad}, - {1, 1}, - flag_bias, - flag_relu, - {1, 2, 4}, - {FLAGS_power_mode}); } } } @@ -329,7 +340,7 @@ TEST(TestConv5x5DW, test_conv5x5_depthwise) { weights_dim, c, {stride, stride}, - {pad, pad}, + {pad, pad, pad, pad}, {1, 1}, flag_bias, flag_relu, @@ -366,7 +377,7 @@ TEST(TestConv1x1s1, test_conv1x1s1) { weights_dim, g, {1, 1}, - {0, 0}, + {0, 0, 0, 0}, {1, 1}, flag_bias, flag_relu, @@ -386,26 +397,32 @@ TEST(TestConv3x3s1, test_conv_3x3s1) { if (FLAGS_basic_test) { for (auto& cin : {1, 3, 8, 32, 48}) { for (auto& cout : {1, 5, 8, 32, 48}) { - for (auto& pad : {1, 2}) { - for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { - std::vector dims; - DDim weights_dim({cout, cin, 3, 3}); - for (auto& batch : {1, 2}) { - for (auto& h : {1, 7, 19, 56, 32}) { - dims.push_back(DDim({batch, cin, h, h})); + for (auto& pad_left : {1, 2}) { + for (auto& pad_right : {1, 2}) { + for (auto& pad_top : {1, 2}) { + for (auto& pad_bottom : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + std::vector dims; + DDim weights_dim({cout, cin, 3, 3}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 7, 19, 56, 32}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_fp32(dims, + weights_dim, + 1, + {1, 1}, + {pad_top, pad_bottom, pad_left, pad_right}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } } } - test_conv_fp32(dims, - weights_dim, - 1, - {1, 1}, - {pad, pad}, - {1, 1}, - flag_bias, - flag_relu, - {1, 2, 4}, - {FLAGS_power_mode}); } } } @@ -420,26 +437,32 @@ TEST(TestConv3x3s2, test_conv_3x3s2) { if (FLAGS_basic_test) { for (auto& cin : {1, 3, 8, 32}) { for (auto& cout : {1, 5, 8, 32}) { - for (auto& pad : {1, 2}) { - for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { - std::vector dims; - DDim weights_dim({cout, cin, 3, 3}); - for (auto& batch : {1, 2}) { - for (auto& h : {1, 7, 19, 28, 75, 56, 32}) { - dims.push_back(DDim({batch, cin, h, h})); + for (auto& pad_left : {1, 2}) { + for (auto& pad_right : {1, 2}) { + for (auto& pad_top : {1, 2}) { + for (auto& pad_bottom : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + std::vector dims; + DDim weights_dim({cout, cin, 3, 3}); + for (auto& batch : {1, 
2}) { + for (auto& h : {1, 7, 19, 28, 75, 56, 32}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_fp32(dims, + weights_dim, + 1, + {2, 2}, + {pad_top, pad_bottom, pad_left, pad_right}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } } } - test_conv_fp32(dims, - weights_dim, - 1, - {2, 2}, - {pad, pad}, - {1, 1}, - flag_bias, - flag_relu, - {1, 2, 4}, - {FLAGS_power_mode}); } } } @@ -458,30 +481,37 @@ TEST(TestConvRand, test_conv_rand) { for (auto& kw : {1, 2, 3}) { for (auto& kh : {1, 2, 3}) { for (auto& stride : {1, 2}) { - for (auto& pad : {0, 1, 2}) { - for (auto& dila : {1, 2}) { - for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { - if (cin % g != 0 || cout % g != 0) { - continue; - } - std::vector dims; - DDim weights_dim({cout, cin / g, kh, kw}); - for (auto& batch : {1, 2}) { - for (auto& h : {1, 3, 19, 32, 28}) { - dims.push_back(DDim({batch, cin, h, h})); + for (auto& pad_left : {0, 1, 2}) { + for (auto& pad_right : {0, 1, 2}) { + for (auto& pad_top : {0, 1, 2}) { + for (auto& pad_bottom : {0, 1, 2}) { + for (auto& dila : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + if (cin % g != 0 || cout % g != 0) { + continue; + } + std::vector dims; + DDim weights_dim({cout, cin / g, kh, kw}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 19, 32, 28}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_fp32( + dims, + weights_dim, + g, + {stride, stride}, + {pad_top, pad_bottom, pad_left, pad_right}, + {dila, dila}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } } } - test_conv_fp32(dims, - weights_dim, - g, - {stride, stride}, - {pad, pad}, - {dila, dila}, - flag_bias, - flag_relu, - {1, 2, 4}, - {FLAGS_power_mode}); } } } @@ -510,7 +540,7 @@ TEST(TestConvCustom, test_conv_fp32_custom_size) { FLAGS_kernel_w}), FLAGS_group, {FLAGS_stride_h, FLAGS_stride_w}, - {FLAGS_pad_h, FLAGS_pad_w}, + {FLAGS_pad_h, FLAGS_pad_h, FLAGS_pad_w, FLAGS_pad_w}, {FLAGS_dila_h, FLAGS_dila_w}, FLAGS_flag_bias, FLAGS_flag_relu, diff --git a/lite/tests/math/conv_int8_compute_test.cc b/lite/tests/math/conv_int8_compute_test.cc index e15b7d22bc..6af9bbd431 100644 --- a/lite/tests/math/conv_int8_compute_test.cc +++ b/lite/tests/math/conv_int8_compute_test.cc @@ -63,22 +63,22 @@ using paddle::lite::Timer; DDim compute_out_dim(const DDim& dim_in, const paddle::lite::operators::ConvParam& param) { + auto paddings = *param.paddings; + auto dilations = *param.dilations; DDim dim_out = dim_in; dim_out[1] = param.filter->dims()[0]; auto kernel_h = param.filter->dims()[2]; auto kernel_w = param.filter->dims()[3]; auto h = dim_in[2]; auto w = dim_in[3]; - int dila_h = param.dilations[0]; - int dila_w = param.dilations[1]; - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + int dila_h = dilations[0]; + int dila_w = dilations[1]; int stride_h = param.strides[0]; int stride_w = param.strides[1]; auto kernel_exten = dila_h * (kernel_h - 1) + 1; - auto hout = (h + 2 * pad_h - kernel_exten) / stride_h + 1; + auto hout = (h + paddings[0] + paddings[1] - kernel_exten) / stride_h + 1; kernel_exten = dila_w * (kernel_w - 1) + 1; - auto wout = (w + 2 * pad_w - kernel_exten) / stride_w + 1; + auto wout = (w + paddings[2] + paddings[3] - kernel_exten) / stride_w + 1; dim_out[2] = hout; dim_out[3] = wout; return dim_out; @@ -104,8 +104,8 @@ void get_conv_param(const DDim& dim_w, param->bias->set_precision(PRECISION(kFloat)); } param->strides = strides; - 
param->paddings = pads; - param->dilations = dila; + param->paddings = std::make_shared>(pads); + param->dilations = std::make_shared>(dila); param->fuse_relu = flag_relu; param->groups = g; @@ -288,7 +288,7 @@ void test_conv_int8(const std::vector& input_dims, strides[0], dilas[1], dilas[0], - pads[1], + pads[2], pads[0], flag_bias, flag_relu); @@ -358,7 +358,8 @@ void test_conv_int8(const std::vector& input_dims, LOG(FATAL) << "test int8 conv, fp32 out: input: " << dim_in << ", output: " << dim_out << ", weight dim: " << weight_dim - << ", pad: " << pads[0] << ", " << pads[1] + << ", pad: " << pads[0] << ", " << pads[1] << ", " + << pads[2] << ", " << pads[3] << ", stride: " << strides[0] << ", " << strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1] << ", bias: " << (flag_bias ? "true" : "false") @@ -416,7 +417,8 @@ void test_conv_int8(const std::vector& input_dims, LOG(FATAL) << "test int8 conv, int8 out: input: " << dim_in << ", output: " << dim_out << ", weight dim: " << weight_dim - << ", pad: " << pads[0] << ", " << pads[1] + << ", pad: " << pads[0] << ", " << pads[1] << ", " + << pads[2] << ", " << pads[3] << ", stride: " << strides[0] << ", " << strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1] << ", bias: " << (flag_bias ? "true" : "false") @@ -428,9 +430,9 @@ void test_conv_int8(const std::vector& input_dims, } LOG(INFO) << "test int8 conv: input: " << dim_in << ", output: " << dim_out << ", weight dim: " << weight_dim - << ", pad: " << pads[0] << ", " << pads[1] - << ", stride: " << strides[0] << ", " << strides[1] - << ", dila_: " << dilas[0] << ", " << dilas[1] + << ", pad: " << pads[0] << ", " << pads[1] << ", " << pads[2] + << ", " << pads[3] << ", stride: " << strides[0] << ", " + << strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1] << ", bias: " << (flag_bias ? "true" : "false") << ", relu: " << (flag_relu ? 
"true" : "false") << ", threads: " << th << ", power_mode: " << cls @@ -473,7 +475,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { weights_dim, c, {stride, stride}, - {pad, pad}, + {pad, pad, pad, pad}, {1, 1}, flag_bias, flag_relu, @@ -507,7 +509,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { weights_dim, c, {stride, stride}, - {pad, pad}, + {pad, pad, pad, pad}, {1, 1}, flag_bias, flag_relu, @@ -544,7 +546,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) { weights_dim, g, {1, 1}, - {0, 0}, + {0, 0, 0, 0}, {1, 1}, flag_bias, flag_relu, @@ -564,26 +566,32 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { if (FLAGS_basic_test) { for (auto& cin : {1, 3, 8, 32, 48}) { for (auto& cout : {1, 5, 8, 32, 48}) { - for (auto& pad : {1, 2}) { - for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { - std::vector dims; - DDim weights_dim({cout, cin, 3, 3}); - for (auto& batch : {1, 2}) { - for (auto& h : {1, 7, 19, 56, 32}) { - dims.push_back(DDim({batch, cin, h, h})); + for (auto& pad_top : {1, 2}) { + for (auto& pad_bottom : {1, 2}) { + for (auto& pad_left : {1, 2}) { + for (auto& pad_right : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + std::vector dims; + DDim weights_dim({cout, cin, 3, 3}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 7, 19, 56, 32}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_int8(dims, + weights_dim, + 1, + {1, 1}, + {pad_top, pad_bottom, pad_left, pad_right}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } } } - test_conv_int8(dims, - weights_dim, - 1, - {1, 1}, - {pad, pad}, - {1, 1}, - flag_bias, - flag_relu, - {1, 2, 4}, - {FLAGS_power_mode}); } } } @@ -598,26 +606,32 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { if (FLAGS_basic_test) { for (auto& cin : {1, 3, 8, 32}) { for (auto& cout : {1, 5, 8, 32}) { - for (auto& pad : {1, 2}) { - for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { - std::vector dims; - DDim weights_dim({cout, cin, 3, 3}); - for (auto& batch : {1, 2}) { - for (auto& h : {1, 7, 19, 28, 75, 56, 32}) { - dims.push_back(DDim({batch, cin, h, h})); + for (auto& pad_top : {1, 2}) { + for (auto& pad_bottom : {1, 2}) { + for (auto& pad_left : {1, 2}) { + for (auto& pad_right : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + std::vector dims; + DDim weights_dim({cout, cin, 3, 3}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 7, 19, 28, 75, 56, 32}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_int8(dims, + weights_dim, + 1, + {2, 2}, + {pad_top, pad_bottom, pad_left, pad_right}, + {1, 1}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } } } - test_conv_int8(dims, - weights_dim, - 1, - {2, 2}, - {pad, pad}, - {1, 1}, - flag_bias, - flag_relu, - {1, 2, 4}, - {FLAGS_power_mode}); } } } @@ -636,30 +650,37 @@ TEST(TestConvRandInt8, test_conv_rand) { for (auto& kw : {1, 2, 3}) { for (auto& kh : {1, 2, 3}) { for (auto& stride : {1, 2}) { - for (auto& pad : {0, 1, 2}) { - for (auto& dila : {1, 2}) { - for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { - if (cin % g != 0 || cout % g != 0) { - continue; - } - std::vector dims; - DDim weights_dim({cout, cin / g, kh, kw}); - for (auto& batch : {1, 2}) { - for (auto& h : {1, 3, 19, 32, 28}) { - dims.push_back(DDim({batch, cin, h, h})); + for (auto& pad_top : {0, 1, 2}) { + for (auto& pad_bottom : {0, 1, 2}) { + for (auto& pad_left : {0, 
1, 2}) { + for (auto& pad_right : {0, 1, 2}) { + for (auto& dila : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + if (cin % g != 0 || cout % g != 0) { + continue; + } + std::vector dims; + DDim weights_dim({cout, cin / g, kh, kw}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 19, 32, 28}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_int8( + dims, + weights_dim, + g, + {stride, stride}, + {pad_top, pad_bottom, pad_left, pad_right}, + {dila, dila}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_power_mode}); + } } } - test_conv_int8(dims, - weights_dim, - g, - {stride, stride}, - {pad, pad}, - {dila, dila}, - flag_bias, - flag_relu, - {1, 2, 4}, - {FLAGS_power_mode}); } } } @@ -688,7 +709,7 @@ TEST(TestConvCustomInt8, test_conv_custom_size) { FLAGS_kernel_w}), FLAGS_group, {FLAGS_stride_h, FLAGS_stride_w}, - {FLAGS_pad_h, FLAGS_pad_w}, + {FLAGS_pad_h, FLAGS_pad_h, FLAGS_pad_w, FLAGS_pad_w}, {FLAGS_dila_h, FLAGS_dila_w}, FLAGS_flag_bias, FLAGS_flag_relu, diff --git a/lite/tests/math/conv_transpose_compute_test.cc b/lite/tests/math/conv_transpose_compute_test.cc index e0da07a534..fd2d5195a3 100644 --- a/lite/tests/math/conv_transpose_compute_test.cc +++ b/lite/tests/math/conv_transpose_compute_test.cc @@ -66,10 +66,12 @@ DDim compute_out_dim(const DDim& dim_in, auto filter_dims = param.filter->dims(); DDim output_shape = dim_in; output_shape[1] = filter_dims[1] * param.groups; + auto paddings = *param.paddings; + auto dilations = *param.dilations; for (int i = 0; i < 2; i++) { - int kernel_extent = param.dilations[i] * (filter_dims[i + 2] - 1) + 1; + int kernel_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1; int output_len = (dim_in[i + 2] - 1) * param.strides[i] + kernel_extent - - 2 * param.paddings[i]; + (paddings[2 * i] + paddings[2 * i + 1]); output_shape[i + 2] = output_len; } return output_shape; @@ -101,8 +103,8 @@ void test_conv_transpose_fp32(const std::vector& input_dims, param.bias->set_precision(PRECISION(kFloat)); } param.strides = strides; - param.paddings = pads; - param.dilations = dilas; + param.paddings = std::make_shared>(pads); + param.dilations = std::make_shared>(dilas); param.fuse_relu = flag_relu; param.groups = group; @@ -182,7 +184,7 @@ void test_conv_transpose_fp32(const std::vector& input_dims, strides[0], dilas[1], dilas[0], - pads[1], + pads[2], pads[0], flag_bias, flag_relu); @@ -296,7 +298,7 @@ TEST(TestConvRand, test_conv_transpose_rand) { weights_dim, g, {stride, stride}, - {pad, pad}, + {pad, pad, pad, pad}, {dila, dila}, flag_bias, flag_relu, @@ -330,7 +332,7 @@ TEST(TestConvCustom, test_conv_transpose_fp32_custom_size) { FLAGS_kernel_w}), FLAGS_group, {FLAGS_stride_h, FLAGS_stride_w}, - {FLAGS_pad_h, FLAGS_pad_w}, + {FLAGS_pad_h, FLAGS_pad_h, FLAGS_pad_w, FLAGS_pad_w}, {FLAGS_dila_h, FLAGS_dila_w}, FLAGS_flag_bias, FLAGS_flag_relu, -- GitLab