Commit aef72653 authored by mindspore-ci-bot, committed by Gitee

!5606 tflite custom parser

Merge pull request !5606 from 徐安越/master
......@@ -428,7 +428,7 @@ build_flatbuffer() {
if [[ ! -f "${FLATC}" ]]; then
git submodule update --init --recursive third_party/flatbuffers
cd ${BASEPATH}/third_party/flatbuffers
rm -rf build && mkdir -pv build && cd build && cmake .. && make -j$THREAD_NUM
rm -rf build && mkdir -pv build && cd build && cmake -DFLATBUFFERS_BUILD_SHAREDLIB=ON .. && make -j$THREAD_NUM
gene_flatbuffer
fi
if [[ "${INC_BUILD}" == "off" ]]; then
......
......@@ -67,6 +67,7 @@ else ()
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${LIB_DIR_RUN_X86} COMPONENT ${RUN_X86_COMPONENT_NAME})
install(FILES ${TOP_DIR}/third_party/protobuf/build/lib/libprotobuf.so.19.0.0 DESTINATION ${PROTOBF_DIR}/lib RENAME libprotobuf.so.19 COMPONENT ${COMPONENT_NAME})
install(FILES ${TOP_DIR}/third_party/flatbuffers/build/libflatbuffers.so.1.11.0 DESTINATION ${FLATBF_DIR}/lib RENAME libflatbuffers.so.1 COMPONENT ${COMPONENT_NAME})
endif ()
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
......@@ -89,4 +90,4 @@ else ()
set(CPACK_PACKAGE_DIRECTORY ${TOP_DIR}/output/tmp)
endif()
set(CPACK_PACKAGE_CHECKSUM SHA256)
include(CPack)
\ No newline at end of file
include(CPack)
......@@ -355,6 +355,7 @@ table DetectionPostProcess {
MaxClassesPreDetection: long;
NumClasses: long;
UseRegularNms: bool;
OutQuantized: bool;
}
table FullConnection {
......
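
For context, with flatc's object API the added OutQuantized field would land on the generated mutable table. The snippet below is an illustrative sketch under that assumption (only the trailing fields of the table are shown; the real header is generated from the full schema):

#include <cstdint>
#include "flatbuffers/flatbuffers.h"

// Illustrative sketch of the flatc object-API struct for the table above; field
// names follow the schema, but this is not the actual generated header.
struct DetectionPostProcessT : public flatbuffers::NativeTable {
  int64_t MaxClassesPreDetection = 0;
  int64_t NumClasses = 0;
  bool UseRegularNms = false;
  bool OutQuantized = false;  // field added by this change
};

int main() {
  DetectionPostProcessT attr;
  attr.OutQuantized = true;  // the TFLite custom parser sets this from "_output_quantized"
  return attr.OutQuantized ? 0 : 1;
}
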
......@@ -120,6 +120,7 @@
#include "src/ops/tuple_get_item.h"
#include "src/ops/l2_norm.h"
#include "src/ops/sparse_to_dense.h"
#include "src/ops/detection_post_process.h"
namespace mindspore {
namespace lite {
......@@ -469,6 +470,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
return new L2Norm(primitive);
case schema::PrimitiveType_SparseToDense:
return new SparseToDense(primitive);
case schema::PrimitiveType_DetectionPostProcess:
return new DetectionPostProcess(primitive);
default:
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitiveT : "
<< schema::EnumNamePrimitiveType(op_type);
......@@ -681,6 +684,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(const schema::Primitive *primi
return NewPrimitiveC<L2Norm>(primitive);
case schema::PrimitiveType_SparseToDense:
return NewPrimitiveC<SparseToDense>(primitive);
case schema::PrimitiveType_DetectionPostProcess:
return NewPrimitiveC<DetectionPostProcess>(primitive);
default:
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitive : "
<< schema::EnumNamePrimitiveType(op_type);
......
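
Both dispatch sites above follow the same factory-switch pattern, so a new op type has to be registered in each of them. A minimal, self-contained sketch of that pattern (the enum and classes below are illustrative stand-ins, not the real mindspore types):

#include <iostream>
#include <memory>

// Illustrative stand-ins for schema::PrimitiveType and the PrimitiveC hierarchy.
enum class PrimitiveType { SparseToDense, DetectionPostProcess };

struct PrimitiveC {
  virtual ~PrimitiveC() = default;
  virtual const char *Name() const = 0;
};
struct SparseToDense : PrimitiveC {
  const char *Name() const override { return "SparseToDense"; }
};
struct DetectionPostProcess : PrimitiveC {
  const char *Name() const override { return "DetectionPostProcess"; }
};

// Every new op must be added to this switch, which is what the hunks above do.
std::unique_ptr<PrimitiveC> UnPack(PrimitiveType type) {
  switch (type) {
    case PrimitiveType::SparseToDense: return std::make_unique<SparseToDense>();
    case PrimitiveType::DetectionPostProcess: return std::make_unique<DetectionPostProcess>();
    default: return nullptr;
  }
}

int main() {
  std::cout << UnPack(PrimitiveType::DetectionPostProcess)->Name() << std::endl;
  return 0;
}
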
......@@ -377,7 +377,7 @@ tar -zxf mindspore-lite-${version}-runtime-x86-${process_unit_x86}.tar.gz || exi
tar -zxf mindspore-lite-${version}-converter-ubuntu.tar.gz || exit 1
cd ${convertor_path}/mindspore-lite-${version}-converter-ubuntu || exit 1
cp converter/converter_lite ./ || exit 1
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib/:./third_party/protobuf/lib
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib/:./third_party/protobuf/lib:./third_party/flatbuffers/lib
# Convert the models
cd ${convertor_path}/mindspore-lite-${version}-converter-ubuntu || exit 1
......
......@@ -99,8 +99,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
auto output_quant_params = primitive->GetOutputQuantParams();
if (output_quant_params.empty()) {
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
MS_LOG(ERROR) << "node: " << dst_node->name << " output quant params is empty";
return RET_ERROR;
MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty";
}
} else {
for (auto output_quant_param : output_quant_params[0]) {
......
......@@ -75,10 +75,9 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
}
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
auto nodeName = node->name;
for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) {
if (node->inputIndex.at(inputIndexIdx) == graphInIdx) {
auto nodeName = (*iter)->name;
for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) {
if ((*iter)->inputIndex.at(inputIndexIdx) == graphInIdx) {
STATUS status = RET_OK;
// insert dtype cast node between input tensor and input node
......@@ -108,11 +107,10 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
auto &graphOutIdxes = graph->outputIndex;
for (auto graphOutIdx : graphOutIdxes) {
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
auto nodeName = node->name;
auto nodeName = (*iter)->name;
MS_ASSERT(node != nullptr);
for (size_t outputIndexIdx = 0; outputIndexIdx < node->outputIndex.size(); outputIndexIdx++) {
if (node->outputIndex.at(outputIndexIdx) == graphOutIdx) {
for (size_t outputIndexIdx = 0; outputIndexIdx < (*iter)->outputIndex.size(); outputIndexIdx++) {
if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
// insert transNode
STATUS status = RET_OK;
iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
......@@ -135,7 +133,6 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
continue;
}
auto &node = *iter;
if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) {
continue;
}
......@@ -143,8 +140,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
if (GetCNodeTType(**iter) == PrimitiveType_Shape) {
needInsertPost = false;
}
auto nodeName = node->name;
if (node->inputIndex.size() < kMinInputNum) {
auto nodeName = (*iter)->name;
if ((*iter)->inputIndex.size() < kMinInputNum) {
MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least";
return RET_ERROR;
}
......
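
A plausible reason for dropping the cached node reference in favor of re-dereferencing the iterator (an assumption; the commit message does not say): InsertDTypeTransNode inserts into graph->nodes and returns a fresh iterator, and inserting into the underlying container can invalidate references taken before the insertion. A minimal, self-contained sketch of the hazard, using illustrative types rather than the real schema classes:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for schema::CNodeT; the real type lives in the generated schema.
struct Node {
  std::string name;
};

int main() {
  std::vector<std::unique_ptr<Node>> nodes;
  nodes.push_back(std::make_unique<Node>());
  nodes.back()->name = "conv";

  auto iter = nodes.begin();
  auto &node = *iter;  // cached reference into the vector's storage

  // Inserting (as InsertDTypeTransNode does for the cast node) may reallocate,
  // invalidating both the old iterator and the cached reference `node`.
  iter = nodes.insert(iter, std::make_unique<Node>());
  (*iter)->name = "cast";

  // Safe pattern after insertion: use only the iterator returned by insert()
  // and re-dereference it, instead of touching the stale `node` reference.
  std::cout << (*(iter + 1))->name << std::endl;  // prints "conv"
  return 0;
}
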
file(GLOB_RECURSE TFLITE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
*.cc
)
ADD_DEFINITIONS(-DFLATBUFFERS_LOCALE_INDEPENDENT=1)
find_library(FLATBUFFERS_LIBRARY flatbuffers HINTS ${TOP_DIR}/third_party/flatbuffers/build)
add_library(tflite_parser_mid OBJECT
${TFLITE_SRC_LIST}
)
target_link_libraries(tflite_parser_mid ${FLATBUFFERS_LIBRARY})
......@@ -18,6 +18,8 @@
#include <vector>
#include <memory>
#include <map>
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/flexbuffers.h"
namespace mindspore {
namespace lite {
......@@ -39,15 +41,42 @@ STATUS TfliteCustomParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}
std::unique_ptr<schema::CustomT> attr = std::make_unique<schema::CustomT>();
std::unique_ptr<schema::DetectionPostProcessT> attr = std::make_unique<schema::DetectionPostProcessT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &custom_attr = tflite_op->custom_options;
attr->custom = custom_attr;
op->primitive->value.type = schema::PrimitiveType_Custom;
auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
attr->format = schema::Format_NHWC;
attr->inputSize = tflite_op->inputs.size();
attr->hScale = attr_map["h_scale"].AsFloat();
attr->wScale = attr_map["w_scale"].AsFloat();
attr->xScale = attr_map["x_scale"].AsFloat();
attr->yScale = attr_map["y_scale"].AsFloat();
attr->NmsIouThreshold = attr_map["nms_iou_threshold"].AsFloat();
attr->NmsScoreThreshold = attr_map["nms_score_threshold"].AsFloat();
attr->MaxDetections = attr_map["max_detections"].AsInt32();
if (attr_map["detections_per_class"].IsNull()) {
attr->DetectionsPreClass = 100;
} else {
attr->DetectionsPreClass = attr_map["detections_per_class"].AsInt32();
}
attr->MaxClassesPreDetection = attr_map["max_classes_per_detection"].AsInt32();
attr->NumClasses = attr_map["num_classes"].AsInt32();
if (attr_map["use_regular_nms"].IsNull()) {
attr->UseRegularNms = false;
} else {
attr->UseRegularNms = attr_map["use_regular_nms"].AsBool();
}
if (attr_map["_output_quantized"].IsNull()) {
attr->OutQuantized = false;
} else {
attr->OutQuantized = attr_map["_output_quantized"].AsBool();
}
op->primitive->value.type = schema::PrimitiveType_DetectionPostProcess;
op->primitive->value.value = attr.release();
for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
......
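
For readers unfamiliar with the flexbuffers API used by the parser above, here is a minimal, self-contained sketch of writing and reading such a custom-options map. The builder side is purely illustrative (in the converter the buffer arrives as tflite_op->custom_options); the key names mirror the parser:

#include <iostream>
#include <vector>
#include "flatbuffers/flexbuffers.h"

int main() {
  // Build a flexbuffer map standing in for a TFLite custom op's options.
  flexbuffers::Builder fbb;
  fbb.Map([&]() {
    fbb.Float("nms_iou_threshold", 0.6f);
    fbb.Float("nms_score_threshold", 0.3f);
    fbb.Int("max_detections", 10);
    fbb.Int("num_classes", 90);
    fbb.Bool("use_regular_nms", false);
  });
  fbb.Finish();
  const std::vector<uint8_t> &buf = fbb.GetBuffer();

  // Read it back the same way the parser reads tflite_op->custom_options.
  auto attr_map = flexbuffers::GetRoot(buf).AsMap();
  float iou = attr_map["nms_iou_threshold"].AsFloat();
  int32_t max_det = attr_map["max_detections"].AsInt32();
  // A missing key yields a null reference, hence the IsNull() checks in the parser.
  bool out_quantized =
      attr_map["_output_quantized"].IsNull() ? false : attr_map["_output_quantized"].AsBool();

  std::cout << iou << " " << max_det << " " << out_quantized << std::endl;
  return 0;
}

The IsNull() guards matter because older TFLite models may omit detections_per_class, use_regular_nms, or _output_quantized, so the parser falls back to the defaults shown in the hunk above.
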