未验证 提交 0ac48fcb 编写于 作者: 石晓伟 提交者: GitHub

Add the conversion of cpp and flatbuffers program, test=develop (#4079)

* update compatible_pb.cc, test=develop

* using unique_ptr in fbs::desc, test=develop

* add fbs to compatible_pb, test=develop

* update model_parser functions, test=develop

* fix bugs, test=develop
上级 9b9245d9
......@@ -19,9 +19,9 @@ endif()
if (NOT LITE_ON_TINY_PUBLISH)
lite_cc_library(compatible_pb SRCS compatible_pb.cc
DEPS ${cpp_wrapper} ${naive_wrapper} ${pb_wrapper} framework_proto)
DEPS ${cpp_wrapper} ${naive_wrapper} ${pb_wrapper} framework_proto fbs_io)
else()
lite_cc_library(compatible_pb SRCS compatible_pb.cc DEPS ${cpp_wrapper} ${naive_wrapper})
lite_cc_library(compatible_pb SRCS compatible_pb.cc DEPS ${cpp_wrapper} ${naive_wrapper} fbs_io)
endif()
lite_cc_library(model_parser SRCS model_parser.cc DEPS
......
......@@ -15,6 +15,7 @@
#include "lite/model_parser/compatible_pb.h"
#include <string>
#include <vector>
#include "lite/model_parser/flatbuffers/program_desc.h"
#include "lite/model_parser/naive_buffer/block_desc.h"
#include "lite/model_parser/naive_buffer/op_desc.h"
#include "lite/model_parser/naive_buffer/program_desc.h"
......@@ -73,6 +74,18 @@ void TransformVarDescAnyToCpp<naive_buffer::VarDesc>(
}*/
}
template <>
void TransformVarDescAnyToCpp<fbs::VarDesc>(const fbs::VarDesc &any_desc,
cpp::VarDesc *cpp_desc) {
cpp_desc->SetName(any_desc.Name());
cpp_desc->SetType(any_desc.GetType());
cpp_desc->SetPersistable(any_desc.Persistable());
if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") {
cpp_desc->SetDataType(any_desc.GetDataType());
cpp_desc->SetShape(any_desc.GetShape());
}
}
/// For OpDesc transform
template <typename OpDescType>
void OpInputsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) {
......@@ -219,34 +232,34 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
}
/// For BlockDesc transform
#define TRANS_BLOCK_ANY_WITH_CPP_IMPL(T, NT, PNT) \
#define TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpT, VarT, NT, PNT) \
template <> \
void TransformBlockDescAnyToCpp<NT::T>(const NT::T &any_desc, \
cpp::BlockDesc *cpp_desc) { \
NT::T desc = any_desc; \
void TransformBlockDescAnyToCpp<NT::BlockDesc>( \
const NT::BlockDesc &any_desc, cpp::BlockDesc *cpp_desc) { \
NT::BlockDesc &desc = const_cast<NT::BlockDesc &>(any_desc); \
cpp_desc->SetIdx(desc.Idx()); \
cpp_desc->SetParentIdx(desc.ParentIdx()); \
cpp_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \
\
cpp_desc->ClearOps(); \
for (size_t i = 0; i < desc.OpsSize(); ++i) { \
auto any_op_desc = NT::OpDesc(desc.GetOp<PNT::proto::OpDesc>(i)); \
auto any_op_desc = NT::OpDesc(desc.GetOp<PNT::proto::OpT>(i)); \
auto *cpp_op_desc = cpp_desc->AddOp<cpp::OpDesc>(); \
TransformOpDescAnyToCpp(any_op_desc, cpp_op_desc); \
} \
\
cpp_desc->ClearVars(); \
for (size_t i = 0; i < desc.VarsSize(); ++i) { \
auto any_var_desc = NT::VarDesc(desc.GetVar<PNT::proto::VarDesc>(i)); \
auto any_var_desc = NT::VarDesc(desc.GetVar<PNT::proto::VarT>(i)); \
auto *cpp_var_desc = cpp_desc->AddVar<cpp::VarDesc>(); \
TransformVarDescAnyToCpp(any_var_desc, cpp_var_desc); \
} \
} \
\
template <> \
void TransformBlockDescCppToAny<NT::T>(const cpp::T &cpp_desc, \
NT::T *any_desc) { \
const cpp::T &desc = cpp_desc; \
void TransformBlockDescCppToAny<NT::BlockDesc>( \
const cpp::BlockDesc &cpp_desc, NT::BlockDesc *any_desc) { \
const cpp::BlockDesc &desc = cpp_desc; \
any_desc->SetIdx(desc.Idx()); \
any_desc->SetParentIdx(desc.ParentIdx()); \
any_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \
......@@ -254,41 +267,39 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
any_desc->ClearOps(); \
for (size_t i = 0; i < desc.OpsSize(); ++i) { \
auto *cpp_op_desc = desc.GetOp<cpp::OpDesc>(i); \
auto any_op_desc = NT::OpDesc(any_desc->AddOp<PNT::proto::OpDesc>()); \
auto any_op_desc = NT::OpDesc(any_desc->AddOp<PNT::proto::OpT>()); \
TransformOpDescCppToAny(*cpp_op_desc, &any_op_desc); \
} \
\
any_desc->ClearVars(); \
for (size_t i = 0; i < desc.VarsSize(); ++i) { \
auto *cpp_var_desc = desc.GetVar<cpp::VarDesc>(i); \
auto any_var_desc = \
NT::VarDesc(any_desc->AddVar<PNT::proto::VarDesc>()); \
auto any_var_desc = NT::VarDesc(any_desc->AddVar<PNT::proto::VarT>()); \
TransformVarDescCppToAny(*cpp_var_desc, &any_var_desc); \
} \
}
/// For ProgramDesc transform
#define TRANS_PROGRAM_ANY_WITH_CPP_IMPL(T, NT, PNT) \
#define TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockT, NT, PNT) \
template <> \
void TransformProgramDescAnyToCpp<NT::T>(const NT::T &any_desc, \
cpp::ProgramDesc *cpp_desc) { \
NT::T desc = any_desc; \
void TransformProgramDescAnyToCpp<NT::ProgramDesc>( \
const NT::ProgramDesc &any_desc, cpp::ProgramDesc *cpp_desc) { \
NT::ProgramDesc &desc = const_cast<NT::ProgramDesc &>(any_desc); \
if (desc.HasVersion()) { \
cpp_desc->SetVersion(desc.Version()); \
} \
\
cpp_desc->ClearBlocks(); \
for (size_t i = 0; i < desc.BlocksSize(); ++i) { \
auto any_block_desc = \
NT::BlockDesc(desc.GetBlock<PNT::proto::BlockDesc>(i)); \
NT::BlockDesc any_block_desc(desc.GetBlock<PNT::proto::BlockT>(i)); \
auto *cpp_block_desc = cpp_desc->AddBlock<cpp::BlockDesc>(); \
TransformBlockDescAnyToCpp(any_block_desc, cpp_block_desc); \
} \
} \
\
template <> \
void TransformProgramDescCppToAny<NT::T>(const cpp::T &cpp_desc, \
NT::T *any_desc) { \
void TransformProgramDescCppToAny<NT::ProgramDesc>( \
const cpp::ProgramDesc &cpp_desc, NT::ProgramDesc *any_desc) { \
auto &desc = cpp_desc; \
if (desc.HasVersion()) { \
any_desc->SetVersion(desc.Version()); \
......@@ -297,22 +308,26 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
any_desc->ClearBlocks(); \
for (size_t i = 0; i < desc.BlocksSize(); ++i) { \
auto *cpp_block_desc = desc.GetBlock<cpp::BlockDesc>(i); \
auto any_block_desc = \
NT::BlockDesc(any_desc->AddBlock<PNT::proto::BlockDesc>()); \
NT::BlockDesc any_block_desc(any_desc->AddBlock<PNT::proto::BlockT>()); \
TransformBlockDescCppToAny(*cpp_block_desc, &any_block_desc); \
} \
}
TRANS_VAR_ANY_WITH_CPP_IMPL(naive_buffer::VarDesc);
TRANS_OP_ANY_WITH_CPP_IMPL(naive_buffer::OpDesc);
TRANS_BLOCK_ANY_WITH_CPP_IMPL(BlockDesc, naive_buffer, naive_buffer);
TRANS_PROGRAM_ANY_WITH_CPP_IMPL(ProgramDesc, naive_buffer, naive_buffer);
TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpDesc, VarDesc, naive_buffer, naive_buffer);
TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockDesc, naive_buffer, naive_buffer);
TRANS_VAR_ANY_WITH_CPP_IMPL(fbs::VarDesc);
TRANS_OP_ANY_WITH_CPP_IMPL(fbs::OpDesc);
TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpDescT, VarDescT, fbs, fbs);
TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockDescT, fbs, fbs);
#ifndef LITE_ON_TINY_PUBLISH
TRANS_VAR_ANY_WITH_CPP_IMPL(pb::VarDesc);
TRANS_OP_ANY_WITH_CPP_IMPL(pb::OpDesc);
TRANS_BLOCK_ANY_WITH_CPP_IMPL(BlockDesc, pb, framework);
TRANS_PROGRAM_ANY_WITH_CPP_IMPL(ProgramDesc, pb, framework);
TRANS_BLOCK_ANY_WITH_CPP_IMPL(OpDesc, VarDesc, pb, framework);
TRANS_PROGRAM_ANY_WITH_CPP_IMPL(BlockDesc, pb, framework);
#endif
#undef TRANS_VAR_ANY_WITH_CPP_IMPL
......
......@@ -15,6 +15,8 @@
#include "lite/model_parser/compatible_pb.h"
#include <gtest/gtest.h>
#include "lite/model_parser/cpp_desc.h"
#include "lite/model_parser/flatbuffers/program_desc.h"
#include "lite/model_parser/flatbuffers/test_helper.h"
#include "lite/model_parser/naive_buffer/block_desc.h"
#include "lite/model_parser/naive_buffer/op_desc.h"
#include "lite/model_parser/naive_buffer/program_desc.h"
......@@ -430,5 +432,14 @@ TEST(ProgramDesc, AnyToCpp) {
TestProgramAnyToCpp<naive_buffer::ProgramDesc>(&nb_desc);
}
TEST(ProgramDesc, FbsCpp) {
fbs::ProgramDesc fbs_program(fbs::test::GenerateProgramCache());
cpp::ProgramDesc cpp_program;
TransformProgramDescAnyToCpp(fbs_program, &cpp_program);
fbs::ProgramDesc fbs_program_2;
TransformProgramDescCppToAny(cpp_program, &fbs_program_2);
fbs::test::CheckProgramCache(&fbs_program_2);
}
} // namespace lite
} // namespace paddle
......@@ -21,52 +21,52 @@ namespace fbs {
template <>
proto::VarDesc const* BlockDescView::GetVar<proto::VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
return desc_->vars()->Get(idx);
}
template <>
proto::OpDesc const* BlockDescView::GetOp<proto::OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= ops.size()";
return desc_->ops()->Get(idx);
}
template <>
VarDescView const* BlockDescView::GetVar<VarDescView>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
return &vars_[idx];
}
template <>
OpDescView const* BlockDescView::GetOp<OpDescView>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= ops.size()";
return &ops_[idx];
}
template <>
proto::VarDescT* BlockDesc::GetVar<proto::VarDescT>(int32_t idx) {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return vars_[idx].raw_desc();
CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
return vars_[idx]->raw_desc();
}
template <>
proto::VarDescT* BlockDesc::AddVar<proto::VarDescT>() {
desc_->vars.push_back(std::unique_ptr<proto::VarDescT>(new proto::VarDescT));
SyncVars();
return vars_.back().raw_desc();
return vars_.back()->raw_desc();
}
template <>
proto::OpDescT* BlockDesc::GetOp<proto::OpDescT>(int32_t idx) {
CHECK_LT(idx, OpsSize()) << "idx >= vars.size()";
return ops_[idx].raw_desc();
CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= vars.size()";
return ops_[idx]->raw_desc();
}
template <>
proto::OpDescT* BlockDesc::AddOp<proto::OpDescT>() {
desc_->ops.push_back(std::unique_ptr<proto::OpDescT>(new proto::OpDescT));
SyncOps();
return ops_.back().raw_desc();
return ops_.back()->raw_desc();
}
} // namespace fbs
......
......@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <vector>
#include "lite/model_parser/base/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
......@@ -150,24 +151,24 @@ class BlockDesc : public BlockDescAPI {
void SyncVars() {
vars_.resize(desc_->vars.size());
for (size_t i = 0; i < desc_->vars.size(); ++i) {
if (vars_[i].raw_desc() != desc_->vars[i].get()) {
vars_[i] = VarDesc(desc_->vars[i].get());
if (!vars_[i] || vars_[i]->raw_desc() != desc_->vars[i].get()) {
vars_[i].reset(new VarDesc(desc_->vars[i].get()));
}
}
}
void SyncOps() {
ops_.resize(desc_->ops.size());
for (size_t i = 0; i < desc_->ops.size(); ++i) {
if (ops_[i].raw_desc() != desc_->ops[i].get()) {
ops_[i] = OpDesc(desc_->ops[i].get());
if (!ops_[i] || ops_[i]->raw_desc() != desc_->ops[i].get()) {
ops_[i].reset(new OpDesc(desc_->ops[i].get()));
}
}
}
bool owned_{false};
proto::BlockDescT* desc_{nullptr};
std::vector<VarDesc> vars_;
std::vector<OpDesc> ops_;
std::vector<std::unique_ptr<VarDesc>> vars_;
std::vector<std::unique_ptr<OpDesc>> ops_;
};
} // namespace fbs
......
......@@ -25,11 +25,12 @@ namespace fbs {
std::vector<char> LoadFile(const std::string& path) {
FILE* file = fopen(path.c_str(), "rb");
CHECK(file);
fseek(file, 0, SEEK_END);
int64_t length = ftell(file);
uint64_t length = ftell(file);
rewind(file);
std::vector<char> buf(length);
CHECK(fread(buf.data(), 1, length, file) == length);
CHECK_EQ(fread(buf.data(), 1, length, file), length);
fclose(file);
return buf;
}
......@@ -37,6 +38,7 @@ std::vector<char> LoadFile(const std::string& path) {
void SaveFile(const std::string& path, const void* src, size_t byte_size) {
CHECK(src);
FILE* file = fopen(path.c_str(), "wb");
CHECK(file);
CHECK(fwrite(src, sizeof(char), byte_size, file) == byte_size);
fclose(file);
}
......@@ -60,7 +62,7 @@ void SetTensorWithParam(lite::Tensor* tensor, const ParamDescReadAPI& param) {
}
void SetCombinedParamsWithScope(const lite::Scope& scope,
const std::vector<std::string>& params_name,
const std::set<std::string>& params_name,
CombinedParamsDescWriteAPI* params) {
for (const auto& name : params_name) {
auto* param = params->AddParamDesc();
......
......@@ -14,6 +14,7 @@
#pragma once
#include <set>
#include <string>
#include <vector>
#include "lite/core/scope.h"
......@@ -30,8 +31,9 @@ void SaveFile(const std::string& path, const void* src, size_t byte_size);
void SetScopeWithCombinedParams(lite::Scope* scope,
const CombinedParamsDescReadAPI& params);
void SetCombinedParamsWithScope(const lite::Scope& scope,
const std::vector<std::string>& params_name,
const std::set<std::string>& params_name,
CombinedParamsDescWriteAPI* params);
} // namespace fbs
......
......@@ -32,7 +32,7 @@ void set_tensor(paddle::lite::Tensor* tensor,
tensor->Resize(dims);
std::vector<T> data;
data.resize(production);
for (size_t i = 0; i < production; ++i) {
for (int i = 0; i < production; ++i) {
data[i] = i / 2.f;
}
std::memcpy(tensor->mutable_data<T>(), data.data(), sizeof(T) * data.size());
......@@ -53,7 +53,8 @@ TEST(CombinedParamsDesc, Scope) {
set_tensor<int8_t>(tensor_1, std::vector<int64_t>({10, 1}));
// Set combined parameters
fbs::CombinedParamsDesc combined_param;
SetCombinedParamsWithScope(scope, params_name, &combined_param);
std::set<std::string> params_set(params_name.begin(), params_name.end());
SetCombinedParamsWithScope(scope, params_set, &combined_param);
/* --------- Check scope ---------- */
auto check_params = [&](const CombinedParamsDescReadAPI& desc) {
......
......@@ -103,6 +103,7 @@ GET_ATTRS_IMPL(std::vector<int64_t>, longs);
new proto::OpDesc_::AttrT())), \
&(desc_->attrs)); \
p->fb_f__ = v; \
p->type = ConvertAttrType(OpDataTypeTrait<T>::AT); \
SetKey(name, &p); \
}
ATTR_IMPL(int32_t, i);
......
......@@ -115,7 +115,11 @@ class ParamDesc : public ParamDescAPI {
}
explicit ParamDesc(proto::ParamDescT* desc) : desc_(desc) {
if (desc_->variable.type == proto::ParamDesc_::VariableDesc_NONE) {
desc_->variable.Set(proto::ParamDesc_::LoDTensorDescT());
}
CHECK(desc_->variable.type ==
proto::ParamDesc_::VariableDesc_LoDTensorDesc);
lod_tensor_ = desc_->variable.AsLoDTensorDesc();
CHECK(lod_tensor_);
}
......@@ -169,7 +173,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
}
const ParamDescReadAPI* GetParamDesc(size_t idx) const override {
return &params_[idx];
return params_[idx].get();
}
size_t GetParamsSize() const override { return desc_.params.size(); }
......@@ -178,7 +182,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
desc_.params.push_back(
std::unique_ptr<proto::ParamDescT>(new proto::ParamDescT));
SyncParams();
return &params_[params_.size() - 1];
return params_[params_.size() - 1].get();
}
const void* data() {
......@@ -195,8 +199,8 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
void SyncParams() {
params_.resize(GetParamsSize());
for (size_t i = 0; i < GetParamsSize(); ++i) {
if (params_[i].raw_desc() != desc_.params[i].get()) {
params_[i] = ParamDesc(desc_.params[i].get());
if (!params_[i] || params_[i]->raw_desc() != desc_.params[i].get()) {
params_[i].reset(new ParamDesc(desc_.params[i].get()));
}
}
}
......@@ -212,7 +216,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
flatbuffers::DetachedBuffer buf_;
flatbuffers::FlatBufferBuilder fbb_;
proto::CombinedParamsDescT desc_;
std::vector<ParamDesc> params_;
std::vector<std::unique_ptr<ParamDesc>> params_;
};
} // namespace fbs
......
......@@ -21,21 +21,21 @@ namespace fbs {
template <>
proto::BlockDesc const* ProgramDescView::GetBlock<proto::BlockDesc>(
int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
CHECK_LT(idx, static_cast<int32_t>(BlocksSize())) << "idx >= blocks.size()";
return desc_->blocks()->Get(idx);
}
template <>
BlockDescView const* ProgramDescView::GetBlock<BlockDescView>(
int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
CHECK_LT(idx, static_cast<int32_t>(BlocksSize())) << "idx >= blocks.size()";
return &blocks_[idx];
}
template <>
proto::BlockDescT* ProgramDesc::GetBlock<proto::BlockDescT>(int32_t idx) {
CHECK_LT(idx, BlocksSize()) << "idx >= vars.size()";
return blocks_[idx].raw_desc();
CHECK_LT(idx, static_cast<int32_t>(BlocksSize())) << "idx >= vars.size()";
return blocks_[idx]->raw_desc();
}
template <>
......@@ -43,7 +43,7 @@ proto::BlockDescT* ProgramDesc::AddBlock<proto::BlockDescT>() {
desc_.blocks.push_back(
std::unique_ptr<proto::BlockDescT>(new proto::BlockDescT));
SyncBlocks();
return blocks_.back().raw_desc();
return blocks_.back()->raw_desc();
}
} // namespace fbs
......
......@@ -150,8 +150,8 @@ class ProgramDesc : public ProgramDescAPI {
void SyncBlocks() {
blocks_.resize(desc_.blocks.size());
for (size_t i = 0; i < desc_.blocks.size(); ++i) {
if (blocks_[i].raw_desc() != desc_.blocks[i].get()) {
blocks_[i] = BlockDesc(desc_.blocks[i].get());
if (!blocks_[i] || blocks_[i]->raw_desc() != desc_.blocks[i].get()) {
blocks_[i].reset(new BlockDesc(desc_.blocks[i].get()));
}
}
}
......@@ -167,7 +167,7 @@ class ProgramDesc : public ProgramDescAPI {
flatbuffers::DetachedBuffer buf_;
flatbuffers::FlatBufferBuilder fbb_;
proto::ProgramDescT desc_;
std::vector<BlockDesc> blocks_;
std::vector<std::unique_ptr<BlockDesc>> blocks_;
};
} // namespace fbs
......
......@@ -15,136 +15,22 @@
#include "lite/model_parser/flatbuffers/program_desc.h"
#include <gtest/gtest.h>
#include <string>
#include "lite/model_parser/flatbuffers/test_helper.h"
namespace paddle {
namespace lite {
namespace fbs {
namespace {
std::vector<char> GenerateProgramCache() {
/* --------- Set Program --------- */
ProgramDesc program;
program.SetVersion(1000600);
/* --------- Set Block A --------- */
BlockDesc block_a(program.AddBlock<proto::BlockDescT>());
VarDesc var_a2(block_a.AddVar<proto::VarDescT>());
var_a2.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a2.SetName("var_a2");
var_a2.SetShape({2, 2, 1});
VarDesc var_a0(block_a.AddVar<proto::VarDescT>());
var_a0.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a0.SetName("var_a0");
var_a0.SetShape({1, 2});
OpDesc op_a0(block_a.AddOp<proto::OpDescT>());
op_a0.SetType("Type");
op_a0.SetInput("X", {"var_a0"});
op_a0.SetOutput("Y0", {"var_a0", "var_a1"});
op_a0.SetOutput("Y1", {"var_a2"});
op_a0.SetAttr<std::string>("Attr5", "attr_5");
op_a0.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_a0.SetAttr<float>("Attr1", 0.98f);
op_a0.SetAttr<int32_t>("Attr0", 16);
/* --------- Set Block B --------- */
BlockDesc block_b(program.AddBlock<proto::BlockDescT>());
VarDesc var_b0(block_b.AddVar<proto::VarDescT>());
var_b0.SetName("var_b0");
var_b0.SetShape({-1, 1});
OpDesc op_b0(block_b.AddOp<proto::OpDescT>());
op_b0.SetType("Type0");
op_b0.SetInput("X", {"var_b0"});
op_b0.SetOutput("Y1", {"var_b0"});
op_b0.SetAttr<std::string>("Attr5", "attr_5");
OpDesc op_b1(block_b.AddOp<proto::OpDescT>());
op_b1.SetType("Type1");
op_b1.SetInput("X", {"var_b0"});
op_b1.SetOutput("Y1", {"var_b0"});
op_b1.SetAttr<std::string>("Attr5", "attr_5");
op_b1.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_b1.SetAttr<bool>("Attr1", true);
/* --------- Cache Program ---------- */
std::vector<char> cache;
cache.resize(program.buf_size());
std::memcpy(cache.data(), program.data(), program.buf_size());
return cache;
}
} // namespace
TEST(ProgramDesc, LoadTest) {
ProgramDesc program(GenerateProgramCache());
CHECK_EQ(program.Version(), 1000600);
CHECK_EQ(program.BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
auto block_a = BlockDesc(program.GetBlock<proto::BlockDescT>(0));
CHECK_EQ(block_a.OpsSize(), 1);
CHECK_EQ(block_a.VarsSize(), 2);
auto var_a2 = VarDesc(block_a.GetVar<proto::VarDescT>(0));
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
auto op_a0 = OpDesc(block_a.GetOp<proto::OpDescT>(0));
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(op_a0.GetAttr<std::vector<std::string>>("Attr2") ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
auto block_b = BlockDesc(program.GetBlock<proto::BlockDescT>(1));
CHECK_EQ(block_b.OpsSize(), 2);
CHECK_EQ(block_b.VarsSize(), 1);
auto op_b0 = OpDesc(block_b.GetOp<proto::OpDescT>(1));
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
ProgramDesc program(test::GenerateProgramCache());
test::CheckProgramCache(&program);
}
TEST(ProgramDescView, LoadTest) {
const ProgramDescView program(GenerateProgramCache());
CHECK_EQ(program.Version(), 1000600);
CHECK_EQ(program.BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
const auto& block_a = *program.GetBlock<BlockDescView>(0);
CHECK_EQ(block_a.OpsSize(), 1);
CHECK_EQ(block_a.VarsSize(), 2);
const auto& var_a2 = *block_a.GetVar<VarDescView>(0);
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
const auto& op_a0 = *block_a.GetOp<OpDescView>(0);
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(static_cast<std::vector<std::string>>(
op_a0.GetAttr<std::vector<std::string>>("Attr2")) ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
const auto& block_b = *program.GetBlock<BlockDescView>(1);
CHECK_EQ(block_b.OpsSize(), 2);
CHECK_EQ(block_b.VarsSize(), 1);
const auto& op_b0 = *block_b.GetOp<OpDescView>(1);
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
const ProgramDescView program(test::GenerateProgramCache());
test::CheckProgramCache(program);
}
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/model_parser/flatbuffers/program_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
namespace test {
inline std::vector<char> GenerateProgramCache() {
/* --------- Set Program --------- */
ProgramDesc program;
program.SetVersion(1000600);
/* --------- Set Block A --------- */
BlockDesc block_a(program.AddBlock<proto::BlockDescT>());
VarDesc var_a2(block_a.AddVar<proto::VarDescT>());
var_a2.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a2.SetName("var_a2");
var_a2.SetShape({2, 2, 1});
VarDesc var_a0(block_a.AddVar<proto::VarDescT>());
var_a0.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a0.SetName("var_a0");
var_a0.SetShape({1, 2});
OpDesc op_a0(block_a.AddOp<proto::OpDescT>());
op_a0.SetType("Type");
op_a0.SetInput("X", {"var_a0"});
op_a0.SetOutput("Y0", {"var_a0", "var_a1"});
op_a0.SetOutput("Y1", {"var_a2"});
op_a0.SetAttr<std::string>("Attr5", "attr_5");
op_a0.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_a0.SetAttr<float>("Attr1", 0.98f);
op_a0.SetAttr<int32_t>("Attr0", 16);
/* --------- Set Block B --------- */
BlockDesc block_b(program.AddBlock<proto::BlockDescT>());
VarDesc var_b0(block_b.AddVar<proto::VarDescT>());
var_b0.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_b0.SetName("var_b0");
var_b0.SetShape({-1, 1});
OpDesc op_b0(block_b.AddOp<proto::OpDescT>());
op_b0.SetType("Type0");
op_b0.SetInput("X", {"var_b0"});
op_b0.SetOutput("Y1", {"var_b0"});
op_b0.SetAttr<std::string>("Attr5", "attr_5");
OpDesc op_b1(block_b.AddOp<proto::OpDescT>());
op_b1.SetType("Type1");
op_b1.SetInput("X", {"var_b0"});
op_b1.SetOutput("Y1", {"var_b0"});
op_b1.SetAttr<std::string>("Attr5", "attr_5");
op_b1.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_b1.SetAttr<bool>("Attr1", true);
/* --------- Cache Program ---------- */
std::vector<char> cache;
cache.resize(program.buf_size());
std::memcpy(cache.data(), program.data(), program.buf_size());
return cache;
}
inline void CheckProgramCache(ProgramDesc* program) {
CHECK_EQ(program->Version(), 1000600);
CHECK_EQ(program->BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
BlockDesc block_a(program->GetBlock<proto::BlockDescT>(0));
CHECK_EQ(block_a.OpsSize(), static_cast<size_t>(1));
CHECK_EQ(block_a.VarsSize(), static_cast<size_t>(2));
auto var_a2 = VarDesc(block_a.GetVar<proto::VarDescT>(0));
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
auto op_a0 = OpDesc(block_a.GetOp<proto::OpDescT>(0));
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(op_a0.GetAttr<std::vector<std::string>>("Attr2") ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
BlockDesc block_b(program->GetBlock<proto::BlockDescT>(1));
CHECK_EQ(block_b.OpsSize(), static_cast<size_t>(2));
CHECK_EQ(block_b.VarsSize(), static_cast<size_t>(1));
auto op_b0 = OpDesc(block_b.GetOp<proto::OpDescT>(1));
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
}
inline void CheckProgramCache(const ProgramDescView& program) {
CHECK_EQ(program.Version(), 1000600);
CHECK_EQ(program.BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
const auto& block_a = *program.GetBlock<BlockDescView>(0);
CHECK_EQ(block_a.OpsSize(), static_cast<size_t>(1));
CHECK_EQ(block_a.VarsSize(), static_cast<size_t>(2));
const auto& var_a2 = *block_a.GetVar<VarDescView>(0);
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
const auto& op_a0 = *block_a.GetOp<OpDescView>(0);
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(static_cast<std::vector<std::string>>(
op_a0.GetAttr<std::vector<std::string>>("Attr2")) ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
const auto& block_b = *program.GetBlock<BlockDescView>(1);
CHECK_EQ(block_b.OpsSize(), static_cast<size_t>(2));
CHECK_EQ(block_b.VarsSize(), static_cast<size_t>(1));
const auto& op_b0 = *block_b.GetOp<OpDescView>(1);
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
}
} // namespace test
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -93,9 +93,14 @@ class VarDesc : public VarDescAPI {
Type GetType() const override { return ConvertVarType(type_->type); }
void SetType(Type type) override {
CHECK(type == VarDescAPI::Type::LOD_TENSOR);
type_->type = ConvertVarType(type);
void SetType(Type type) override { type_->type = ConvertVarType(type); }
void SetDataType(Type type) {
type_->lod_tensor->tensor->data_type = ConvertVarType(type);
}
Type GetDataType() const {
return ConvertVarType(type_->lod_tensor->tensor->data_type);
}
bool Persistable() const override { return desc_->persistable; }
......
......@@ -17,6 +17,7 @@
#include <fstream>
#include <limits>
#include <set>
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/core/variable.h"
......@@ -27,6 +28,7 @@
#include "lite/model_parser/naive_buffer/program_desc.h"
#include "lite/model_parser/naive_buffer/var_desc.h"
#ifndef LITE_ON_TINY_PUBLISH
#include "lite/model_parser/flatbuffers/io.h"
#include "lite/model_parser/pb/program_desc.h"
#include "lite/model_parser/pb/var_desc.h"
#endif
......@@ -592,7 +594,54 @@ void SaveModelNaive(const std::string &model_dir,
LOG(INFO) << "Save naive buffer model in '" << model_dir
<< ".nb' successfully";
}
#endif
/* ---------- Flatbuffers ---------- */
void SaveModelFbs(const std::string &model_dir,
const Scope &exec_scope,
const cpp::ProgramDesc &cpp_prog) {
/* 1. Save model to model.fbs */
const std::string prog_path = model_dir + "/model.fbs";
fbs::ProgramDesc fbs_prog;
TransformProgramDescCppToAny(cpp_prog, &fbs_prog);
fbs::SaveFile(prog_path, fbs_prog.data(), fbs_prog.buf_size());
/* 2. Get param names from cpp::ProgramDesc */
auto &main_block_desc = *cpp_prog.GetBlock<cpp::BlockDesc>(0);
// set unique_var_names to avoid saving shared params repeatedly
std::set<std::string> unique_var_names;
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable() ||
unique_var_names.count(var.Name()) > 0)
continue;
unique_var_names.emplace(var.Name());
}
/* 3. Save combined params to params.fbs */
const std::string params_path = model_dir + "/params.fbs";
fbs::CombinedParamsDesc params_prog;
fbs::SetCombinedParamsWithScope(exec_scope, unique_var_names, &params_prog);
fbs::SaveFile(params_path, params_prog.data(), params_prog.buf_size());
}
void LoadModelFbsFromFile(const std::string &filename,
Scope *scope,
cpp::ProgramDesc *cpp_prog) {
CHECK(cpp_prog);
CHECK(scope);
/* 1. Save cpp::ProgramDesc with model.fbs */
const std::string prog_path = filename + "/model.fbs";
fbs::ProgramDesc program(fbs::LoadFile(prog_path));
TransformProgramDescAnyToCpp(program, cpp_prog);
/* 2. Save scope with params.fbs */
const std::string params_path = filename + "/params.fbs";
fbs::CombinedParamsDesc params(fbs::LoadFile(params_path));
fbs::SetScopeWithCombinedParams(scope, params);
}
#endif // LITE_ON_TINY_PUBLISH
template <typename T>
void SetTensorDataNaive(T *out, size_t size, const std::vector<T> &src) {
......
......@@ -88,7 +88,15 @@ void SaveModelNaive(const std::string& model_dir,
const Scope& exec_scope,
const cpp::ProgramDesc& cpp_prog,
bool combined = true);
#endif
void SaveModelFbs(const std::string& model_dir,
const Scope& exec_scope,
const cpp::ProgramDesc& cpp_prog);
void LoadModelFbsFromFile(const std::string& filename,
Scope* scope,
cpp::ProgramDesc* cpp_prog);
#endif // LITE_ON_TINY_PUBLISH
void LoadParamNaive(const std::string& path,
lite::Scope* scope,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册