Commit 18be23f3 authored by Megvii Engine Team, committed by Xu Xinran

fix(mgb/gopt): fix nchwxx gopt with no fuse conv_bias and winograd fast-run

GitOrigin-RevId: 49ccbdf2d43229f3883af91f9f9641f695e2a799
Parent 38f7cbd9
......@@ -154,7 +154,7 @@ public:
for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
continue;
for (uint32_t tile_size : {8, 16, 24, 32, 40, 48, 64, 80}) {
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
......
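This hunk trims the Winograd F(2,3) 4x4 tile-size candidates that fast-run profiles from eight values down to {16, 8, 24, 32}. The sketch below (all types are hypothetical stand-ins, not megdnn's real classes) only illustrates the registration pattern visible above: one algorithm object is created per (matmul backend, tile_size) pair, so a shorter candidate list directly shrinks the fast-run search space.

```cpp
// Hypothetical stand-ins for fallback::MatrixMulImpl::AlgoBase and
// AlgoFP32WinogradF23_4x4; only the registration pattern mirrors the hunk.
#include <cstdint>
#include <memory>
#include <vector>

struct MatmulAlgo {
    const char* type;  // nullptr means "not usable as a winograd backend"
};

struct WinogradAlgo {
    WinogradAlgo(MatmulAlgo* m, uint32_t tile) : matmul(m), tile_size(tile) {}
    MatmulAlgo* matmul;
    uint32_t tile_size;
};

int main() {
    std::vector<MatmulAlgo> matmul_algos{{"gemm_a"}, {nullptr}, {"gemm_b"}};
    std::vector<std::unique_ptr<WinogradAlgo>> refhold;
    for (auto& algo : matmul_algos) {
        if (algo.type == nullptr)
            continue;  // skip backends that cannot host winograd
        // Reduced candidate list from the patch: fewer entries means fewer
        // algorithms for fast-run to benchmark per conv_bias operator.
        for (uint32_t tile_size : {16u, 8u, 24u, 32u}) {
            refhold.push_back(std::make_unique<WinogradAlgo>(&algo, tile_size));
        }
    }
    return 0;  // 2 usable backends x 4 tile sizes = 8 candidates registered
}
```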
......@@ -725,6 +725,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
cb(nchw4, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
......@@ -736,10 +737,21 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
add_pass<FuseConvBiasNonlinPass>();
add_pass(ConvertFormatPass::make_nhwcd4_converter());
});
cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); });
cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); });
cb(nchw44_dot,
{ add_pass(EnableNchw44DotPass::make_nchw44_dot_converter()); });
cb(nchw88, {
add_pass<FuseConvBiasNonlinPass>();
add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
add_pass<ShuffleShuffleRemovePass>();
});
cb(nchw44, {
add_pass<FuseConvBiasNonlinPass>();
add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
add_pass<ShuffleShuffleRemovePass>();
});
cb(nchw44_dot, {
add_pass<FuseConvBiasNonlinPass>();
add_pass(EnableNchw44DotPass::make_nchw44_dot_converter());
add_pass<ShuffleShuffleRemovePass>();
});
cb(nchw32, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
......
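Before this change the nchw88, nchw44 and nchw44_dot options only ran the layout converter; a graph whose conv/bias/nonlinearity had not been fused beforehand could miss the packed-layout conv_bias kernels, and the shuffles introduced by the conversion were left in place. The sketch below (hypothetical Graph/Pass/Optimizer types, not MegBrain's real GraphOptimizer API) shows the three-stage pipeline each option now builds: fuse, convert layout, then remove redundant shuffle pairs.

```cpp
// Minimal pipeline sketch; only the pass ordering mirrors the patch.
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Graph {};  // stand-in for the computing graph

using Pass = std::function<void(Graph&)>;

class Optimizer {
public:
    Optimizer& add_pass(std::string name, Pass pass) {
        m_passes.emplace_back(std::move(name), std::move(pass));
        return *this;
    }
    void apply(Graph& graph) const {
        for (const auto& [name, pass] : m_passes) {
            std::cout << "running " << name << "\n";
            pass(graph);
        }
    }

private:
    std::vector<std::pair<std::string, Pass>> m_passes;
};

int main() {
    Graph graph;
    Optimizer opt;
    // Order matters: fusing first lets the nchwxx converter see ConvBias
    // nodes, and the shuffle-remove pass cleans up the relayout ops the
    // converter inserts around operators it could not convert.
    opt.add_pass("FuseConvBiasNonlinPass", [](Graph&) {})
       .add_pass("EnableNchwxxPass(pack_size=8)", [](Graph&) {})
       .add_pass("ShuffleShuffleRemovePass", [](Graph&) {});
    opt.apply(graph);
    return 0;
}
```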
......@@ -707,7 +707,9 @@ template <>
void AlgoChooser<megdnn::ConvBias>::ExeContext::
modify_param_with_weights_preprocessed(
typename TimedProfiler<megdnn::ConvBias>::Param& param) const {
if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW) {
if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW ||
param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW44 ||
param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW88) {
auto winograd_param =
megdnn::ConvBias::parse_winograd_name(param.algo_name);
if (winograd_param == megdnn::ConvBias::INVALID_WINOGRAD_PARAM) {
......@@ -727,8 +729,18 @@ void AlgoChooser<megdnn::ConvBias>::ExeContext::
filter_transform_layout);
param.shapes[1] = filter_transform_layout;
param.dtypes[1] = filter_transform_layout.dtype.enumv();
param.opr_param.format = megdnn::ConvBias::Param::Format::NCHW_WINOGRAD;
if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW) {
param.opr_param.format =
megdnn::ConvBias::Param::Format::NCHW_WINOGRAD;
} else if (param.opr_param.format ==
megdnn::ConvBias::Param::Format::NCHW44) {
param.opr_param.format =
megdnn::ConvBias::Param::Format::NCHW44_WINOGRAD;
} else if (param.opr_param.format ==
megdnn::ConvBias::Param::Format::NCHW88) {
param.opr_param.format =
megdnn::ConvBias::Param::Format::NCHW88_WINOGRAD;
}
param.opr_param.output_block_size = winograd_param.output_block_size;
}
}
......
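The net effect of this hunk is a per-layout remapping once the filter has been transformed for Winograd profiling. A condensed sketch (the enum values here are illustrative, not megdnn's actual definitions):

```cpp
// Each winograd-capable source layout maps to its own *_WINOGRAD format;
// anything else keeps its format unchanged during fast-run profiling.
#include <optional>

enum class Format {
    NCHW, NCHW44, NCHW88,
    NCHW_WINOGRAD, NCHW44_WINOGRAD, NCHW88_WINOGRAD
};

std::optional<Format> winograd_format_for(Format fmt) {
    switch (fmt) {
        case Format::NCHW:   return Format::NCHW_WINOGRAD;
        case Format::NCHW44: return Format::NCHW44_WINOGRAD;
        case Format::NCHW88: return Format::NCHW88_WINOGRAD;
        default:             return std::nullopt;  // not winograd-capable
    }
}

int main() {
    return winograd_format_for(Format::NCHW44) == Format::NCHW44_WINOGRAD ? 0 : 1;
}
```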
......@@ -160,6 +160,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
spatial_start = 2;
break;
case Param::Format::NCHW_WINOGRAD:
case Param::Format::NCHW44_WINOGRAD:
case Param::Format::NCHW88_WINOGRAD:
cpos = 1;
spatial_start = 0;
......@@ -191,9 +192,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]);
uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]);
if (param.format == Param::Format::NCHW_WINOGRAD ||
param.format == Param::Format::NCHW44_WINOGRAD ||
param.format == Param::Format::NCHW88_WINOGRAD) {
mgb_assert(opr->same_type<opr::ConvBias>(),
"Only conv bias support NCHW_WINOGRAD");
"Only conv bias support WINOGRAD");
auto&& conv_bias_opr = opr->cast_final_safe<opr::ConvBias>();
uint32_t output_block_size = conv_bias_opr.param().output_block_size;
mgb_assert(fh == fw,
......@@ -208,6 +210,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
return dst_shape.total_nr_elems() * fh * fw *
static_cast<uint64_t>(src_shape[cpos] * 8) / group * 2;
}
if (param.format == Param::Format::NCHW44_WINOGRAD) {
return dst_shape.total_nr_elems() * fh * fw *
static_cast<uint64_t>(src_shape[cpos] * 4) / group * 2;
}
return dst_shape.total_nr_elems() * fh * fw *
static_cast<uint64_t>(src_shape[cpos]) / group * 2;
}
......
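The new NCHW44_WINOGRAD branch mirrors the existing NCHW88 one: the operation count multiplies the output element count, the spatial filter size fh x fw used by the surrounding code, and the unpacked channel count, where src_shape[cpos] is the packed channel dimension and the pack size is 4 instead of 8. A standalone worked example with assumed shapes (not taken from the commit):

```cpp
// Assumed example: N=1, IC=16, OC=32, 56x56 output, 3x3 filter, group=1,
// input stored as NCHW44, so src_shape[cpos] == IC / 4 == 4.
#include <cstdint>
#include <iostream>

int main() {
    uint64_t dst_elems = 1ull * 32 * 56 * 56;  // dst_shape.total_nr_elems()
    uint64_t fh = 3, fw = 3;                   // spatial filter size
    uint64_t src_c_packed = 4;                 // src_shape[cpos] for NCHW44
    uint64_t group = 1;
    // dst_elems * fh * fw * (packed_channels * 4) / group * 2 (mul + add)
    uint64_t ops = dst_elems * fh * fw * (src_c_packed * 4) / group * 2;
    std::cout << ops << " ops\n";              // prints 28901376
    return 0;
}
```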