diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index ce22d07615391478f6b17fc87230dec023fb5cb7..aa287d37220e4a4ae036bb147ad5acf09ac6a20a 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -716,8 +716,8 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( need_param_fuse = true; } if (options.transform_nhwcd4()) { - add_pass(ConvertFormatPass::make_nhwcd4_converter()); add_pass(); + add_pass(ConvertFormatPass::make_nhwcd4_converter()); need_param_fuse = true; } if (options.transform_nchw88()) { diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 87626537eb10d935235111bb52a797e97edae801..cb1ae70c29cac533776272166ac4a9d45f05319c 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1169,18 +1169,33 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param); conv_bias_weights = relayout_weight.node(); - param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I; - auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param); - conv_bias_bias = relayout_bias.node(); + mgb_assert(new_inp.size() < 4, + "ConvertFormat pass does not support fuse Z"); + bool has_bias = new_inp.size() > 2; + if (has_bias && + new_inp[2]->format().type() == TensorFormat::Type::DEFAULT) { + param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I; + auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param); + conv_bias_bias = relayout_bias.node(); + } else if (has_bias) { + conv_bias_bias = new_inp[2]; + } auto new_param = conv_bias_opr.param(); new_param.format = megdnn::param::ConvBias::Format::NHWCD4; mgb_assert(conv_bias_src->shape().ndim == 5 && conv_bias_src->format().type() == TensorFormat::Type::IMAGE2D_PACK4); - auto new_conv_bias_opr = opr::ConvBias::make( - conv_bias_src, conv_bias_weights, conv_bias_bias, new_param, - conv_bias_opr.execution_policy(), conv_bias_opr.config()); + SymbolVar new_conv_bias_opr; + if (has_bias) { + new_conv_bias_opr = opr::ConvBias::make( + conv_bias_src, conv_bias_weights, conv_bias_bias, new_param, + conv_bias_opr.execution_policy(), conv_bias_opr.config()); + } else { + new_conv_bias_opr = opr::ConvBias::make( + conv_bias_src, conv_bias_weights, new_param, + conv_bias_opr.execution_policy(), conv_bias_opr.config()); + } OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr(); mgb_assert(new_conv_bias_opr.shape().ndim == 5 && new_conv_bias_opr.format().type() ==