Unverified commit 1527af23 authored by 石晓伟, committed by GitHub

[Feature] add flatbuffers, test=develop (#3790)

* add flatbuffers, test=develop

* fix cmake codes, test=develop
Parent 2f317a5e
@@ -10,3 +10,6 @@
[submodule "third-party/protobuf-host"]
path = third-party/protobuf-host
url = https://github.com/protocolbuffers/protobuf.git
[submodule "third-party/flatbuffers"]
path = third-party/flatbuffers
url = https://github.com/google/flatbuffers.git
@@ -168,6 +168,7 @@ if(LITE_WITH_RKNPU)
include(device/rknpu)
endif()
include(external/flatbuffers)
# for mobile
if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
# Introduce variables:
# * CMAKE_INSTALL_LIBDIR
INCLUDE(GNUInstallDirs)
SET(LIBDIR "lib")
if(CMAKE_INSTALL_LIBDIR MATCHES ".*lib64$")
  SET(LIBDIR "lib64")
endif()

SET(FLATBUFFERS_SOURCES_DIR ${CMAKE_SOURCE_DIR}/third-party/flatbuffers)
SET(FLATBUFFERS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flatbuffers)
SET(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_INSTALL_DIR}/include" CACHE PATH "flatbuffers include directory." FORCE)

IF(WIN32)
  set(FLATBUFFERS_LIBRARIES "${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.lib" CACHE FILEPATH "FLATBUFFERS_LIBRARIES" FORCE)
ELSE(WIN32)
  set(FLATBUFFERS_LIBRARIES "${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.a" CACHE FILEPATH "FLATBUFFERS_LIBRARIES" FORCE)
ENDIF(WIN32)

INCLUDE_DIRECTORIES(${FLATBUFFERS_INCLUDE_DIR})

if(NOT HOST_CXX_COMPILER)
  set(HOST_CXX_COMPILER ${CMAKE_CXX_COMPILER})
  set(HOST_C_COMPILER ${CMAKE_C_COMPILER})
endif()

SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${HOST_CXX_COMPILER}"
                  "-DCMAKE_C_COMPILER=${HOST_C_COMPILER}")

ExternalProject_Add(
  extern_flatbuffers
  ${EXTERNAL_PROJECT_LOG_ARGS}
  GIT_REPOSITORY  ""
  GIT_TAG         "v1.12.0"
  SOURCE_DIR      ${FLATBUFFERS_SOURCES_DIR}
  PREFIX          ${FLATBUFFERS_INCLUDE_DIR}
  UPDATE_COMMAND  ""
  CMAKE_ARGS      -DBUILD_STATIC_LIBS=ON
                  -DCMAKE_INSTALL_PREFIX=${FLATBUFFERS_INSTALL_DIR}
                  -DCMAKE_POSITION_INDEPENDENT_CODE=ON
                  -DBUILD_TESTING=OFF
                  -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                  ${CROSS_COMPILE_CMAKE_ARGS}
                  ${OPTIONAL_ARGS}
                  ${EXTERNAL_OPTIONAL_ARGS}
  CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${FLATBUFFERS_INSTALL_DIR}
                   -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
                   -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)

IF(WIN32)
  IF(NOT EXISTS "${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.lib")
    add_custom_command(TARGET extern_flatbuffers POST_BUILD
      COMMAND cmake -E copy ${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/flatbuffers_static.lib ${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.lib
    )
  ENDIF()
ENDIF(WIN32)
ADD_LIBRARY(flatbuffers STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET flatbuffers PROPERTY IMPORTED_LOCATION ${FLATBUFFERS_LIBRARIES})
ADD_DEPENDENCIES(flatbuffers extern_flatbuffers)
SET(FLATBUFFERS_FLATC_EXECUTABLE ${FLATBUFFERS_INSTALL_DIR}/bin/flatc)
function(register_generated_output file_name)
  get_property(tmp GLOBAL PROPERTY FBS_GENERATED_OUTPUTS)
  list(APPEND tmp ${file_name})
  set_property(GLOBAL PROPERTY FBS_GENERATED_OUTPUTS ${tmp})
endfunction(register_generated_output)

function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT)
  if(FLATBUFFERS_BUILD_LEGACY)
    set(OPT ${OPT};--cpp-std c++0x)
  else()
    # --cpp-std is defined by flatc default settings.
  endif()
  message(STATUS "`${SRC_FBS}`: add generation of C++ code with '${OPT}'")
  get_filename_component(SRC_FBS_DIR ${SRC_FBS} PATH)
  message(STATUS "SRC_FBS_DIR: ${SRC_FBS_DIR}")
  string(REGEX REPLACE "\\.fbs$" "_generated.h" GEN_HEADER ${SRC_FBS})
  add_custom_command(
    OUTPUT ${GEN_HEADER}
    COMMAND "${FLATBUFFERS_FLATC_EXECUTABLE}"
            --cpp --gen-mutable --gen-object-api --reflect-names
            --cpp-ptr-type flatbuffers::unique_ptr # Used to test with C++98 STLs
            ${OPT}
            -I "${CMAKE_CURRENT_SOURCE_DIR}/tests/include_test"
            -o "${SRC_FBS_DIR}"
            "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS}"
    DEPENDS flatbuffers
    COMMENT "Run generation: '${GEN_HEADER}'")
  include_directories(${FLATBUFFERS_INCLUDE_DIR})
  register_generated_output(${GEN_HEADER})
  add_custom_target(${TARGET} ALL DEPENDS ${GEN_HEADER})
endfunction()
set(FRAMEWORK_FBS_DIR "lite/model_parser/flatbuffers")
set(FRAMEWORK_SCHEMA_PATH "${FRAMEWORK_FBS_DIR}/framework.fbs")
compile_flatbuffers_schema_to_cpp_opt(framework_fbs_header ${FRAMEWORK_SCHEMA_PATH} "--no-includes;--gen-compare;--force-empty")
@@ -70,7 +70,7 @@ else()
set(TARGET_COMIPILE_FLAGS "${TARGET_COMIPILE_FLAGS} -flto")
endif()
set_target_properties(paddle_light_api_shared PROPERTIES COMPILE_FLAGS "${TARGET_COMIPILE_FLAGS}")
-add_dependencies(paddle_light_api_shared op_list_h kernel_list_h)
+add_dependencies(paddle_light_api_shared op_list_h kernel_list_h framework_fbs_header)
if (LITE_WITH_NPU)
# Need to add HIAI runtime libs (libhiai.so) dependency
target_link_libraries(paddle_light_api_shared ${npu_builder_libs} ${npu_runtime_libs})
......
@@ -3,6 +3,7 @@ if (NOT LITE_ON_TINY_PUBLISH)
endif()
add_subdirectory(cpp)
add_subdirectory(naive_buffer)
add_subdirectory(flatbuffers)
#lite_cc_library(runtime_lite SRCS runtime.cc)
......
function(lite_fbs_library TARGET)
  set(multiValueArgs SRCS FBS_DEPS)
  cmake_parse_arguments(args "" "" "${multiValueArgs}" ${ARGN})
  lite_cc_library(${TARGET} SRCS ${args_SRCS})
  add_dependencies(${TARGET} ${args_FBS_DEPS})
endfunction()
lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_program_desc SRCS program_desc.cc FBS_DEPS framework_fbs_header)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/block_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
template <>
proto::VarDesc* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) {
  CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
  return const_cast<proto::VarDesc*>(desc_->vars()->Get(idx));
}

template <>
proto::OpDesc* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) {
  CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
  return const_cast<proto::OpDesc*>(desc_->ops()->Get(idx));
}
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/model_parser/base/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace fbs {
class BlockDesc : public BlockDescReadAPI {
 public:
  explicit BlockDesc(proto::BlockDesc* desc) : desc_(desc) { CHECK(desc_); }

  int32_t Idx() const override { return desc_->idx(); }

  int32_t ParentIdx() const override { return desc_->parent_idx(); }

  size_t VarsSize() const override { return desc_->vars()->size(); }

  template <typename T>
  T* GetVar(int32_t idx);

  template <typename T>
  T const* GetVar(int32_t idx) const {
    // Delegate to the non-const overload; calling GetVar<T>(idx) directly
    // from a const member would recurse into this overload.
    return const_cast<BlockDesc*>(this)->GetVar<T>(idx);
  }

  size_t OpsSize() const override {
    CHECK(desc_);
    CHECK(desc_->ops());
    return desc_->ops()->size();
  }

  template <typename T>
  T* GetOp(int32_t idx);

  template <typename T>
  T const* GetOp(int32_t idx) const {
    // Same delegation as GetVar to avoid infinite recursion.
    return const_cast<BlockDesc*>(this)->GetOp<T>(idx);
  }

  int32_t ForwardBlockIdx() const override {
    return desc_->forward_block_idx();
  }

  BlockDesc() = delete;

 private:
  proto::BlockDesc* desc_;  // not_own
};
} // namespace fbs
} // namespace lite
} // namespace paddle
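A minimal sketch of how the reader class above could be driven, assuming a valid proto::BlockDesc pointer obtained from an already-parsed flatbuffer model; the function name and the idea of collecting op types are illustrative, not part of this patch:

#include <cstdint>
#include <string>
#include <vector>

#include "lite/model_parser/flatbuffers/block_desc.h"
#include "lite/model_parser/flatbuffers/op_desc.h"

// Collect the type string of every op in one block (sketch only).
std::vector<std::string> CollectOpTypes(paddle::lite::fbs::proto::BlockDesc* raw_block) {
  paddle::lite::fbs::BlockDesc block(raw_block);
  std::vector<std::string> types;
  for (int32_t i = 0; i < static_cast<int32_t>(block.OpsSize()); ++i) {
    // GetOp<proto::OpDesc> returns the i-th flatbuffer op table of this block.
    auto* raw_op = block.GetOp<paddle::lite::fbs::proto::OpDesc>(i);
    types.push_back(paddle::lite::fbs::OpDesc(raw_op).Type());
  }
  return types;
}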
// Generated from framework.proto
namespace paddle.lite.fbs.proto;
enum AttrType : int {
INT = 0,
FLOAT = 1,
STRING = 2,
INTS = 3,
FLOATS = 4,
STRINGS = 5,
BOOLEAN = 6,
BOOLEANS = 7,
BLOCK = 8,
LONG = 9,
BLOCKS = 10,
LONGS = 11,
}
namespace paddle.lite.fbs.proto.VarType_;
enum Type : int {
BOOL = 0,
INT16 = 1,
INT32 = 2,
INT64 = 3,
FP16 = 4,
FP32 = 5,
FP64 = 6,
LOD_TENSOR = 7,
SELECTED_ROWS = 8,
FEED_MINIBATCH = 9,
FETCH_LIST = 10,
STEP_SCOPES = 11,
LOD_RANK_TABLE = 12,
LOD_TENSOR_ARRAY = 13,
PLACE_LIST = 14,
READER = 15,
RAW = 17,
TUPLE = 18,
SIZE_T = 19,
UINT8 = 20,
INT8 = 21,
}
namespace paddle.lite.fbs.proto.CompatibleInfo_;
enum Type : int {
COMPATIBLE = 0,
DEFINITELY_NOT = 1,
POSSIBLE = 2,
BUG_FIX = 3,
PRECISION_CHANGE = 4,
}
namespace paddle.lite.fbs.proto;
table Version {
version:long;
}
table OpDesc {
type:string (required);
inputs:[paddle.lite.fbs.proto.OpDesc_.Var];
outputs:[paddle.lite.fbs.proto.OpDesc_.Var];
attrs:[paddle.lite.fbs.proto.OpDesc_.Attr];
is_target:bool;
}
namespace paddle.lite.fbs.proto.OpDesc_;
table Attr {
name:string (required, key);
type:paddle.lite.fbs.proto.AttrType;
i:int;
f:float;
s:string;
ints:[int];
floats:[float];
strings:[string];
b:bool;
bools:[bool];
block_idx:int;
l:long;
blocks_idx:[int];
longs:[long];
}
table Var {
parameter:string (required, key);
arguments:[string];
}
namespace paddle.lite.fbs.proto;
table VarType {
type:paddle.lite.fbs.proto.VarType_.Type;
selected_rows:paddle.lite.fbs.proto.VarType_.TensorDesc;
lod_tensor:paddle.lite.fbs.proto.VarType_.LoDTensorDesc;
tensor_array:paddle.lite.fbs.proto.VarType_.LoDTensorArrayDesc;
reader:paddle.lite.fbs.proto.VarType_.ReaderDesc;
tuple:paddle.lite.fbs.proto.VarType_.Tuple;
}
namespace paddle.lite.fbs.proto.VarType_;
table TensorDesc {
data_type:paddle.lite.fbs.proto.VarType_.Type;
dims:[long];
}
table LoDTensorDesc {
tensor:paddle.lite.fbs.proto.VarType_.TensorDesc (required);
lod_level:int;
}
table LoDTensorArrayDesc {
tensor:paddle.lite.fbs.proto.VarType_.TensorDesc (required);
lod_level:int;
}
table ReaderDesc {
lod_tensor:[paddle.lite.fbs.proto.VarType_.LoDTensorDesc];
}
table Tuple {
element_type:[paddle.lite.fbs.proto.VarType_.Type];
}
namespace paddle.lite.fbs.proto;
table VarDesc {
name:string (required, key);
type:paddle.lite.fbs.proto.VarType (required);
persistable:bool;
need_check_feed:bool;
}
table BlockDesc {
idx:int;
parent_idx:int;
vars:[paddle.lite.fbs.proto.VarDesc];
ops:[paddle.lite.fbs.proto.OpDesc];
forward_block_idx:int = -1;
}
table CompatibleInfo {
version:string (required);
type:paddle.lite.fbs.proto.CompatibleInfo_.Type;
}
table OpCompatibleMap {
pair:[paddle.lite.fbs.proto.OpCompatibleMap_.OpCompatiblePair];
default_required_version:string;
}
namespace paddle.lite.fbs.proto.OpCompatibleMap_;
table OpCompatiblePair {
op_name:string (required, key);
compatible_info:paddle.lite.fbs.proto.CompatibleInfo (required);
}
namespace paddle.lite.fbs.proto;
table ProgramDesc {
blocks:[paddle.lite.fbs.proto.BlockDesc];
version:paddle.lite.fbs.proto.Version;
op_compatible_map:paddle.lite.fbs.proto.OpCompatibleMap;
}
root_type paddle.lite.fbs.proto.ProgramDesc;
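The schema's root_type makes ProgramDesc the buffer root, so the generated C++ API can return the root table directly. A minimal sketch, assuming framework_generated.h has been produced by the flatc rule above and that `buf` already holds the serialized model bytes (loading and verification are out of scope here):

#include <cstdint>
#include <vector>

#include "flatbuffers/flatbuffers.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"

// Fetch the mutable root ProgramDesc table from raw model bytes (sketch only).
paddle::lite::fbs::proto::ProgramDesc* GetProgramRoot(std::vector<uint8_t>* buf) {
  return flatbuffers::GetMutableRoot<paddle::lite::fbs::proto::ProgramDesc>(buf->data());
}

In a real loader the buffer would normally be checked with a flatbuffers::Verifier before any field is read.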
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/op_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
template <>
std::string OpDesc::GetAttr<std::string>(const std::string& name) const {
  const auto& it = desc_->attrs()->LookupByKey(name.c_str());
  if (!it->s()) {
    return std::string();
  }
  return it->s()->str();
}

template <>
std::string OpDesc::GetAttr<std::string>(size_t idx) const {
  const auto& it = desc_->attrs()->Get(idx);
  if (!it->s()) {
    return std::string();
  }
  return it->s()->str();
}

template <>
std::vector<std::string> OpDesc::GetAttr<std::vector<std::string>>(
    const std::string& name) const {
  const auto& it = desc_->attrs()->LookupByKey(name.c_str());
  CHECK(it) << "Attr " << name << " does not exist.";
  std::vector<std::string> res;
  if (it->strings()) {
    res.reserve(it->strings()->size());
    for (const auto& v : *it->strings()) {
      res.push_back(v->str());
    }
  }
  return res;
}

template <>
std::vector<std::string> OpDesc::GetAttr<std::vector<std::string>>(
    size_t idx) const {
  const auto& it = desc_->attrs()->Get(idx);
  CHECK(it) << "Attr " << idx << " does not exist.";
  std::vector<std::string> res;
  if (it->strings()) {
    res.reserve(it->strings()->size());
    for (const auto& v : *it->strings()) {
      res.push_back(v->str());
    }
  }
  return res;
}

#define GET_ATTR_IMPL(T, fb_f__)                                 \
  template <>                                                    \
  T OpDesc::GetAttr<T>(const std::string& name) const {          \
    const auto& it = desc_->attrs()->LookupByKey(name.c_str());  \
    return it->fb_f__();                                         \
  }                                                              \
  template <>                                                    \
  T OpDesc::GetAttr<T>(size_t idx) const {                       \
    const auto& it = desc_->attrs()->Get(idx);                   \
    return it->fb_f__();                                         \
  }

#define GET_ATTRS_IMPL(T, fb_f__)                                \
  template <>                                                    \
  T OpDesc::GetAttr<T>(const std::string& name) const {          \
    const auto& it = desc_->attrs()->LookupByKey(name.c_str());  \
    T res;                                                       \
    res.reserve(it->fb_f__()->size());                           \
    for (const auto& v : *it->fb_f__()) {                        \
      res.push_back(v);                                          \
    }                                                            \
    return res;                                                  \
  }                                                              \
  template <>                                                    \
  T OpDesc::GetAttr<T>(size_t idx) const {                       \
    const auto& it = desc_->attrs()->Get(idx);                   \
    T res;                                                       \
    res.reserve(it->fb_f__()->size());                           \
    for (const auto& v : *it->fb_f__()) {                        \
      res.push_back(v);                                          \
    }                                                            \
    return res;                                                  \
  }

GET_ATTR_IMPL(int32_t, i);
GET_ATTR_IMPL(int16_t, block_idx);
GET_ATTR_IMPL(float, f);
GET_ATTR_IMPL(bool, b);
GET_ATTR_IMPL(int64_t, l);
GET_ATTRS_IMPL(std::vector<int>, ints);
GET_ATTRS_IMPL(std::vector<float>, floats);
GET_ATTRS_IMPL(std::vector<int64_t>, longs);
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "lite/model_parser/base/op_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace fbs {
class OpDesc : public OpDescReadAPI {
 public:
  explicit OpDesc(proto::OpDesc* desc) : desc_(desc) { CHECK(desc_); }

  std::string Type() const override { return desc_->type()->str(); }

  // Get the arguments of the parameter called `param`.
  std::vector<std::string> Input(const std::string& param) const override {
    const auto& var = desc_->inputs()->LookupByKey(param.c_str());
    std::vector<std::string> args_vec;
    if (var->arguments()) {
      args_vec.reserve(var->arguments()->size());
      for (const auto& in : *var->arguments()) {
        args_vec.push_back(in->str());
      }
    }
    return args_vec;
  }

  std::vector<std::string> InputArgumentNames() const override {
    const auto& vars = desc_->inputs();
    std::vector<std::string> input_names_vec;
    if (vars) {
      input_names_vec.reserve(vars->size());
      for (const auto& in : *vars) {
        input_names_vec.push_back(in->parameter()->str());
      }
    }
    return input_names_vec;
  }

  std::vector<std::string> Output(const std::string& param) const override {
    const auto& var = desc_->outputs()->LookupByKey(param.c_str());
    std::vector<std::string> args_vec;
    if (var->arguments()) {
      args_vec.reserve(var->arguments()->size());
      for (const auto& out : *var->arguments()) {
        args_vec.push_back(out->str());
      }
    }
    return args_vec;
  }

  std::vector<std::string> OutputArgumentNames() const override {
    const auto& vars = desc_->outputs();
    std::vector<std::string> output_names_vec;
    if (vars) {
      output_names_vec.reserve(vars->size());
      for (const auto& out : *vars) {
        output_names_vec.push_back(out->parameter()->str());
      }
    }
    return output_names_vec;
  }

  bool HasAttr(const std::string& name) const override {
    // The attribute exists iff the key lookup succeeds.
    return desc_->attrs()->LookupByKey(name.c_str()) != nullptr;
  }

  size_t AttrsSize() const { return desc_->attrs()->size(); }

  std::string AttrName(size_t idx) const {
    return desc_->attrs()->Get(idx)->name()->str();
  }

  OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
    const auto& attr = desc_->attrs()->LookupByKey(name.c_str());
    CHECK(attr);
    return static_cast<OpDescAPI::AttrType>(attr->type());
  }

  OpDescAPI::AttrType GetAttrType(size_t idx) const {
    const auto& attr = desc_->attrs()->Get(idx);
    CHECK(attr);
    return static_cast<OpDescAPI::AttrType>(attr->type());
  }

  std::vector<std::string> AttrNames() const override {
    const auto& attrs = desc_->attrs();
    std::vector<std::string> attr_names_vec;
    if (attrs) {
      attr_names_vec.reserve(attrs->size());
      for (const auto& attr : *attrs) {
        attr_names_vec.push_back(attr->name()->str());
      }
    }
    return attr_names_vec;
  }

  template <typename T>
  T GetAttr(const std::string& name) const;

  template <typename T>
  T GetAttr(size_t idx) const;

  OpDesc() = delete;

 private:
  proto::OpDesc* desc_;
};
} // namespace fbs
} // namespace lite
} // namespace paddle
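A short sketch of reading attributes through the wrapper above; the attribute names used here ("use_mkldnn", "scale") are hypothetical and only stand in for whatever the model actually carries:

#include <string>

#include "lite/model_parser/flatbuffers/op_desc.h"

// Read two attributes if present (sketch only; attribute names are hypothetical).
void ReadSomeAttrs(const paddle::lite::fbs::OpDesc& op) {
  if (op.HasAttr("use_mkldnn")) {
    // Typed access dispatches to the explicit specializations in op_desc.cc.
    bool use_mkldnn = op.GetAttr<bool>("use_mkldnn");
    (void)use_mkldnn;
  }
  if (op.HasAttr("scale")) {
    float scale = op.GetAttr<float>("scale");
    (void)scale;
  }
}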
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/program_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
template <>
proto::BlockDesc* ProgramDesc::GetBlock<proto::BlockDesc>(int32_t idx) {
  CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
  return const_cast<proto::BlockDesc*>(desc_->blocks()->Get(idx));
}
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/model_parser/base/program_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace fbs {
class ProgramDesc : public ProgramDescReadAPI {
 public:
  explicit ProgramDesc(proto::ProgramDesc* desc) : desc_(desc) { CHECK(desc); }

  size_t BlocksSize() const override { return desc_->blocks()->size(); }

  template <typename T>
  T* GetBlock(int32_t idx);

  template <typename T>
  T const* GetBlock(int32_t idx) const {
    // Delegate to the non-const overload to avoid recursing into this one.
    return const_cast<ProgramDesc*>(this)->GetBlock<T>(idx);
  }

  bool HasVersion() const override { return desc_->version() != nullptr; }

  int64_t Version() const override {
    CHECK(HasVersion());
    return desc_->version()->version();
  }

 private:
  proto::ProgramDesc* desc_;  // not_own
};
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/var_desc.h"
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "lite/model_parser/base/var_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace fbs {
class VarDesc : public VarDescReadAPI {
 public:
  explicit VarDesc(proto::VarDesc* desc) : desc_(desc) {}

  std::string Name() const override { return desc_->name()->str(); }

  VarDescAPI::Type GetType() const override {
    return static_cast<VarDescAPI::Type>(desc_->type()->type());
  }

  bool Persistable() const override { return desc_->persistable(); }

  std::vector<int64_t> GetShape() const override {
    CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
    const auto& dims = desc_->type()->lod_tensor()->tensor()->dims();
    std::vector<int64_t> dims_vec;
    dims_vec.reserve(dims->size());
    for (const auto& dim : *dims) {
      dims_vec.push_back(dim);
    }
    return dims_vec;
  }

  VarDesc() = delete;

 private:
  proto::VarDesc* desc_;
};
} // namespace fbs
} // namespace lite
} // namespace paddle
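A matching sketch for variables, assuming a BlockDesc wrapper over a parsed model; it lists the names of persistable variables (typically the weights) and is illustrative rather than part of this change:

#include <cstdint>
#include <string>
#include <vector>

#include "lite/model_parser/flatbuffers/block_desc.h"
#include "lite/model_parser/flatbuffers/var_desc.h"

// List persistable variable names of one block (sketch only).
std::vector<std::string> PersistableVarNames(paddle::lite::fbs::BlockDesc* block) {
  std::vector<std::string> names;
  for (int32_t i = 0; i < static_cast<int32_t>(block->VarsSize()); ++i) {
    paddle::lite::fbs::VarDesc var(block->GetVar<paddle::lite::fbs::proto::VarDesc>(i));
    if (var.Persistable()) {
      names.push_back(var.Name());
    }
  }
  return names;
}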
Subproject commit ac203b20926b13a35ff85277d2e5d3c38698eee8