提交 18be23f3 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mgb/gopt): fix nchwxx gopt with no fuse conv_bias and winograd

fast-run

GitOrigin-RevId: 49ccbdf2d43229f3883af91f9f9641f695e2a799
上级 38f7cbd9
...@@ -154,7 +154,7 @@ public: ...@@ -154,7 +154,7 @@ public:
for (auto&& algo : matmul_algos) { for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr) if (algo->type() == nullptr)
continue; continue;
for (uint32_t tile_size : {8, 16, 24, 32, 40, 48, 64, 80}) { for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP32WinogradF23_4x4( refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
......
...@@ -725,6 +725,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( ...@@ -725,6 +725,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); }); cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); }); cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
cb(nchw4, { cb(nchw4, {
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>(); add_pass<FuseConvBiasZPass>();
...@@ -736,10 +737,21 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( ...@@ -736,10 +737,21 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
add_pass(ConvertFormatPass::make_nhwcd4_converter()); add_pass(ConvertFormatPass::make_nhwcd4_converter());
}); });
cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); }); cb(nchw88, {
cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); }); add_pass<FuseConvBiasNonlinPass>();
cb(nchw44_dot, add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
{ add_pass(EnableNchw44DotPass::make_nchw44_dot_converter()); }); add_pass<ShuffleShuffleRemovePass>();
});
cb(nchw44, {
add_pass<FuseConvBiasNonlinPass>();
add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
add_pass<ShuffleShuffleRemovePass>();
});
cb(nchw44_dot, {
add_pass<FuseConvBiasNonlinPass>();
add_pass(EnableNchw44DotPass::make_nchw44_dot_converter());
add_pass<ShuffleShuffleRemovePass>();
});
cb(nchw32, { cb(nchw32, {
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>(); add_pass<FuseConvBiasZPass>();
......
...@@ -707,7 +707,9 @@ template <> ...@@ -707,7 +707,9 @@ template <>
void AlgoChooser<megdnn::ConvBias>::ExeContext:: void AlgoChooser<megdnn::ConvBias>::ExeContext::
modify_param_with_weights_preprocessed( modify_param_with_weights_preprocessed(
typename TimedProfiler<megdnn::ConvBias>::Param& param) const { typename TimedProfiler<megdnn::ConvBias>::Param& param) const {
if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW) { if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW ||
param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW44 ||
param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW88) {
auto winograd_param = auto winograd_param =
megdnn::ConvBias::parse_winograd_name(param.algo_name); megdnn::ConvBias::parse_winograd_name(param.algo_name);
if (winograd_param == megdnn::ConvBias::INVALID_WINOGRAD_PARAM) { if (winograd_param == megdnn::ConvBias::INVALID_WINOGRAD_PARAM) {
...@@ -727,8 +729,18 @@ void AlgoChooser<megdnn::ConvBias>::ExeContext:: ...@@ -727,8 +729,18 @@ void AlgoChooser<megdnn::ConvBias>::ExeContext::
filter_transform_layout); filter_transform_layout);
param.shapes[1] = filter_transform_layout; param.shapes[1] = filter_transform_layout;
param.dtypes[1] = filter_transform_layout.dtype.enumv(); param.dtypes[1] = filter_transform_layout.dtype.enumv();
if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW) {
param.opr_param.format = megdnn::ConvBias::Param::Format::NCHW_WINOGRAD; param.opr_param.format =
megdnn::ConvBias::Param::Format::NCHW_WINOGRAD;
} else if (param.opr_param.format ==
megdnn::ConvBias::Param::Format::NCHW44) {
param.opr_param.format =
megdnn::ConvBias::Param::Format::NCHW44_WINOGRAD;
} else if (param.opr_param.format ==
megdnn::ConvBias::Param::Format::NCHW) {
param.opr_param.format =
megdnn::ConvBias::Param::Format::NCHW88_WINOGRAD;
}
param.opr_param.output_block_size = winograd_param.output_block_size; param.opr_param.output_block_size = winograd_param.output_block_size;
} }
} }
......
...@@ -160,6 +160,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, ...@@ -160,6 +160,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
spatial_start = 2; spatial_start = 2;
break; break;
case Param::Format::NCHW_WINOGRAD: case Param::Format::NCHW_WINOGRAD:
case Param::Format::NCHW44_WINOGRAD:
case Param::Format::NCHW88_WINOGRAD: case Param::Format::NCHW88_WINOGRAD:
cpos = 1; cpos = 1;
spatial_start = 0; spatial_start = 0;
...@@ -191,9 +192,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, ...@@ -191,9 +192,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]); uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]);
uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]); uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]);
if (param.format == Param::Format::NCHW_WINOGRAD || if (param.format == Param::Format::NCHW_WINOGRAD ||
param.format == Param::Format::NCHW44_WINOGRAD ||
param.format == Param::Format::NCHW88_WINOGRAD) { param.format == Param::Format::NCHW88_WINOGRAD) {
mgb_assert(opr->same_type<opr::ConvBias>(), mgb_assert(opr->same_type<opr::ConvBias>(),
"Only conv bias support NCHW_WINOGRAD"); "Only conv bias support WINOGRAD");
auto&& conv_bias_opr = opr->cast_final_safe<opr::ConvBias>(); auto&& conv_bias_opr = opr->cast_final_safe<opr::ConvBias>();
uint32_t output_block_size = conv_bias_opr.param().output_block_size; uint32_t output_block_size = conv_bias_opr.param().output_block_size;
mgb_assert(fh == fw, mgb_assert(fh == fw,
...@@ -208,6 +210,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, ...@@ -208,6 +210,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
return dst_shape.total_nr_elems() * fh * fw * return dst_shape.total_nr_elems() * fh * fw *
static_cast<uint64_t>(src_shape[cpos] * 8) / group * 2; static_cast<uint64_t>(src_shape[cpos] * 8) / group * 2;
} }
if (param.format == Param::Format::NCHW44_WINOGRAD) {
return dst_shape.total_nr_elems() * fh * fw *
static_cast<uint64_t>(src_shape[cpos] * 4) / group * 2;
}
return dst_shape.total_nr_elems() * fh * fw * return dst_shape.total_nr_elems() * fh * fw *
static_cast<uint64_t>(src_shape[cpos]) / group * 2; static_cast<uint64_t>(src_shape[cpos]) / group * 2;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册