提交 6e497cbb 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4607 Supplement custom parser

Merge pull request !4607 from lyvette/master
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserL2Norm : public TestTfliteParser {
public:
TestTfliteParserL2Norm() = default;
void SetUp() override { meta_graph = LoadAndConvert("./l2norm.tflite", ""); }
};
TEST_F(TestTfliteParserL2Norm, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_L2Norm) << "wrong Op Type";
}
TEST_F(TestTfliteParserL2Norm, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsL2Norm(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsL2Norm();
ASSERT_EQ(val->epsilon, 0.0);
std::vector<int32_t> axis = {0, 1, 2, 3};
ASSERT_EQ(val->axis, axis);
}
} // namespace mindspore
......@@ -92,8 +92,6 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an AS
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_l2norm_parser.h"
#include <vector>
#include <memory>
#include <map>
namespace mindspore {
namespace lite {
STATUS TfliteL2NormParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteL2NormParser";
// set attr
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::L2NormT> attr(new schema::L2NormT());
auto data_index = tflite_op->inputs[0];
const auto &data_tensor = tflite_tensors[data_index];
if (data_tensor == nullptr) {
MS_LOG(ERROR) << "the input tensor is null";
return RET_NULL_PTR;
}
auto ndim = data_tensor->shape.size();
std::vector<int32_t> axis;
axis.reserve(ndim);
for (int i = 0; i < ndim; i++) {
axis.emplace_back(i);
}
attr->axis = axis;
attr->epsilon = 0.0f;
op->primitive->value.type = schema::PrimitiveType_L2Norm;
op->primitive->value.value = attr.release();
// set input
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}
TfliteNodeRegister g_tfliteL2NormParser("L2_NORMALIZATION", new TfliteL2NormParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H
#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteL2NormParser : public TfliteNodeParser {
public:
TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_L2NORM_PARSER_H
......@@ -96,6 +96,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
for (const auto &tflite_op : tflite_subgraph->operators) {
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
auto op_type = GetMSOpType(tflite_op_type);
if (op_type == "CUSTOM") {
auto custom_type = (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code;
MS_LOG(ERROR) << "CUSTOM op is not supported, the type is " << custom_type;
return RET_ERROR;
}
std::unique_ptr<schema::CNodeT> op(new schema::CNodeT);
op->name = op_type + "-" + std::to_string(idx++);
......@@ -219,7 +224,7 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT>
return RET_OK;
}
STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) {
STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph) {
for (auto &op : sub_graph->nodes) {
if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
auto attr = op->primitive->value.AsDepthwiseConv2D();
......@@ -268,7 +273,6 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) {
auto weight_id = op->inputIndex[1];
auto &weight_tensor = sub_graph->allTensors.at(weight_id);
if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) {
// convert weight format KHWC -> CHWK
auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
......@@ -276,13 +280,13 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) {
}
}
if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) {
// convert weight format KHWC -> CHWK
auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
MS_LOG(ERROR) << "Trans filter format failed.";
return RET_ERROR;
}
}
weight_tensor->format = schema::Format_CHWK;
}
}
}
......@@ -323,8 +327,8 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, const std::s
}
// update for depthwiseConv
if (UpdateOp(sub_graph.get()) != RET_OK) {
MS_LOG(ERROR) << "update depthwise conv failed";
if (ConvertGroupDepthwiseOp(sub_graph.get()) != RET_OK) {
MS_LOG(ERROR) << "convert group depthwise conv failed";
return nullptr;
}
......
......@@ -67,7 +67,7 @@ class TfliteModelParser : public ModelParser {
STATUS GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
schema::MetaGraphT* sub_graph);
STATUS UpdateOp(schema::MetaGraphT* sub_graph);
STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph);
private:
std::vector<int32_t> tensorsId;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册