diff --git a/dnn/src/common/winograd/winograd_helper.cpp b/dnn/src/common/winograd/winograd_helper.cpp index 767bdda4ce1768d161393f66e983aef77c3ac7a0..c950cf7b532588165097e567906f31059dbed3e8 100644 --- a/dnn/src/common/winograd/winograd_helper.cpp +++ b/dnn/src/common/winograd/winograd_helper.cpp @@ -247,33 +247,31 @@ void StrategyHelper< Getter getter(dtype); InputVisitor intput_visitor(IC); - rep(ic, IC) { - memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type)); - rep(i, alpha) rep(j, alpha) { - int ih = ih_start + i; - int iw = iw_start + j; - if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) { - mid_buf1[i * alpha + j] = getter( - input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]); - } + memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type)); + rep(i, alpha) rep(j, alpha) { + int ih = ih_start + i; + int iw = iw_start + j; + if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) { + mid_buf1[i * alpha + j] = getter( + input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]); } + } - megdnn::naive::run_matrix_mul_tpl( - winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha, - alpha, alpha, alpha, alpha, alpha, dtype, dtype); - megdnn::naive::run_matrix_mul_tpl( - mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha, - alpha, alpha, alpha, alpha, alpha, dtype, dtype); - - rep(i, alpha) rep(j, alpha) { - input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile, - unit_idx, i, j)] = - mid_buf1[i * alpha + j]; - } + megdnn::naive::run_matrix_mul_tpl( + winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha, + alpha, alpha, alpha, alpha, alpha, dtype, dtype); + megdnn::naive::run_matrix_mul_tpl( + mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha, + alpha, alpha, alpha, alpha, alpha, dtype, dtype); + + rep(i, alpha) rep(j, alpha) { + input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile, + unit_idx, i, j)] = + mid_buf1[i * alpha + j]; } } @@ -287,7 +285,7 @@ void StrategyHelper< output_compute_type* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, - size_t oc_end, size_t unit_idx, size_t nr_units_in_tile, + size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, const std::vector& interp_points, DType dtype, float input_filter_scale, float input_filter_rescale, @@ -300,49 +298,49 @@ void StrategyHelper< OutputGetter getter(dtype); OutputVisitor output_visitor(oc_end - oc_start); - for (size_t oc = oc_start; oc < oc_end; oc++) { - /* gather */ - rep(i, alpha) rep(j, alpha) { - mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get( - alpha, oc - oc_start, oc, nr_units_in_tile, unit_idx, i, - j)]; - } - /* A[alpha*m] M[alpha*alpha] */ - megdnn::naive::run_matrix_mul_tpl( - winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha, - alpha, m, alpha, alpha, dtype, dtype); - megdnn::naive::run_matrix_mul_tpl( - mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m, - alpha, alpha, m, m, dtype, dtype); - - rep(i, m) rep(j, m) { - auto oh = oh_start + i; - auto ow = ow_start + j; - if (oh < OH && ow < OW) { - float val = mid_buf1[i * m + j]; - if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - val += bias[oc] * input_filter_rescale * - input_filter_rescale; - } else if (bmode == BiasMode::BIAS) { - val += bias[output_visitor.put(oc, OH, OW, oh, ow)] * - input_filter_rescale * input_filter_rescale; - } - val = val * input_filter_scale / - (input_filter_rescale * input_filter_rescale * rescale * - rescale); - if (nonline_mode == NonlineMode::RELU) { - val = val > 0 ? val : 0; - } else if (nonline_mode == NonlineMode::SIGMOID) { - val = 1.f / (expf(-val) + 1.f); - } else if (nonline_mode == NonlineMode::H_SWISH) { - val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f; - } else { - megdnn_assert(nonline_mode == NonlineMode::IDENTITY); - } - output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val); + size_t oc = oc_start + oc_index; + + /* gather */ + rep(i, alpha) rep(j, alpha) { + mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get( + alpha, oc_index, oc, nr_units_in_tile, unit_idx, i, + j)]; + } + /* A[alpha*m] M[alpha*alpha] */ + megdnn::naive::run_matrix_mul_tpl( + winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha, + alpha, m, alpha, alpha, dtype, dtype); + megdnn::naive::run_matrix_mul_tpl( + mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m, + alpha, alpha, m, m, dtype, dtype); + + rep(i, m) rep(j, m) { + auto oh = oh_start + i; + auto ow = ow_start + j; + if (oh < OH && ow < OW) { + float val = mid_buf1[i * m + j]; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + val += bias[oc] * input_filter_rescale * + input_filter_rescale; + } else if (bmode == BiasMode::BIAS) { + val += bias[output_visitor.put(oc, OH, OW, oh, ow)] * + input_filter_rescale * input_filter_rescale; + } + val = val * input_filter_scale / + (input_filter_rescale * input_filter_rescale * rescale * + rescale); + if (nonline_mode == NonlineMode::RELU) { + val = val > 0 ? val : 0; + } else if (nonline_mode == NonlineMode::SIGMOID) { + val = 1.f / (expf(-val) + 1.f); + } else if (nonline_mode == NonlineMode::H_SWISH) { + val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f; + } else { + megdnn_assert(nonline_mode == NonlineMode::IDENTITY); } + output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val); } } }; diff --git a/dnn/src/common/winograd/winograd_helper.h b/dnn/src/common/winograd/winograd_helper.h index c2cd945bc6a4c8062dd3fd6f4a62459b629ea00f..fd7d5ccf371e8d080a7bd7e2287567fbbb22d563 100644 --- a/dnn/src/common/winograd/winograd_helper.h +++ b/dnn/src/common/winograd/winograd_helper.h @@ -44,7 +44,7 @@ public: input_filter_compute_type* input_transform_buf, input_filter_compute_type* transform_mid_buf, int ih_start, int iw_start, size_t IH, size_t IW, - size_t IC, size_t unit_idx, size_t nr_units_in_tile, + size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, const std::vector& interp_points, DType dtype, float rescale = 1.0f); @@ -54,7 +54,7 @@ public: const output_compute_type* bias, dst_type* output, output_compute_type* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t oh_start, size_t ow_start, - size_t OH, size_t OW, size_t oc_start, size_t oc_end, + size_t OH, size_t OW, size_t oc_start, size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, const std::vector& interp_points, DType dtype, float input_filter_scale = 1.0f, // input_scale * filter_scale diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index 2af41765138d023b1925ce5d2da77c35702bb437..eba4f05aa5c4e80197ccbf218cccdf9506bbf16a 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -55,7 +55,7 @@ public: ohw_tile_size)); all_algos.emplace_back(refhold.back().get()); } -#if 0 +#if 1 //! As these algos maybe very slow, it will make fastrun search slow, so //! we disable it, but for the test of strategyhelper, we just keep it. //! FIXME: I do not know a better way to do it. diff --git a/dnn/src/fallback/conv_bias/winograd/strategy.cpp b/dnn/src/fallback/conv_bias/winograd/strategy.cpp index de0ff614b7d71ff2974d63d0b35efaa0f1bcc53e..f59fb04984ea05ce3eedd47e8e85175b48942f2b 100644 --- a/dnn/src/fallback/conv_bias/winograd/strategy.cpp +++ b/dnn/src/fallback/conv_bias/winograd/strategy.cpp @@ -6,8 +6,7 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "src/fallback/conv_bias/winograd/strategy.h" @@ -31,27 +30,54 @@ void winograd_2x3_1x1_f::filter(const float* filter, } void winograd_2x3_1x1_f::input(const float* input, float* input_transform_buf, - float* transform_mid_buf, int ih_start, - int iw_start, size_t IH, size_t IW, size_t IC, - size_t unit_idx, size_t nr_units_in_tile) { - ::megdnn::winograd::StrategyHelper::input( - input, input_transform_buf, transform_mid_buf, ih_start, iw_start, - IH, IW, IC, unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, - KERNEL_SIZE, {0, 1, -1}, src_dtype); + float* transform_mid_buf, size_t IH, size_t IW, + size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, size_t nr_units_in_tile) { + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = + div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + rep(ic, IC) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + ::megdnn::winograd::StrategyHelper:: + input(input, input_transform_buf, transform_mid_buf, + ih_start, iw_start, IH, IW, IC, ic, unit_idx, + nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, + {0, 1, -1}, src_dtype); + } + } } void winograd_2x3_1x1_f::output(const float* output_transform_buf, const float* bias, float* output, float* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, size_t unit_idx, + NonlineMode nonline_mode, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, + size_t unit_start_idx, size_t nr_units_in_tile) { - ::megdnn::winograd::StrategyHelper::output( - output_transform_buf, bias, output, transform_mid_buf, bmode, - nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end, - unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, - {0, 1, -1}, dst_dtype); + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + size_t OC = oc_end - oc_start; + + for (size_t oc = oc_start; oc < oc_end; ++oc) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + ::megdnn::winograd::StrategyHelper:: + output(output_transform_buf, bias, output, + transform_mid_buf, bmode, nonline_mode, oh_start, + ow_start, OH, OW, OC, oc_start, oc_index, unit_idx, + nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, + {0, 1, -1}, dst_dtype); + } + } } MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) @@ -71,38 +97,70 @@ void winograd_2x3_4x4_f::filter(const float* filter, } void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, - float* transform_mid_buf, int ih_start, - int iw_start, size_t IH, size_t IW, size_t IC, - size_t unit_idx, size_t nr_units_in_tile) { - ::megdnn::winograd::StrategyHelper< - float, float, float, float, param::ConvBias::Format::NCHW, - param::MatrixMul::Format::MK4>::input(input, input_transform_buf, - transform_mid_buf, ih_start, - iw_start, IH, IW, IC, - unit_idx, nr_units_in_tile, - OUTPUT_BLOCK_SIZE, - KERNEL_SIZE, {0, 1, -1}, - src_dtype); + float* transform_mid_buf, size_t IH, size_t IW, + size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, size_t nr_units_in_tile) { + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = + div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + rep(ic, IC) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + ::megdnn::winograd::StrategyHelper< + float, float, float, float, param::ConvBias::Format::NCHW, + param::MatrixMul::Format::MK4>::input(input, + input_transform_buf, + transform_mid_buf, + ih_start, iw_start, + IH, IW, IC, ic, + unit_idx, + nr_units_in_tile, + OUTPUT_BLOCK_SIZE, + KERNEL_SIZE, + {0, 1, -1}, + src_dtype); + } + } } void winograd_2x3_4x4_f::output(const float* output_transform_buf, const float* bias, float* output, float* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, size_t unit_idx, + NonlineMode nonline_mode, size_t OH, size_t OW, + size_t oc_start, size_t oc_end, + size_t unit_start_idx, size_t nr_units_in_tile) { - ::megdnn::winograd::StrategyHelper< - float, float, float, float, param::ConvBias::Format::NCHW, - param::MatrixMul::Format::MK4>::output(output_transform_buf, bias, - output, transform_mid_buf, - bmode, nonline_mode, - oh_start, ow_start, OH, OW, - oc_start, oc_end, unit_idx, - nr_units_in_tile, - OUTPUT_BLOCK_SIZE, - KERNEL_SIZE, {0, 1, -1}, - dst_dtype); + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + size_t OC = oc_end - oc_start; + + for (size_t oc = oc_start; oc < oc_end; ++oc) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + ::megdnn::winograd::StrategyHelper< + float, float, float, float, param::ConvBias::Format::NCHW, + param::MatrixMul::Format::MK4>::output(output_transform_buf, + bias, output, + transform_mid_buf, + bmode, nonline_mode, + oh_start, ow_start, + OH, OW, OC, oc_start, + oc_index, unit_idx, + nr_units_in_tile, + OUTPUT_BLOCK_SIZE, + KERNEL_SIZE, + {0, 1, -1}, + dst_dtype); + } + } } MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_1x1_qs8) @@ -119,29 +177,59 @@ void winograd_2x3_1x1_qs8::filter(const int8_t* filter, void winograd_2x3_1x1_qs8::input(const int8_t* input, int16_t* input_transform_buf, - int16_t* transform_mid_buf, int ih_start, - int iw_start, size_t IH, size_t IW, size_t IC, - size_t unit_idx, size_t nr_units_in_tile) { - ::megdnn::winograd::StrategyHelper::input( - input, input_transform_buf, transform_mid_buf, ih_start, iw_start, - IH, IW, IC, unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, - KERNEL_SIZE, {0, 1, -1}, src_dtype, 1.0f); + int16_t* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, + size_t nr_units_in_tile) { + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = + div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + rep(ic, IC) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + ::megdnn::winograd::StrategyHelper:: + input(input, input_transform_buf, transform_mid_buf, + ih_start, iw_start, IH, IW, IC, ic, unit_idx, + nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, + {0, 1, -1}, src_dtype, 1.0f); + } + } } void winograd_2x3_1x1_qs8::output(const int* output_transform_buf, const int* bias, int8_t* output, int* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, - size_t unit_idx, size_t nr_units_in_tile) { + NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, + size_t unit_start_idx, + size_t nr_units_in_tile) { float scale_input = src_dtype.param().scale; float scale_filter = filter_dtype.param().scale; - ::megdnn::winograd::StrategyHelper::output( - output_transform_buf, bias, output, transform_mid_buf, bmode, - nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end, - unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, - {0, 1, -1}, dst_dtype, scale_input * scale_filter, 2.0f, 1.0f); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + size_t OC = oc_end - oc_start; + + for (size_t oc = oc_start; oc < oc_end; ++oc) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + ::megdnn::winograd::StrategyHelper:: + output(output_transform_buf, bias, output, + transform_mid_buf, bmode, nonline_mode, oh_start, + ow_start, OH, OW, OC, oc_start, oc_index, unit_idx, + nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, + {0, 1, -1}, dst_dtype, scale_input * scale_filter, + 2.0f, 1.0f); + } + } } MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_qs8) @@ -162,27 +250,44 @@ void winograd_2x3_8x8_qs8::filter(const int8_t* filter, void winograd_2x3_8x8_qs8::input(const int8_t* input, int16_t* input_transform_buf, - int16_t* transform_mid_buf, int ih_start, - int iw_start, size_t IH, size_t IW, size_t IC, - size_t unit_idx, size_t nr_units_in_tile) { - ::megdnn::winograd::StrategyHelper< - int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, - param::MatrixMul::Format::MK8>::input(input, input_transform_buf, - transform_mid_buf, ih_start, - iw_start, IH, IW, IC, - unit_idx, nr_units_in_tile, - OUTPUT_BLOCK_SIZE, - KERNEL_SIZE, {0, 1, -1}, - src_dtype, 1.0f); + int16_t* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, size_t PW, + size_t unit_start_idx, + size_t nr_units_in_tile) { + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = + div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + rep(ic, IC) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + ::megdnn::winograd::StrategyHelper< + int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, + param::MatrixMul::Format::MK8>::input(input, + input_transform_buf, + transform_mid_buf, + ih_start, iw_start, + IH, IW, IC, ic, + unit_idx, + nr_units_in_tile, + OUTPUT_BLOCK_SIZE, + KERNEL_SIZE, + {0, 1, -1}, src_dtype, + 1.0f); + } + } } void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, const int* bias, int8_t* output, int* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, - size_t unit_idx, size_t nr_units_in_tile) { + NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, + size_t unit_start_idx, + size_t nr_units_in_tile) { float scale_input = src_dtype.param().scale; float scale_filter = 0.f; if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) { @@ -191,19 +296,37 @@ void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, megdnn_assert(filter_dtype.enumv() == DTypeEnum::QuantizedS16); scale_filter = filter_dtype.param().scale; } - ::megdnn::winograd::StrategyHelper< - int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, - param::MatrixMul::Format::MK8>::output(output_transform_buf, bias, - output, transform_mid_buf, - bmode, nonline_mode, - oh_start, ow_start, OH, OW, - oc_start, oc_end, unit_idx, - nr_units_in_tile, - OUTPUT_BLOCK_SIZE, - KERNEL_SIZE, {0, 1, -1}, - dst_dtype, - scale_input * scale_filter, - 2.0f, 1.0f); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + size_t OC = oc_end - oc_start; + + for (size_t oc = oc_start; oc < oc_end; ++oc) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + ::megdnn::winograd::StrategyHelper< + int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, + param::MatrixMul::Format::MK8>::output(output_transform_buf, + bias, output, + transform_mid_buf, + bmode, nonline_mode, + oh_start, ow_start, + OH, OW, OC, oc_start, + oc_index, unit_idx, + nr_units_in_tile, + OUTPUT_BLOCK_SIZE, + KERNEL_SIZE, + {0, 1, -1}, + dst_dtype, + scale_input * + scale_filter, + 2.0f, 1.0f); + } + } } } // namespace winograd diff --git a/dnn/src/fallback/conv_bias/winograd/winograd.h b/dnn/src/fallback/conv_bias/winograd/winograd.h index 56ab3d98769e717b4cdd0b71f62b9287c76b191e..03bfa3d24bd361fa453c8905181b1a96fae8b46f 100644 --- a/dnn/src/fallback/conv_bias/winograd/winograd.h +++ b/dnn/src/fallback/conv_bias/winograd/winograd.h @@ -321,17 +321,10 @@ public: "nr_tiles_in_unit: %zu TILE_SIZE:%zu", nr_tiles_in_unit, unit_tile_size); } - rep(unit_idx, nr_tiles_in_unit) { - size_t index = unit_start_idx + unit_idx; - size_t nh = index / units_w; - size_t nw = index % units_w; - int ih_start = nh * Strategy::OUTPUT_BLOCK_SIZE - PH; - int iw_start = nw * Strategy::OUTPUT_BLOCK_SIZE - PW; - - strategy.input(src_ptr, input_transform_buf, transform_mid_buf, - ih_start, iw_start, IH, IW, IC, unit_idx, - nr_tiles_in_unit); - } + //! BTdB + strategy.input(src_ptr, input_transform_buf, transform_mid_buf, + IH, IW, IC, PH, PW, unit_start_idx, nr_tiles_in_unit); + rep(i, Strategy::ALPHA) rep(j, Strategy::ALPHA) { if (format == param::MatrixMul::Format::DEFAULT) { matmul_param.A_ptr = @@ -368,22 +361,14 @@ public: } matmul_kern(matmul_param); } - /* Y = ATmA */ - rep(unit_idx, nr_tiles_in_unit) { - size_t index = unit_start_idx + unit_idx; - auto nh = index / units_w; - auto nw = index % units_w; - size_t oh_start = nh * Strategy::OUTPUT_BLOCK_SIZE; - size_t ow_start = nw * Strategy::OUTPUT_BLOCK_SIZE; - size_t oc_end_idx = oc_start_idx + nr_oc_in_unit; - - strategy.output( - output_transform_buf, bias_ptr, dst_ptr, - reinterpret_cast(transform_mid_buf), - ncb_param.bias_mode, ncb_param.nonlineMode, oh_start, - ow_start, OH, OW, oc_start_idx, oc_end_idx, unit_idx, - nr_tiles_in_unit); - } + + //! Y = ATmA + size_t oc_end_idx = oc_start_idx + nr_oc_in_unit; + strategy.output( + output_transform_buf, bias_ptr, dst_ptr, + reinterpret_cast(transform_mid_buf), + ncb_param.bias_mode, ncb_param.nonlineMode, OH, OW, + oc_start_idx, oc_end_idx, unit_start_idx, nr_tiles_in_unit); }; SmallVector get_kerns( @@ -542,15 +527,16 @@ public: size_t IC, size_t oc_start, size_t oc_end); \ void input(const stype* input, \ input_filter_compute_type* input_transform_buf, \ - input_filter_compute_type* transform_mid_buf, int ih_start, \ - int iw_start, size_t IH, size_t IW, size_t IC, \ - size_t unit_idx, size_t nr_tiles_in_unit); \ + input_filter_compute_type* transform_mid_buf, \ + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, \ + size_t unit_start_idx, size_t nr_tiles_in_unit); \ void output(const output_compute_type* output_transform_buf, \ const output_compute_type* bias, dst_type* output, \ output_compute_type* transform_mid_buf, BiasMode bmode, \ - NonlineMode nonline_mode, size_t oh_start, \ - size_t ow_start, size_t OH, size_t OW, size_t oc_start, \ - size_t oc_end, size_t unit_idx, size_t nr_tiles_in_unit); \ + NonlineMode nonline_mode, size_t OH, size_t OW, \ + size_t oc_start, size_t oc_end, size_t unit_start_idx, \ + size_t nr_tiles_in_unit); \ + }; #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ diff --git a/dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp b/dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp index 9b3fb6f202927938f6fd4ae07a4de4a44278f372..54e68b87ec296405d12dd2704961bc62ffeacd7f 100644 --- a/dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp +++ b/dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp @@ -274,31 +274,43 @@ void winograd_nchw88_2x3_8x8_f::filter(const float* filter, transform_mid_buf, OC, IC, oc_start, oc_end); } + void winograd_nchw88_2x3_8x8_f::input(const float* input, float* input_transform_buf, - float* transform_mid_buf, int ih_start, - int iw_start, size_t IH, size_t IW, - size_t IC, size_t unit_idx, + float* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, + size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) { megdnn_assert(IC % 8 == 0); + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); float* patch = transform_mid_buf; float* patchT = transform_mid_buf + 8 * alpha * alpha; - if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && - iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { - for (size_t ic = 0; ic < IC; ic += 8) { - InputTransform2X3_NCHW88::prepare( - input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); - InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); - } - } else { - for (size_t ic = 0; ic < IC; ic += 8) { - InputTransform2X3_NCHW88::prepare(input, patch, patchT, ih_start, - iw_start, IH, IW, ic, IC); - InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + + for (size_t ic = 0; ic < IC; ic += 8) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform2X3_NCHW88::prepare(input, patch, patchT, + ih_start, iw_start, IH, + IW, ic, IC); + InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, + ic, IC); + } else { + InputTransform2X3_NCHW88::prepare(input, patch, patchT, + ih_start, iw_start, IH, + IW, ic, IC); + InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, + ic, IC); + } } } } diff --git a/dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp b/dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp index 44a9d3e89c38b4c193cbf486cc9a05470755dd81..6d5ab52b117fbbe436ce56b62953329fe6e3879a 100644 --- a/dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp +++ b/dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp @@ -338,32 +338,43 @@ void winograd_nchw88_6x3_8x8_f::filter(const float* filter, transform_mid_buf, OC, IC, oc_start, oc_end); } + void winograd_nchw88_6x3_8x8_f::input(const float* input, float* input_transform_buf, - float* transform_mid_buf, int ih_start, - int iw_start, size_t IH, size_t IW, - size_t IC, size_t unit_idx, + float* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, + size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) { megdnn_assert(IC % 8 == 0); + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); float* patch = transform_mid_buf; float* patchT = transform_mid_buf + 8 * alpha * alpha; - if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && - iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { - for (size_t ic = 0; ic < IC; ic += 8) { - InputTransform6X3_NCHW88::prepare( - input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); - InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); - } - } else { - for (size_t ic = 0; ic < IC; ic += 8) { - InputTransform6X3_NCHW88::prepare(input, patch, patchT, ih_start, - iw_start, IH, IW, ic, IC); - InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + + for (size_t ic = 0; ic < IC; ic += 8) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransform6X3_NCHW88::prepare(input, patch, patchT, + ih_start, iw_start, IH, + IW, ic, IC); + InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, + ic, IC); + } else { + InputTransform6X3_NCHW88::prepare(input, patch, patchT, + ih_start, iw_start, IH, + IW, ic, IC); + InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, + ic, IC); + } } } }