diff --git a/dnn/src/common/winograd/winograd_helper.cpp b/dnn/src/common/winograd/winograd_helper.cpp index c950cf7b532588165097e567906f31059dbed3e8..15e3c5884001c71a0b3db73dd3a6bacc707fe20c 100644 --- a/dnn/src/common/winograd/winograd_helper.cpp +++ b/dnn/src/common/winograd/winograd_helper.cpp @@ -235,7 +235,7 @@ void StrategyHelper< input_filter_compute_type* input_transform_buf, input_filter_compute_type* transform_mid_buf, int ih_start, int iw_start, size_t IH, size_t IW, - size_t IC, size_t unit_idx, size_t nr_units_in_tile, + size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, const std::vector& interp_points, DType dtype, float rescale) { @@ -284,7 +284,7 @@ void StrategyHelper< const output_compute_type* bias, dst_type* output, output_compute_type* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, size_t oc_start, + size_t ow_start, size_t OH, size_t OW, size_t OC, size_t oc_start, size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, const std::vector& interp_points, DType dtype, @@ -296,7 +296,7 @@ void StrategyHelper< output_compute_type* mid_buf1 = transform_mid_buf; output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; OutputGetter getter(dtype); - OutputVisitor output_visitor(oc_end - oc_start); + OutputVisitor output_visitor(OC); size_t oc = oc_start + oc_index; diff --git a/dnn/src/common/winograd/winograd_helper.h b/dnn/src/common/winograd/winograd_helper.h index fd7d5ccf371e8d080a7bd7e2287567fbbb22d563..017b70ab73ad529c69e5c5bc20f343e297f364fb 100644 --- a/dnn/src/common/winograd/winograd_helper.h +++ b/dnn/src/common/winograd/winograd_helper.h @@ -6,8 +6,7 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. 
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once @@ -44,8 +43,8 @@ public: input_filter_compute_type* input_transform_buf, input_filter_compute_type* transform_mid_buf, int ih_start, int iw_start, size_t IH, size_t IW, - size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, - size_t m, size_t r, + size_t IC, size_t ic, size_t unit_idx, + size_t nr_units_in_tile, size_t m, size_t r, const std::vector& interp_points, DType dtype, float rescale = 1.0f); @@ -54,7 +53,7 @@ public: const output_compute_type* bias, dst_type* output, output_compute_type* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t oh_start, size_t ow_start, - size_t OH, size_t OW, size_t oc_start, size_t oc_index, + size_t OH, size_t OW, size_t OC, size_t oc_start, size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, const std::vector& interp_points, DType dtype, float input_filter_scale = 1.0f, // input_scale * filter_scale diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index eba4f05aa5c4e80197ccbf218cccdf9506bbf16a..c4e5a7aca2609ccae794824d3c210af7fb4fcf2a 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -45,7 +45,6 @@ public: static_cast(matmul_opr)->algo_pack(); for (auto&& algo : matmul_algos) { if (algo->algoset() == - //! 
TODO: threre should filter MK matmul MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { continue; } diff --git a/dnn/src/fallback/conv_bias/winograd/winograd.h b/dnn/src/fallback/conv_bias/winograd/winograd.h index 03bfa3d24bd361fa453c8905181b1a96fae8b46f..f4d4f25849e8f136ac11418f76512de4372b73af 100644 --- a/dnn/src/fallback/conv_bias/winograd/winograd.h +++ b/dnn/src/fallback/conv_bias/winograd/winograd.h @@ -536,7 +536,6 @@ public: NonlineMode nonline_mode, size_t OH, size_t OW, \ size_t oc_start, size_t oc_end, size_t unit_start_idx, \ size_t nr_tiles_in_unit); \ - }; #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ diff --git a/dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp b/dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp index 54e68b87ec296405d12dd2704961bc62ffeacd7f..7f9d0f4091a01cd69a9e268ea4a2d046464d5c99 100644 --- a/dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp +++ b/dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp @@ -186,58 +186,56 @@ struct OutputTransform2X3_NCHW88 { float* output, float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t oc_end, - size_t unit_idx, size_t nr_units_in_tile, - const DType& src_dtype, const DType& dst_dtype) { + size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, + const DType& dst_dtype) { MEGDNN_MARK_USED_VAR(transform_mid_buf); - megdnn_assert( - (oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 && - oc_end % 8 == 0, - "Winograd output transform input param is not times of 8!"); Op op(src_dtype, dst_dtype); //! 
AT * m * A size_t OCB = (oc_end - oc_start) / 8; - for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { - size_t ocb = (oc - oc_start) / 8; + size_t oc = oc_start + oc_index; + size_t ocb = oc_index / 8; + #define cb(m, n) \ auto v##m##n = Vector::load( \ output_transform_buf + \ (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ ocb * nr_units_in_tile * 8 + unit_idx * 8); - UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb - //! 1 1 1 0 v00 v01 v02 v03 1 0 - //! 0 1 -1 1 v10 v11 v12 v13 1 1 - //! v20 v21 v22 v23 1 -1 - //! v30 v31 v32 v33 0 1 + //! 1 1 1 0 v00 v01 v02 v03 1 0 + //! 0 1 -1 1 v10 v11 v12 v13 1 1 + //! v20 v21 v22 v23 1 -1 + //! v30 v31 v32 v33 0 1 #define cb(m) \ auto t0##m = v0##m + v1##m + v2##m; \ auto t1##m = v1##m - v2##m + v3##m; - UNROLL_CALL_NOWRAPPER(4, cb); + UNROLL_CALL_NOWRAPPER(4, cb); #undef cb #define cb(m) \ v##m##0 = t##m##0 + t##m##1 + t##m##2; \ v##m##1 = t##m##1 - t##m##2 + t##m##3; - UNROLL_CALL_NOWRAPPER(2, cb); + UNROLL_CALL_NOWRAPPER(2, cb); #undef cb - Vector vbias; - if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - vbias = Vector::load(bias + oc); + Vector vbias; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = Vector::load(bias + oc); #define cb(m, n) v##m##n += vbias; - UNROLL_CALL_RAW_D2(2, 2, cb); + UNROLL_CALL_RAW_D2(2, 2, cb); #undef cb - } - if (bmode != BiasMode::BIAS) { + } + if (bmode != BiasMode::BIAS) { #define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); - UNROLL_CALL_RAW_D2(2, 2, cb); + UNROLL_CALL_RAW_D2(2, 2, cb); #undef cb - } + } #define out_save(oho, owo) \ do { \ size_t oh = oh_start + oho; \ @@ -252,8 +250,7 @@ struct OutputTransform2X3_NCHW88 { ow * 8); \ } \ } while (0); - UNROLL_CALL_RAW_D2(2, 2, out_save); - } + UNROLL_CALL_RAW_D2(2, 2, out_save); } }; #undef CONCAT @@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input, } } -void winograd_nchw88_2x3_8x8_f::output( - const float* output_transform_buf, const float* bias, float* 
output, - float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, - size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, - size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { +void winograd_nchw88_2x3_8x8_f::output(const float* output_transform_buf, + const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, + size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) \ OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ __VA_ARGS__); - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, float, - float, bmode, nonline_mode, output_transform_buf, bias, output, - transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, - unit_idx, nr_units_in_tile, src_dtype, dst_dtype); + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + size_t OC = oc_end - oc_start; + + megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0, + "Winograd output transform input param is not times of 8!"); + + for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, + float, float, bmode, nonline_mode, output_transform_buf, + bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, + oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, + dst_dtype); + } + } #undef cb } diff --git a/dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp b/dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp index 6d5ab52b117fbbe436ce56b62953329fe6e3879a..280d10a981ea345bfd8462db6dbf9d145e04b18c 100644 --- 
a/dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp +++ b/dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "src/common/unroll_macro.h" @@ -19,10 +20,10 @@ #include #ifdef WIN32CMAKE -#include -#include #include +#include #include +#include #endif #include "midout.h" @@ -40,7 +41,7 @@ struct InputTransform6X3_NCHW88 { int ih_start, int iw_start, size_t IH, size_t IW, size_t ic, size_t IC) { MEGDNN_MARK_USED_VAR(patch); - size_t IW8 = IW * 8; //! For nchw88 mode + size_t IW8 = IW * 8; //! For nchw88 mode size_t iw8_start = iw_start * 8; //! For nchw88 mode size_t icb = ic / 8; if (!(inner && ic + 8 < IC)) { @@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 { for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) { for (size_t icb = 0; icb < ICB; icb++) { - for (size_t ic_inner = 0; ic_inner < 8; ic_inner++){ + for (size_t ic_inner = 0; ic_inner < 8; ic_inner++) { const float* fptr = filter + (ocb * ICB + icb) * 3 * 3 * 8 * 8 + ic_inner * 8; @@ -220,41 +221,39 @@ struct OutputTransform6X3_NCHW88 { float* output, float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t oc_end, - size_t unit_idx, size_t nr_units_in_tile, - const DType& src_dtype, const DType& dst_dtype) { + size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, + const DType& dst_dtype) { MEGDNN_MARK_USED_VAR(transform_mid_buf); - megdnn_assert( - (oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 && - oc_end % 8 == 0, - "Winograd output transform input param is not times of 8!"); + Op op(src_dtype, dst_dtype); //! 
AT * m * A size_t OCB = (oc_end - oc_start) / 8; - for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { - size_t ocb = (oc - oc_start) / 8; + size_t oc = oc_start + oc_index; + size_t ocb = oc_index / 8; + #define cb(m, n) \ auto v##m##n = Vector::load( \ output_transform_buf + \ (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ ocb * nr_units_in_tile * 8 + unit_idx * 8); - UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb - /** - * A - * - * 1 0 0 0 0 0 - * 1 1 1 1 1 1 - * 1 -1 1 -1 1 -1 - * 1 2 4 8 16 32 - * 1 -2 4 -8 16 -32 - * 1 0.5 0.25 0.125 0.0625 0.03125 - * 1 -0.5 0.25 -0.125 0.0625 -0.03125 - * 0 0.0 0 0 0 1 - */ - - Vector v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, - v5subv6; + /** + * A + * + * 1 0 0 0 0 0 + * 1 1 1 1 1 1 + * 1 -1 1 -1 1 -1 + * 1 2 4 8 16 32 + * 1 -2 4 -8 16 -32 + * 1 0.5 0.25 0.125 0.0625 0.03125 + * 1 -0.5 0.25 -0.125 0.0625 -0.03125 + * 0 0.0 0 0 0 1 + */ + + Vector v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; #define cb(m) \ v1addv2 = v1##m + v2##m; \ v1subv2 = v1##m - v2##m; \ @@ -269,7 +268,7 @@ struct OutputTransform6X3_NCHW88 { auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; - UNROLL_CALL_NOWRAPPER(8, cb); + UNROLL_CALL_NOWRAPPER(8, cb); #undef cb #define cb(m) \ @@ -286,22 +285,22 @@ struct OutputTransform6X3_NCHW88 { v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; - UNROLL_CALL_NOWRAPPER(6, cb); + UNROLL_CALL_NOWRAPPER(6, cb); #undef cb - Vector vbias; - if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - vbias = Vector::load(bias + oc); + Vector vbias; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = Vector::load(bias + oc); #define cb(m, n) v##m##n += vbias; - UNROLL_CALL_RAW_D2(6, 6, cb); + UNROLL_CALL_RAW_D2(6, 6, cb); #undef cb - } - if (bmode != BiasMode::BIAS) { + } + if (bmode != BiasMode::BIAS) { #define 
cb(m, n) v##m##n = op(CONCAT(v##m, n).value); - UNROLL_CALL_RAW_D2(6, 6, cb); + UNROLL_CALL_RAW_D2(6, 6, cb); #undef cb - } + } #define out_save(oho, owo) \ do { \ size_t oh = oh_start + oho; \ @@ -316,8 +315,7 @@ struct OutputTransform6X3_NCHW88 { ow * 8); \ } \ } while (0); - UNROLL_CALL_RAW_D2(6, 6, out_save); - } + UNROLL_CALL_RAW_D2(6, 6, out_save); } }; #undef CONCAT @@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, megdnn_assert(IC % 8 == 0); // OW = IW + 2 * PW - KERNEL_SIZE + 1 - auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + auto units_w = + div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); float* patch = transform_mid_buf; float* patchT = transform_mid_buf + 8 * alpha * alpha; @@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, } } -void winograd_nchw88_6x3_8x8_f::output( - const float* output_transform_buf, const float* bias, float* output, - float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, - size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, - size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { +void winograd_nchw88_6x3_8x8_f::output(const float* output_transform_buf, + const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, + size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) 
\ OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ __VA_ARGS__); - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, float, - float, bmode, nonline_mode, output_transform_buf, bias, output, - transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, - unit_idx, nr_units_in_tile, src_dtype, dst_dtype); + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + size_t OC = oc_end - oc_start; + + megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0, + "Winograd output transform input param is not times of 8!"); + + for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, + float, float, bmode, nonline_mode, output_transform_buf, + bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, + oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, + src_dtype, dst_dtype); + } + } #undef cb } } // namespace winograd -} // namespace arm_common +} // namespace x86 } // namespace megdnn // vim: syntax=cpp.doxygen