diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index 3d690614ba82da2b3581da1247df252dac3e0a48..956beb53c9a9e9d857d9c129d90443b09c0b3bb8 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -130,8 +130,8 @@ inline void WinogradConv3x3(const ConvParam ¶m) { auto winograd_pad = [&](int width, int pad) { int output_tile = tile - kernel + 1; - // int tiles = (width + pad - kernel) / output_tile + 1; - // return (tiles - 1) * output_tile + tile - width; + // int tiles = (width + pad - kernel) / output_tile + 1; + // return (tiles - 1) * output_tile + tile - width; int pad_width = (width + 2 * pad - kernel) / output_tile * output_tile; return pad_width + tile - width; }; @@ -141,8 +141,10 @@ inline void WinogradConv3x3(const ConvParam ¶m) { for (int i = 0; i < batch_size; ++i) { Tensor in_batch = input->Slice(i, i + 1); Tensor out_batch = output->Slice(i, i + 1); - int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]); - int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]); + // int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]); + // int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]); + int pad_bottom = paddings[0]; + int pad_right = paddings[1]; if (paddings[0] || paddings[1] || pad_bottom || pad_right) { framework::DDim pad_shape = in_batch.dims(); pad_shape[2] += paddings[0] + pad_bottom; diff --git a/src/operators/math/winograd/winograd_transform_f6k3.cpp b/src/operators/math/winograd/winograd_transform_f6k3.cpp index 77ba052ebf6395e15de64cacb1aac05c14c0a6b1..e2a6d4558b95f8a60988f25a8ce5201c2fa05507 100644 --- a/src/operators/math/winograd/winograd_transform_f6k3.cpp +++ b/src/operators/math/winograd/winograd_transform_f6k3.cpp @@ -336,6 +336,23 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, memset(outptr, 0, output->numel() * sizeof(float)); const float *inptr = input.data(); + int inter_h = (height - 2) / 6; + int inter_w = (width - 2) / 6; + int remain_h = height - (inter_h * 6); + int remain_w = width - (inter_w * 6); + framework::Tensor input_pad; + if (remain_h > 2 || remain_w > 2) { + inter_h += (remain_h > 2); + inter_w += (remain_w > 2); + height = (inter_h - 1) * 6 + 8; + width = (inter_w - 1) * 6 + 8; + framework::DDim input_shape = + framework::make_ddim(std::vector{1, channel, height, width}); + PadFunctor pad; + inptr = input_pad.mutable_data(input_shape); + pad(input, 0, height - input.dims()[2], 0, width - input.dims()[3], + &input_pad); + } size_t image_size = height * width; const float transform_matrix[8] = {5.25f, -5.f, -4.25f, -2.5f, 2.f, -1.25f, 0.5f, 0.25f};