From 39ac3273a8cdfaa724691581b353d33913ed9b3a Mon Sep 17 00:00:00 2001 From: kai00 Date: Mon, 3 Aug 2020 17:14:55 +0800 Subject: [PATCH] fix anf exporter --- .../src/common/anf_exporter/anf_exporter.cc | 3 +++ .../anf_populater/anf_reducemean_populater.cc | 24 ++++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index e375bcaf4..1e36068b3 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -241,6 +241,9 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size(); fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); meta_graph->allTensors.emplace_back(std::move(paramTensor)); + } else if (value->isa()) { + MS_LOG(INFO) << "Value type is ValueSequence."; + break; } else { MS_LOG(ERROR) << "Not support value type , need add support."; } diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc index e7f5f71ff..8ec0a93cf 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc @@ -21,6 +21,10 @@ #include "ir/primitive.h" namespace mindspore::lite { +namespace { + constexpr int kReduceInputNum = 3; + constexpr int kReduceInputIndex = 2; +} int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) { auto p = GetCNodePrimitive(cnodePtr); @@ -28,7 +32,25 @@ int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CN attr->mode = schema::ReduceMode_ReduceMean; attr->keepDims = GetValue(p->GetAttr("keep_dims")); - // attr->axes = GetValue>(p->GetAttr("shape")); + if (cnodePtr->inputs().size() == kReduceInputNum) { + auto inputNode = cnodePtr->input(kReduceInputIndex); + MS_ASSERT(inputNode != nullptr); + if (inputNode->isa()) { + auto valueNode = inputNode->cast(); + MS_ASSERT(valueNode != nullptr); + auto value = valueNode->value(); + MS_ASSERT(value != nullptr); + if (value->isa()) { + auto valTuplPtr = dyn_cast(value); + MS_ASSERT(valTuplPtr != nullptr); + for (size_t i = 0; i < valTuplPtr->size(); i++) { + auto elem = dyn_cast((*valTuplPtr)[i]); + MS_ASSERT(elem != nullptr); + attr->axes.emplace_back(elem->value()); + } + } + } + } node->nodeType = schema::NodeType_CNode; node->primitive = std::make_unique(); -- GitLab