diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 68bf5ff4b716588b027a63d1ce0de3ff3cb2e881..b4541097ff26aeab82aa1363fd463ea87ac16bc8 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -387,7 +387,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptrallTensors.emplace_back(msTensor); if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) - || IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D)) { + || IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) + || IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm)) { break; } } diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc index ac38c7967136fc2abd4360d60ef2d2caecd64462..523c73501b51532fcefc392a3da33aae33ce8740 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc @@ -29,7 +29,7 @@ namespace mindspore::lite { void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim, const std::unique_ptr &primitive, - const int &group) { + const int &group, const std::vector &inputs) { auto attr = std::make_unique(); auto format = GetValue(prim->GetAttr("data_format")); if (format == "NCHW") { @@ -66,6 +66,28 @@ void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim, attr->padMode = schema::PadMode_NOTSET; } + int channel_mutiplier = 1; + if (prim->GetAttr("channel_mutiplier") != nullptr) { + channel_mutiplier = GetValue(prim->GetAttr("channel_multiplier")); + } + attr->channelMultiplier = channel_mutiplier; + + MS_ASSERT(inputs.size() == kAnfPopulaterTwo); + auto inputNode = inputs[kAnfPopulaterOne]; + MS_ASSERT(inputNode != nullptr); + if (inputNode->isa()) { + auto paramNode = inputNode->cast(); + auto abstractBase = paramNode->abstract(); + MS_ASSERT(abstractBase != nullptr); + if (utils::isa(abstractBase)) { + auto abstractTensor = utils::cast(abstractBase); + MS_ASSERT(abstractTensor != nullptr); + if (abstractTensor->format() == schema::Format_NCHW) { + abstractTensor->set_format(schema::Format_KCHW); + } + } + } + primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; primitive->value.value = attr.release(); } @@ -214,7 +236,7 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit int group = GetValue(prim->GetAttr("group")); if (group > 1) { - PopulaterConv2DMultiGroup(prim, primitive, group); + PopulaterConv2DMultiGroup(prim, primitive, group, inputs); } else { PopulaterConv2DSingleGroup(prim, primitive, group); } diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h index 678897fe92eced506d9aa9dba17ef35b2e19d792..7386b24afe20eb74c3defaa08499d49cb4ec109f 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h @@ -35,7 +35,7 @@ class AnfConvPopulater : public AnfNodePopulater { private: void PopulaterConv2DMultiGroup( const PrimitivePtr &prim, - const std::unique_ptr &primitive, const int &group); + const std::unique_ptr &primitive, const int &group, const std::vector &inputs); void PopulaterConv2DSingleGroup( const PrimitivePtr &prim, const std::unique_ptr &primitive, const int &group); diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc index 3b5b8c40b0e6831daeb720c4b14e91dfa260f570..9b7991d8d6667a79ed2a61a5184430eeb2f3fb52 100644 --- a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc @@ -1129,7 +1129,12 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output auto abstract_tensor = std::make_shared(type_ptr, output_shape); inputs.clear(); - inputs.push_back(NewValueNode(prim::kPrimReturn)); + auto primReturn = std::make_unique(); + MS_ASSERT(primReturn != nullptr); + primReturn->value.type = schema::PrimitiveType_Return; + std::shared_ptr primitiveTReturnValuePtr = std::make_shared(primReturn.release()); + MS_ASSERT(primitiveTReturnValuePtr != nullptr); + inputs.push_back(NewValueNode(primitiveTReturnValuePtr)); inputs.push_back(cnode_ptr); auto return_node = outputFuncGraph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(return_node); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc index b8aa778b906815ad590bd4f0686128e491f2db52..5c25e73a6f5a1444375149c86ebe3d7b9e8a0c05 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc @@ -18,6 +18,7 @@ #include "tools/common/converter_op_utils.h" #include "utils/log_adapter.h" #include "src/common/utils.h" +#include "tools/common/node_util.h" namespace mindspore { namespace lite { @@ -166,6 +167,9 @@ STATUS WeightFormatHardCodePass::HardCodeMS(const std::unique_ptr &node, if (opType == PrimitiveType_Conv2D) { weightTensor->format = Format_KCHW; } else if (opType == PrimitiveType_DepthwiseConv2D) { + if (weightTensor->format == Format_KCHW) { + TransFilterFormat(weightTensor.get(), kKCHW2CKHW); + } weightTensor->format = Format_CKHW; } else { MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name;