提交 5ca7be57 编写于 作者: K kai00

anf exporter fixed

上级 58523a41
......@@ -387,7 +387,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
}
meta_graphT->allTensors.emplace_back(msTensor);
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D)
|| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D)) {
|| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D)
|| IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm)) {
break;
}
}
......
......@@ -29,7 +29,7 @@
namespace mindspore::lite {
void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive,
const int &group) {
const int &group, const std::vector<AnfNodePtr> &inputs) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
......@@ -66,6 +66,28 @@ void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim,
attr->padMode = schema::PadMode_NOTSET;
}
int channel_mutiplier = 1;
if (prim->GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = GetValue<int>(prim->GetAttr("channel_multiplier"));
}
attr->channelMultiplier = channel_mutiplier;
MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
auto inputNode = inputs[kAnfPopulaterOne];
MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<Parameter>()) {
auto paramNode = inputNode->cast<ParameterPtr>();
auto abstractBase = paramNode->abstract();
MS_ASSERT(abstractBase != nullptr);
if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
MS_ASSERT(abstractTensor != nullptr);
if (abstractTensor->format() == schema::Format_NCHW) {
abstractTensor->set_format(schema::Format_KCHW);
}
}
}
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
}
......@@ -214,7 +236,7 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit
int group = GetValue<int>(prim->GetAttr("group"));
if (group > 1) {
PopulaterConv2DMultiGroup(prim, primitive, group);
PopulaterConv2DMultiGroup(prim, primitive, group, inputs);
} else {
PopulaterConv2DSingleGroup(prim, primitive, group);
}
......
......@@ -35,7 +35,7 @@ class AnfConvPopulater : public AnfNodePopulater {
private:
void PopulaterConv2DMultiGroup(
const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group, const std::vector<AnfNodePtr> &inputs);
void PopulaterConv2DSingleGroup(
const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
......
......@@ -1129,7 +1129,12 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape);
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
auto primReturn = std::make_unique<schema::PrimitiveT>();
MS_ASSERT(primReturn != nullptr);
primReturn->value.type = schema::PrimitiveType_Return;
std::shared_ptr<PrimitiveTValue> primitiveTReturnValuePtr = std::make_shared<PrimitiveTValue>(primReturn.release());
MS_ASSERT(primitiveTReturnValuePtr != nullptr);
inputs.push_back(NewValueNode(primitiveTReturnValuePtr));
inputs.push_back(cnode_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
......
......@@ -18,6 +18,7 @@
#include "tools/common/converter_op_utils.h"
#include "utils/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/node_util.h"
namespace mindspore {
namespace lite {
......@@ -166,6 +167,9 @@ STATUS WeightFormatHardCodePass::HardCodeMS(const std::unique_ptr<CNodeT> &node,
if (opType == PrimitiveType_Conv2D) {
weightTensor->format = Format_KCHW;
} else if (opType == PrimitiveType_DepthwiseConv2D) {
if (weightTensor->format == Format_KCHW) {
TransFilterFormat<float>(weightTensor.get(), kKCHW2CKHW);
}
weightTensor->format = Format_CKHW;
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册