From d51324bfee1a21a2c0515c4d085d8e06ca0837d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Tue, 4 Aug 2020 16:01:43 +0800 Subject: [PATCH] new class: ParamDesc, test=develop (#4009) * add ParamDesc, test=develop * serialize tensor funcs, test=develop --- cmake/external/flatbuffers.cmake | 5 +- lite/api/CMakeLists.txt | 4 +- lite/api/android/jni/native/CMakeLists.txt | 4 +- lite/model_parser/base/apis.h | 1 + lite/model_parser/base/param_desc.h | 88 ++++++++ lite/model_parser/base/traits.h | 73 +++++++ lite/model_parser/base/var_desc.h | 32 +-- lite/model_parser/flatbuffers/CMakeLists.txt | 9 +- lite/model_parser/flatbuffers/io.cc | 39 ++++ lite/model_parser/flatbuffers/io.h | 15 ++ lite/model_parser/flatbuffers/op_desc.h | 5 +- lite/model_parser/flatbuffers/param.fbs | 37 ++++ lite/model_parser/flatbuffers/param_desc.cc | 15 ++ lite/model_parser/flatbuffers/param_desc.h | 216 +++++++++++++++++++ lite/model_parser/flatbuffers/traits.h | 144 +++++++++++++ lite/model_parser/flatbuffers/var_desc.h | 6 +- 16 files changed, 648 insertions(+), 45 deletions(-) create mode 100644 lite/model_parser/base/param_desc.h create mode 100644 lite/model_parser/flatbuffers/param.fbs create mode 100644 lite/model_parser/flatbuffers/param_desc.cc create mode 100644 lite/model_parser/flatbuffers/param_desc.h create mode 100644 lite/model_parser/flatbuffers/traits.h diff --git a/cmake/external/flatbuffers.cmake b/cmake/external/flatbuffers.cmake index e6ab31ee85..4c2413c620 100644 --- a/cmake/external/flatbuffers.cmake +++ b/cmake/external/flatbuffers.cmake @@ -100,7 +100,7 @@ function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT) ${OPT} -o "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS}" - DEPENDS flatbuffers + DEPENDS flatbuffers ${SRC_FBS} COMMENT "Run generation: '${GEN_HEADER}'") register_generated_output(${GEN_HEADER}) add_custom_target(${TARGET} ALL DEPENDS ${GEN_HEADER}) @@ -108,7 +108,10 @@ endfunction() set(FRAMEWORK_FBS_DIR "lite/model_parser/flatbuffers") set(FRAMEWORK_SCHEMA_PATH "${FRAMEWORK_FBS_DIR}/framework.fbs") +set(PARAM_SCHEMA_PATH "${FRAMEWORK_FBS_DIR}/param.fbs") compile_flatbuffers_schema_to_cpp_opt(framework_fbs_header ${FRAMEWORK_SCHEMA_PATH} "--no-includes;--gen-compare;--force-empty") +compile_flatbuffers_schema_to_cpp_opt(param_fbs_header ${PARAM_SCHEMA_PATH} "--no-includes;--gen-compare;--force-empty") include_directories(${FLATBUFFERS_INCLUDE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS_DIR}) +add_custom_target(fbs_headers ALL DEPENDS framework_fbs_header param_fbs_header) diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index 6ff381268a..3bceda8717 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -16,7 +16,7 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR LITE_WITH lite_cc_library(paddle_full_api_shared SHARED SRCS paddle_api.cc light_api.cc cxx_api.cc cxx_api_impl.cc light_api_impl.cc DEPS paddle_api paddle_api_light paddle_api_full) target_sources(paddle_full_api_shared PUBLIC ${__lite_cc_files}) - add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto op_registry framework_fbs_header) + add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto op_registry fbs_headers) target_link_libraries(paddle_full_api_shared framework_proto op_registry) if(LITE_WITH_X86) add_dependencies(paddle_full_api_shared xxhash) @@ -72,7 +72,7 @@ else() set(TARGET_COMIPILE_FLAGS "${TARGET_COMIPILE_FLAGS} -flto") endif() set_target_properties(paddle_light_api_shared PROPERTIES COMPILE_FLAGS "${TARGET_COMIPILE_FLAGS}") - add_dependencies(paddle_light_api_shared op_list_h kernel_list_h framework_fbs_header) + add_dependencies(paddle_light_api_shared op_list_h kernel_list_h fbs_headers) if (LITE_WITH_NPU) # Need to add HIAI runtime libs (libhiai.so) dependency target_link_libraries(paddle_light_api_shared ${npu_builder_libs} ${npu_runtime_libs}) diff --git a/lite/api/android/jni/native/CMakeLists.txt b/lite/api/android/jni/native/CMakeLists.txt index 2929e24117..4638ed5fdf 100644 --- a/lite/api/android/jni/native/CMakeLists.txt +++ b/lite/api/android/jni/native/CMakeLists.txt @@ -17,7 +17,7 @@ if (NOT LITE_ON_TINY_PUBLISH) # Unlike static library, module library has to link target to be able to work # as a single .so lib. target_link_libraries(paddle_lite_jni ${lib_DEPS} ${arm_kernels} ${npu_kernels}) - add_dependencies(paddle_lite_jni framework_fbs_header) + add_dependencies(paddle_lite_jni fbs_headers) if (LITE_WITH_NPU) # Strips the symbols of our protobuf functions to fix the conflicts during # loading HIAI builder libs (libhiai_ir.so and libhiai_ir_build.so) @@ -32,7 +32,7 @@ else() endif() set_target_properties(paddle_lite_jni PROPERTIES COMPILE_FLAGS ${TARGET_COMIPILE_FLAGS}) target_sources(paddle_lite_jni PUBLIC ${__lite_cc_files} paddle_lite_jni.cc tensor_jni.cc) - add_dependencies(paddle_lite_jni op_list_h kernel_list_h framework_fbs_header) + add_dependencies(paddle_lite_jni op_list_h kernel_list_h fbs_headers) if (LITE_WITH_NPU) # Need to add HIAI runtime libs (libhiai.so) dependency target_link_libraries(paddle_lite_jni ${npu_builder_libs} ${npu_runtime_libs}) diff --git a/lite/model_parser/base/apis.h b/lite/model_parser/base/apis.h index fa3449017c..898604dda1 100644 --- a/lite/model_parser/base/apis.h +++ b/lite/model_parser/base/apis.h @@ -16,6 +16,7 @@ #include "lite/model_parser/base/block_desc.h" #include "lite/model_parser/base/op_desc.h" +#include "lite/model_parser/base/param_desc.h" #include "lite/model_parser/base/program_desc.h" #include "lite/model_parser/base/proto_desc.h" #include "lite/model_parser/base/traits.h" diff --git a/lite/model_parser/base/param_desc.h b/lite/model_parser/base/param_desc.h new file mode 100644 index 0000000000..1c40ba3e89 --- /dev/null +++ b/lite/model_parser/base/param_desc.h @@ -0,0 +1,88 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/model_parser/base/traits.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { + +class ParamDescReadAPI { + public: + virtual std::string Name() const = 0; + virtual std::vector Dim() const = 0; + virtual VarDataType GetDataType() const = 0; + virtual const void *GetData() const = 0; + virtual size_t byte_size() const = 0; + + virtual ~ParamDescReadAPI() = default; +}; + +class ParamDescWriteAPI { + public: + virtual void SetName(const std::string &name) { NotImplemented(); } + virtual void SetDim(const std::vector &dim) { NotImplemented(); } + virtual void SetDataType(VarDataType data_type) { NotImplemented(); } + virtual void SetData(const void *data, size_t byte_size) { NotImplemented(); } + + virtual ~ParamDescWriteAPI() = default; + + private: + void NotImplemented() const { + LOG(FATAL) << "ParamDescWriteAPI is not available in model read-only mode."; + } +}; + +class CombinedParamsDescReadAPI { + public: + virtual const ParamDescReadAPI *GetParamDesc(size_t idx) const = 0; + virtual size_t GetParamsSize() const = 0; + virtual ~CombinedParamsDescReadAPI() = default; +}; + +class CombinedParamsDescWriteAPI { + public: + virtual ParamDescWriteAPI *AddParamDesc() { + NotImplemented(); + return nullptr; + } + virtual ~CombinedParamsDescWriteAPI() = default; + + private: + void NotImplemented() const { + LOG(FATAL) << "CombinedParamsDescWriteAPI is not available in model " + "read-only mode."; + } +}; + +// The reading and writing of the model are one-time and separate. +// This interface is a combination of reading and writing interfaces, +// which is used to support legacy interfaces. + +class ParamDescAPI : public ParamDescReadAPI, public ParamDescWriteAPI { + public: + virtual ~ParamDescAPI() = default; +}; + +class CombinedParamsDescAPI : public CombinedParamsDescReadAPI, + public CombinedParamsDescWriteAPI { + public: + virtual ~CombinedParamsDescAPI() = default; +}; + +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/base/traits.h b/lite/model_parser/base/traits.h index bda293686c..09ac80ddc9 100644 --- a/lite/model_parser/base/traits.h +++ b/lite/model_parser/base/traits.h @@ -16,6 +16,8 @@ #include #include +#include "lite/api/paddle_place.h" +#include "lite/utils/cp_logging.h" namespace paddle { namespace lite { @@ -37,6 +39,77 @@ enum class OpAttrType { UNK, }; +enum class VarDataType { + // Pod Types + BOOL = 0, + INT16, + INT32, + INT64, + FP16, + FP32, + FP64, + // Tensor is used in C++. + SIZE_T, + UINT8, + INT8, + + // Other types that may need additional descriptions + LOD_TENSOR, + SELECTED_ROWS, + FEED_MINIBATCH, + FETCH_LIST, + STEP_SCOPES, + LOD_RANK_TABLE, + LOD_TENSOR_ARRAY, + PLACE_LIST, + READER, + // Any runtime decided variable type is raw + // raw variables should manage their own allocations + // in operators like nccl_op + RAW, + TUPLE +}; + +inline VarDataType ConvertPrecisionType(lite_api::PrecisionType type) { +#define CASE(ptype, vtype) \ + case lite_api::PrecisionType::k##ptype: \ + return lite::VarDataType::vtype; \ + break + switch (type) { + CASE(Float, FP32); + CASE(Int8, INT8); + CASE(Int32, INT32); + CASE(FP16, FP16); + CASE(Bool, BOOL); + CASE(Int64, INT64); + CASE(Int16, INT16); + default: + LOG(FATAL) << "Illegal flatbuffer VarType."; + return lite::VarDataType(); + } +#undef CASE +} + +inline lite_api::PrecisionType ConvertPrecisionType(VarDataType type) { +#define CASE(ptype, vtype) \ + case lite::VarDataType::vtype: \ + return lite_api::PrecisionType::k##ptype; \ + break + switch (type) { + CASE(Float, FP32); + CASE(Int8, INT8); + CASE(Int32, INT32); + CASE(FP16, FP16); + CASE(Bool, BOOL); + CASE(Int64, INT64); + CASE(Int16, INT16); + default: + LOG(FATAL) << "Illegal flatbuffer VarType."; + return lite_api::PrecisionType(); + } +#undef CASE +} + struct Standard {}; struct Flatbuffers {}; diff --git a/lite/model_parser/base/var_desc.h b/lite/model_parser/base/var_desc.h index 47596f8792..fa5c89b8c7 100644 --- a/lite/model_parser/base/var_desc.h +++ b/lite/model_parser/base/var_desc.h @@ -16,42 +16,12 @@ #include #include +#include "lite/model_parser/base/traits.h" #include "lite/utils/cp_logging.h" namespace paddle { namespace lite { -enum class VarDataType { - // Pod Types - BOOL = 0, - INT16, - INT32, - INT64, - FP16, - FP32, - FP64, - // Tensor is used in C++. - SIZE_T, - UINT8, - INT8, - - // Other types that may need additional descriptions - LOD_TENSOR, - SELECTED_ROWS, - FEED_MINIBATCH, - FETCH_LIST, - STEP_SCOPES, - LOD_RANK_TABLE, - LOD_TENSOR_ARRAY, - PLACE_LIST, - READER, - // Any runtime decided variable type is raw - // raw variables should manage their own allocations - // in operators like nccl_op - RAW, - TUPLE -}; - class VarDescReadAPI { public: virtual std::string Name() const = 0; diff --git a/lite/model_parser/flatbuffers/CMakeLists.txt b/lite/model_parser/flatbuffers/CMakeLists.txt index b7ae9514ef..66723808ad 100644 --- a/lite/model_parser/flatbuffers/CMakeLists.txt +++ b/lite/model_parser/flatbuffers/CMakeLists.txt @@ -5,9 +5,10 @@ function(lite_fbs_library TARGET) add_dependencies(${TARGET} ${args_FBS_DEPS}) endfunction() -lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS framework_fbs_header) -lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS framework_fbs_header) -lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS framework_fbs_header) +lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS fbs_headers) +lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS fbs_headers) +lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS fbs_headers) lite_cc_library(fbs_program_desc SRCS program_desc.cc DEPS fbs_op_desc fbs_var_desc fbs_block_desc) -lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc) +lite_fbs_library(fbs_param_desc SRCS param_desc.cc FBS_DEPS fbs_headers) +lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc fbs_param_desc) lite_cc_test(test_vector_view SRCS vector_view_test.cc DEPS fbs_program_desc) diff --git a/lite/model_parser/flatbuffers/io.cc b/lite/model_parser/flatbuffers/io.cc index ef8e9afaef..b4a6f661eb 100644 --- a/lite/model_parser/flatbuffers/io.cc +++ b/lite/model_parser/flatbuffers/io.cc @@ -13,9 +13,11 @@ // limitations under the License. #include "lite/model_parser/flatbuffers/io.h" +#include #include #include #include +#include "lite/model_parser/flatbuffers/traits.h" namespace paddle { namespace lite { @@ -33,6 +35,43 @@ void LoadModel(const std::string& path, ProgramDesc* prog) { prog->Init(std::move(buf)); } +void SetParamWithTensor(const std::string& name, + const lite::Tensor& tensor, + ParamDescWriteAPI* prog) { + CHECK(prog); + prog->SetName(name); + prog->SetDim(tensor.dims().Vectorize()); + prog->SetDataType(lite::ConvertPrecisionType(tensor.precision())); + prog->SetData(tensor.raw_data(), tensor.memory_size()); +} + +void SetTensorWithParam(lite::Tensor* tensor, const ParamDescReadAPI& param) { + tensor->Resize(param.Dim()); + tensor->set_precision(lite::ConvertPrecisionType(param.GetDataType())); + std::memcpy(tensor->mutable_data(param.byte_size()), + param.GetData(), + param.byte_size()); +} + +void SetCombinedParamsWithScope(const lite::Scope& scope, + const std::vector& params_name, + CombinedParamsDescWriteAPI* params) { + for (const auto& name : params_name) { + auto* param = params->AddParamDesc(); + auto& tensor = scope.FindVar(name)->Get(); + SetParamWithTensor(name, tensor, param); + } +} + +void SetScopeWithCombinedParams(lite::Scope* scope, + const CombinedParamsDescReadAPI& params) { + CHECK(scope); + for (size_t i = 0; i < params.GetParamsSize(); ++i) { + const auto& param = *params.GetParamDesc(i); + auto* tensor = scope->Var(param.Name())->GetMutable(); + SetTensorWithParam(tensor, param); + } +} } // namespace fbs } // namespace lite } // namespace paddle diff --git a/lite/model_parser/flatbuffers/io.h b/lite/model_parser/flatbuffers/io.h index 1c81b192bb..9a46ed42bb 100644 --- a/lite/model_parser/flatbuffers/io.h +++ b/lite/model_parser/flatbuffers/io.h @@ -15,6 +15,10 @@ #pragma once #include +#include +#include "lite/core/scope.h" +#include "lite/core/tensor.h" +#include "lite/model_parser/flatbuffers/param_desc.h" #include "lite/model_parser/flatbuffers/program_desc.h" namespace paddle { @@ -23,6 +27,17 @@ namespace fbs { void LoadModel(const std::string& path, ProgramDesc* prog); +void SetParamWithTensor(const std::string& name, + const lite::Tensor& tensor, + ParamDescWriteAPI* prog); +void SetTensorWithParam(const lite::Tensor& tensor, ParamDescReadAPI* prog); + +void SetCombinedParamsWithScope(const lite::Scope& scope, + const std::vector& params_name, + CombinedParamsDescWriteAPI* params); +void SetScopeWithCombinedParams(lite::Scope* scope, + const CombinedParamsDescReadAPI& params); + } // namespace fbs } // namespace lite } // namespace paddle diff --git a/lite/model_parser/flatbuffers/op_desc.h b/lite/model_parser/flatbuffers/op_desc.h index 450aa49fa1..f6e4ab81e6 100644 --- a/lite/model_parser/flatbuffers/op_desc.h +++ b/lite/model_parser/flatbuffers/op_desc.h @@ -21,6 +21,7 @@ #include "lite/model_parser/base/op_desc.h" #include "lite/model_parser/flatbuffers/framework_generated.h" +#include "lite/model_parser/flatbuffers/traits.h" #include "lite/model_parser/flatbuffers/vector_view.h" #include "lite/utils/all.h" @@ -96,13 +97,13 @@ class OpDesc : public OpDescAPI { OpDescAPI::AttrType GetAttrType(const std::string& name) const override { const auto& attr = desc_->attrs()->LookupByKey(name.c_str()); CHECK(attr) << "Can not find attr: " << name; - return static_cast(attr->type()); + return ConvertAttrType(attr->type()); } OpDescAPI::AttrType GetAttrType(size_t idx) const { const auto& attr = desc_->attrs()->Get(idx); CHECK(attr); - return static_cast(attr->type()); + return ConvertAttrType(attr->type()); } std::vector AttrNames() const override { diff --git a/lite/model_parser/flatbuffers/param.fbs b/lite/model_parser/flatbuffers/param.fbs new file mode 100644 index 0000000000..94437a8880 --- /dev/null +++ b/lite/model_parser/flatbuffers/param.fbs @@ -0,0 +1,37 @@ +include "framework.fbs"; + +namespace paddle.lite.fbs.proto; + +table CombinedParamsDesc { + params:[paddle.lite.fbs.proto.ParamDesc]; +} + +namespace paddle.lite.fbs.proto.ParamDesc_; + +table LoDTensorDesc { + lod_level:int; + lod:[long]; + dim:[long]; + data_type:paddle.lite.fbs.proto.VarType_.Type; + data:[byte]; +} + +table VersionDesc { + version:int; + model_version:int; +} + +union VariableDesc { + LoDTensorDesc +} + +namespace paddle.lite.fbs.proto; + +table ParamDesc { + version:paddle.lite.fbs.proto.ParamDesc_.VersionDesc; + name:string (required, key); + variable:paddle.lite.fbs.proto.ParamDesc_.VariableDesc; +} + +root_type paddle.lite.fbs.proto.ParamDesc; +root_type paddle.lite.fbs.proto.CombinedParamsDesc; diff --git a/lite/model_parser/flatbuffers/param_desc.cc b/lite/model_parser/flatbuffers/param_desc.cc new file mode 100644 index 0000000000..b69b2fd9d9 --- /dev/null +++ b/lite/model_parser/flatbuffers/param_desc.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/model_parser/flatbuffers/param_desc.h" diff --git a/lite/model_parser/flatbuffers/param_desc.h b/lite/model_parser/flatbuffers/param_desc.h new file mode 100644 index 0000000000..e23c91fdc5 --- /dev/null +++ b/lite/model_parser/flatbuffers/param_desc.h @@ -0,0 +1,216 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "lite/model_parser/base/param_desc.h" +#include "lite/model_parser/flatbuffers/framework_generated.h" +#include "lite/model_parser/flatbuffers/param_generated.h" +#include "lite/model_parser/flatbuffers/traits.h" + +namespace paddle { +namespace lite { +namespace fbs { + +class ParamDescView : public ParamDescReadAPI { + public: + explicit ParamDescView(proto::ParamDesc const* desc) : desc_(desc) { + CHECK(desc_); + CHECK(desc_->variable_type() == + proto::ParamDesc_::VariableDesc_LoDTensorDesc); + tensor_desc_ = desc_->variable_as(); + } + std::string Name() const override { return desc_->name()->c_str(); } + + std::vector Dim() const override { + const auto& dims = tensor_desc_->dim(); + std::vector dims_vec; + dims_vec.reserve(dims->size()); + for (const auto& dim : *dims) { + dims_vec.push_back(dim); + } + return dims_vec; + } + + VarDataType GetDataType() const override { + return ConvertVarType(tensor_desc_->data_type()); + } + + const void* GetData() const override { return tensor_desc_->data()->Data(); } + + size_t byte_size() const override { return tensor_desc_->data()->size(); } + + ParamDescView() = delete; + + private: + proto::ParamDesc const* desc_; + proto::ParamDesc_::LoDTensorDesc const* tensor_desc_; +}; + +class CombinedParamsDescView : public CombinedParamsDescReadAPI { + public: + CombinedParamsDescView() = default; + explicit CombinedParamsDescView(const std::vector& buf) { Init(buf); } + explicit CombinedParamsDescView(std::vector&& buf) { + Init(std::forward>(buf)); + } + + void Init(const std::vector& buf) { + CHECK(buf.data()); + buf_ = buf; + InitParams(); + } + + void Init(std::vector&& buf) { + CHECK(buf.data()); + buf_ = std::move(buf); + InitParams(); + } + + void InitParams() { + desc_ = proto::GetCombinedParamsDesc(buf_.data()); + params_.reserve(GetParamsSize()); + for (size_t idx = 0; idx < GetParamsSize(); ++idx) { + params_.push_back(ParamDescView(desc_->params()->Get(idx))); + } + } + + const ParamDescReadAPI* GetParamDesc(size_t idx) const override { + CHECK(idx < GetParamsSize()); + return ¶ms_[idx]; + } + + size_t GetParamsSize() const override { return params_.size(); } + + private: + std::vector params_; + std::vector buf_; + proto::CombinedParamsDesc const* desc_; +}; + +class ParamDesc : public ParamDescAPI { + public: + ParamDesc() : owned_(true), desc_(new proto::ParamDescT()) { + desc_->variable.Set(proto::ParamDesc_::LoDTensorDescT()); + lod_tensor_ = desc_->variable.AsLoDTensorDesc(); + CHECK(lod_tensor_); + } + + explicit ParamDesc(proto::ParamDescT* desc) : desc_(desc) { + lod_tensor_ = desc_->variable.AsLoDTensorDesc(); + CHECK(lod_tensor_); + } + + std::string Name() const override { return desc_->name; } + void SetName(const std::string& name) override { desc_->name = name; } + + std::vector Dim() const override { return lod_tensor_->dim; } + void SetDim(const std::vector& dim) override { + lod_tensor_->dim = dim; + } + + VarDataType GetDataType() const override { + return ConvertVarType(lod_tensor_->data_type); + } + void SetDataType(VarDataType data_type) override { + lod_tensor_->data_type = ConvertVarType(data_type); + } + + const void* GetData() const override { return lod_tensor_->data.data(); } + + size_t byte_size() const override { return lod_tensor_->data.size(); } + + void SetData(const void* data, size_t byte_size) { + lod_tensor_->data.resize(byte_size); + std::memcpy(lod_tensor_->data.data(), data, byte_size); + } + + const proto::ParamDescT* raw_desc() const { return desc_; } + + ~ParamDesc() { + if (owned_) { + delete desc_; + } + } + + private: + bool owned_{false}; + proto::ParamDescT* desc_{nullptr}; + proto::ParamDesc_::LoDTensorDescT* lod_tensor_{nullptr}; +}; + +class CombinedParamsDesc : public CombinedParamsDescAPI { + public: + CombinedParamsDesc() = default; + + explicit CombinedParamsDesc(const std::vector& buf) { + const auto* raw_buf = proto::GetCombinedParamsDesc(buf.data()); + raw_buf->UnPackTo(&desc_); + SyncParams(); + } + const ParamDescReadAPI* GetParamDesc(size_t idx) const override { + return ¶ms_[idx]; + } + + size_t GetParamsSize() const override { return desc_.params.size(); } + + ParamDescWriteAPI* AddParamDesc() override { + desc_.params.push_back(std::unique_ptr()); + SyncParams(); + return ¶ms_[params_.size() - 1]; + } + + const void* data() { + SyncBuffer(); + return buf_.data(); + } + + size_t buf_size() { + SyncBuffer(); + return buf_.size(); + } + + private: + void SyncParams() { + params_.resize(GetParamsSize()); + for (size_t i = 0; i < GetParamsSize(); ++i) { + if (params_[i].raw_desc() != desc_.params[i].get()) { + params_[i] = ParamDesc(desc_.params[i].get()); + } + } + } + + void SyncBuffer() { + fbb_.Reset(); + flatbuffers::Offset desc = + proto::CombinedParamsDesc::Pack(fbb_, &desc_); + fbb_.Finish(desc); + buf_ = fbb_.Release(); + } + + flatbuffers::DetachedBuffer buf_; + flatbuffers::FlatBufferBuilder fbb_; + proto::CombinedParamsDescT desc_; + std::vector params_; +}; + +} // namespace fbs +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/flatbuffers/traits.h b/lite/model_parser/flatbuffers/traits.h new file mode 100644 index 0000000000..f8447926d1 --- /dev/null +++ b/lite/model_parser/flatbuffers/traits.h @@ -0,0 +1,144 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/model_parser/base/traits.h" +#include "lite/model_parser/flatbuffers/framework_generated.h" + +namespace paddle { +namespace lite { +namespace fbs { + +inline lite::VarDataType ConvertVarType(proto::VarType_::Type type) { +#define CASE(type) \ + case proto::VarType_::Type_##type: \ + return lite::VarDataType::type; \ + break + switch (type) { + CASE(BOOL); + CASE(INT16); + CASE(INT32); + CASE(INT64); + CASE(FP16); + CASE(FP32); + CASE(FP64); + CASE(LOD_TENSOR); + CASE(SELECTED_ROWS); + CASE(FEED_MINIBATCH); + CASE(FETCH_LIST); + CASE(STEP_SCOPES); + CASE(LOD_RANK_TABLE); + CASE(LOD_TENSOR_ARRAY); + CASE(PLACE_LIST); + CASE(READER); + CASE(RAW); + CASE(TUPLE); + CASE(SIZE_T); + CASE(UINT8); + CASE(INT8); + default: + LOG(FATAL) << "Illegal flatbuffer VarType."; + return lite::VarDataType(); + } +#undef CASE +} + +inline proto::VarType_::Type ConvertVarType(lite::VarDataType type) { +#define CASE(type) \ + case lite::VarDataType::type: \ + return proto::VarType_::Type_##type; \ + break + switch (type) { + CASE(BOOL); + CASE(INT16); + CASE(INT32); + CASE(INT64); + CASE(FP16); + CASE(FP32); + CASE(FP64); + CASE(LOD_TENSOR); + CASE(SELECTED_ROWS); + CASE(FEED_MINIBATCH); + CASE(FETCH_LIST); + CASE(STEP_SCOPES); + CASE(LOD_RANK_TABLE); + CASE(LOD_TENSOR_ARRAY); + CASE(PLACE_LIST); + CASE(READER); + CASE(RAW); + CASE(TUPLE); + CASE(SIZE_T); + CASE(UINT8); + CASE(INT8); + default: + LOG(FATAL) << "Illegal flatbuffer VarType."; + return proto::VarType_::Type(); + } +#undef CASE +} + +inline lite::OpAttrType ConvertAttrType(proto::AttrType type) { +#define CASE(type) \ + case proto::AttrType_##type: \ + return lite::OpAttrType::type; \ + break + switch (type) { + CASE(INT); + CASE(FLOAT); + CASE(STRING); + CASE(INTS); + CASE(FLOATS); + CASE(STRINGS); + CASE(BOOLEAN); + CASE(BOOLEANS); + CASE(BLOCK); + CASE(LONG); + CASE(BLOCKS); + CASE(LONGS); + default: + LOG(FATAL) << "Illegal flatbuffer AttrType."; + return lite::OpAttrType(); + } +#undef CASE +} + +inline proto::AttrType ConvertAttrType(lite::OpAttrType type) { +#define CASE(type) \ + case lite::OpAttrType::type: \ + return proto::AttrType_##type; \ + break + switch (type) { + CASE(INT); + CASE(FLOAT); + CASE(STRING); + CASE(INTS); + CASE(FLOATS); + CASE(STRINGS); + CASE(BOOLEAN); + CASE(BOOLEANS); + CASE(BLOCK); + CASE(LONG); + CASE(BLOCKS); + CASE(LONGS); + default: + LOG(FATAL) << "Illegal flatbuffer AttrType."; + return proto::AttrType(); + } +#undef CASE +} + +} // namespace fbs +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/flatbuffers/var_desc.h b/lite/model_parser/flatbuffers/var_desc.h index 48d81df30f..bbc5f3c40c 100644 --- a/lite/model_parser/flatbuffers/var_desc.h +++ b/lite/model_parser/flatbuffers/var_desc.h @@ -19,6 +19,7 @@ #include #include "lite/model_parser/base/var_desc.h" #include "lite/model_parser/flatbuffers/framework_generated.h" +#include "lite/model_parser/flatbuffers/traits.h" #include "lite/utils/all.h" namespace paddle { @@ -32,7 +33,7 @@ class VarDesc : public VarDescAPI { std::string Name() const override { return desc_->name()->str(); } VarDescAPI::Type GetType() const override { - return static_cast(desc_->type()->type()); + return ConvertVarType(desc_->type()->type()); } bool Persistable() const override { return desc_->persistable(); } @@ -50,8 +51,7 @@ class VarDesc : public VarDescAPI { VarDescAPI::Type GetDataType() const { CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR); - return static_cast( - desc_->type()->lod_tensor()->tensor()->data_type()); + return ConvertVarType(desc_->type()->lod_tensor()->tensor()->data_type()); } private: -- GitLab