提交 fdf19edf 编写于 作者: L lyvette

merge similar parsers.

modify parser format
上级 f287471f
......@@ -228,6 +228,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::Elu(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_DeDepthwiseConv2D:
return new lite::DeconvDepthwiseConv2D(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Shape:
return new lite::Shape(const_cast<schema::Primitive *>(srcPrim));
default:
break;
}
......
......@@ -26,6 +26,7 @@
#include "src/runtime/kernel/arm/nnacl/fp32/slice.h"
#include "src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h"
#include "src/runtime/kernel/arm/nnacl/reshape_parameter.h"
#include "src/runtime/kernel/arm/nnacl/shape.h"
#include "src/runtime/kernel/arm/nnacl/fp32/stack.h"
#include "src/runtime/kernel/arm/nnacl/unstack.h"
#include "src/runtime/kernel/arm/nnacl/depth_to_space.h"
......@@ -874,6 +875,16 @@ OpParameter *PopulateReshapeParameter(const lite::Primitive *primitive) {
return reinterpret_cast<OpParameter *>(reshape_param);
}
OpParameter *PopulateShapeParameter(const lite::Primitive *primitive) {
ShapeParameter *shape_param = new (std::nothrow) ShapeParameter();
if (shape_param == nullptr) {
MS_LOG(ERROR) << "new ShapeParameter failed.";
return nullptr;
}
shape_param->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(shape_param);
}
OpParameter *PopulateReverseParameter(const lite::Primitive *primitive) {
auto reverse_attr = primitive->Value()->value_as_Reverse();
ReverseParameter *reverse_param = new (std::nothrow) ReverseParameter();
......@@ -1306,6 +1317,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_Cast] = PopulateCastParameter;
populate_parameter_funcs_[schema::PrimitiveType_Scale] = PopulateScaleParameter;
populate_parameter_funcs_[schema::PrimitiveType_Reshape] = PopulateReshapeParameter;
populate_parameter_funcs_[schema::PrimitiveType_Shape] = PopulateShapeParameter;
populate_parameter_funcs_[schema::PrimitiveType_Concat] = PopulateConcatParameter;
populate_parameter_funcs_[schema::PrimitiveType_Tile] = PopulateTileParameter;
populate_parameter_funcs_[schema::PrimitiveType_TopK] = PopulateTopKParameter;
......
......@@ -12,19 +12,19 @@ cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./
TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/
cp -fr $TEST_DATA_DIR/testPK ./data
./lite-test --gtest_filter="*MindDataTestTensorDE*"
./lite-test --gtest_filter="*MindDataTestEager*"
./lite-test --gtest_filter="TestTfliteParser*"
./lite-test --gtest_filter="*TestHebing*"
./lite-test --gtest_filter=TestFcFp32*
./lite-test --gtest_filter=TestConv1x1Fp32*
./lite-test --gtest_filter=TestStrassenFp32*
./lite-test --gtest_filter=TestDeConvolutionFp32*
./lite-test --gtest_filter=TestPadInt8.*
./lite-test --gtest_filter=TestDeconvInt8.*
#./lite-test --gtest_filter="*MindDataTestTensorDE*"
#./lite-test --gtest_filter="*MindDataTestEager*"
#
#./lite-test --gtest_filter="TestTfliteParser*"
#
#./lite-test --gtest_filter="*TestHebing*"
#
#./lite-test --gtest_filter=TestFcFp32*
#./lite-test --gtest_filter=TestConv1x1Fp32*
#./lite-test --gtest_filter=TestStrassenFp32*
#./lite-test --gtest_filter=TestDeConvolutionFp32*
#
#./lite-test --gtest_filter=TestPadInt8.*
#./lite-test --gtest_filter=TestDeconvInt8.*
./lite-test --gtest_filter="TestTfliteParser*"
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_abs_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteAbsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteAbsParser";
std::unique_ptr<schema::AbsT> attr(new schema::AbsT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Abs;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteAbsParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_ABS_PARSER_H
#define PREDICT_TFLITE_ABS_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteAbsParser : public TfliteNodeParser {
public:
TfliteAbsParser() : TfliteNodeParser("Abs") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_ABS_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <vector>
#include <string>
#include "tools/converter/parser/tflite/tflite_activation_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
std::vector<std::string> node_name_str;
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Relu") == 0) {
MS_LOG(DEBUG) << "parse TfliteReluParser";
attr->type = schema::ActivationType_RELU;
} else if (std::strcmp(node_name, "Relu6") == 0) {
MS_LOG(DEBUG) << "parse TfliteRelu6Parser";
attr->type = schema::ActivationType_RELU6;
} else if (std::strcmp(node_name, "Tanh") == 0) {
MS_LOG(DEBUG) << "parse TfliteTanhParser";
attr->type = schema::ActivationType_TANH;
} else if (std::strcmp(node_name, "Logistic") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
attr->type = schema::ActivationType_SIGMOID;
} else {
MS_LOG(ERROR) << "wrong activation type";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "paser TflitePreluParser";
std::unique_ptr<schema::PreluT> attr(new schema::PreluT());
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
MS_LOG(ERROR) << "get pRelu -> slope failed";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_Prelu;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteLeakyReluParser";
std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT());
const auto &tflite_attr = tfliteOp->builtin_options.AsLeakyReluOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->negativeSlope = tflite_attr->alpha;
op->primitive->value.type = schema::PrimitiveType_LeakyReLU;
op->primitive->value.value = attr.release();
return RET_OK;
}
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser());
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());
} // namespace lite
} // namespace mindspore
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_LEAKY_RELU_PARSER_H
#define PREDICT_TFLITE_LEAKY_RELU_PARSER_H
#ifndef PREDICT_TFLITE_RELU_PARSER_H
#define PREDICT_TFLITE_RELU_PARSER_H
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
......@@ -24,6 +24,49 @@
namespace mindspore {
namespace lite {
class TfliteActivationParser : public TfliteNodeParser {
public:
TfliteActivationParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
class TfliteReluParser : public TfliteActivationParser {
public:
TfliteReluParser() : TfliteActivationParser() {}
};
class TfliteRelu6Parser : public TfliteActivationParser{
public:
TfliteRelu6Parser() : TfliteActivationParser() {}
};
class TfliteTanhParser : public TfliteActivationParser{
public:
TfliteTanhParser() : TfliteActivationParser() {}
};
class TfliteLogisticParser : public TfliteActivationParser {
public:
TfliteLogisticParser() : TfliteActivationParser() {}
};
class TflitePreluParser : public TfliteNodeParser {
public:
TflitePreluParser() : TfliteNodeParser("Prelu") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) override;
};
class TfliteLeakyReluParser : public TfliteNodeParser {
public:
TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {}
......@@ -34,8 +77,9 @@ class TfliteLeakyReluParser : public TfliteNodeParser {
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_LEAKY_RELU_PARSER_H
#endif // PREDICT_TFLITE_RELU_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_add_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr(new schema::AddT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
auto x_index = tfliteOp->inputs[0];
const auto &x_tensor = tfliteTensors[x_index];
if (x_tensor == nullptr) {
MS_LOG(ERROR) << "the first input is null";
return RET_NULL_PTR;
}
auto &x_data = tfliteModelBuffer.at(x_tensor->buffer);
if (x_data == nullptr) {
MS_LOG(ERROR) << "the data of the first input is null";
return RET_NULL_PTR;
}
if (x_data->data.size() > 0) {
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse the first tensor failed";
return RET_ERROR;
}
}
auto y_index = tfliteOp->inputs[1];
const auto &y_tensor = tfliteTensors[y_index];
if (y_tensor == nullptr) {
MS_LOG(ERROR) << "the second input is null";
return RET_NULL_PTR;
}
auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
if (y_data == nullptr) {
MS_LOG(ERROR) << "the data of the second input is null";
return RET_NULL_PTR;
}
if (y_data->data.size() > 0) {
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse the second tensor failed";
return RET_ERROR;
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_ADD_PARSER_H
#define PREDICT_TFLITE_ADD_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteAddParser : public TfliteNodeParser {
public:
TfliteAddParser() : TfliteNodeParser("Add") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_ADD_PARSER_H
......@@ -26,16 +26,23 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteAddNParser";
std::unique_ptr<schema::AddNT> attr(new schema::AddNT());
attr->N = tfliteTensors.size() - 1;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_AddN;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_AddN;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -27,6 +27,16 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT());
......@@ -49,11 +59,8 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
}
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -25,6 +25,16 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteArgminParser";
std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT());
......@@ -47,11 +57,8 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
}
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_ArgMin;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_ArgMin;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_arithmetic_parser.h"
#include <vector>
#include <memory>
#include <string>
namespace mindspore {
namespace lite {
STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Add") == 0
|| std::strcmp(node_name, "Sub") == 0
|| std::strcmp(node_name, "Mul") == 0
|| std::strcmp(node_name, "Div") == 0) {
auto x_index = tfliteOp->inputs[0];
const auto &x_tensor = tfliteTensors[x_index];
if (x_tensor == nullptr) {
MS_LOG(ERROR) << "the first input is null";
return RET_NULL_PTR;
}
auto &x_data = tfliteModelBuffer.at(x_tensor->buffer);
if (x_data == nullptr) {
MS_LOG(ERROR) << "the data of the first input is null";
return RET_NULL_PTR;
}
if (!x_data->data.empty()) {
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse the first tensor failed";
return RET_ERROR;
}
}
auto y_index = tfliteOp->inputs[1];
const auto &y_tensor = tfliteTensors[y_index];
if (y_tensor == nullptr) {
MS_LOG(ERROR) << "the second input is null";
return RET_NULL_PTR;
}
auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
if (y_data == nullptr) {
MS_LOG(ERROR) << "the data of the second input is null";
return RET_NULL_PTR;
}
if (!y_data->data.empty()) {
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse the second tensor failed";
return RET_ERROR;
}
}
if (std::strcmp(node_name, "Add") == 0) {
MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr(new schema::AddT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sub") == 0) {
MS_LOG(DEBUG) << "parse TfliteSubParser";
std::unique_ptr<schema::SubT> attr(new schema::SubT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Sub;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Mul") == 0) {
MS_LOG(DEBUG) << "parse TfliteMulParser";
std::unique_ptr<schema::MulT> attr(new schema::MulT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Div") == 0) {
MS_LOG(DEBUG) << "parse TfliteDivParser";
std::unique_ptr<schema::DivT> attr(new schema::DivT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
return RET_OK;
}
} else if (std::strcmp(node_name, "FloorDiv") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorDivParser";
std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT());
op->primitive->value.type = schema::PrimitiveType_FloorDiv;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "FloorMod") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorModParser";
std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT());
op->primitive->value.type = schema::PrimitiveType_FloorMod;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "RealDiv") == 0) {
MS_LOG(DEBUG) << "parse TfliteRealDivParser";
std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT());
op->primitive->value.type = schema::PrimitiveType_RealDiv;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "SquaredDifference") == 0) {
MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser";
std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT());
op->primitive->value.type = schema::PrimitiveType_SquaredDifference;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Pow") == 0) {
MS_LOG(DEBUG) << "parse TflitePowParser";
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
attr->power = 0.0f;
attr->scale = 1.0f;
attr->shift = 0.0f;
op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Maximum") == 0) {
MS_LOG(DEBUG) << "parse TfliteMaximumParser";
std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT());
op->primitive->value.type = schema::PrimitiveType_Maximum;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Minimum") == 0) {
MS_LOG(DEBUG) << "parse TfliteMinimumParser";
std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT());
op->primitive->value.type = schema::PrimitiveType_Minimum;
op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
}
return RET_OK;
}
STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Abs") == 0) {
MS_LOG(DEBUG) << "parse TfliteAbsParser";
std::unique_ptr<schema::AbsT> attr(new schema::AbsT());
op->primitive->value.type = schema::PrimitiveType_Abs;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Exp") == 0) {
MS_LOG(DEBUG) << "parse TfliteExpParser";
std::unique_ptr<schema::ExpT> attr(new schema::ExpT());
op->primitive->value.type = schema::PrimitiveType_Exp;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sqrt") == 0) {
MS_LOG(DEBUG) << "parse TfliteSqrtParser";
std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT());
op->primitive->value.type = schema::PrimitiveType_Sqrt;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Rsqrt") == 0) {
MS_LOG(DEBUG) << "parse TfliteRsqrtParser";
std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT());
op->primitive->value.type = schema::PrimitiveType_Rsqrt;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Square") == 0) {
MS_LOG(DEBUG) << "parse TfliteSquareParser";
std::unique_ptr<schema::SquareT> attr(new schema::SquareT());
op->primitive->value.type = schema::PrimitiveType_Square;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sin") == 0) {
MS_LOG(DEBUG) << "parse TfliteSinParser";
std::unique_ptr<schema::SinT> attr(new schema::SinT());
op->primitive->value.type = schema::PrimitiveType_Sin;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Cos") == 0) {
MS_LOG(DEBUG) << "parse TfliteCosParser";
std::unique_ptr<schema::CosT> attr(new schema::CosT());
op->primitive->value.type = schema::PrimitiveType_Cos;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Log") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogParser";
std::unique_ptr<schema::LogT> attr(new schema::LogT());
op->primitive->value.type = schema::PrimitiveType_Log;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Round") == 0) {
MS_LOG(DEBUG) << "parse TfliteRoundParser";
std::unique_ptr<schema::RoundT> attr(new schema::RoundT());
op->primitive->value.type = schema::PrimitiveType_Round;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Ceil") == 0) {
MS_LOG(DEBUG) << "parse TfliteCeilParser";
std::unique_ptr<schema::CeilT> attr(new schema::CeilT());
op->primitive->value.type = schema::PrimitiveType_Ceil;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "flOOR") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorParser";
std::unique_ptr<schema::FloorT> attr(new schema::FloorT());
op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
}
}
STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Equal") == 0) {
MS_LOG(DEBUG) << "parse TfliteEqualParser";
std::unique_ptr<schema::EqualT> attr(new schema::EqualT());
op->primitive->value.type = schema::PrimitiveType_Equal;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "NotEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteNotEqualParser";
std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT());
op->primitive->value.type = schema::PrimitiveType_NotEqual;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Greater") == 0) {
MS_LOG(DEBUG) << "parse TfliteGreaterParser";
std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT());
op->primitive->value.type = schema::PrimitiveType_Greater;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "GreaterEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser";
std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT());
op->primitive->value.type = schema::PrimitiveType_GreaterEqual;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Less") == 0) {
MS_LOG(DEBUG) << "parse TfliteLessParser";
std::unique_ptr<schema::LessT> attr(new schema::LessT());
op->primitive->value.type = schema::PrimitiveType_Less;
op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "LessEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteLessEqualParser";
std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT());
op->primitive->value.type = schema::PrimitiveType_LessEqual;
op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
}
}
TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser());
TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteSubParser());
TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteMulParser());
TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDivParser());
TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteFloorDivParser());
TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteFloorModParser());
TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteRealDivParser());
TfliteNodeRegister g_TflitePowParser("Pow", new TflitePowParser());
TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteSquaredDifferenceParser());
TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteMaximumParser());
TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteMinimumParser());
TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteAbsParser());
TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteExpParser());
TfliteNodeRegister g_TfliteSqrtParser("Sqrt", new TfliteSqrtParser());
TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteRsqrtParser());
TfliteNodeRegister g_TfliteSquareParser("Square", new TfliteSquareParser());
TfliteNodeRegister g_TfliteSinParser("Sin", new TfliteSinParser());
TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteCosParser());
TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser());
TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser());
TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser());
TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser());
TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser());
TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser());
TfliteNodeRegister g_tfliteGreaterEParser("Greater", new TfliteGreaterParser());
TfliteNodeRegister g_tfliteGreaterEqualParser("GreaterEqual", new TfliteGreaterEqualParser());
TfliteNodeRegister g_tfliteLessParser("Less", new TfliteLessParser());
TfliteNodeRegister g_tfliteLessEqualParser("LessEqual", new TfliteLessEqualParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_MATH_PARSER_H
#define PREDICT_TFLITE_MATH_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteDoubleInputOpParser : public TfliteNodeParser {
public:
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
class TfliteAddParser : public TfliteDoubleInputOpParser {
public:
TfliteAddParser() : TfliteDoubleInputOpParser() {}
};
class TfliteSubParser : public TfliteDoubleInputOpParser {
public:
TfliteSubParser() : TfliteDoubleInputOpParser() {}
};
class TfliteMulParser : public TfliteDoubleInputOpParser {
public:
TfliteMulParser() : TfliteDoubleInputOpParser() {}
};
class TfliteDivParser : public TfliteDoubleInputOpParser {
public:
TfliteDivParser() : TfliteDoubleInputOpParser() {}
};
class TfliteFloorDivParser : public TfliteDoubleInputOpParser {
public:
TfliteFloorDivParser() : TfliteDoubleInputOpParser() {}
};
class TfliteFloorModParser : public TfliteDoubleInputOpParser {
public:
TfliteFloorModParser() : TfliteDoubleInputOpParser() {}
};
class TfliteSquaredDifferenceParser : public TfliteDoubleInputOpParser {
public:
TfliteSquaredDifferenceParser() : TfliteDoubleInputOpParser() {}
};
class TfliteRealDivParser : public TfliteDoubleInputOpParser {
public:
TfliteRealDivParser() : TfliteDoubleInputOpParser() {}
};
class TflitePowParser : public TfliteDoubleInputOpParser {
public:
TflitePowParser() : TfliteDoubleInputOpParser() {}
};
class TfliteMaximumParser : public TfliteDoubleInputOpParser {
public:
TfliteMaximumParser() : TfliteDoubleInputOpParser() {}
};
class TfliteMinimumParser : public TfliteDoubleInputOpParser {
public:
TfliteMinimumParser() : TfliteDoubleInputOpParser() {}
};
class TfliteSingleInputOpParser : public TfliteNodeParser {
public:
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
class TfliteAbsParser : public TfliteSingleInputOpParser {
public:
TfliteAbsParser() : TfliteSingleInputOpParser() {}
};
class TfliteExpParser : public TfliteSingleInputOpParser {
public:
TfliteExpParser() : TfliteSingleInputOpParser() {}
};
class TfliteSqrtParser : public TfliteSingleInputOpParser {
public:
TfliteSqrtParser() : TfliteSingleInputOpParser() {}
};
class TfliteSquareParser : public TfliteSingleInputOpParser {
public:
TfliteSquareParser() : TfliteSingleInputOpParser() {}
};
class TfliteSinParser : public TfliteSingleInputOpParser {
public:
TfliteSinParser() : TfliteSingleInputOpParser() {}
};
class TfliteCosParser : public TfliteSingleInputOpParser {
public:
TfliteCosParser() : TfliteSingleInputOpParser() {}
};
class TfliteRsqrtParser : public TfliteSingleInputOpParser {
public:
TfliteRsqrtParser() : TfliteSingleInputOpParser() {}
};
class TfliteLogParser : public TfliteSingleInputOpParser {
public:
TfliteLogParser() : TfliteSingleInputOpParser() {}
};
class TfliteRoundParser : public TfliteSingleInputOpParser {
public:
TfliteRoundParser() : TfliteSingleInputOpParser() {}
};
class TfliteCeilParser : public TfliteSingleInputOpParser {
public:
TfliteCeilParser() : TfliteSingleInputOpParser() {}
};
class TfliteFloorParser : public TfliteSingleInputOpParser {
public:
TfliteFloorParser() : TfliteSingleInputOpParser() {}
};
class TfliteCompareOpParser : public TfliteNodeParser {
public:
TfliteCompareOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
class TfliteEqualParser : public TfliteCompareOpParser {
public:
TfliteEqualParser() : TfliteCompareOpParser() {}
};
class TfliteNotEqualParser : public TfliteCompareOpParser {
public:
TfliteNotEqualParser() : TfliteCompareOpParser() {}
};
class TfliteGreaterParser : public TfliteCompareOpParser {
public:
TfliteGreaterParser() : TfliteCompareOpParser() {}
};
class TfliteGreaterEqualParser : public TfliteCompareOpParser {
public:
TfliteGreaterEqualParser() : TfliteCompareOpParser() {}
};
class TfliteLessParser : public TfliteCompareOpParser {
public:
TfliteLessParser() : TfliteCompareOpParser() {}
};
class TfliteLessEqualParser : public TfliteCompareOpParser {
public:
TfliteLessEqualParser() : TfliteCompareOpParser() {}
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_MATH_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_batch_to_sapce_nd_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteBatchToSpaceNDParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteBatchToSpaceNDParser";
std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT());
// in tflite
// blockShape should be a 1D tensor with dimension [spatial_dims_num]
// crops should be a 2D tensor with dimension [spatial_dims_num, 2]
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) {
MS_LOG(ERROR) << "get BatchToSpaceNd -> blockShape failed";
return RET_ERROR;
}
if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) {
MS_LOG(ERROR) << "get BatchToSpaceNd -> crops failed";
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceNDParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H
#define PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteBatchToSpaceNDParser : public TfliteNodeParser {
public:
TfliteBatchToSpaceNDParser() : TfliteNodeParser("BatchToSpaceND") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H
......@@ -18,6 +18,7 @@
#include "tools/converter/parser/tflite/tflite_batch_to_space_parser.h"
#include <vector>
#include <memory>
#include <string>
namespace mindspore {
namespace lite {
......@@ -26,7 +27,28 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "BatchToSpace") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser";
} else if (std::strcmp(node_name, "BatchToSpaceND") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser";
// in tflite
// blockShape should be a 1D tensor with dimension [spatial_dims_num]
// crops should be a 2D tensor with dimension [spatial_dims_num, 2]
}
std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT());
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) {
......@@ -38,14 +60,13 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
op->primitive->value.value = attr.release();
return RET_OK;
}
TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser());
TfliteNodeRegister g_TfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceNDParser());
} // namespace lite
} // namespace mindspore
......@@ -32,9 +32,14 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
TensorCache *tensor_cache, bool quantized_model) override;
};
class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
public:
TfliteBatchToSpaceNDParser() : TfliteBatchToSpaceParser() {}
};
} // namespace lite
} // namespace mindspore
......
......@@ -26,6 +26,16 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT());
......@@ -34,11 +44,8 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -26,6 +26,16 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteCastParser";
std::unique_ptr<schema::CastT> attr(new schema::CastT());
......@@ -43,11 +53,8 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
}
attr->dstT = dtype_map[out_tensor->type];
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_ceil_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteCeilParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteCeilParser";
std::unique_ptr<schema::CeilT> attr(new schema::CeilT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Ceil;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_CEIL_PARSER_H
#define PREDICT_TFLITE_CEIL_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteCeilParser : public TfliteNodeParser {
public:
TfliteCeilParser() : TfliteNodeParser("Ceil") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_CEIL_PARSER_H
......@@ -25,6 +25,16 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteConcatParser";
std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT());
......@@ -37,11 +47,8 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->n = tfliteOp->inputs.size();
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -25,6 +25,16 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteConvParser";
std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions();
......@@ -49,7 +59,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
......@@ -69,7 +79,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) {
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
......@@ -77,11 +87,8 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
// calculate pad params
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_
#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_
#include <string>
#include <memory>
......@@ -34,5 +34,5 @@ class TfliteConverter : public Converter {
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_cos_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteCosParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteCosParser";
std::unique_ptr<schema::CosT> attr(new schema::CosT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Cos;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteCosParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_COS_PARSER_H
#define PREDICT_TFLITE_COS_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteCosParser : public TfliteNodeParser {
public:
TfliteCosParser() : TfliteNodeParser("Cos") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_COS_PARSER_H
......@@ -25,6 +25,16 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT());
const auto &tflite_attr = tfliteOp->builtin_options.AsTransposeConvOptions();
......@@ -49,7 +59,7 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
return RET_ERROR;
}
auto weight_shape = weight_tensor->shape;
......@@ -58,11 +68,8 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->kernelW = weight_shape[CHWK_W];
attr->kernelH = weight_shape[CHWK_H];
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -26,6 +26,16 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";
std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT());
......@@ -38,11 +48,8 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
attr->format = schema::Format_NHWC;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -66,11 +66,8 @@ STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op,
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = convAttr.release();
}
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = convAttr.release();
return RET_OK;
}
......@@ -79,6 +76,16 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT());
const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions();
......@@ -96,10 +103,18 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
// get the conv op weight tensor
auto input_index = tflite_op->inputs[0];
const auto &input_tenosr = tflite_tensors[input_index];
if (input_tenosr == nullptr) {
MS_LOG(ERROR) << "the first input is null";
return RET_NULL_PTR;
}
auto input_shape = input_tenosr->shape;
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_tensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "the weight tensor is null";
return RET_NULL_PTR;
}
auto weight_shape = weight_tensor->shape;
attr->channelIn = input_shape[KHWC_C];
attr->channelMultiplier = tflite_attr->depth_multiplier;
......@@ -108,7 +123,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
......@@ -118,7 +133,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
auto bias_index = tflite_op->inputs[2];
const auto &bias_tensor = tflite_tensors[bias_index];
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) {
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
......@@ -126,11 +141,10 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
if (attr->channelMultiplier > 1) {
if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) {
// MS_LOGE("Parse Group DepthwiseConv failed");
MS_LOG(ERROR) << "Parse Group DepthwiseConv failed";
return RET_ERROR;
}
} else {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
op->primitive->value.value = attr.release();
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_div_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteDivParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteDivParser";
std::unique_ptr<schema::DivT> attr(new schema::DivT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
auto x_index = tfliteOp->inputs[0];
const auto &x_tensor = tfliteTensors[x_index];
if (x_tensor == nullptr) {
MS_LOG(ERROR) << "the first input is null";
return RET_NULL_PTR;
}
auto &x_data = tfliteModelBuffer.at(x_tensor->buffer);
if (x_data == nullptr) {
MS_LOG(ERROR) << "the data of the first input is null";
return RET_NULL_PTR;
}
if (x_data->data.size() > 0) {
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse the first tensor failed";
return RET_ERROR;
}
}
auto y_index = tfliteOp->inputs[1];
const auto &y_tensor = tfliteTensors[y_index];
if (y_tensor == nullptr) {
MS_LOG(ERROR) << "the second input is null";
return RET_NULL_PTR;
}
auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
if (y_data == nullptr) {
MS_LOG(ERROR) << "the data of the second input is null";
return RET_NULL_PTR;
}
if (y_data->data.size() > 0) {
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse the second tensor failed";
return RET_ERROR;
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDivParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_DIV_PARSER_H
#define PREDICT_TFLITE_DIV_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteDivParser : public TfliteNodeParser {
public:
TfliteDivParser() : TfliteNodeParser("Div") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_DIV_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an AS
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_equal_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteEqualParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) {
MS_LOG(DEBUG) << "parse TfliteEqualParser";
std::unique_ptr<schema::EqualT> attr(new schema::EqualT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Equal;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_EQUAL_PARSER_H
#define LITE_TFLITE_EQUAL_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteEqualParser : public TfliteNodeParser {
public:
TfliteEqualParser() : TfliteNodeParser("Equal") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_TFLITE_EQUAL_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_exp_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteExpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteExpParser";
std::unique_ptr<schema::ExpT> attr(new schema::ExpT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Exp;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteExpParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_EXP_PARSER_H
#define PREDICT_TFLITE_EXP_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteExpParser : public TfliteNodeParser {
public:
TfliteExpParser() : TfliteNodeParser("Exp") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_EXP_PARSER_H
......@@ -27,6 +27,16 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteExpandDimsParser";
std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT());
......
......@@ -24,6 +24,16 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser";
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
......@@ -34,7 +44,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_NHWC)) {
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
......@@ -48,18 +58,15 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) {
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
}
attr->axis = 1;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_FullConnection;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_FullConnection;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -27,6 +27,16 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteFillParser";
std::unique_ptr<schema::FillT> attr(new schema::FillT());
......@@ -37,11 +47,8 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Fill;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Fill;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_floor_div_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteFloorDivParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteFloorDivParser";
std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_FloorDiv;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteFloorDivParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_FLOOR_DIV_PARSER_H
#define PREDICT_TFLITE_FLOOR_DIV_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteFloorDivParser : public TfliteNodeParser {
public:
TfliteFloorDivParser() : TfliteNodeParser("FloorDiv") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_FLOOR_DIV_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_floor_mod_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteFloorModParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteFloorModParser";
std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_FloorMod;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteFloorModParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_FLOOR_MOD_PARSER_H
#define PREDICT_TFLITE_FLOOR_MOD_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteFloorModParser : public TfliteNodeParser {
public:
TfliteFloorModParser() : TfliteNodeParser("FloorMod") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_FLOOR_MOD_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_floor_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteFloorParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteFloorParser";
std::unique_ptr<schema::FloorT> attr(new schema::FloorT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_FLOOR_PARSER_H
#define PREDICT_TFLITE_FLOOR_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteFloorParser : public TfliteNodeParser {
public:
TfliteFloorParser() : TfliteNodeParser("flOOR") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_FLOOR_PARSER_H
......@@ -25,6 +25,16 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser";
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
......@@ -35,7 +45,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_NHWC)) {
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
......@@ -49,18 +59,15 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) {
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
}
attr->axis = 1;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_FullConnection;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_FullConnection;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -27,16 +27,23 @@ STATUS TfliteGatherNdParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteGatherNdParser";
std::unique_ptr<schema::GatherNdT> attr(new schema::GatherNdT());
attr->batchDims = 0;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_GatherNd;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_GatherNd;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -27,6 +27,16 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteGatherParser";
std::unique_ptr<schema::GatherT> attr(new schema::GatherT());
......@@ -39,11 +49,8 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->batchDims = 0;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Gather;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Gather;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an AS
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_greater_equal_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteGreaterEqualParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) {
MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser";
std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_GreaterEqual;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteGreaterEqualParser("GreaterEqual", new TfliteGreaterEqualParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_GREATER_EQUAL_PARSER_H
#define LITE_TFLITE_GREATER_EQUAL_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteGreaterEqualParser : public TfliteNodeParser {
public:
TfliteGreaterEqualParser() : TfliteNodeParser("GreaterEqual") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_TFLITE_GREATER_EQUAL_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an AS
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_greater_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteGreaterParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) {
MS_LOG(DEBUG) << "parse TfliteGreaterParser";
std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Greater;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteGreaterParser("Greater", new TfliteGreaterParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_GREATER_PARSER_H
#define LITE_TFLITE_GREATER_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteGreaterParser : public TfliteNodeParser {
public:
TfliteGreaterParser() : TfliteNodeParser("Greater") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_TFLITE_GREATER_PARSER_H
......@@ -25,16 +25,23 @@ STATUS TfliteHardSwishParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(INFO) << "parse TfliteHardSwishParser";
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
attr->type = schema::ActivationType_HSWISH;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_leaky_relu_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteLeakyReluParser";
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
const auto &tflite_attr = tfliteOp->builtin_options.AsLeakyReluOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_LEAKY_RELU;
attr->alpha = tflite_attr->alpha;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an AS
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_less_equal_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteLessEqualParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) {
MS_LOG(DEBUG) << "parse TfliteLessEqualParser";
std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_LessEqual;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteLessEqualParser("LessEqual", new TfliteLessEqualParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_LESS_EQUAL_PARSER_H
#define LITE_TFLITE_LESS_EQUAL_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteLessEqualParser : public TfliteNodeParser {
public:
TfliteLessEqualParser() : TfliteNodeParser("LessEqual") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_TFLITE_LESS_EQUAL_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an AS
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_less_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteLessParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) {
MS_LOG(DEBUG) << "parse TfliteLessParser";
std::unique_ptr<schema::LessT> attr(new schema::LessT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Less;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteLessParser("Less", new TfliteLessParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_LESS_PARSER_H
#define LITE_TFLITE_LESS_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteLessParser : public TfliteNodeParser {
public:
TfliteLessParser() : TfliteNodeParser("Less") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_TFLITE_LESS_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_log_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteLogParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteLogParser";
std::unique_ptr<schema::LogT> attr(new schema::LogT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Log;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_LOG_PARSER_H
#define PREDICT_TFLITE_LOG_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteLogParser : public TfliteNodeParser {
public:
TfliteLogParser() : TfliteNodeParser("Log") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_LOG_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_logical_and_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteLogicalAndParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteLogicalAndParser";
std::unique_ptr<schema::LogicalAndT> attr(new schema::LogicalAndT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_LogicalAnd;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteLogicalAndParser("LogicalAnd", new TfliteLogicalAndParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_logical_not_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteLogicalNotParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteLogicalNotParser";
std::unique_ptr<schema::LogicalNotT> attr(new schema::LogicalNotT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_LogicalNot;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteLogicalNotParser("LogicalNot", new TfliteLogicalNotParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_LOGICAL_NOT_PARSER_H
#define PREDICT_TFLITE_LOGICAL_NOT_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteLogicalNotParser : public TfliteNodeParser {
public:
TfliteLogicalNotParser() : TfliteNodeParser("LogicalNot") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_LOGICAL_NOT_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_logical_or_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteLogicalOrParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteLogicalOrParser";
std::unique_ptr<schema::LogicalOrT> attr(new schema::LogicalOrT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_LogicalOr;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteLogicalOrParser("LogicalOr", new TfliteLogicalOrParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_LOGICAL_OR_PARSER_H
#define PREDICT_TFLITE_LOGICAL_OR_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteLogicalOrParser : public TfliteNodeParser {
public:
TfliteLogicalOrParser() : TfliteNodeParser("LogicalOr") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_LOGICAL_OR_PARSER_H
......@@ -16,26 +16,55 @@
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_maximum_parser.h"
#include <string>
#include "tools/converter/parser/tflite/tflite_logical_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteMaximumParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
STATUS TfliteLogicalParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteMaximumParser";
std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT());
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Maximum;
std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "LogicalAnd") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogicalAndParser";
std::unique_ptr<schema::LogicalAndT> attr(new schema::LogicalAndT());
op->primitive->value.type = schema::PrimitiveType_LogicalAnd;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "LogicalNot") == 0) {
MS_LOG(INFO) << "parse TfliteLogicalNotParser";
std::unique_ptr<schema::LogicalNotT> attr(new schema::LogicalNotT());
op->primitive->value.type = schema::PrimitiveType_LogicalNot;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "LogicalOr") == 0) {
MS_LOG(INFO) << "parse TfliteLogicalOrParser";
std::unique_ptr<schema::LogicalOrT> attr(new schema::LogicalOrT());
op->primitive->value.type = schema::PrimitiveType_LogicalOr;
op->primitive->value.value = attr.release();
} else {
MS_LOG(ERROR) << "wrong logical type";
return RET_ERROR;
}
return RET_OK;
return RET_OK;
}
TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteMaximumParser());
TfliteNodeRegister g_TfliteLogicalAndParser("LogicalAnd", new TfliteLogicalAndParser());
TfliteNodeRegister g_TfliteLogicalNotParser("LogicalNot", new TfliteLogicalNotParser());
TfliteNodeRegister g_TfliteLogicalOrParser("LogicalOr", new TfliteLogicalOrParser());
} // namespace lite
} // namespace mindspore
......@@ -24,9 +24,10 @@
namespace mindspore {
namespace lite {
class TfliteLogicalAndParser : public TfliteNodeParser {
class TfliteLogicalParser : public TfliteNodeParser {
public:
TfliteLogicalAndParser() : TfliteNodeParser("LogicalAnd") {}
TfliteLogicalParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
......@@ -35,6 +36,21 @@ class TfliteLogicalAndParser : public TfliteNodeParser {
TensorCache *tensor_cache,
bool quantizedModel) override;
};
class TfliteLogicalAndParser : public TfliteLogicalParser {
public:
TfliteLogicalAndParser() : TfliteLogicalParser() {}
};
class TfliteLogicalNotParser : public TfliteLogicalParser {
public:
TfliteLogicalNotParser() : TfliteLogicalParser() {}
};
class TfliteLogicalOrParser : public TfliteLogicalParser {
public:
TfliteLogicalOrParser() : TfliteLogicalParser() {}
};
} // namespace lite
} // namespace mindspore
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_logistic_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteLogisticParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
attr->type = schema::ActivationType_SIGMOID;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_LOGISTIC_PARSER_H
#define PREDICT_TFLITE_LOGISTIC_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteLogisticParser : public TfliteNodeParser {
public:
TfliteLogisticParser() : TfliteNodeParser("Logistic") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_CONCAT_PARSER_H
......@@ -27,6 +27,16 @@ STATUS TfliteLRNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteLRNParser";
std::unique_ptr<schema::LocalResponseNormalizationT> attr(new schema::LocalResponseNormalizationT());
......@@ -40,11 +50,8 @@ STATUS TfliteLRNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
attr->beta = tflite_attr->beta;
attr->bias = tflite_attr->bias;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_max_pooling_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteMaxPoolingParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteMaxPoolingParser";
std::unique_ptr<schema::PoolingT> attr(new schema::PoolingT());
const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->format = schema::Format_NHWC;
// attr->global
attr->poolingMode = schema::PoolMode_MAX_POOLING;
attr->windowW = tflite_attr->filter_width;
attr->windowH = tflite_attr->filter_height;
attr->strideW = tflite_attr->stride_w;
attr->strideH = tflite_attr->stride_h;
attr->padMode = GetPadMode(tflite_attr->padding);
// calculate pad params
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Pooling;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TfliteMaxPoolingParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_MAX_POOLING_PARSER_H
#define PREDICT_TFLITE_MAX_POOLING_PARSER_H
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteMaxPoolingParser : public TfliteNodeParser {
public:
TfliteMaxPoolingParser() : TfliteNodeParser("MaxPooling") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_CONV_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_MAXIMUM_PARSER_H
#define PREDICT_TFLITE_MAXIMUM_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteMaximumParser : public TfliteNodeParser {
public:
TfliteMaximumParser() : TfliteNodeParser("Maximum") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_MAXIMUM_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_mean_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteMeanParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteMeanParser";
std::unique_ptr<schema::MeanT> attr(new schema::MeanT());
const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->keepDims = tflite_attr->keep_dims;
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axis)) {
MS_LOG(ERROR) << "Mean get axis attr failed";
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Mean;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteMeanParser("Mean", new TfliteMeanParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_MEAN_PARSER_H
#define PREDICT_TFLITE_MEAN_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteMeanParser : public TfliteNodeParser {
public:
TfliteMeanParser() : TfliteNodeParser("Mean") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_MEAN_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_minimum_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteMinimumParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteMinimumParser";
std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Minimum;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteMinimumParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_MINIMUM_PARSER_H
#define PREDICT_TFLITE_MINIMUM_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteMinimumParser : public TfliteNodeParser {
public:
TfliteMinimumParser() : TfliteNodeParser("Minimum") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_MINIMUM_PARSER_H
......@@ -15,7 +15,6 @@
*/
#include "tools/converter/parser/tflite/tflite_model_parser.h"
#include <fstream>
#include <utility>
#include <memory>
#include "tools/common/graph_util.h"
......@@ -71,6 +70,10 @@ STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr<tflite::S
quant_params_index.insert(quant_params_index.end(), tflite_op->outputs.begin(), tflite_op->outputs.end());
for (const auto &index : quant_params_index) {
const auto &tflite_tensor = tflite_subgraph->tensors[index];
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << index <<" is null";
return RET_ERROR;
}
if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) {
continue;
......@@ -101,6 +104,10 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT
TensorCache *tensorCache) {
for (const auto &index : tflite_op->outputs) {
const auto &tflite_tensor = tflite_subgraph->tensors[index];
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << index <<" is null";
return RET_ERROR;
}
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT());
tensor->dataType = GetTfliteDataType(tflite_tensor->type);
tensor->dims = tflite_tensor->shape;
......@@ -108,7 +115,6 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT
auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT);
op->outputIndex.emplace_back(opOutputIndex);
}
return RET_OK;
}
......@@ -123,6 +129,10 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &t
for (const auto &tflite_index : op_inputs) {
const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index];
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << tflite_index <<" is null";
return RET_ERROR;
}
auto tensor_name = tflite_tensor->name;
auto op = tfliteOpMap[tflite_op.get()];
unsigned int index = tensorCache->FindTensor(tensor_name);
......@@ -144,10 +154,8 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_
std::unique_ptr<schema::CNodeT> op(new schema::CNodeT);
op->name = opType + "-" + std::to_string(i++);
MS_LOG(INFO) << "parse op: " << op->name.c_str();
MS_LOG(INFO) << "parse op: [%s]" << op->name.c_str();
// 1. init op attr params
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType);
if (node_parser == nullptr) {
MS_LOG(ERROR) << "cannot find node parser, opType: "<< opType.c_str();
......@@ -164,7 +172,7 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_
status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache);
if (status != RET_OK) {
MS_LOG(ERROR) << "Set Op " << op->name.c_str() << " Output Index Failed!";
MS_LOG(ERROR) << "Set Op "<< op->name.c_str() << " Output Index Failed!";
return RET_ERROR;
}
......@@ -175,8 +183,7 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_
return RET_OK;
}
void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
TensorCache *tensor_cache) {
for (const auto &index : tflite_subgraph->inputs) {
const auto &tflite_tensor = tflite_subgraph->tensors[index];
......@@ -206,35 +213,31 @@ void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &
}
MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile) {
std::unique_ptr<schema::MetaGraphT> subGraph(new schema::MetaGraphT);
if (ValidateFileStr(modelFile, ".tflite") != RET_OK) {
// MS_LOGE("INPUT ILLEGAL: modelFile must be *.tflite");
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.tflite";
return nullptr;
}
MS_LOG(INFO) << "modelFile is :" << modelFile;
std::unique_ptr<tflite::ModelT> tflite_model(new tflite::ModelT());
tflite_model = ReadTfliteModelFromFlat(modelFile.c_str());
if (tflite_model == nullptr) {
// MS_LOGE("read tflite model failed");
MS_LOG(ERROR) << "read tflite model failed";
return nullptr;
}
MS_LOG(INFO) << "after read model";
TensorCache tensorCache;
if (tflite_model->subgraphs.size() != 1) {
MS_LOG(ERROR) << "read tflite model subgraphs failed";
return nullptr;
}
const auto &tflite_subgraph = tflite_model->subgraphs[0];
subGraph->name = "MS_model converted by TF-Lite";
// set dst subGraph input/output tensor
SetInputTensor(tflite_model, tflite_subgraph, &tensorCache);
// set dst subGraph op attr etc.
TensorCache tensorCache;
SetInputTensor(tflite_subgraph, &tensorCache);
// set dst subGraph op attr and tensor_cache.
std::unique_ptr<schema::MetaGraphT> subGraph(new schema::MetaGraphT);
subGraph->name = "MS_model converted by TF-Lite";
auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache);
if (status != RET_OK) {
MS_LOG(ERROR) << "ParseOp failed.";
......@@ -244,21 +247,20 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
for (const auto &tflite_op : tflite_subgraph->operators) {
auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache);
if (status_tmp != RET_OK) {
// MS_LOGE("Set Op %s Input Index Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str());
MS_LOG(ERROR) << "Set Op " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Input Index Failed!";
}
}
for (const auto &tflite_op : tflite_subgraph->operators) {
auto statusTmp = ParseTfliteQuantParams(tflite_subgraph, tflite_op);
if (statusTmp != RET_OK) {
// MS_LOGE("ParseTfliteQuantParams %s Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str());
MS_LOG(ERROR) << "ParseTfliteQuantParams " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Failed!";
}
}
SetGraphTensorIndex(tensorCache, subGraph.get());
SetAllTensors(tensorCache, subGraph.get());
return subGraph.release();
// return Fb2Anf(subGraph.release());
}
} // namespace lite
} // namespace mindspore
......
......@@ -14,29 +14,24 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H
#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H
#include <fcntl.h>
#include <unistd.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <string>
#include <vector>
#include <memory>
#include <map>
#include "securec/include/securec.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
#include "tools/common/tensor_util.h"
#include "mindspore/lite/schema/inner/model_generated.h"
// using namespace tflite;
namespace mindspore {
namespace lite {
class TfliteModelParser : public ModelParser {
......@@ -50,8 +45,7 @@ class TfliteModelParser : public ModelParser {
private:
std::unique_ptr<tflite::ModelT> ReadTfliteModelFromFlat(const char *buf);
void SetInputTensor(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache);
void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache);
void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache,
schema::MetaGraphT *subGraphDef);
......@@ -82,6 +76,5 @@ class TfliteModelParser : public ModelParser {
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_CONV
// ERTER_PARSER_TFLITE_MODEL_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_mul_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteMulParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteMulParser";
std::unique_ptr<schema::MulT> attr(new schema::MulT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
auto x_index = tfliteOp->inputs[0];
const auto &x_tensor = tfliteTensors[x_index];
if (x_tensor == nullptr) {
MS_LOG(ERROR) << "the first input is null";
return RET_NULL_PTR;
}
auto &x_data = tfliteModelBuffer.at(x_tensor->buffer);
if (x_data == nullptr) {
MS_LOG(ERROR) << "the data of the first input is null";
return RET_NULL_PTR;
}
if (x_data->data.size() > 0) {
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse the first tensor failed";
return RET_ERROR;
}
}
auto y_index = tfliteOp->inputs[1];
const auto &y_tensor = tfliteTensors[y_index];
if (y_tensor == nullptr) {
MS_LOG(ERROR) << "the second input is null";
return RET_NULL_PTR;
}
auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
if (y_data == nullptr) {
MS_LOG(ERROR) << "the data of the second input is null";
return RET_NULL_PTR;
}
if (y_data->data.size() > 0) {
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse the second tensor failed";
return RET_ERROR;
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteMulParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_MUL_PARSER_H
#define PREDICT_TFLITE_MUL_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteMulParser : public TfliteNodeParser {
public:
TfliteMulParser() : TfliteNodeParser("Mul") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_MUL_PARSER_H
......@@ -16,80 +16,36 @@
#include <vector>
#include <memory>
#include <unordered_map>
#include "securec/include/securec.h"
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_util.h"
namespace mindspore {
namespace lite {
STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) {
const tflite::TensorT *tflite_tensor,
schema::TensorT *tensor) {
auto count = 1;
std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; });
auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType));
auto buffer_idx = tflite_tensor->buffer;
if (!tfliteModelBuffer[buffer_idx]->data.empty()) {
tensor->data.resize(data_size);
auto ret = memcpy_s(tensor->data.data(), data_size, tfliteModelBuffer[buffer_idx]->data.data(), data_size);
if (ret) {
MS_LOG(ERROR) << "memcpy tensor data failed, error code: %d" << ret;
return ret;
if (memcpy_s(tensor->data.data(), data_size, tfliteModelBuffer[buffer_idx]->data.data(), data_size)) {
MS_LOG(ERROR) << "memcpy tensor data failed";
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "src tensor data is empty.";
MS_LOG(ERROR) << "src tensor data is empty";
return RET_ERROR;
}
return RET_OK;
}
STATUS TfliteNodeParser::ParseWeight(const std::vector<tflite::TensorT *> &weight_tenosrs,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
mindspore::lite::TensorCache *tensor_cache, schema::Format format) {
for (const auto &weight_tensor : weight_tenosrs) {
auto idx = tensor_cache->FindTensor(weight_tensor->name);
if (idx < 0) {
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT);
tensor->dataType = GetTfliteDataType(weight_tensor->type);
tensor->dims = weight_tensor->shape;
tensor->nodeType = schema::NodeType_ValueNode;
// memcpy tensor data
// buffer is 0 (which refers to an always existent empty buffer)
if (weight_tensor->buffer > 0) {
CopyTfliteTensorData(tfliteModelBuffer, weight_tensor, tensor.get());
}
MS_LOG(DEBUG) << "add weight tensor name: %s", weight_tensor->name.c_str();
tensor_cache->AddTensor(weight_tensor->name, tensor.release(), TF_CONST);
}
}
return RET_OK;
}
STATUS TfliteNodeParser::ParseBias(const std::vector<tflite::TensorT *> &bias_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
TensorCache *tensor_cache) {
for (const auto &bias_tensor : bias_tensors) {
auto idx = tensor_cache->FindTensor(bias_tensor->name);
if (idx < 0) {
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT);
tensor->dataType = GetTfliteDataType(bias_tensor->type);
tensor->dims = bias_tensor->shape;
tensor->nodeType = schema::NodeType_ValueNode;
// memcpy tensor data
// buffer is 0 (which refers to an always existent empty buffer)
if (bias_tensor->buffer > 0) {
CopyTfliteTensorData(tfliteModelBuffer, bias_tensor, tensor.get());
}
// MS_LOGD("add weight tensor name: %s", bias_tensor->name.c_str());
tensor_cache->AddTensor(bias_tensor->name, tensor.release(), TF_CONST);
}
}
return RET_OK;
}
STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
mindspore::lite::TensorCache *tensor_cache, int node_type,
bool ifCopy) {
mindspore::lite::TensorCache *tensor_cache,
int node_type) {
for (const auto &t : ts) {
auto idx = tensor_cache->FindTensor(t->name);
if (idx < 0) {
......@@ -97,29 +53,15 @@ STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts,
tensor->dataType = GetTfliteDataType(t->type);
tensor->dims = t->shape;
// memcpy tensor data, buffer is 0 (which refers to an always existent empty buffer)
if (ifCopy && t->buffer > 0) {
if (t->buffer > 0) {
CopyTfliteTensorData(tfliteModelBuffer, t, tensor.get());
}
MS_LOG(DEBUG) << "add weight tensor name: %s", t->name.c_str();
MS_LOG(DEBUG) << "add tensor name: " << t->name.c_str();
tensor_cache->AddTensor(t->name, tensor.release(), node_type);
}
}
return RET_OK;
}
TypeId TfliteNodeParser::GetTfliteDataType(const tflite::TensorType &tflite_data_type) {
static std::unordered_map<int, TypeId> type_map = {
{tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16},
{tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8},
{tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8},
};
auto iter = type_map.find(tflite_data_type);
if (iter == type_map.end()) {
return kTypeUnknown;
}
return iter->second;
}
} // namespace lite
} // namespace mindspore
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_NODE_PARSER_H
#define PREDICT_TFLITE_NODE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_H
#include <string>
#include <vector>
......@@ -34,30 +34,24 @@ class TfliteNodeParser {
public:
explicit TfliteNodeParser(const std::string &nodeName) : name(nodeName) {}
virtual ~TfliteNodeParser() {}
virtual ~TfliteNodeParser() = default;
virtual STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) = 0;
STATUS ParseWeight(const std::vector<tflite::TensorT *> &weight_tenosr,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, TensorCache *tensor_cache,
schema::Format format);
STATUS ParseBias(const std::vector<tflite::TensorT *> &weight_tenosr,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, TensorCache *tensor_cache);
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) = 0;
STATUS ParseTensor(const std::vector<tflite::TensorT *> &ts,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
mindspore::lite::TensorCache *tensor_cache, int node_type,
bool ifCopy);
mindspore::lite::TensorCache *tensor_cache,
int node_type);
STATUS CopyTfliteTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const tflite::TensorT *tflite_tensor, schema::TensorT *tensor);
TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type);
const tflite::TensorT *tflite_tensor,
schema::TensorT *tensor);
template <typename T>
STATUS GetTfliteData(const int32_t tensor_index, const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
......@@ -67,6 +61,10 @@ class TfliteNodeParser {
std::for_each(tfliteTensors[tensor_index]->shape.begin(), tfliteTensors[tensor_index]->shape.end(),
[&](int32_t sha) { count *= sha; });
auto &buf_data = tfliteModelBuffer[tfliteTensors[tensor_index]->buffer];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "buf_data is null";
return RET_NULL_PTR;
}
auto data_ptr = buf_data->data.data();
switch (tfliteTensors[tensor_index]->type) {
case tflite::TensorType_UINT8: {
......@@ -117,18 +115,18 @@ class TfliteNodeParser {
}
break;
}
default: {
MS_LOG(ERROR) << "wrong tensor type";
return RET_ERROR;
}
}
return RET_OK;
}
protected:
bool isQuantizedModel();
protected:
const std::string &name;
bool quantizedModel;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_NODE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_H
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H
#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H
#include <string>
#include <unordered_map>
......@@ -46,5 +46,5 @@ class TfliteNodeRegister {
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an AS
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_not_equal_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteNotEqualParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) {
MS_LOG(DEBUG) << "parse TfliteNotEqualParser";
std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_NotEqual;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_NOT_EQUAL_PARSER_H
#define LITE_TFLITE_NOT_EQUAL_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteNotEqualParser : public TfliteNodeParser {
public:
TfliteNotEqualParser() : TfliteNodeParser("NotEqual") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_TFLITE_NOT_EQUAL_PARSER_H
......@@ -25,6 +25,16 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(INFO) << "parse TfliteOneHotParser";
std::unique_ptr<schema::OneHotT> attr(new schema::OneHotT());
......@@ -46,11 +56,8 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
}
attr->axis = axis;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_OneHot;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_OneHot;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an AS
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_p_relu_parser.h"
namespace mindspore {
namespace lite {
STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) {
MS_LOG(DEBUG) << "paser TflitePreluParser";
std::unique_ptr<schema::PreluT> attr(new schema::PreluT());
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
MS_LOG(ERROR) << "get pRelu -> slope failed";
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Prelu;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_P_RELU_PARSER_H
#define LITE_TFLITE_P_RELU_PARSER_H
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TflitePreluParser : public TfliteNodeParser {
public:
TflitePreluParser() : TfliteNodeParser("Prelu") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_TFLITE_P_RELU_PARSER_H
......@@ -25,6 +25,16 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TflitePadParser";
std::unique_ptr<schema::PadT> attr(new schema::PadT());
const auto &tflite_attr = tfliteOp->builtin_options.AsPadOptions();
......@@ -40,11 +50,8 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Pad;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Pad;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -16,18 +16,42 @@
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_mean_pooling_parser.h"
#include <string>
#include "tools/converter/parser/tflite/tflite_pooling_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteMeanPoolingParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parser TfliteMeanPoolingParser";
STATUS TflitePoolingParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PoolingT> attr(new schema::PoolingT());
std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "MeanPooling") == 0) {
MS_LOG(DEBUG) << "parser TfliteMeanPoolingParser";
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
} else if (std::strcmp(node_name, "MaxPooling") == 0) {
MS_LOG(DEBUG) << "parse TfliteMaxPoolingParser";
attr->poolingMode = schema::PoolMode_MAX_POOLING;
} else {
MS_LOG(ERROR) << "wrong pooling type";
return RET_ERROR;
}
const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
......@@ -38,22 +62,18 @@ STATUS TfliteMeanPoolingParser::Parse(const std::unique_ptr<tflite::OperatorT> &
attr->strideW = tflite_attr->stride_w;
attr->strideH = tflite_attr->stride_h;
attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format_NHWC;
// attr->global
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
// calculate pad params
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Pooling;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Pooling;
op->primitive->value.value = attr.release();
return RET_OK;
}
TfliteNodeRegister g_tfliteMeanPoolingParser("MeanPooling", new TfliteMeanPoolingParser());
TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TfliteMaxPoolingParser());
} // namespace lite
} // namespace mindspore
......
......@@ -24,9 +24,9 @@
namespace mindspore {
namespace lite {
class TfliteMeanPoolingParser : public TfliteNodeParser {
class TflitePoolingParser : public TfliteNodeParser {
public:
TfliteMeanPoolingParser() : TfliteNodeParser("MeanPooling") {}
TflitePoolingParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
......@@ -35,6 +35,16 @@ class TfliteMeanPoolingParser : public TfliteNodeParser {
TensorCache *tensor_cache,
bool quantizedModel) override;
};
class TfliteMeanPoolingParser : public TflitePoolingParser {
public:
TfliteMeanPoolingParser() : TflitePoolingParser() {}
};
class TfliteMaxPoolingParser : public TflitePoolingParser {
public:
TfliteMaxPoolingParser() : TflitePoolingParser() {}
};
} // namespace lite
} // namespace mindspore
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_pow_parser.h"
namespace mindspore {
namespace lite {
STATUS TflitePowParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
MS_LOG(DEBUG) << "parse TflitePowParser";
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
attr->power = 0.0f;
attr->scale = 1.0f;
attr->shift = 0.0f;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TflitePowParser("Pow", new TflitePowParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_POW_PARSER_H
#define PREDICT_TFLITE_POW_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TflitePowParser : public TfliteNodeParser {
public:
TflitePowParser() : TfliteNodeParser("Pow") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_POW_PARSER_H
......@@ -27,16 +27,23 @@ STATUS TfliteRangeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteRangeParser";
std::unique_ptr<schema::RangeT> attr(new schema::RangeT());
attr->dType = 0;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Range;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Range;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
......@@ -27,14 +27,21 @@ STATUS TfliteRankParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteRankParser";
std::unique_ptr<schema::RankT> attr(new schema::RankT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Rank;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Rank;
op->primitive->value.value = attr.release();
return RET_OK;
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an AS
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_real_div_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteRealDivParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) {
MS_LOG(DEBUG) << "parse TfliteRealDivParser";
std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_RealDiv;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteRealDivParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_REAL_DIV_PARSER_H
#define LITE_TFLITE_REAL_DIV_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteRealDivParser : public TfliteNodeParser {
public:
TfliteRealDivParser() : TfliteNodeParser("RealDiv") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_TFLITE_REAL_DIV_PARSER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_reduce_any_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteReduceAnyParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteReduceAnyParser";
std::unique_ptr<schema::ReduceT> attr(new schema::ReduceT());
const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name << " attr failed";
return RET_NULL_PTR;
}
attr->keepDims = tflite_attr->keep_dims;
// attr->mode = schema::;
MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now";
return RET_NOT_FIND_OP;
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axes)) {
MS_LOG(ERROR) << "get reduce_any->axes failed";
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Reduce;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteReduceAnyParser("ReduceAny", new TfliteReduceAnyParser());
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_REDUCE_ANY_PARSER_H
#define PREDICT_TFLITE_REDUCE_ANY_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteReduceAnyParser : public TfliteNodeParser {
public:
TfliteReduceAnyParser() : TfliteNodeParser("ReduceAny") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_REDUCE_ANY_PARSER_H
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册