未验证 提交 0ac48fcb 编写于 作者: 石晓伟 提交者: GitHub

Add the conversion of cpp and flatbuffers program, test=develop (#4079)

* update compatible_pb.cc, test=develop

* using unique_ptr in fbs::desc, test=develop

* add fbs to compatible_pb, test=develop

* update model_parser functions, test=develop

* fix bugs, test=develop
上级 9b9245d9
...@@ -19,9 +19,9 @@ endif() ...@@ -19,9 +19,9 @@ endif()
if (NOT LITE_ON_TINY_PUBLISH) if (NOT LITE_ON_TINY_PUBLISH)
lite_cc_library(compatible_pb SRCS compatible_pb.cc lite_cc_library(compatible_pb SRCS compatible_pb.cc
DEPS ${cpp_wrapper} ${naive_wrapper} ${pb_wrapper} framework_proto) DEPS ${cpp_wrapper} ${naive_wrapper} ${pb_wrapper} framework_proto fbs_io)
else() else()
lite_cc_library(compatible_pb SRCS compatible_pb.cc DEPS ${cpp_wrapper} ${naive_wrapper}) lite_cc_library(compatible_pb SRCS compatible_pb.cc DEPS ${cpp_wrapper} ${naive_wrapper} fbs_io)
endif() endif()
lite_cc_library(model_parser SRCS model_parser.cc DEPS lite_cc_library(model_parser SRCS model_parser.cc DEPS
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "lite/model_parser/compatible_pb.h" #include "lite/model_parser/compatible_pb.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/model_parser/flatbuffers/program_desc.h"
#include "lite/model_parser/naive_buffer/block_desc.h" #include "lite/model_parser/naive_buffer/block_desc.h"
#include "lite/model_parser/naive_buffer/op_desc.h" #include "lite/model_parser/naive_buffer/op_desc.h"
#include "lite/model_parser/naive_buffer/program_desc.h" #include "lite/model_parser/naive_buffer/program_desc.h"
...@@ -73,6 +74,18 @@ void TransformVarDescAnyToCpp<naive_buffer::VarDesc>( ...@@ -73,6 +74,18 @@ void TransformVarDescAnyToCpp<naive_buffer::VarDesc>(
}*/ }*/
} }
template <>
void TransformVarDescAnyToCpp<fbs::VarDesc>(const fbs::VarDesc &any_desc,
cpp::VarDesc *cpp_desc) {
cpp_desc->SetName(any_desc.Name());
cpp_desc->SetType(any_desc.GetType());
cpp_desc->SetPersistable(any_desc.Persistable());
if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") {
cpp_desc->SetDataType(any_desc.GetDataType());
cpp_desc->SetShape(any_desc.GetShape());
}
}
/// For OpDesc transform /// For OpDesc transform
template <typename OpDescType> template <typename OpDescType>
void OpInputsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { void OpInputsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) {
...@@ -219,100 +232,102 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { ...@@ -219,100 +232,102 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
} }
/// For BlockDesc transform /// For BlockDesc transform
#define TRANS_BLOCK_ANY_WITH_CPP_IMPL(T, NT, PNT) \ #define TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpT, VarT, NT, PNT) \
template <> \ template <> \
void TransformBlockDescAnyToCpp<NT::T>(const NT::T &any_desc, \ void TransformBlockDescAnyToCpp<NT::BlockDesc>( \
cpp::BlockDesc *cpp_desc) { \ const NT::BlockDesc &any_desc, cpp::BlockDesc *cpp_desc) { \
NT::T desc = any_desc; \ NT::BlockDesc &desc = const_cast<NT::BlockDesc &>(any_desc); \
cpp_desc->SetIdx(desc.Idx()); \ cpp_desc->SetIdx(desc.Idx()); \
cpp_desc->SetParentIdx(desc.ParentIdx()); \ cpp_desc->SetParentIdx(desc.ParentIdx()); \
cpp_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ cpp_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \
\ \
cpp_desc->ClearOps(); \ cpp_desc->ClearOps(); \
for (size_t i = 0; i < desc.OpsSize(); ++i) { \ for (size_t i = 0; i < desc.OpsSize(); ++i) { \
auto any_op_desc = NT::OpDesc(desc.GetOp<PNT::proto::OpDesc>(i)); \ auto any_op_desc = NT::OpDesc(desc.GetOp<PNT::proto::OpT>(i)); \
auto *cpp_op_desc = cpp_desc->AddOp<cpp::OpDesc>(); \ auto *cpp_op_desc = cpp_desc->AddOp<cpp::OpDesc>(); \
TransformOpDescAnyToCpp(any_op_desc, cpp_op_desc); \ TransformOpDescAnyToCpp(any_op_desc, cpp_op_desc); \
} \ } \
\ \
cpp_desc->ClearVars(); \ cpp_desc->ClearVars(); \
for (size_t i = 0; i < desc.VarsSize(); ++i) { \ for (size_t i = 0; i < desc.VarsSize(); ++i) { \
auto any_var_desc = NT::VarDesc(desc.GetVar<PNT::proto::VarDesc>(i)); \ auto any_var_desc = NT::VarDesc(desc.GetVar<PNT::proto::VarT>(i)); \
auto *cpp_var_desc = cpp_desc->AddVar<cpp::VarDesc>(); \ auto *cpp_var_desc = cpp_desc->AddVar<cpp::VarDesc>(); \
TransformVarDescAnyToCpp(any_var_desc, cpp_var_desc); \ TransformVarDescAnyToCpp(any_var_desc, cpp_var_desc); \
} \ } \
} \ } \
\ \
template <> \ template <> \
void TransformBlockDescCppToAny<NT::T>(const cpp::T &cpp_desc, \ void TransformBlockDescCppToAny<NT::BlockDesc>( \
NT::T *any_desc) { \ const cpp::BlockDesc &cpp_desc, NT::BlockDesc *any_desc) { \
const cpp::T &desc = cpp_desc; \ const cpp::BlockDesc &desc = cpp_desc; \
any_desc->SetIdx(desc.Idx()); \ any_desc->SetIdx(desc.Idx()); \
any_desc->SetParentIdx(desc.ParentIdx()); \ any_desc->SetParentIdx(desc.ParentIdx()); \
any_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ any_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \
\ \
any_desc->ClearOps(); \ any_desc->ClearOps(); \
for (size_t i = 0; i < desc.OpsSize(); ++i) { \ for (size_t i = 0; i < desc.OpsSize(); ++i) { \
auto *cpp_op_desc = desc.GetOp<cpp::OpDesc>(i); \ auto *cpp_op_desc = desc.GetOp<cpp::OpDesc>(i); \
auto any_op_desc = NT::OpDesc(any_desc->AddOp<PNT::proto::OpDesc>()); \ auto any_op_desc = NT::OpDesc(any_desc->AddOp<PNT::proto::OpT>()); \
TransformOpDescCppToAny(*cpp_op_desc, &any_op_desc); \ TransformOpDescCppToAny(*cpp_op_desc, &any_op_desc); \
} \ } \
\ \
any_desc->ClearVars(); \ any_desc->ClearVars(); \
for (size_t i = 0; i < desc.VarsSize(); ++i) { \ for (size_t i = 0; i < desc.VarsSize(); ++i) { \
auto *cpp_var_desc = desc.GetVar<cpp::VarDesc>(i); \ auto *cpp_var_desc = desc.GetVar<cpp::VarDesc>(i); \
auto any_var_desc = \ auto any_var_desc = NT::VarDesc(any_desc->AddVar<PNT::proto::VarT>()); \
NT::VarDesc(any_desc->AddVar<PNT::proto::VarDesc>()); \ TransformVarDescCppToAny(*cpp_var_desc, &any_var_desc); \
TransformVarDescCppToAny(*cpp_var_desc, &any_var_desc); \ } \
} \
} }
/// For ProgramDesc transform /// For ProgramDesc transform
#define TRANS_PROGRAM_ANY_WITH_CPP_IMPL(T, NT, PNT) \ #define TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockT, NT, PNT) \
template <> \ template <> \
void TransformProgramDescAnyToCpp<NT::T>(const NT::T &any_desc, \ void TransformProgramDescAnyToCpp<NT::ProgramDesc>( \
cpp::ProgramDesc *cpp_desc) { \ const NT::ProgramDesc &any_desc, cpp::ProgramDesc *cpp_desc) { \
NT::T desc = any_desc; \ NT::ProgramDesc &desc = const_cast<NT::ProgramDesc &>(any_desc); \
if (desc.HasVersion()) { \ if (desc.HasVersion()) { \
cpp_desc->SetVersion(desc.Version()); \ cpp_desc->SetVersion(desc.Version()); \
} \ } \
\ \
cpp_desc->ClearBlocks(); \ cpp_desc->ClearBlocks(); \
for (size_t i = 0; i < desc.BlocksSize(); ++i) { \ for (size_t i = 0; i < desc.BlocksSize(); ++i) { \
auto any_block_desc = \ NT::BlockDesc any_block_desc(desc.GetBlock<PNT::proto::BlockT>(i)); \
NT::BlockDesc(desc.GetBlock<PNT::proto::BlockDesc>(i)); \ auto *cpp_block_desc = cpp_desc->AddBlock<cpp::BlockDesc>(); \
auto *cpp_block_desc = cpp_desc->AddBlock<cpp::BlockDesc>(); \ TransformBlockDescAnyToCpp(any_block_desc, cpp_block_desc); \
TransformBlockDescAnyToCpp(any_block_desc, cpp_block_desc); \ } \
} \ } \
} \ \
\ template <> \
template <> \ void TransformProgramDescCppToAny<NT::ProgramDesc>( \
void TransformProgramDescCppToAny<NT::T>(const cpp::T &cpp_desc, \ const cpp::ProgramDesc &cpp_desc, NT::ProgramDesc *any_desc) { \
NT::T *any_desc) { \ auto &desc = cpp_desc; \
auto &desc = cpp_desc; \ if (desc.HasVersion()) { \
if (desc.HasVersion()) { \ any_desc->SetVersion(desc.Version()); \
any_desc->SetVersion(desc.Version()); \ } \
} \ \
\ any_desc->ClearBlocks(); \
any_desc->ClearBlocks(); \ for (size_t i = 0; i < desc.BlocksSize(); ++i) { \
for (size_t i = 0; i < desc.BlocksSize(); ++i) { \ auto *cpp_block_desc = desc.GetBlock<cpp::BlockDesc>(i); \
auto *cpp_block_desc = desc.GetBlock<cpp::BlockDesc>(i); \ NT::BlockDesc any_block_desc(any_desc->AddBlock<PNT::proto::BlockT>()); \
auto any_block_desc = \ TransformBlockDescCppToAny(*cpp_block_desc, &any_block_desc); \
NT::BlockDesc(any_desc->AddBlock<PNT::proto::BlockDesc>()); \ } \
TransformBlockDescCppToAny(*cpp_block_desc, &any_block_desc); \
} \
} }
TRANS_VAR_ANY_WITH_CPP_IMPL(naive_buffer::VarDesc); TRANS_VAR_ANY_WITH_CPP_IMPL(naive_buffer::VarDesc);
TRANS_OP_ANY_WITH_CPP_IMPL(naive_buffer::OpDesc); TRANS_OP_ANY_WITH_CPP_IMPL(naive_buffer::OpDesc);
TRANS_BLOCK_ANY_WITH_CPP_IMPL(BlockDesc, naive_buffer, naive_buffer); TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpDesc, VarDesc, naive_buffer, naive_buffer);
TRANS_PROGRAM_ANY_WITH_CPP_IMPL(ProgramDesc, naive_buffer, naive_buffer); TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockDesc, naive_buffer, naive_buffer);
TRANS_VAR_ANY_WITH_CPP_IMPL(fbs::VarDesc);
TRANS_OP_ANY_WITH_CPP_IMPL(fbs::OpDesc);
TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpDescT, VarDescT, fbs, fbs);
TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockDescT, fbs, fbs);
#ifndef LITE_ON_TINY_PUBLISH #ifndef LITE_ON_TINY_PUBLISH
TRANS_VAR_ANY_WITH_CPP_IMPL(pb::VarDesc); TRANS_VAR_ANY_WITH_CPP_IMPL(pb::VarDesc);
TRANS_OP_ANY_WITH_CPP_IMPL(pb::OpDesc); TRANS_OP_ANY_WITH_CPP_IMPL(pb::OpDesc);
TRANS_BLOCK_ANY_WITH_CPP_IMPL(BlockDesc, pb, framework); TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpDesc, VarDesc, pb, framework);
TRANS_PROGRAM_ANY_WITH_CPP_IMPL(ProgramDesc, pb, framework); TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockDesc, pb, framework);
#endif #endif
#undef TRANS_VAR_ANY_WITH_CPP_IMPL #undef TRANS_VAR_ANY_WITH_CPP_IMPL
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "lite/model_parser/compatible_pb.h" #include "lite/model_parser/compatible_pb.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "lite/model_parser/cpp_desc.h" #include "lite/model_parser/cpp_desc.h"
#include "lite/model_parser/flatbuffers/program_desc.h"
#include "lite/model_parser/flatbuffers/test_helper.h"
#include "lite/model_parser/naive_buffer/block_desc.h" #include "lite/model_parser/naive_buffer/block_desc.h"
#include "lite/model_parser/naive_buffer/op_desc.h" #include "lite/model_parser/naive_buffer/op_desc.h"
#include "lite/model_parser/naive_buffer/program_desc.h" #include "lite/model_parser/naive_buffer/program_desc.h"
...@@ -430,5 +432,14 @@ TEST(ProgramDesc, AnyToCpp) { ...@@ -430,5 +432,14 @@ TEST(ProgramDesc, AnyToCpp) {
TestProgramAnyToCpp<naive_buffer::ProgramDesc>(&nb_desc); TestProgramAnyToCpp<naive_buffer::ProgramDesc>(&nb_desc);
} }
TEST(ProgramDesc, FbsCpp) {
fbs::ProgramDesc fbs_program(fbs::test::GenerateProgramCache());
cpp::ProgramDesc cpp_program;
TransformProgramDescAnyToCpp(fbs_program, &cpp_program);
fbs::ProgramDesc fbs_program_2;
TransformProgramDescCppToAny(cpp_program, &fbs_program_2);
fbs::test::CheckProgramCache(&fbs_program_2);
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -21,52 +21,52 @@ namespace fbs { ...@@ -21,52 +21,52 @@ namespace fbs {
template <> template <>
proto::VarDesc const* BlockDescView::GetVar<proto::VarDesc>(int32_t idx) const { proto::VarDesc const* BlockDescView::GetVar<proto::VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()"; CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
return desc_->vars()->Get(idx); return desc_->vars()->Get(idx);
} }
template <> template <>
proto::OpDesc const* BlockDescView::GetOp<proto::OpDesc>(int32_t idx) const { proto::OpDesc const* BlockDescView::GetOp<proto::OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()"; CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= ops.size()";
return desc_->ops()->Get(idx); return desc_->ops()->Get(idx);
} }
template <> template <>
VarDescView const* BlockDescView::GetVar<VarDescView>(int32_t idx) const { VarDescView const* BlockDescView::GetVar<VarDescView>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()"; CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
return &vars_[idx]; return &vars_[idx];
} }
template <> template <>
OpDescView const* BlockDescView::GetOp<OpDescView>(int32_t idx) const { OpDescView const* BlockDescView::GetOp<OpDescView>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()"; CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= ops.size()";
return &ops_[idx]; return &ops_[idx];
} }
template <> template <>
proto::VarDescT* BlockDesc::GetVar<proto::VarDescT>(int32_t idx) { proto::VarDescT* BlockDesc::GetVar<proto::VarDescT>(int32_t idx) {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()"; CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
return vars_[idx].raw_desc(); return vars_[idx]->raw_desc();
} }
template <> template <>
proto::VarDescT* BlockDesc::AddVar<proto::VarDescT>() { proto::VarDescT* BlockDesc::AddVar<proto::VarDescT>() {
desc_->vars.push_back(std::unique_ptr<proto::VarDescT>(new proto::VarDescT)); desc_->vars.push_back(std::unique_ptr<proto::VarDescT>(new proto::VarDescT));
SyncVars(); SyncVars();
return vars_.back().raw_desc(); return vars_.back()->raw_desc();
} }
template <> template <>
proto::OpDescT* BlockDesc::GetOp<proto::OpDescT>(int32_t idx) { proto::OpDescT* BlockDesc::GetOp<proto::OpDescT>(int32_t idx) {
CHECK_LT(idx, OpsSize()) << "idx >= vars.size()"; CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= vars.size()";
return ops_[idx].raw_desc(); return ops_[idx]->raw_desc();
} }
template <> template <>
proto::OpDescT* BlockDesc::AddOp<proto::OpDescT>() { proto::OpDescT* BlockDesc::AddOp<proto::OpDescT>() {
desc_->ops.push_back(std::unique_ptr<proto::OpDescT>(new proto::OpDescT)); desc_->ops.push_back(std::unique_ptr<proto::OpDescT>(new proto::OpDescT));
SyncOps(); SyncOps();
return ops_.back().raw_desc(); return ops_.back()->raw_desc();
} }
} // namespace fbs } // namespace fbs
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <memory>
#include <vector> #include <vector>
#include "lite/model_parser/base/block_desc.h" #include "lite/model_parser/base/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h" #include "lite/model_parser/flatbuffers/framework_generated.h"
...@@ -150,24 +151,24 @@ class BlockDesc : public BlockDescAPI { ...@@ -150,24 +151,24 @@ class BlockDesc : public BlockDescAPI {
void SyncVars() { void SyncVars() {
vars_.resize(desc_->vars.size()); vars_.resize(desc_->vars.size());
for (size_t i = 0; i < desc_->vars.size(); ++i) { for (size_t i = 0; i < desc_->vars.size(); ++i) {
if (vars_[i].raw_desc() != desc_->vars[i].get()) { if (!vars_[i] || vars_[i]->raw_desc() != desc_->vars[i].get()) {
vars_[i] = VarDesc(desc_->vars[i].get()); vars_[i].reset(new VarDesc(desc_->vars[i].get()));
} }
} }
} }
void SyncOps() { void SyncOps() {
ops_.resize(desc_->ops.size()); ops_.resize(desc_->ops.size());
for (size_t i = 0; i < desc_->ops.size(); ++i) { for (size_t i = 0; i < desc_->ops.size(); ++i) {
if (ops_[i].raw_desc() != desc_->ops[i].get()) { if (!ops_[i] || ops_[i]->raw_desc() != desc_->ops[i].get()) {
ops_[i] = OpDesc(desc_->ops[i].get()); ops_[i].reset(new OpDesc(desc_->ops[i].get()));
} }
} }
} }
bool owned_{false}; bool owned_{false};
proto::BlockDescT* desc_{nullptr}; proto::BlockDescT* desc_{nullptr};
std::vector<VarDesc> vars_; std::vector<std::unique_ptr<VarDesc>> vars_;
std::vector<OpDesc> ops_; std::vector<std::unique_ptr<OpDesc>> ops_;
}; };
} // namespace fbs } // namespace fbs
......
...@@ -25,11 +25,12 @@ namespace fbs { ...@@ -25,11 +25,12 @@ namespace fbs {
std::vector<char> LoadFile(const std::string& path) { std::vector<char> LoadFile(const std::string& path) {
FILE* file = fopen(path.c_str(), "rb"); FILE* file = fopen(path.c_str(), "rb");
CHECK(file);
fseek(file, 0, SEEK_END); fseek(file, 0, SEEK_END);
int64_t length = ftell(file); uint64_t length = ftell(file);
rewind(file); rewind(file);
std::vector<char> buf(length); std::vector<char> buf(length);
CHECK(fread(buf.data(), 1, length, file) == length); CHECK_EQ(fread(buf.data(), 1, length, file), length);
fclose(file); fclose(file);
return buf; return buf;
} }
...@@ -37,6 +38,7 @@ std::vector<char> LoadFile(const std::string& path) { ...@@ -37,6 +38,7 @@ std::vector<char> LoadFile(const std::string& path) {
void SaveFile(const std::string& path, const void* src, size_t byte_size) { void SaveFile(const std::string& path, const void* src, size_t byte_size) {
CHECK(src); CHECK(src);
FILE* file = fopen(path.c_str(), "wb"); FILE* file = fopen(path.c_str(), "wb");
CHECK(file);
CHECK(fwrite(src, sizeof(char), byte_size, file) == byte_size); CHECK(fwrite(src, sizeof(char), byte_size, file) == byte_size);
fclose(file); fclose(file);
} }
...@@ -60,7 +62,7 @@ void SetTensorWithParam(lite::Tensor* tensor, const ParamDescReadAPI& param) { ...@@ -60,7 +62,7 @@ void SetTensorWithParam(lite::Tensor* tensor, const ParamDescReadAPI& param) {
} }
void SetCombinedParamsWithScope(const lite::Scope& scope, void SetCombinedParamsWithScope(const lite::Scope& scope,
const std::vector<std::string>& params_name, const std::set<std::string>& params_name,
CombinedParamsDescWriteAPI* params) { CombinedParamsDescWriteAPI* params) {
for (const auto& name : params_name) { for (const auto& name : params_name) {
auto* param = params->AddParamDesc(); auto* param = params->AddParamDesc();
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/scope.h" #include "lite/core/scope.h"
...@@ -30,8 +31,9 @@ void SaveFile(const std::string& path, const void* src, size_t byte_size); ...@@ -30,8 +31,9 @@ void SaveFile(const std::string& path, const void* src, size_t byte_size);
void SetScopeWithCombinedParams(lite::Scope* scope, void SetScopeWithCombinedParams(lite::Scope* scope,
const CombinedParamsDescReadAPI& params); const CombinedParamsDescReadAPI& params);
void SetCombinedParamsWithScope(const lite::Scope& scope, void SetCombinedParamsWithScope(const lite::Scope& scope,
const std::vector<std::string>& params_name, const std::set<std::string>& params_name,
CombinedParamsDescWriteAPI* params); CombinedParamsDescWriteAPI* params);
} // namespace fbs } // namespace fbs
......
...@@ -32,7 +32,7 @@ void set_tensor(paddle::lite::Tensor* tensor, ...@@ -32,7 +32,7 @@ void set_tensor(paddle::lite::Tensor* tensor,
tensor->Resize(dims); tensor->Resize(dims);
std::vector<T> data; std::vector<T> data;
data.resize(production); data.resize(production);
for (size_t i = 0; i < production; ++i) { for (int i = 0; i < production; ++i) {
data[i] = i / 2.f; data[i] = i / 2.f;
} }
std::memcpy(tensor->mutable_data<T>(), data.data(), sizeof(T) * data.size()); std::memcpy(tensor->mutable_data<T>(), data.data(), sizeof(T) * data.size());
...@@ -53,7 +53,8 @@ TEST(CombinedParamsDesc, Scope) { ...@@ -53,7 +53,8 @@ TEST(CombinedParamsDesc, Scope) {
set_tensor<int8_t>(tensor_1, std::vector<int64_t>({10, 1})); set_tensor<int8_t>(tensor_1, std::vector<int64_t>({10, 1}));
// Set combined parameters // Set combined parameters
fbs::CombinedParamsDesc combined_param; fbs::CombinedParamsDesc combined_param;
SetCombinedParamsWithScope(scope, params_name, &combined_param); std::set<std::string> params_set(params_name.begin(), params_name.end());
SetCombinedParamsWithScope(scope, params_set, &combined_param);
/* --------- Check scope ---------- */ /* --------- Check scope ---------- */
auto check_params = [&](const CombinedParamsDescReadAPI& desc) { auto check_params = [&](const CombinedParamsDescReadAPI& desc) {
......
...@@ -103,6 +103,7 @@ GET_ATTRS_IMPL(std::vector<int64_t>, longs); ...@@ -103,6 +103,7 @@ GET_ATTRS_IMPL(std::vector<int64_t>, longs);
new proto::OpDesc_::AttrT())), \ new proto::OpDesc_::AttrT())), \
&(desc_->attrs)); \ &(desc_->attrs)); \
p->fb_f__ = v; \ p->fb_f__ = v; \
p->type = ConvertAttrType(OpDataTypeTrait<T>::AT); \
SetKey(name, &p); \ SetKey(name, &p); \
} }
ATTR_IMPL(int32_t, i); ATTR_IMPL(int32_t, i);
......
...@@ -115,7 +115,11 @@ class ParamDesc : public ParamDescAPI { ...@@ -115,7 +115,11 @@ class ParamDesc : public ParamDescAPI {
} }
explicit ParamDesc(proto::ParamDescT* desc) : desc_(desc) { explicit ParamDesc(proto::ParamDescT* desc) : desc_(desc) {
desc_->variable.Set(proto::ParamDesc_::LoDTensorDescT()); if (desc_->variable.type == proto::ParamDesc_::VariableDesc_NONE) {
desc_->variable.Set(proto::ParamDesc_::LoDTensorDescT());
}
CHECK(desc_->variable.type ==
proto::ParamDesc_::VariableDesc_LoDTensorDesc);
lod_tensor_ = desc_->variable.AsLoDTensorDesc(); lod_tensor_ = desc_->variable.AsLoDTensorDesc();
CHECK(lod_tensor_); CHECK(lod_tensor_);
} }
...@@ -169,7 +173,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI { ...@@ -169,7 +173,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
} }
const ParamDescReadAPI* GetParamDesc(size_t idx) const override { const ParamDescReadAPI* GetParamDesc(size_t idx) const override {
return &params_[idx]; return params_[idx].get();
} }
size_t GetParamsSize() const override { return desc_.params.size(); } size_t GetParamsSize() const override { return desc_.params.size(); }
...@@ -178,7 +182,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI { ...@@ -178,7 +182,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
desc_.params.push_back( desc_.params.push_back(
std::unique_ptr<proto::ParamDescT>(new proto::ParamDescT)); std::unique_ptr<proto::ParamDescT>(new proto::ParamDescT));
SyncParams(); SyncParams();
return &params_[params_.size() - 1]; return params_[params_.size() - 1].get();
} }
const void* data() { const void* data() {
...@@ -195,8 +199,8 @@ class CombinedParamsDesc : public CombinedParamsDescAPI { ...@@ -195,8 +199,8 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
void SyncParams() { void SyncParams() {
params_.resize(GetParamsSize()); params_.resize(GetParamsSize());
for (size_t i = 0; i < GetParamsSize(); ++i) { for (size_t i = 0; i < GetParamsSize(); ++i) {
if (params_[i].raw_desc() != desc_.params[i].get()) { if (!params_[i] || params_[i]->raw_desc() != desc_.params[i].get()) {
params_[i] = ParamDesc(desc_.params[i].get()); params_[i].reset(new ParamDesc(desc_.params[i].get()));
} }
} }
} }
...@@ -212,7 +216,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI { ...@@ -212,7 +216,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
flatbuffers::DetachedBuffer buf_; flatbuffers::DetachedBuffer buf_;
flatbuffers::FlatBufferBuilder fbb_; flatbuffers::FlatBufferBuilder fbb_;
proto::CombinedParamsDescT desc_; proto::CombinedParamsDescT desc_;
std::vector<ParamDesc> params_; std::vector<std::unique_ptr<ParamDesc>> params_;
}; };
} // namespace fbs } // namespace fbs
......
...@@ -21,21 +21,21 @@ namespace fbs { ...@@ -21,21 +21,21 @@ namespace fbs {
template <> template <>
proto::BlockDesc const* ProgramDescView::GetBlock<proto::BlockDesc>( proto::BlockDesc const* ProgramDescView::GetBlock<proto::BlockDesc>(
int32_t idx) const { int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()"; CHECK_LT(idx, static_cast<int32_t>(BlocksSize())) << "idx >= blocks.size()";
return desc_->blocks()->Get(idx); return desc_->blocks()->Get(idx);
} }
template <> template <>
BlockDescView const* ProgramDescView::GetBlock<BlockDescView>( BlockDescView const* ProgramDescView::GetBlock<BlockDescView>(
int32_t idx) const { int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()"; CHECK_LT(idx, static_cast<int32_t>(BlocksSize())) << "idx >= blocks.size()";
return &blocks_[idx]; return &blocks_[idx];
} }
template <> template <>
proto::BlockDescT* ProgramDesc::GetBlock<proto::BlockDescT>(int32_t idx) { proto::BlockDescT* ProgramDesc::GetBlock<proto::BlockDescT>(int32_t idx) {
CHECK_LT(idx, BlocksSize()) << "idx >= vars.size()"; CHECK_LT(idx, static_cast<int32_t>(BlocksSize())) << "idx >= vars.size()";
return blocks_[idx].raw_desc(); return blocks_[idx]->raw_desc();
} }
template <> template <>
...@@ -43,7 +43,7 @@ proto::BlockDescT* ProgramDesc::AddBlock<proto::BlockDescT>() { ...@@ -43,7 +43,7 @@ proto::BlockDescT* ProgramDesc::AddBlock<proto::BlockDescT>() {
desc_.blocks.push_back( desc_.blocks.push_back(
std::unique_ptr<proto::BlockDescT>(new proto::BlockDescT)); std::unique_ptr<proto::BlockDescT>(new proto::BlockDescT));
SyncBlocks(); SyncBlocks();
return blocks_.back().raw_desc(); return blocks_.back()->raw_desc();
} }
} // namespace fbs } // namespace fbs
......
...@@ -150,8 +150,8 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -150,8 +150,8 @@ class ProgramDesc : public ProgramDescAPI {
void SyncBlocks() { void SyncBlocks() {
blocks_.resize(desc_.blocks.size()); blocks_.resize(desc_.blocks.size());
for (size_t i = 0; i < desc_.blocks.size(); ++i) { for (size_t i = 0; i < desc_.blocks.size(); ++i) {
if (blocks_[i].raw_desc() != desc_.blocks[i].get()) { if (!blocks_[i] || blocks_[i]->raw_desc() != desc_.blocks[i].get()) {
blocks_[i] = BlockDesc(desc_.blocks[i].get()); blocks_[i].reset(new BlockDesc(desc_.blocks[i].get()));
} }
} }
} }
...@@ -167,7 +167,7 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -167,7 +167,7 @@ class ProgramDesc : public ProgramDescAPI {
flatbuffers::DetachedBuffer buf_; flatbuffers::DetachedBuffer buf_;
flatbuffers::FlatBufferBuilder fbb_; flatbuffers::FlatBufferBuilder fbb_;
proto::ProgramDescT desc_; proto::ProgramDescT desc_;
std::vector<BlockDesc> blocks_; std::vector<std::unique_ptr<BlockDesc>> blocks_;
}; };
} // namespace fbs } // namespace fbs
......
...@@ -15,136 +15,22 @@ ...@@ -15,136 +15,22 @@
#include "lite/model_parser/flatbuffers/program_desc.h" #include "lite/model_parser/flatbuffers/program_desc.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string> #include <string>
#include "lite/model_parser/flatbuffers/test_helper.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace fbs { namespace fbs {
namespace {
std::vector<char> GenerateProgramCache() {
/* --------- Set Program --------- */
ProgramDesc program;
program.SetVersion(1000600);
/* --------- Set Block A --------- */
BlockDesc block_a(program.AddBlock<proto::BlockDescT>());
VarDesc var_a2(block_a.AddVar<proto::VarDescT>());
var_a2.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a2.SetName("var_a2");
var_a2.SetShape({2, 2, 1});
VarDesc var_a0(block_a.AddVar<proto::VarDescT>());
var_a0.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a0.SetName("var_a0");
var_a0.SetShape({1, 2});
OpDesc op_a0(block_a.AddOp<proto::OpDescT>());
op_a0.SetType("Type");
op_a0.SetInput("X", {"var_a0"});
op_a0.SetOutput("Y0", {"var_a0", "var_a1"});
op_a0.SetOutput("Y1", {"var_a2"});
op_a0.SetAttr<std::string>("Attr5", "attr_5");
op_a0.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_a0.SetAttr<float>("Attr1", 0.98f);
op_a0.SetAttr<int32_t>("Attr0", 16);
/* --------- Set Block B --------- */
BlockDesc block_b(program.AddBlock<proto::BlockDescT>());
VarDesc var_b0(block_b.AddVar<proto::VarDescT>());
var_b0.SetName("var_b0");
var_b0.SetShape({-1, 1});
OpDesc op_b0(block_b.AddOp<proto::OpDescT>());
op_b0.SetType("Type0");
op_b0.SetInput("X", {"var_b0"});
op_b0.SetOutput("Y1", {"var_b0"});
op_b0.SetAttr<std::string>("Attr5", "attr_5");
OpDesc op_b1(block_b.AddOp<proto::OpDescT>());
op_b1.SetType("Type1");
op_b1.SetInput("X", {"var_b0"});
op_b1.SetOutput("Y1", {"var_b0"});
op_b1.SetAttr<std::string>("Attr5", "attr_5");
op_b1.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_b1.SetAttr<bool>("Attr1", true);
/* --------- Cache Program ---------- */
std::vector<char> cache;
cache.resize(program.buf_size());
std::memcpy(cache.data(), program.data(), program.buf_size());
return cache;
}
} // namespace
TEST(ProgramDesc, LoadTest) { TEST(ProgramDesc, LoadTest) {
ProgramDesc program(GenerateProgramCache()); ProgramDesc program(test::GenerateProgramCache());
CHECK_EQ(program.Version(), 1000600); test::CheckProgramCache(&program);
CHECK_EQ(program.BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
auto block_a = BlockDesc(program.GetBlock<proto::BlockDescT>(0));
CHECK_EQ(block_a.OpsSize(), 1);
CHECK_EQ(block_a.VarsSize(), 2);
auto var_a2 = VarDesc(block_a.GetVar<proto::VarDescT>(0));
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
auto op_a0 = OpDesc(block_a.GetOp<proto::OpDescT>(0));
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(op_a0.GetAttr<std::vector<std::string>>("Attr2") ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
auto block_b = BlockDesc(program.GetBlock<proto::BlockDescT>(1));
CHECK_EQ(block_b.OpsSize(), 2);
CHECK_EQ(block_b.VarsSize(), 1);
auto op_b0 = OpDesc(block_b.GetOp<proto::OpDescT>(1));
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
} }
TEST(ProgramDescView, LoadTest) { TEST(ProgramDescView, LoadTest) {
const ProgramDescView program(GenerateProgramCache()); const ProgramDescView program(test::GenerateProgramCache());
CHECK_EQ(program.Version(), 1000600); test::CheckProgramCache(program);
CHECK_EQ(program.BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
const auto& block_a = *program.GetBlock<BlockDescView>(0);
CHECK_EQ(block_a.OpsSize(), 1);
CHECK_EQ(block_a.VarsSize(), 2);
const auto& var_a2 = *block_a.GetVar<VarDescView>(0);
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
const auto& op_a0 = *block_a.GetOp<OpDescView>(0);
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(static_cast<std::vector<std::string>>(
op_a0.GetAttr<std::vector<std::string>>("Attr2")) ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
const auto& block_b = *program.GetBlock<BlockDescView>(1);
CHECK_EQ(block_b.OpsSize(), 2);
CHECK_EQ(block_b.VarsSize(), 1);
const auto& op_b0 = *block_b.GetOp<OpDescView>(1);
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
} }
} // namespace fbs } // namespace fbs
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/model_parser/flatbuffers/program_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
namespace test {
inline std::vector<char> GenerateProgramCache() {
/* --------- Set Program --------- */
ProgramDesc program;
program.SetVersion(1000600);
/* --------- Set Block A --------- */
BlockDesc block_a(program.AddBlock<proto::BlockDescT>());
VarDesc var_a2(block_a.AddVar<proto::VarDescT>());
var_a2.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a2.SetName("var_a2");
var_a2.SetShape({2, 2, 1});
VarDesc var_a0(block_a.AddVar<proto::VarDescT>());
var_a0.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a0.SetName("var_a0");
var_a0.SetShape({1, 2});
OpDesc op_a0(block_a.AddOp<proto::OpDescT>());
op_a0.SetType("Type");
op_a0.SetInput("X", {"var_a0"});
op_a0.SetOutput("Y0", {"var_a0", "var_a1"});
op_a0.SetOutput("Y1", {"var_a2"});
op_a0.SetAttr<std::string>("Attr5", "attr_5");
op_a0.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_a0.SetAttr<float>("Attr1", 0.98f);
op_a0.SetAttr<int32_t>("Attr0", 16);
/* --------- Set Block B --------- */
BlockDesc block_b(program.AddBlock<proto::BlockDescT>());
VarDesc var_b0(block_b.AddVar<proto::VarDescT>());
var_b0.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_b0.SetName("var_b0");
var_b0.SetShape({-1, 1});
OpDesc op_b0(block_b.AddOp<proto::OpDescT>());
op_b0.SetType("Type0");
op_b0.SetInput("X", {"var_b0"});
op_b0.SetOutput("Y1", {"var_b0"});
op_b0.SetAttr<std::string>("Attr5", "attr_5");
OpDesc op_b1(block_b.AddOp<proto::OpDescT>());
op_b1.SetType("Type1");
op_b1.SetInput("X", {"var_b0"});
op_b1.SetOutput("Y1", {"var_b0"});
op_b1.SetAttr<std::string>("Attr5", "attr_5");
op_b1.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_b1.SetAttr<bool>("Attr1", true);
/* --------- Cache Program ---------- */
std::vector<char> cache;
cache.resize(program.buf_size());
std::memcpy(cache.data(), program.data(), program.buf_size());
return cache;
}
inline void CheckProgramCache(ProgramDesc* program) {
CHECK_EQ(program->Version(), 1000600);
CHECK_EQ(program->BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
BlockDesc block_a(program->GetBlock<proto::BlockDescT>(0));
CHECK_EQ(block_a.OpsSize(), static_cast<size_t>(1));
CHECK_EQ(block_a.VarsSize(), static_cast<size_t>(2));
auto var_a2 = VarDesc(block_a.GetVar<proto::VarDescT>(0));
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
auto op_a0 = OpDesc(block_a.GetOp<proto::OpDescT>(0));
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(op_a0.GetAttr<std::vector<std::string>>("Attr2") ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
BlockDesc block_b(program->GetBlock<proto::BlockDescT>(1));
CHECK_EQ(block_b.OpsSize(), static_cast<size_t>(2));
CHECK_EQ(block_b.VarsSize(), static_cast<size_t>(1));
auto op_b0 = OpDesc(block_b.GetOp<proto::OpDescT>(1));
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
}
inline void CheckProgramCache(const ProgramDescView& program) {
CHECK_EQ(program.Version(), 1000600);
CHECK_EQ(program.BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
const auto& block_a = *program.GetBlock<BlockDescView>(0);
CHECK_EQ(block_a.OpsSize(), static_cast<size_t>(1));
CHECK_EQ(block_a.VarsSize(), static_cast<size_t>(2));
const auto& var_a2 = *block_a.GetVar<VarDescView>(0);
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
const auto& op_a0 = *block_a.GetOp<OpDescView>(0);
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(static_cast<std::vector<std::string>>(
op_a0.GetAttr<std::vector<std::string>>("Attr2")) ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
const auto& block_b = *program.GetBlock<BlockDescView>(1);
CHECK_EQ(block_b.OpsSize(), static_cast<size_t>(2));
CHECK_EQ(block_b.VarsSize(), static_cast<size_t>(1));
const auto& op_b0 = *block_b.GetOp<OpDescView>(1);
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
}
} // namespace test
} // namespace fbs
} // namespace lite
} // namespace paddle
...@@ -93,9 +93,14 @@ class VarDesc : public VarDescAPI { ...@@ -93,9 +93,14 @@ class VarDesc : public VarDescAPI {
Type GetType() const override { return ConvertVarType(type_->type); } Type GetType() const override { return ConvertVarType(type_->type); }
void SetType(Type type) override { void SetType(Type type) override { type_->type = ConvertVarType(type); }
CHECK(type == VarDescAPI::Type::LOD_TENSOR);
type_->type = ConvertVarType(type); void SetDataType(Type type) {
type_->lod_tensor->tensor->data_type = ConvertVarType(type);
}
Type GetDataType() const {
return ConvertVarType(type_->lod_tensor->tensor->data_type);
} }
bool Persistable() const override { return desc_->persistable; } bool Persistable() const override { return desc_->persistable; }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <fstream> #include <fstream>
#include <limits> #include <limits>
#include <set> #include <set>
#include "lite/core/scope.h" #include "lite/core/scope.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/core/variable.h" #include "lite/core/variable.h"
...@@ -27,6 +28,7 @@ ...@@ -27,6 +28,7 @@
#include "lite/model_parser/naive_buffer/program_desc.h" #include "lite/model_parser/naive_buffer/program_desc.h"
#include "lite/model_parser/naive_buffer/var_desc.h" #include "lite/model_parser/naive_buffer/var_desc.h"
#ifndef LITE_ON_TINY_PUBLISH #ifndef LITE_ON_TINY_PUBLISH
#include "lite/model_parser/flatbuffers/io.h"
#include "lite/model_parser/pb/program_desc.h" #include "lite/model_parser/pb/program_desc.h"
#include "lite/model_parser/pb/var_desc.h" #include "lite/model_parser/pb/var_desc.h"
#endif #endif
...@@ -592,7 +594,54 @@ void SaveModelNaive(const std::string &model_dir, ...@@ -592,7 +594,54 @@ void SaveModelNaive(const std::string &model_dir,
LOG(INFO) << "Save naive buffer model in '" << model_dir LOG(INFO) << "Save naive buffer model in '" << model_dir
<< ".nb' successfully"; << ".nb' successfully";
} }
#endif
/* ---------- Flatbuffers ---------- */
void SaveModelFbs(const std::string &model_dir,
const Scope &exec_scope,
const cpp::ProgramDesc &cpp_prog) {
/* 1. Save model to model.fbs */
const std::string prog_path = model_dir + "/model.fbs";
fbs::ProgramDesc fbs_prog;
TransformProgramDescCppToAny(cpp_prog, &fbs_prog);
fbs::SaveFile(prog_path, fbs_prog.data(), fbs_prog.buf_size());
/* 2. Get param names from cpp::ProgramDesc */
auto &main_block_desc = *cpp_prog.GetBlock<cpp::BlockDesc>(0);
// set unique_var_names to avoid saving shared params repeatedly
std::set<std::string> unique_var_names;
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable() ||
unique_var_names.count(var.Name()) > 0)
continue;
unique_var_names.emplace(var.Name());
}
/* 3. Save combined params to params.fbs */
const std::string params_path = model_dir + "/params.fbs";
fbs::CombinedParamsDesc params_prog;
fbs::SetCombinedParamsWithScope(exec_scope, unique_var_names, &params_prog);
fbs::SaveFile(params_path, params_prog.data(), params_prog.buf_size());
}
void LoadModelFbsFromFile(const std::string &filename,
Scope *scope,
cpp::ProgramDesc *cpp_prog) {
CHECK(cpp_prog);
CHECK(scope);
/* 1. Save cpp::ProgramDesc with model.fbs */
const std::string prog_path = filename + "/model.fbs";
fbs::ProgramDesc program(fbs::LoadFile(prog_path));
TransformProgramDescAnyToCpp(program, cpp_prog);
/* 2. Save scope with params.fbs */
const std::string params_path = filename + "/params.fbs";
fbs::CombinedParamsDesc params(fbs::LoadFile(params_path));
fbs::SetScopeWithCombinedParams(scope, params);
}
#endif // LITE_ON_TINY_PUBLISH
template <typename T> template <typename T>
void SetTensorDataNaive(T *out, size_t size, const std::vector<T> &src) { void SetTensorDataNaive(T *out, size_t size, const std::vector<T> &src) {
......
...@@ -88,7 +88,15 @@ void SaveModelNaive(const std::string& model_dir, ...@@ -88,7 +88,15 @@ void SaveModelNaive(const std::string& model_dir,
const Scope& exec_scope, const Scope& exec_scope,
const cpp::ProgramDesc& cpp_prog, const cpp::ProgramDesc& cpp_prog,
bool combined = true); bool combined = true);
#endif
void SaveModelFbs(const std::string& model_dir,
const Scope& exec_scope,
const cpp::ProgramDesc& cpp_prog);
void LoadModelFbsFromFile(const std::string& filename,
Scope* scope,
cpp::ProgramDesc* cpp_prog);
#endif // LITE_ON_TINY_PUBLISH
void LoadParamNaive(const std::string& path, void LoadParamNaive(const std::string& path,
lite::Scope* scope, lite::Scope* scope,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册