提交 65855149 编写于 作者: M Megvii Engine Team

fix(gopt): fix convbias replace of cd4 pass

GitOrigin-RevId: b0715e2b77026aaec4b7348fda0ae6fa282778b9
上级 36f17dec
......@@ -716,8 +716,8 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options(
need_param_fuse = true;
}
if (options.transform_nhwcd4()) {
add_pass(ConvertFormatPass::make_nhwcd4_converter());
add_pass<FuseConvBiasNonlinPass>();
add_pass(ConvertFormatPass::make_nhwcd4_converter());
need_param_fuse = true;
}
if (options.transform_nchw88()) {
......
......@@ -1169,18 +1169,33 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
conv_bias_weights = relayout_weight.node();
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param);
conv_bias_bias = relayout_bias.node();
mgb_assert(new_inp.size() < 4,
"ConvertFormat pass does not support fuse Z");
bool has_bias = new_inp.size() > 2;
if (has_bias &&
new_inp[2]->format().type() == TensorFormat::Type::DEFAULT) {
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param);
conv_bias_bias = relayout_bias.node();
} else if (has_bias) {
conv_bias_bias = new_inp[2];
}
auto new_param = conv_bias_opr.param();
new_param.format = megdnn::param::ConvBias::Format::NHWCD4;
mgb_assert(conv_bias_src->shape().ndim == 5 &&
conv_bias_src->format().type() ==
TensorFormat::Type::IMAGE2D_PACK4);
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_weights, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
SymbolVar new_conv_bias_opr;
if (has_bias) {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_weights, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
} else {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_weights, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
}
OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5 &&
new_conv_bias_opr.format().type() ==
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册