diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index e7738120c6b01f26f074fefc80aa8ac8671321f4..f7edf5956b620176effeddf0d7ee3d316ae51775 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -154,7 +154,7 @@ public: for (auto&& algo : matmul_algos) { if (algo->type() == nullptr) continue; - for (uint32_t tile_size : {8, 16, 24, 32, 40, 48, 64, 80}) { + for (uint32_t tile_size : {16, 8, 24, 32}) { refhold.emplace_back(new AlgoFP32WinogradF23_4x4( static_cast(algo), tile_size)); diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index 6f5e9e81bce3c0feb8210f11774a48d1d422a7df..a49f824737d919443c6404cfce9e6d05798302da 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -725,6 +725,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); }); cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); }); + cb(nchw4, { add_pass(); add_pass(); @@ -736,10 +737,21 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( add_pass(); add_pass(ConvertFormatPass::make_nhwcd4_converter()); }); - cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); }); - cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); }); - cb(nchw44_dot, - { add_pass(EnableNchw44DotPass::make_nchw44_dot_converter()); }); + cb(nchw88, { + add_pass(); + add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); + add_pass(); + }); + cb(nchw44, { + add_pass(); + add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); + add_pass(); + }); + cb(nchw44_dot, { + add_pass(); + add_pass(EnableNchw44DotPass::make_nchw44_dot_converter()); + add_pass(); + }); cb(nchw32, { add_pass(); add_pass(); diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp index 
9f72c3f71fbd3cc29abe1dc2064b6926bd270301..6db500041653d34ae69dde862a6d57c7669017fa 100644 --- a/src/opr/impl/dnn/convolution.cpp +++ b/src/opr/impl/dnn/convolution.cpp @@ -707,7 +707,9 @@ template <> void AlgoChooser::ExeContext:: modify_param_with_weights_preprocessed( typename TimedProfiler::Param& param) const { - if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW) { + if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW || + param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW44 || + param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW88) { auto winograd_param = megdnn::ConvBias::parse_winograd_name(param.algo_name); if (winograd_param == megdnn::ConvBias::INVALID_WINOGRAD_PARAM) { @@ -727,8 +729,18 @@ void AlgoChooser::ExeContext:: filter_transform_layout); param.shapes[1] = filter_transform_layout; param.dtypes[1] = filter_transform_layout.dtype.enumv(); - - param.opr_param.format = megdnn::ConvBias::Param::Format::NCHW_WINOGRAD; + if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW) { + param.opr_param.format = + megdnn::ConvBias::Param::Format::NCHW_WINOGRAD; + } else if (param.opr_param.format == + megdnn::ConvBias::Param::Format::NCHW44) { + param.opr_param.format = + megdnn::ConvBias::Param::Format::NCHW44_WINOGRAD; + } else if (param.opr_param.format == + megdnn::ConvBias::Param::Format::NCHW88) { + param.opr_param.format = + megdnn::ConvBias::Param::Format::NCHW88_WINOGRAD; + } param.opr_param.output_block_size = winograd_param.output_block_size; } } diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index eaabb8cf83da38cc1de43c1f7fabf55419d278b6..22111a7dbf7d07f181c770d1fab8d3ffbd8ef431 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -160,6 +160,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, spatial_start = 2; break; case Param::Format::NCHW_WINOGRAD: + case Param::Format::NCHW44_WINOGRAD: 
case Param::Format::NCHW88_WINOGRAD: cpos = 1; spatial_start = 0; @@ -191,9 +192,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]); uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]); if (param.format == Param::Format::NCHW_WINOGRAD || + param.format == Param::Format::NCHW44_WINOGRAD || param.format == Param::Format::NCHW88_WINOGRAD) { mgb_assert(opr->same_type(), - "Only conv bias support NCHW_WINOGRAD"); + "Only conv bias support WINOGRAD"); auto&& conv_bias_opr = opr->cast_final_safe(); uint32_t output_block_size = conv_bias_opr.param().output_block_size; mgb_assert(fh == fw, @@ -208,6 +210,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, return dst_shape.total_nr_elems() * fh * fw * static_cast<uint64_t>(src_shape[cpos] * 8) / group * 2; } + if (param.format == Param::Format::NCHW44_WINOGRAD) { + return dst_shape.total_nr_elems() * fh * fw * + static_cast<uint64_t>(src_shape[cpos] * 4) / group * 2; + } return dst_shape.total_nr_elems() * fh * fw * static_cast<uint64_t>(src_shape[cpos]) / group * 2; }