From 1527af23d9d6d25c7abfe851087cf6828e37ea8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Wed, 17 Jun 2020 16:51:02 +0800 Subject: [PATCH] [Feature] add flatbuffers, test=develop (#3790) * add flatbuffers, test=develop * fix cmake codes, test=develop --- .gitmodules | 3 + CMakeLists.txt | 1 + cmake/external/flatbuffers.cmake | 111 +++++++++++ lite/api/CMakeLists.txt | 2 +- lite/model_parser/CMakeLists.txt | 1 + lite/model_parser/flatbuffers/CMakeLists.txt | 11 ++ lite/model_parser/flatbuffers/block_desc.cc | 35 ++++ lite/model_parser/flatbuffers/block_desc.h | 69 +++++++ lite/model_parser/flatbuffers/framework.fbs | 172 ++++++++++++++++++ lite/model_parser/flatbuffers/op_desc.cc | 114 ++++++++++++ lite/model_parser/flatbuffers/op_desc.h | 132 ++++++++++++++ lite/model_parser/flatbuffers/program_desc.cc | 29 +++ lite/model_parser/flatbuffers/program_desc.h | 53 ++++++ lite/model_parser/flatbuffers/var_desc.cc | 15 ++ lite/model_parser/flatbuffers/var_desc.h | 59 ++++++ third-party/flatbuffers | 1 + 16 files changed, 807 insertions(+), 1 deletion(-) create mode 100644 cmake/external/flatbuffers.cmake create mode 100644 lite/model_parser/flatbuffers/CMakeLists.txt create mode 100644 lite/model_parser/flatbuffers/block_desc.cc create mode 100644 lite/model_parser/flatbuffers/block_desc.h create mode 100644 lite/model_parser/flatbuffers/framework.fbs create mode 100644 lite/model_parser/flatbuffers/op_desc.cc create mode 100644 lite/model_parser/flatbuffers/op_desc.h create mode 100644 lite/model_parser/flatbuffers/program_desc.cc create mode 100644 lite/model_parser/flatbuffers/program_desc.h create mode 100644 lite/model_parser/flatbuffers/var_desc.cc create mode 100644 lite/model_parser/flatbuffers/var_desc.h create mode 160000 third-party/flatbuffers diff --git a/.gitmodules b/.gitmodules index 107036c702..37af6a7245 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "third-party/protobuf-host"] path = third-party/protobuf-host url = https://github.com/protocolbuffers/protobuf.git +[submodule "third-party/flatbuffers"] + path = third-party/flatbuffers + url = https://github.com/google/flatbuffers.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 8ac227f015..9188dd83d6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -168,6 +168,7 @@ if(LITE_WITH_RKNPU) include(device/rknpu) endif() +include(external/flatbuffers) # for mobile if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) diff --git a/cmake/external/flatbuffers.cmake b/cmake/external/flatbuffers.cmake new file mode 100644 index 0000000000..fcd286254d --- /dev/null +++ b/cmake/external/flatbuffers.cmake @@ -0,0 +1,111 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +INCLUDE(ExternalProject) + +# Introduce variables: +# * CMAKE_INSTALL_LIBDIR +INCLUDE(GNUInstallDirs) +SET(LIBDIR "lib") +if(CMAKE_INSTALL_LIBDIR MATCHES ".*lib64$") + SET(LIBDIR "lib64") +endif() + +SET(FLATBUFFERS_SOURCES_DIR ${CMAKE_SOURCE_DIR}/third-party/flatbuffers) +SET(FLATBUFFERS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flatbuffers) +SET(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_INSTALL_DIR}/include" CACHE PATH "flatbuffers include directory." FORCE) +IF(WIN32) + set(FLATBUFFERS_LIBRARIES "${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.lib" CACHE FILEPATH "FLATBUFFERS_LIBRARIES" FORCE) +ELSE(WIN32) + set(FLATBUFFERS_LIBRARIES "${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.a" CACHE FILEPATH "FLATBUFFERS_LIBRARIES" FORCE) +ENDIF(WIN32) + +INCLUDE_DIRECTORIES(${FLATBUFFERS_INCLUDE_DIR}) + +if(NOT HOST_CXX_COMPILER) + set(HOST_CXX_COMPILER ${CMAKE_CXX_COMPILER}) + set(HOST_C_COMPILER ${CMAKE_C_COMPILER}) +endif() + +SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${HOST_CXX_COMPILER}" + "-DCMAKE_C_COMPILER=${HOST_C_COMPILER}") + +ExternalProject_Add( + extern_flatbuffers + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "" + GIT_TAG "v1.12.0" + SOURCE_DIR ${FLATBUFFERS_SOURCES_DIR} + PREFIX ${FLATBUFFERS_INCLUDE_DIR} + UPDATE_COMMAND "" + CMAKE_ARGS -DBUILD_STATIC_LIBS=ON + -DCMAKE_INSTALL_PREFIX=${FLATBUFFERS_INSTALL_DIR} + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DBUILD_TESTING=OFF + -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + ${CROSS_COMPILE_CMAKE_ARGS} + ${OPTIONAL_ARGS} + ${EXTERNAL_OPTIONAL_ARGS} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${FLATBUFFERS_INSTALL_DIR} + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} +) +IF(WIN32) + IF(NOT EXISTS "${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.lib") + add_custom_command(TARGET extern_flatbuffers POST_BUILD + COMMAND cmake -E copy ${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/flatbuffers_static.lib ${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.lib + ) + ENDIF() +ENDIF(WIN32) +ADD_LIBRARY(flatbuffers STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET flatbuffers PROPERTY IMPORTED_LOCATION ${FLATBUFFERS_LIBRARIES}) +ADD_DEPENDENCIES(flatbuffers extern_flatbuffers) + +SET(FLATBUFFERS_FLATC_EXECUTABLE ${FLATBUFFERS_INSTALL_DIR}/bin/flatc) + +function(register_generated_output file_name) + get_property(tmp GLOBAL PROPERTY FBS_GENERATED_OUTPUTS) + list(APPEND tmp ${file_name}) + set_property(GLOBAL PROPERTY FBS_GENERATED_OUTPUTS ${tmp}) +endfunction(register_generated_output) + +function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT) + if(FLATBUFFERS_BUILD_LEGACY) + set(OPT ${OPT};--cpp-std c++0x) + else() + # --cpp-std is defined by flatc default settings. + endif() + message(STATUS "`${SRC_FBS}`: add generation of C++ code with '${OPT}'") + get_filename_component(SRC_FBS_DIR ${SRC_FBS} PATH) + message(STATUS "SRC_FBS_DIR: ${SRC_FBS_DIR}") + string(REGEX REPLACE "\\.fbs$" "_generated.h" GEN_HEADER ${SRC_FBS}) + add_custom_command( + OUTPUT ${GEN_HEADER} + COMMAND "${FLATBUFFERS_FLATC_EXECUTABLE}" + --cpp --gen-mutable --gen-object-api --reflect-names + --cpp-ptr-type flatbuffers::unique_ptr # Used to test with C++98 STLs + ${OPT} + -I "${CMAKE_CURRENT_SOURCE_DIR}/tests/include_test" + -o "${SRC_FBS_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS}" + DEPENDS flatbuffers + COMMENT "Run generation: '${GEN_HEADER}'") + include_directories(${FLATBUFFERS_INCLUDE_DIR}) + register_generated_output(${GEN_HEADER}) + add_custom_target(${TARGET} ALL DEPENDS ${GEN_HEADER}) +endfunction() + +set(FRAMEWORK_FBS_DIR "lite/model_parser/flatbuffers") +set(FRAMEWORK_SCHEMA_PATH "${FRAMEWORK_FBS_DIR}/framework.fbs") +compile_flatbuffers_schema_to_cpp_opt(framework_fbs_header ${FRAMEWORK_SCHEMA_PATH} "--no-includes;--gen-compare;--force-empty") diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index 85744f5cac..3c7f4b23fe 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -70,7 +70,7 @@ else() set(TARGET_COMIPILE_FLAGS "${TARGET_COMIPILE_FLAGS} -flto") endif() set_target_properties(paddle_light_api_shared PROPERTIES COMPILE_FLAGS "${TARGET_COMIPILE_FLAGS}") - add_dependencies(paddle_light_api_shared op_list_h kernel_list_h) + add_dependencies(paddle_light_api_shared op_list_h kernel_list_h framework_fbs_header) if (LITE_WITH_NPU) # Need to add HIAI runtime libs (libhiai.so) dependency target_link_libraries(paddle_light_api_shared ${npu_builder_libs} ${npu_runtime_libs}) diff --git a/lite/model_parser/CMakeLists.txt b/lite/model_parser/CMakeLists.txt index 34d524c5c1..4e53b73886 100644 --- a/lite/model_parser/CMakeLists.txt +++ b/lite/model_parser/CMakeLists.txt @@ -3,6 +3,7 @@ if (NOT LITE_ON_TINY_PUBLISH) endif() add_subdirectory(cpp) add_subdirectory(naive_buffer) +add_subdirectory(flatbuffers) #lite_cc_library(runtime_lite SRCS runtime.cc) diff --git a/lite/model_parser/flatbuffers/CMakeLists.txt b/lite/model_parser/flatbuffers/CMakeLists.txt new file mode 100644 index 0000000000..7b935ba7d3 --- /dev/null +++ b/lite/model_parser/flatbuffers/CMakeLists.txt @@ -0,0 +1,11 @@ +function(lite_fbs_library TARGET) + set(multiValueArgs SRCS FBS_DEPS) + cmake_parse_arguments(args "" "" "${multiValueArgs}" ${ARGN}) + lite_cc_library(${TARGET} SRCS ${args_SRCS}) + add_dependencies(${TARGET} ${args_FBS_DEPS}) +endfunction() + +lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS framework_fbs_header) +lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS framework_fbs_header) +lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS framework_fbs_header) +lite_fbs_library(fbs_program_desc SRCS program_desc.cc FBS_DEPS framework_fbs_header) diff --git a/lite/model_parser/flatbuffers/block_desc.cc b/lite/model_parser/flatbuffers/block_desc.cc new file mode 100644 index 0000000000..fc43af6d62 --- /dev/null +++ b/lite/model_parser/flatbuffers/block_desc.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/model_parser/flatbuffers/block_desc.h" + +namespace paddle { +namespace lite { +namespace fbs { + +template <> +proto::VarDesc* BlockDesc::GetVar(int32_t idx) { + CHECK_LT(idx, VarsSize()) << "idx >= vars.size()"; + return const_cast(desc_->vars()->Get(idx)); +} + +template <> +proto::OpDesc* BlockDesc::GetOp(int32_t idx) { + CHECK_LT(idx, OpsSize()) << "idx >= ops.size()"; + return const_cast(desc_->ops()->Get(idx)); +} + +} // namespace fbs +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/flatbuffers/block_desc.h b/lite/model_parser/flatbuffers/block_desc.h new file mode 100644 index 0000000000..2498776b2e --- /dev/null +++ b/lite/model_parser/flatbuffers/block_desc.h @@ -0,0 +1,69 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/model_parser/base/block_desc.h" +#include "lite/model_parser/flatbuffers/framework_generated.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace fbs { + +class BlockDesc : public BlockDescReadAPI { + public: + explicit BlockDesc(proto::BlockDesc* desc) : desc_(desc) { CHECK(desc_); } + + int32_t Idx() const override { return desc_->idx(); } + + int32_t ParentIdx() const override { return desc_->parent_idx(); } + + size_t VarsSize() const override { return desc_->vars()->size(); } + + template + T* GetVar(int32_t idx); + + template + T const* GetVar(int32_t idx) const { + return GetVar(idx); + } + + size_t OpsSize() const override { + CHECK(desc_); + CHECK(desc_->ops()); + return desc_->ops()->size(); + } + + template + T* GetOp(int32_t idx); + + template + T const* GetOp(int32_t idx) const { + return GetOp(idx); + } + + int32_t ForwardBlockIdx() const override { + return desc_->forward_block_idx(); + } + + BlockDesc() = delete; + + private: + proto::BlockDesc* desc_; // not_own +}; + +} // namespace fbs +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/flatbuffers/framework.fbs b/lite/model_parser/flatbuffers/framework.fbs new file mode 100644 index 0000000000..90f6e62608 --- /dev/null +++ b/lite/model_parser/flatbuffers/framework.fbs @@ -0,0 +1,172 @@ +// Generated from framework.proto + +namespace paddle.lite.fbs.proto; + +enum AttrType : int { + INT = 0, + FLOAT = 1, + STRING = 2, + INTS = 3, + FLOATS = 4, + STRINGS = 5, + BOOLEAN = 6, + BOOLEANS = 7, + BLOCK = 8, + LONG = 9, + BLOCKS = 10, + LONGS = 11, +} + +namespace paddle.lite.fbs.proto.VarType_; + +enum Type : int { + BOOL = 0, + INT16 = 1, + INT32 = 2, + INT64 = 3, + FP16 = 4, + FP32 = 5, + FP64 = 6, + LOD_TENSOR = 7, + SELECTED_ROWS = 8, + FEED_MINIBATCH = 9, + FETCH_LIST = 10, + STEP_SCOPES = 11, + LOD_RANK_TABLE = 12, + LOD_TENSOR_ARRAY = 13, + PLACE_LIST = 14, + READER = 15, + RAW = 17, + TUPLE = 18, + SIZE_T = 19, + UINT8 = 20, + INT8 = 21, +} + +namespace paddle.lite.fbs.proto.CompatibleInfo_; + +enum Type : int { + COMPATIBLE = 0, + DEFINITELY_NOT = 1, + POSSIBLE = 2, + BUG_FIX = 3, + PRECISION_CHANGE = 4, +} + +namespace paddle.lite.fbs.proto; + +table Version { + version:long; +} + +table OpDesc { + type:string (required); + inputs:[paddle.lite.fbs.proto.OpDesc_.Var]; + outputs:[paddle.lite.fbs.proto.OpDesc_.Var]; + attrs:[paddle.lite.fbs.proto.OpDesc_.Attr]; + is_target:bool; +} + +namespace paddle.lite.fbs.proto.OpDesc_; + +table Attr { + name:string (required, key); + type:paddle.lite.fbs.proto.AttrType; + i:int; + f:float; + s:string; + ints:[int]; + floats:[float]; + strings:[string]; + b:bool; + bools:[bool]; + block_idx:int; + l:long; + blocks_idx:[int]; + longs:[long]; +} + +table Var { + parameter:string (required, key); + arguments:[string]; +} + +namespace paddle.lite.fbs.proto; + +table VarType { + type:paddle.lite.fbs.proto.VarType_.Type; + selected_rows:paddle.lite.fbs.proto.VarType_.TensorDesc; + lod_tensor:paddle.lite.fbs.proto.VarType_.LoDTensorDesc; + tensor_array:paddle.lite.fbs.proto.VarType_.LoDTensorArrayDesc; + reader:paddle.lite.fbs.proto.VarType_.ReaderDesc; + tuple:paddle.lite.fbs.proto.VarType_.Tuple; +} + +namespace paddle.lite.fbs.proto.VarType_; + +table TensorDesc { + data_type:paddle.lite.fbs.proto.VarType_.Type; + dims:[long]; +} + +table LoDTensorDesc { + tensor:paddle.lite.fbs.proto.VarType_.TensorDesc (required); + lod_level:int; +} + +table LoDTensorArrayDesc { + tensor:paddle.lite.fbs.proto.VarType_.TensorDesc (required); + lod_level:int; +} + +table ReaderDesc { + lod_tensor:[paddle.lite.fbs.proto.VarType_.LoDTensorDesc]; +} + +table Tuple { + element_type:[paddle.lite.fbs.proto.VarType_.Type]; +} + +namespace paddle.lite.fbs.proto; + +table VarDesc { + name:string (required, key); + type:paddle.lite.fbs.proto.VarType (required); + persistable:bool; + need_check_feed:bool; +} + +table BlockDesc { + idx:int; + parent_idx:int; + vars:[paddle.lite.fbs.proto.VarDesc]; + ops:[paddle.lite.fbs.proto.OpDesc]; + forward_block_idx:int = -1; +} + +table CompatibleInfo { + version:string (required); + type:paddle.lite.fbs.proto.CompatibleInfo_.Type; +} + +table OpCompatibleMap { + pair:[paddle.lite.fbs.proto.OpCompatibleMap_.OpCompatiblePair]; + default_required_version:string; +} + +namespace paddle.lite.fbs.proto.OpCompatibleMap_; + +table OpCompatiblePair { + op_name:string (required, key); + compatible_info:paddle.lite.fbs.proto.CompatibleInfo (required); +} + +namespace paddle.lite.fbs.proto; + +table ProgramDesc { + blocks:[paddle.lite.fbs.proto.BlockDesc]; + version:paddle.lite.fbs.proto.Version; + op_compatible_map:paddle.lite.fbs.proto.OpCompatibleMap; +} + +root_type paddle.lite.fbs.proto.ProgramDesc; diff --git a/lite/model_parser/flatbuffers/op_desc.cc b/lite/model_parser/flatbuffers/op_desc.cc new file mode 100644 index 0000000000..5c96cae17e --- /dev/null +++ b/lite/model_parser/flatbuffers/op_desc.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/model_parser/flatbuffers/op_desc.h" + +namespace paddle { +namespace lite { +namespace fbs { + +template <> +std::string OpDesc::GetAttr(const std::string& name) const { + const auto& it = desc_->attrs()->LookupByKey(name.c_str()); + if (!it->s()) { + return std::string(); + } + return it->s()->str(); +} + +template <> +std::string OpDesc::GetAttr(size_t idx) const { + const auto& it = desc_->attrs()->Get(idx); + if (!it->s()) { + return std::string(); + } + return it->s()->str(); +} + +template <> +std::vector OpDesc::GetAttr>( + const std::string& name) const { + const auto& it = desc_->attrs()->LookupByKey(name.c_str()); + CHECK(it) << "Attr " << name << "does not exist."; + std::vector res; + if (it->strings()) { + res.reserve(it->strings()->size()); + for (const auto& v : *it->strings()) { + res.push_back(v->str()); + } + } + return res; +} + +template <> +std::vector OpDesc::GetAttr>( + size_t idx) const { + const auto& it = desc_->attrs()->Get(idx); + CHECK(it) << "Attr " << idx << "does not exist."; + std::vector res; + if (it->strings()) { + res.reserve(it->strings()->size()); + for (const auto& v : *it->strings()) { + res.push_back(v->str()); + } + } + return res; +} + +#define GET_ATTR_IMPL(T, fb_f__) \ + template <> \ + T OpDesc::GetAttr(const std::string& name) const { \ + const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \ + return it->fb_f__(); \ + } \ + template <> \ + T OpDesc::GetAttr(size_t idx) const { \ + const auto& it = desc_->attrs()->Get(idx); \ + return it->fb_f__(); \ + } + +#define GET_ATTRS_IMPL(T, fb_f__) \ + template <> \ + T OpDesc::GetAttr(const std::string& name) const { \ + const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \ + T res; \ + res.reserve(it->fb_f__()->size()); \ + for (const auto& v : *it->fb_f__()) { \ + res.push_back(v); \ + } \ + return res; \ + } \ + template <> \ + T OpDesc::GetAttr(size_t idx) const { \ + const auto& it = desc_->attrs()->Get(idx); \ + T res; \ + res.reserve(it->fb_f__()->size()); \ + for (const auto& v : *it->fb_f__()) { \ + res.push_back(v); \ + } \ + return res; \ + } + +GET_ATTR_IMPL(int32_t, i); +GET_ATTR_IMPL(int16_t, block_idx); +GET_ATTR_IMPL(float, f); +GET_ATTR_IMPL(bool, b); +GET_ATTR_IMPL(int64_t, l); +GET_ATTRS_IMPL(std::vector, ints); +GET_ATTRS_IMPL(std::vector, floats); +GET_ATTRS_IMPL(std::vector, longs); + +} // namespace fbs +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/flatbuffers/op_desc.h b/lite/model_parser/flatbuffers/op_desc.h new file mode 100644 index 0000000000..2cb7eed9ad --- /dev/null +++ b/lite/model_parser/flatbuffers/op_desc.h @@ -0,0 +1,132 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "lite/model_parser/base/op_desc.h" +#include "lite/model_parser/flatbuffers/framework_generated.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace fbs { + +class OpDesc : public OpDescReadAPI { + public: + explicit OpDesc(proto::OpDesc* desc) : desc_(desc) { CHECK(desc_); } + + std::string Type() const override { return desc_->type()->str(); } + + // Get the arguments of parameter called `param` + std::vector Input(const std::string& param) const override { + const auto& var = desc_->inputs()->LookupByKey(param.c_str()); + std::vector args_vec; + if (var->arguments()) { + args_vec.reserve(var->arguments()->size()); + for (const auto& in : *var->arguments()) { + args_vec.push_back(in->str()); + } + } + return args_vec; + } + + std::vector InputArgumentNames() const override { + const auto& vars = desc_->inputs(); + std::vector input_names_vec; + if (vars) { + input_names_vec.reserve(vars->size()); + for (const auto& in : *vars) { + input_names_vec.push_back(in->parameter()->str()); + } + } + return input_names_vec; + } + + std::vector Output(const std::string& param) const override { + const auto& var = desc_->outputs()->LookupByKey(param.c_str()); + std::vector args_vec; + if (var->arguments()) { + args_vec.reserve(var->arguments()->size()); + for (const auto& out : *var->arguments()) { + args_vec.push_back(out->str()); + } + } + return args_vec; + } + + std::vector OutputArgumentNames() const override { + const auto& vars = desc_->outputs(); + std::vector output_names_vec; + if (vars) { + output_names_vec.reserve(vars->size()); + for (const auto& out : *vars) { + output_names_vec.push_back(out->parameter()->str()); + } + } + return output_names_vec; + } + + bool HasAttr(const std::string& name) const override { + return desc_->attrs()->LookupByKey(name.c_str()) == nullptr; + } + + size_t AttrsSize() const { return desc_->attrs()->size(); } + + std::string AttrName(size_t idx) const { + return desc_->attrs()->Get(idx)->name()->str(); + } + + OpDescAPI::AttrType GetAttrType(const std::string& name) const override { + const auto& attr = desc_->attrs()->LookupByKey(name.c_str()); + CHECK(attr); + return static_cast(attr->type()); + } + + OpDescAPI::AttrType GetAttrType(size_t idx) const { + const auto& attr = desc_->attrs()->Get(idx); + CHECK(attr); + return static_cast(attr->type()); + } + + std::vector AttrNames() const override { + const auto& attrs = desc_->attrs(); + std::vector attr_names_vec; + if (attrs) { + attr_names_vec.reserve(attrs->size()); + for (const auto& attr : *attrs) { + attr_names_vec.push_back(attr->name()->str()); + } + } + return attr_names_vec; + } + + template + T GetAttr(const std::string& name) const; + + template + T GetAttr(size_t idx) const; + + OpDesc() = delete; + + private: + proto::OpDesc* desc_; +}; + +} // namespace fbs +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/flatbuffers/program_desc.cc b/lite/model_parser/flatbuffers/program_desc.cc new file mode 100644 index 0000000000..36429103a7 --- /dev/null +++ b/lite/model_parser/flatbuffers/program_desc.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/model_parser/flatbuffers/program_desc.h" + +namespace paddle { +namespace lite { +namespace fbs { + +template <> +proto::BlockDesc* ProgramDesc::GetBlock(int32_t idx) { + CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()"; + return const_cast(desc_->blocks()->Get(idx)); +} + +} // namespace fbs +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/flatbuffers/program_desc.h b/lite/model_parser/flatbuffers/program_desc.h new file mode 100644 index 0000000000..db3cd936ab --- /dev/null +++ b/lite/model_parser/flatbuffers/program_desc.h @@ -0,0 +1,53 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/model_parser/base/program_desc.h" +#include "lite/model_parser/flatbuffers/framework_generated.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace fbs { + +class ProgramDesc : public ProgramDescReadAPI { + public: + explicit ProgramDesc(proto::ProgramDesc *desc) : desc_(desc) { CHECK(desc); } + + size_t BlocksSize() const override { return desc_->blocks()->size(); } + + template + T *GetBlock(int32_t idx); + + template + T const *GetBlock(int32_t idx) const { + return GetBlock(idx); + } + + bool HasVersion() const override { return desc_->version() != nullptr; } + + int64_t Version() const override { + CHECK(HasVersion()); + return desc_->version()->version(); + } + + private: + proto::ProgramDesc *desc_; // not_own +}; + +} // namespace fbs +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/flatbuffers/var_desc.cc b/lite/model_parser/flatbuffers/var_desc.cc new file mode 100644 index 0000000000..a629ffd5e3 --- /dev/null +++ b/lite/model_parser/flatbuffers/var_desc.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/model_parser/flatbuffers/var_desc.h" diff --git a/lite/model_parser/flatbuffers/var_desc.h b/lite/model_parser/flatbuffers/var_desc.h new file mode 100644 index 0000000000..67402dbb20 --- /dev/null +++ b/lite/model_parser/flatbuffers/var_desc.h @@ -0,0 +1,59 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "lite/model_parser/base/var_desc.h" +#include "lite/model_parser/flatbuffers/framework_generated.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace fbs { + +class VarDesc : public VarDescReadAPI { + public: + explicit VarDesc(proto::VarDesc* desc) : desc_(desc) {} + + std::string Name() const override { return desc_->name()->str(); } + + VarDescAPI::Type GetType() const override { + return static_cast(desc_->type()->type()); + } + + bool Persistable() const override { return desc_->persistable(); } + + std::vector GetShape() const override { + CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR); + const auto& dims = desc_->type()->lod_tensor()->tensor()->dims(); + std::vector dims_vec; + dims_vec.reserve(dims->size()); + for (const auto& dim : *dims) { + dims_vec.push_back(dim); + } + return dims_vec; + } + + VarDesc() = delete; + + private: + proto::VarDesc* desc_; +}; + +} // namespace fbs +} // namespace lite +} // namespace paddle diff --git a/third-party/flatbuffers b/third-party/flatbuffers new file mode 160000 index 0000000000..ac203b2092 --- /dev/null +++ b/third-party/flatbuffers @@ -0,0 +1 @@ +Subproject commit ac203b20926b13a35ff85277d2e5d3c38698eee8 -- GitLab