未验证 提交 e97d9ad3 编写于 作者: 石晓伟 提交者: GitHub

add fbs::ProgramDesc, test=develop (#4074)

上级 ba66bc55
......@@ -8,8 +8,9 @@ endfunction()
lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS fbs_headers)
lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS fbs_headers)
lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS fbs_headers)
lite_cc_library(fbs_program_desc SRCS program_desc.cc DEPS fbs_op_desc fbs_var_desc fbs_block_desc)
lite_cc_library(fbs_program_desc SRCS program_desc.cc DEPS fbs_block_desc fbs_op_desc fbs_var_desc)
lite_fbs_library(fbs_param_desc SRCS param_desc.cc FBS_DEPS fbs_headers)
lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc fbs_param_desc scope)
lite_cc_test(test_vector_view SRCS vector_view_test.cc DEPS fbs_program_desc)
lite_cc_test(test_fbs_io SRCS io_test.cc DEPS fbs_io)
lite_cc_test(test_program_desc SRCS program_desc_test.cc DEPS fbs_program_desc)
......@@ -92,7 +92,11 @@ class BlockDescView : public BlockDescAPI {
class BlockDesc : public BlockDescAPI {
public:
BlockDesc() : owned_(true), desc_(new proto::BlockDescT()) {}
explicit BlockDesc(proto::BlockDescT* desc) : desc_(desc) { CHECK(desc_); }
explicit BlockDesc(proto::BlockDescT* desc) : desc_(desc) {
CHECK(desc_);
SyncVars();
SyncOps();
}
int32_t Idx() const override { return desc_->idx; }
......
......@@ -91,23 +91,30 @@ GET_ATTRS_IMPL(std::vector<int64_t>, longs);
#undef GET_ATTR_IMPL
#undef GET_ATTRS_IMPL
#define ATTR_IMPL(T, fb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
return (*GetKeyIterator(name, desc_->attrs))->fb_f__; \
} \
template <> \
void OpDesc::SetAttr(const std::string& name, const T& v) { \
(*GetKeyIterator(name, desc_->attrs))->fb_f__ = v; \
#define ATTR_IMPL(T, fb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
return (*GetKeyIterator(name, desc_->attrs))->fb_f__; \
} \
template <> \
void OpDesc::SetAttr<T>(const std::string& name, const T& v) { \
auto& p = *InsertPair(name, \
std::move(std::unique_ptr<proto::OpDesc_::AttrT>( \
new proto::OpDesc_::AttrT())), \
&(desc_->attrs)); \
p->fb_f__ = v; \
SetKey(name, &p); \
}
ATTR_IMPL(int32_t, i);
ATTR_IMPL(int16_t, block_idx);
ATTR_IMPL(float, f);
ATTR_IMPL(bool, b);
ATTR_IMPL(int64_t, l);
ATTR_IMPL(std::string, s);
ATTR_IMPL(std::vector<int>, ints);
ATTR_IMPL(std::vector<float>, floats);
ATTR_IMPL(std::vector<int64_t>, longs);
ATTR_IMPL(std::vector<std::string>, strings);
#undef GET_ATTRS_IMPL
} // namespace fbs
......
......@@ -32,6 +32,20 @@ BlockDescView const* ProgramDescView::GetBlock<BlockDescView>(
return &blocks_[idx];
}
template <>
proto::BlockDescT* ProgramDesc::GetBlock<proto::BlockDescT>(int32_t idx) {
CHECK_LT(idx, BlocksSize()) << "idx >= vars.size()";
return blocks_[idx].raw_desc();
}
template <>
proto::BlockDescT* ProgramDesc::AddBlock<proto::BlockDescT>() {
desc_.blocks.push_back(
std::unique_ptr<proto::BlockDescT>(new proto::BlockDescT));
SyncBlocks();
return blocks_.back().raw_desc();
}
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -97,6 +97,79 @@ class ProgramDescView : public ProgramDescAPI {
}
};
class ProgramDesc : public ProgramDescAPI {
public:
ProgramDesc() = default;
explicit ProgramDesc(const std::vector<char>& buf) {
const auto* raw_buf = proto::GetProgramDesc(buf.data());
raw_buf->UnPackTo(&desc_);
SyncBlocks();
}
size_t BlocksSize() const override { return desc_.blocks.size(); }
void ClearBlocks() override {
desc_.blocks.clear();
SyncBlocks();
}
template <typename T>
T* GetBlock(int32_t idx);
template <typename T>
T* AddBlock();
bool HasVersion() const override { return desc_.version.get(); }
int64_t Version() const override {
if (!HasVersion()) {
return -1;
}
return desc_.version->version;
}
void SetVersion(int64_t version_in) override {
if (!HasVersion()) {
desc_.version.reset(new fbs::proto::VersionT());
}
desc_.version->version = version_in;
}
const void* data() {
SyncBuffer();
return buf_.data();
}
size_t buf_size() {
SyncBuffer();
return buf_.size();
}
private:
void SyncBlocks() {
blocks_.resize(desc_.blocks.size());
for (size_t i = 0; i < desc_.blocks.size(); ++i) {
if (blocks_[i].raw_desc() != desc_.blocks[i].get()) {
blocks_[i] = BlockDesc(desc_.blocks[i].get());
}
}
}
void SyncBuffer() {
fbb_.Reset();
flatbuffers::Offset<proto::ProgramDesc> desc =
proto::ProgramDesc::Pack(fbb_, &desc_);
fbb_.Finish(desc);
buf_ = fbb_.Release();
}
flatbuffers::DetachedBuffer buf_;
flatbuffers::FlatBufferBuilder fbb_;
proto::ProgramDescT desc_;
std::vector<BlockDesc> blocks_;
};
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/program_desc.h"
#include <gtest/gtest.h>
#include <string>
namespace paddle {
namespace lite {
namespace fbs {
namespace {
std::vector<char> GenerateProgramCache() {
/* --------- Set Program --------- */
ProgramDesc program;
program.SetVersion(1000600);
/* --------- Set Block A --------- */
BlockDesc block_a(program.AddBlock<proto::BlockDescT>());
VarDesc var_a2(block_a.AddVar<proto::VarDescT>());
var_a2.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a2.SetName("var_a2");
var_a2.SetShape({2, 2, 1});
VarDesc var_a0(block_a.AddVar<proto::VarDescT>());
var_a0.SetType(paddle::lite::VarDataType::LOD_TENSOR);
var_a0.SetName("var_a0");
var_a0.SetShape({1, 2});
OpDesc op_a0(block_a.AddOp<proto::OpDescT>());
op_a0.SetType("Type");
op_a0.SetInput("X", {"var_a0"});
op_a0.SetOutput("Y0", {"var_a0", "var_a1"});
op_a0.SetOutput("Y1", {"var_a2"});
op_a0.SetAttr<std::string>("Attr5", "attr_5");
op_a0.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_a0.SetAttr<float>("Attr1", 0.98f);
op_a0.SetAttr<int32_t>("Attr0", 16);
/* --------- Set Block B --------- */
BlockDesc block_b(program.AddBlock<proto::BlockDescT>());
VarDesc var_b0(block_b.AddVar<proto::VarDescT>());
var_b0.SetName("var_b0");
var_b0.SetShape({-1, 1});
OpDesc op_b0(block_b.AddOp<proto::OpDescT>());
op_b0.SetType("Type0");
op_b0.SetInput("X", {"var_b0"});
op_b0.SetOutput("Y1", {"var_b0"});
op_b0.SetAttr<std::string>("Attr5", "attr_5");
OpDesc op_b1(block_b.AddOp<proto::OpDescT>());
op_b1.SetType("Type1");
op_b1.SetInput("X", {"var_b0"});
op_b1.SetOutput("Y1", {"var_b0"});
op_b1.SetAttr<std::string>("Attr5", "attr_5");
op_b1.SetAttr<std::vector<std::string>>("Attr2", {"attr_2"});
op_b1.SetAttr<bool>("Attr1", true);
/* --------- Cache Program ---------- */
std::vector<char> cache;
cache.resize(program.buf_size());
std::memcpy(cache.data(), program.data(), program.buf_size());
return cache;
}
} // namespace
TEST(ProgramDesc, LoadTest) {
ProgramDesc program(GenerateProgramCache());
CHECK_EQ(program.Version(), 1000600);
CHECK_EQ(program.BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
auto block_a = BlockDesc(program.GetBlock<proto::BlockDescT>(0));
CHECK_EQ(block_a.OpsSize(), 1);
CHECK_EQ(block_a.VarsSize(), 2);
auto var_a2 = VarDesc(block_a.GetVar<proto::VarDescT>(0));
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
auto op_a0 = OpDesc(block_a.GetOp<proto::OpDescT>(0));
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(op_a0.GetAttr<std::vector<std::string>>("Attr2") ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
auto block_b = BlockDesc(program.GetBlock<proto::BlockDescT>(1));
CHECK_EQ(block_b.OpsSize(), 2);
CHECK_EQ(block_b.VarsSize(), 1);
auto op_b0 = OpDesc(block_b.GetOp<proto::OpDescT>(1));
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
}
TEST(ProgramDescView, LoadTest) {
const ProgramDescView program(GenerateProgramCache());
CHECK_EQ(program.Version(), 1000600);
CHECK_EQ(program.BlocksSize(), static_cast<size_t>(2));
/* --------- Check Block A --------- */
const auto& block_a = *program.GetBlock<BlockDescView>(0);
CHECK_EQ(block_a.OpsSize(), 1);
CHECK_EQ(block_a.VarsSize(), 2);
const auto& var_a2 = *block_a.GetVar<VarDescView>(0);
CHECK(var_a2.GetShape() == std::vector<int64_t>({2, 2, 1}));
const auto& op_a0 = *block_a.GetOp<OpDescView>(0);
CHECK_EQ(op_a0.Type(), std::string("Type"));
CHECK(op_a0.Input("X") == std::vector<std::string>({"var_a0"}));
CHECK(op_a0.Output("Y0") == std::vector<std::string>({"var_a0", "var_a1"}));
CHECK(op_a0.Output("Y1") == std::vector<std::string>({"var_a2"}));
CHECK_EQ(op_a0.GetAttr<float>("Attr1"), 0.98f);
CHECK_EQ(op_a0.GetAttr<int32_t>("Attr0"), 16);
CHECK_EQ(op_a0.GetAttr<std::string>("Attr5"), std::string("attr_5"));
CHECK(static_cast<std::vector<std::string>>(
op_a0.GetAttr<std::vector<std::string>>("Attr2")) ==
std::vector<std::string>({"attr_2"}));
/* --------- Check Block B --------- */
const auto& block_b = *program.GetBlock<BlockDescView>(1);
CHECK_EQ(block_b.OpsSize(), 2);
CHECK_EQ(block_b.VarsSize(), 1);
const auto& op_b0 = *block_b.GetOp<OpDescView>(1);
CHECK_EQ(op_b0.GetAttr<bool>("Attr1"), true);
CHECK_EQ(op_b0.HasAttr("Attr4"), false);
}
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -147,11 +147,19 @@ inline proto::AttrType ConvertAttrType(lite::OpAttrType type) {
template <typename FlatbuffersMapT, typename KeyT = std::string>
KeyT GetKey(const std::unique_ptr<FlatbuffersMapT>& object);
#define GET_KEY_INSTANCE(type, key, key_type) \
template <> \
inline key_type GetKey<proto::type>( \
const std::unique_ptr<proto::type>& object) { \
return object->key; \
template <typename FlatbuffersMapT, typename KeyT = std::string>
void SetKey(const KeyT& key, std::unique_ptr<FlatbuffersMapT>* object);
#define GET_KEY_INSTANCE(type, key, key_type) \
template <> \
inline key_type GetKey<proto::type>( \
const std::unique_ptr<proto::type>& object) { \
return object->key; \
} \
template <> \
inline void SetKey<proto::type>(const key_type& key_in, \
std::unique_ptr<proto::type>* object) { \
(*object)->key = key_in; \
}
GET_KEY_INSTANCE(OpDesc_::VarT, parameter, std::string);
GET_KEY_INSTANCE(OpDesc_::AttrT, name, std::string);
......@@ -182,19 +190,20 @@ typename std::vector<std::unique_ptr<MapT>>::const_iterator GetKeyIterator(
const KeyT& key, const std::vector<std::unique_ptr<MapT>>& vector) {
auto iter =
std::lower_bound(vector.begin(), vector.end(), key, CompareFunc());
CHECK(GetKey(*iter) == key);
CHECK_EQ(GetKey(*iter), key);
return iter;
}
template <typename MapT,
typename KeyT = std::string,
typename CompareFunc = CompareLessThanKey<MapT, KeyT>>
void InsertPair(const KeyT& key,
std::unique_ptr<MapT>&& val,
std::vector<std::unique_ptr<MapT>>* vector) {
typename std::vector<std::unique_ptr<MapT>>::iterator InsertPair(
const KeyT& key,
std::unique_ptr<MapT>&& val,
std::vector<std::unique_ptr<MapT>>* vector) {
auto iter =
std::lower_bound(vector->begin(), vector->end(), key, CompareFunc());
vector->insert(iter, std::forward<std::unique_ptr<MapT>>(val));
return vector->insert(iter, std::forward<std::unique_ptr<MapT>>(val));
}
template <typename MapT,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册