提交 6d179089 编写于 作者: X xiaogang 提交者: GitHub

fix: winograd support unsame pad (#3449)

上级 afc14cb8
......@@ -80,8 +80,10 @@ void conv_compute_6x6_3x3(const float* input,
const operators::ConvParam& param,
ARMContext* ctx) {
auto act_param = param.activation_param;
const int pad_h = (*param.paddings)[0];
const int pad_w = (*param.paddings)[2];
const int pad_h0 = (*param.paddings)[0];
const int pad_h1 = (*param.paddings)[1];
const int pad_w0 = (*param.paddings)[2];
const int pad_w1 = (*param.paddings)[3];
float* tmp_work_space =
ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);
......@@ -96,8 +98,8 @@ void conv_compute_6x6_3x3(const float* input,
int tile_h = (hout + 5) / 6;
int size_tile = tile_h * tile_w;
int w_pad = win + pad_w * 2;
int h_pad = hin + pad_h * 2;
int w_pad = win + pad_w0 + pad_w1;
int h_pad = hin + pad_h0 + pad_h1;
const int zero_len = w_pad;
float zero_ptr[zero_len]; // NOLINT
......@@ -127,10 +129,10 @@ void conv_compute_6x6_3x3(const float* input,
prepack_input_nxwc4_dw(input + ni * in_n_stride,
input_c4 + i * new_c_stride,
i * 4,
-pad_h,
hin + pad_h,
-pad_w,
win + pad_w,
-pad_h0,
hin + pad_h1,
-pad_w0,
win + pad_w1,
chin,
win,
hin,
......@@ -367,8 +369,10 @@ void conv_compute_2x2_3x3(const float* input,
const operators::ConvParam& param,
ARMContext* ctx) {
auto act_param = param.activation_param;
const int pad_h = (*param.paddings)[0];
const int pad_w = (*param.paddings)[2];
const int pad_h0 = (*param.paddings)[0];
const int pad_h1 = (*param.paddings)[1];
const int pad_w0 = (*param.paddings)[2];
const int pad_w1 = (*param.paddings)[3];
float* tmp_work_space =
ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);
......@@ -383,8 +387,8 @@ void conv_compute_2x2_3x3(const float* input,
int tile_h = (hout + 1) / 2;
int size_tile = tile_h * tile_w;
int w_pad = win + pad_w * 2;
int h_pad = hin + pad_h * 2;
int w_pad = win + pad_w0 + pad_w1;
int h_pad = hin + pad_h0 + pad_h1;
const int zero_len = w_pad;
float zero_ptr[zero_len]; // NOLINT
......@@ -414,10 +418,10 @@ void conv_compute_2x2_3x3(const float* input,
prepack_input_nxwc4_dw(input + ni * in_n_stride,
input_c4 + i * new_c_stride,
i * 4,
-pad_h,
hin + pad_h,
-pad_w,
win + pad_w,
-pad_h0,
hin + pad_h1,
-pad_w0,
win + pad_w1,
chin,
win,
hin,
......@@ -628,8 +632,10 @@ void conv_compute_2x2_3x3_small(const float* input,
const operators::ConvParam& param,
ARMContext* ctx) {
auto act_param = param.activation_param;
const int pad_h = (*param.paddings)[0];
const int pad_w = (*param.paddings)[2];
const int pad_h0 = (*param.paddings)[0];
const int pad_h1 = (*param.paddings)[1];
const int pad_w0 = (*param.paddings)[2];
const int pad_w1 = (*param.paddings)[3];
float* tmp_work_space =
ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);
......@@ -644,8 +650,8 @@ void conv_compute_2x2_3x3_small(const float* input,
int tile_h = (hout + 1) / 2;
int size_tile = tile_h * tile_w;
int w_pad = win + pad_w * 2;
int h_pad = hin + pad_h * 2;
int w_pad = win + pad_w0 + pad_w1;
int h_pad = hin + pad_h0 + pad_h1;
const int zero_len = w_pad;
float zero_ptr[zero_len]; // NOLINT
......@@ -676,10 +682,10 @@ void conv_compute_2x2_3x3_small(const float* input,
prepack_input_nxwc4_dw(input + ni * in_n_stride,
input_c4 + i * new_c_stride,
i * 4,
-pad_h,
hin + pad_h,
-pad_w,
win + pad_w,
-pad_h0,
hin + pad_h1,
-pad_w0,
win + pad_w1,
chin,
win,
hin,
......
......@@ -72,7 +72,7 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal &&
no_dilation && pads_all_equal) {
no_dilation) {
// TODO(MyPandaShaoxiang): winograd conv support any pad
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking winograd conv";
......
......@@ -45,12 +45,14 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
int ow = o_dims[3];
int tile_block = 8;
auto pad = *(param.paddings);
int pad_h = pad[0];
int pad_w = pad[2];
int pad_h0 = pad[0];
int pad_h1 = pad[1];
int pad_w0 = pad[2];
int pad_w1 = pad[3];
int oc_pad = (oc + 3) / 4 * 4;
int ic_pad = (ic + 3) / 4 * 4;
const int new_input_size =
(ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
(ic + 3) / 4 * 4 * (ih + pad_h0 + pad_h1) * (iw + pad_w0 + pad_w1);
const int temp_size =
(tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw +
8 * wino_iw * wino_iw) *
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册