From 6d1790899583c3c1177f3d83db2577a58feb30d4 Mon Sep 17 00:00:00 2001 From: xiaogang Date: Tue, 21 Apr 2020 11:16:05 +0800 Subject: [PATCH] fix: winograd support unsame pad (#3449) --- .../arm/math/conv3x3_winograd_fp32_c4.cc | 54 ++++++++++--------- lite/kernels/arm/conv_compute.cc | 2 +- lite/kernels/arm/conv_winograd.cc | 8 +-- 3 files changed, 36 insertions(+), 28 deletions(-) diff --git a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc index d1992f62bb..35d9eeaee1 100644 --- a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc +++ b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc @@ -80,8 +80,10 @@ void conv_compute_6x6_3x3(const float* input, const operators::ConvParam& param, ARMContext* ctx) { auto act_param = param.activation_param; - const int pad_h = (*param.paddings)[0]; - const int pad_w = (*param.paddings)[2]; + const int pad_h0 = (*param.paddings)[0]; + const int pad_h1 = (*param.paddings)[1]; + const int pad_w0 = (*param.paddings)[2]; + const int pad_w1 = (*param.paddings)[3]; float* tmp_work_space = ctx->workspace_data() + ctx->llc_size() / sizeof(float); @@ -96,8 +98,8 @@ void conv_compute_6x6_3x3(const float* input, int tile_h = (hout + 5) / 6; int size_tile = tile_h * tile_w; - int w_pad = win + pad_w * 2; - int h_pad = hin + pad_h * 2; + int w_pad = win + pad_w0 + pad_w1; + int h_pad = hin + pad_h0 + pad_h1; const int zero_len = w_pad; float zero_ptr[zero_len]; // NOLINT @@ -127,10 +129,10 @@ void conv_compute_6x6_3x3(const float* input, prepack_input_nxwc4_dw(input + ni * in_n_stride, input_c4 + i * new_c_stride, i * 4, - -pad_h, - hin + pad_h, - -pad_w, - win + pad_w, + -pad_h0, + hin + pad_h1, + -pad_w0, + win + pad_w1, chin, win, hin, @@ -367,8 +369,10 @@ void conv_compute_2x2_3x3(const float* input, const operators::ConvParam& param, ARMContext* ctx) { auto act_param = param.activation_param; - const int pad_h = (*param.paddings)[0]; - const int pad_w = (*param.paddings)[2]; + const int pad_h0 = (*param.paddings)[0]; + const int pad_h1 = (*param.paddings)[1]; + const int pad_w0 = (*param.paddings)[2]; + const int pad_w1 = (*param.paddings)[3]; float* tmp_work_space = ctx->workspace_data() + ctx->llc_size() / sizeof(float); @@ -383,8 +387,8 @@ void conv_compute_2x2_3x3(const float* input, int tile_h = (hout + 1) / 2; int size_tile = tile_h * tile_w; - int w_pad = win + pad_w * 2; - int h_pad = hin + pad_h * 2; + int w_pad = win + pad_w0 + pad_w1; + int h_pad = hin + pad_h0 + pad_h1; const int zero_len = w_pad; float zero_ptr[zero_len]; // NOLINT @@ -414,10 +418,10 @@ void conv_compute_2x2_3x3(const float* input, prepack_input_nxwc4_dw(input + ni * in_n_stride, input_c4 + i * new_c_stride, i * 4, - -pad_h, - hin + pad_h, - -pad_w, - win + pad_w, + -pad_h0, + hin + pad_h1, + -pad_w0, + win + pad_w1, chin, win, hin, @@ -628,8 +632,10 @@ void conv_compute_2x2_3x3_small(const float* input, const operators::ConvParam& param, ARMContext* ctx) { auto act_param = param.activation_param; - const int pad_h = (*param.paddings)[0]; - const int pad_w = (*param.paddings)[2]; + const int pad_h0 = (*param.paddings)[0]; + const int pad_h1 = (*param.paddings)[1]; + const int pad_w0 = (*param.paddings)[2]; + const int pad_w1 = (*param.paddings)[3]; float* tmp_work_space = ctx->workspace_data() + ctx->llc_size() / sizeof(float); @@ -644,8 +650,8 @@ void conv_compute_2x2_3x3_small(const float* input, int tile_h = (hout + 1) / 2; int size_tile = tile_h * tile_w; - int w_pad = win + pad_w * 2; - int h_pad = hin + pad_h * 2; + int w_pad = win + pad_w0 + pad_w1; + int h_pad = hin + pad_h0 + pad_h1; const int zero_len = w_pad; float zero_ptr[zero_len]; // NOLINT @@ -676,10 +682,10 @@ void conv_compute_2x2_3x3_small(const float* input, prepack_input_nxwc4_dw(input + ni * in_n_stride, input_c4 + i * new_c_stride, i * 4, - -pad_h, - hin + pad_h, - -pad_w, - win + pad_w, + -pad_h0, + hin + pad_h1, + -pad_w0, + win + pad_w1, chin, win, hin, diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index fb8529af5a..6baa811737 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -72,7 +72,7 @@ void ConvCompute::PrepareForRun() { impl_ = new DepthwiseConv; // VLOG(3) << "invoking dw conv"; } else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal && - no_dilation && pads_all_equal) { + no_dilation) { // TODO(MyPandaShaoxiang): winograd conv support any pad impl_ = new WinogradConv; // VLOG(3) << "invoking winograd conv"; diff --git a/lite/kernels/arm/conv_winograd.cc b/lite/kernels/arm/conv_winograd.cc index e433a3f4bb..d0880e51de 100644 --- a/lite/kernels/arm/conv_winograd.cc +++ b/lite/kernels/arm/conv_winograd.cc @@ -45,12 +45,14 @@ void WinogradConv::ReInitWhenNeeded() { int ow = o_dims[3]; int tile_block = 8; auto pad = *(param.paddings); - int pad_h = pad[0]; - int pad_w = pad[2]; + int pad_h0 = pad[0]; + int pad_h1 = pad[1]; + int pad_w0 = pad[2]; + int pad_w1 = pad[3]; int oc_pad = (oc + 3) / 4 * 4; int ic_pad = (ic + 3) / 4 * 4; const int new_input_size = - (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2); + (ic + 3) / 4 * 4 * (ih + pad_h0 + pad_h1) * (iw + pad_w0 + pad_w1); const int temp_size = (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw + 8 * wino_iw * wino_iw) * -- GitLab