未验证 提交 be13a60a 编写于 作者: 石晓伟 提交者: GitHub

Add some write-apis for flatbuffers, test=develop (#4065)

* add view suffix to read-only classes for flatbuffers, test=develop

* add some write-apis for flatbuffers, test=develop
上级 4c49f876
......@@ -13,35 +13,62 @@
// limitations under the License.
#include "lite/model_parser/flatbuffers/block_desc.h"
#include <memory>
namespace paddle {
namespace lite {
namespace fbs {
template <>
proto::VarDesc const* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) const {
proto::VarDesc const* BlockDescView::GetVar<proto::VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return desc_->vars()->Get(idx);
}
template <>
proto::OpDesc const* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) const {
proto::OpDesc const* BlockDescView::GetOp<proto::OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return desc_->ops()->Get(idx);
}
template <>
VarDesc const* BlockDesc::GetVar<VarDesc>(int32_t idx) const {
VarDescView const* BlockDescView::GetVar<VarDescView>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return &vars_[idx];
}
template <>
OpDesc const* BlockDesc::GetOp<OpDesc>(int32_t idx) const {
OpDescView const* BlockDescView::GetOp<OpDescView>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return &ops_[idx];
}
template <>
proto::VarDescT* BlockDesc::GetVar<proto::VarDescT>(int32_t idx) {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return vars_[idx].raw_desc();
}
template <>
proto::VarDescT* BlockDesc::AddVar<proto::VarDescT>() {
desc_->vars.push_back(std::unique_ptr<proto::VarDescT>(new proto::VarDescT));
SyncVars();
return vars_.back().raw_desc();
}
template <>
proto::OpDescT* BlockDesc::GetOp<proto::OpDescT>(int32_t idx) {
CHECK_LT(idx, OpsSize()) << "idx >= vars.size()";
return ops_[idx].raw_desc();
}
template <>
proto::OpDescT* BlockDesc::AddOp<proto::OpDescT>() {
desc_->ops.push_back(std::unique_ptr<proto::OpDescT>(new proto::OpDescT));
SyncOps();
return ops_.back().raw_desc();
}
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -25,17 +25,17 @@ namespace paddle {
namespace lite {
namespace fbs {
class BlockDesc : public BlockDescAPI {
class BlockDescView : public BlockDescAPI {
public:
explicit BlockDesc(proto::BlockDesc const* desc) : desc_(desc) {
explicit BlockDescView(proto::BlockDesc const* desc) : desc_(desc) {
CHECK(desc_);
vars_.reserve(VarsSize());
ops_.reserve(OpsSize());
for (size_t idx = 0; idx < VarsSize(); ++idx) {
vars_.push_back(VarDesc(desc_->vars()->Get(idx)));
vars_.push_back(VarDescView(desc_->vars()->Get(idx)));
}
for (size_t idx = 0; idx < OpsSize(); ++idx) {
ops_.push_back(OpDesc(desc_->ops()->Get(idx)));
ops_.push_back(OpDescView(desc_->ops()->Get(idx)));
}
}
......@@ -69,26 +69,103 @@ class BlockDesc : public BlockDescAPI {
return nullptr;
}
const std::vector<VarDesc>& GetVars() const { return vars_; }
const std::vector<VarDescView>& GetVars() const { return vars_; }
int32_t ForwardBlockIdx() const override {
return desc_->forward_block_idx();
}
BlockDesc() { NotImplemented(); }
BlockDescView() { NotImplemented(); }
private:
proto::BlockDesc const* desc_; // not_own
std::vector<VarDesc> vars_;
std::vector<OpDesc> ops_;
std::vector<VarDescView> vars_;
std::vector<OpDescView> ops_;
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of BlockDesc is temporarily "
LOG(FATAL) << "The additional interfaces of BlockDescView is temporarily "
"unavailable in read-only mode.";
}
};
class BlockDesc : public BlockDescAPI {
public:
BlockDesc() : owned_(true), desc_(new proto::BlockDescT()) {}
explicit BlockDesc(proto::BlockDescT* desc) : desc_(desc) { CHECK(desc_); }
int32_t Idx() const override { return desc_->idx; }
void SetIdx(int32_t idx) override { desc_->idx = idx; }
int32_t ParentIdx() const override { return desc_->parent_idx; }
void SetParentIdx(int32_t idx) override { desc_->parent_idx = idx; }
size_t VarsSize() const override { return desc_->vars.size(); }
void ClearVars() override {
desc_->vars.clear();
SyncVars();
}
size_t OpsSize() const override { return desc_->ops.size(); }
void ClearOps() override {
desc_->ops.clear();
SyncOps();
}
int32_t ForwardBlockIdx() const override { return desc_->forward_block_idx; }
void SetForwardBlockIdx(int32_t idx_in) override {
desc_->forward_block_idx = idx_in;
}
proto::BlockDescT* raw_desc() { return desc_; }
template <typename T>
T* GetVar(int32_t idx);
template <typename T>
T* AddVar();
template <typename T>
T* GetOp(int32_t idx);
template <typename T>
T* AddOp();
~BlockDesc() {
if (owned_) {
delete desc_;
}
}
private:
void SyncVars() {
vars_.resize(desc_->vars.size());
for (size_t i = 0; i < desc_->vars.size(); ++i) {
if (vars_[i].raw_desc() != desc_->vars[i].get()) {
vars_[i] = VarDesc(desc_->vars[i].get());
}
}
}
void SyncOps() {
ops_.resize(desc_->ops.size());
for (size_t i = 0; i < desc_->ops.size(); ++i) {
if (ops_[i].raw_desc() != desc_->ops[i].get()) {
ops_[i] = OpDesc(desc_->ops[i].get());
}
}
}
bool owned_{false};
proto::BlockDescT* desc_{nullptr};
std::vector<VarDesc> vars_;
std::vector<OpDesc> ops_;
};
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -19,7 +19,7 @@ namespace lite {
namespace fbs {
template <>
std::string OpDesc::GetAttr<std::string>(const std::string& name) const {
std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str());
if (!it->s()) {
return std::string();
......@@ -28,7 +28,7 @@ std::string OpDesc::GetAttr<std::string>(const std::string& name) const {
}
template <>
std::string OpDesc::GetAttr<std::string>(size_t idx) const {
std::string OpDescView::GetAttr<std::string>(size_t idx) const {
const auto& it = desc_->attrs()->Get(idx);
if (!it->s()) {
return std::string();
......@@ -38,43 +38,43 @@ std::string OpDesc::GetAttr<std::string>(size_t idx) const {
template <>
lite::VectorView<std::string, Flatbuffers>
OpDesc::GetAttr<std::vector<std::string>>(const std::string& name) const {
OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str());
CHECK(it) << "Attr " << name << "does not exist.";
return VectorView<std::string>(it->strings());
}
template <>
VectorView<std::string, Flatbuffers> OpDesc::GetAttr<std::vector<std::string>>(
size_t idx) const {
VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(size_t idx) const {
const auto& it = desc_->attrs()->Get(idx);
CHECK(it) << "Attr " << idx << "does not exist.";
return VectorView<std::string>(it->strings());
}
#define GET_ATTR_IMPL(T, fb_f__) \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \
const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
return it->fb_f__(); \
} \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \
size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
return it->fb_f__(); \
#define GET_ATTR_IMPL(T, fb_f__) \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
return it->fb_f__(); \
} \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
return it->fb_f__(); \
}
#define GET_ATTRS_IMPL(T, fb_f__) \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
} \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
......@@ -88,6 +88,27 @@ GET_ATTR_IMPL(int64_t, l);
GET_ATTRS_IMPL(std::vector<int>, ints);
GET_ATTRS_IMPL(std::vector<float>, floats);
GET_ATTRS_IMPL(std::vector<int64_t>, longs);
#undef GET_ATTR_IMPL
#undef GET_ATTRS_IMPL
#define ATTR_IMPL(T, fb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
return (*GetKeyIterator(name, desc_->attrs))->fb_f__; \
} \
template <> \
void OpDesc::SetAttr(const std::string& name, const T& v) { \
(*GetKeyIterator(name, desc_->attrs))->fb_f__ = v; \
}
ATTR_IMPL(int32_t, i);
ATTR_IMPL(int16_t, block_idx);
ATTR_IMPL(float, f);
ATTR_IMPL(bool, b);
ATTR_IMPL(int64_t, l);
ATTR_IMPL(std::vector<int>, ints);
ATTR_IMPL(std::vector<float>, floats);
ATTR_IMPL(std::vector<int64_t>, longs);
#undef GET_ATTRS_IMPL
} // namespace fbs
} // namespace lite
......
......@@ -17,6 +17,7 @@
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/model_parser/base/op_desc.h"
......@@ -29,9 +30,9 @@ namespace paddle {
namespace lite {
namespace fbs {
class OpDesc : public OpDescAPI {
class OpDescView : public OpDescAPI {
public:
explicit OpDesc(proto::OpDesc const* desc) : desc_(desc) { CHECK(desc_); }
explicit OpDescView(proto::OpDesc const* desc) : desc_(desc) { CHECK(desc_); }
std::string Type() const override { return desc_->type()->str(); }
......@@ -137,7 +138,7 @@ class OpDesc : public OpDescAPI {
// caused by different building options.
public:
OpDesc() { NotImplemented(); }
OpDescView() { NotImplemented(); }
bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
}
......@@ -184,7 +185,7 @@ class OpDesc : public OpDescAPI {
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of OpDesc is temporarily "
LOG(FATAL) << "The additional interfaces of OpDescView is temporarily "
"unavailable in read-only mode.";
}
std::string type_;
......@@ -194,6 +195,93 @@ class OpDesc : public OpDescAPI {
std::map<std::string, AttrType> attr_types_;
};
class OpDesc : public OpDescAPI {
public:
OpDesc() : owned_(true), desc_(new proto::OpDescT()) {}
explicit OpDesc(proto::OpDescT* desc) : desc_(desc) { CHECK(desc_); }
std::string Type() const override { return desc_->type; }
void SetType(const std::string& type) override { desc_->type = type; }
std::vector<std::string> Input(const std::string& param) const override {
return (*GetKeyIterator(param, desc_->inputs))->arguments;
}
std::vector<std::string> InputArgumentNames() const override {
VLOG(5) << "This function call is expensive.";
std::vector<std::string> tmp;
for (const auto& input : desc_->inputs) {
tmp.push_back(input->parameter);
}
return tmp;
}
void SetInput(const std::string& param,
const std::vector<std::string>& args) override {
std::unique_ptr<proto::OpDesc_::VarT> var(new proto::OpDesc_::VarT);
var->parameter = param;
var->arguments = args;
InsertPair(param, std::move(var), &desc_->inputs);
}
std::vector<std::string> Output(const std::string& param) const override {
return (*GetKeyIterator(param, desc_->outputs))->arguments;
}
std::vector<std::string> OutputArgumentNames() const override {
VLOG(5) << "This function call is expensive.";
std::vector<std::string> tmp;
for (const auto& output : desc_->outputs) {
tmp.push_back(output->parameter);
}
return tmp;
}
void SetOutput(const std::string& param,
const std::vector<std::string>& args) override {
std::unique_ptr<proto::OpDesc_::VarT> var(new proto::OpDesc_::VarT);
var->parameter = param;
var->arguments = args;
InsertPair(param, std::move(var), &desc_->outputs);
}
bool HasAttr(const std::string& name) const override {
return HasKey(name, desc_->attrs);
}
OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
return ConvertAttrType((*GetKeyIterator(name, desc_->attrs))->type);
}
std::vector<std::string> AttrNames() const override {
VLOG(5) << "This function call is expensive.";
std::vector<std::string> tmp;
for (const auto& attr : desc_->attrs) {
tmp.push_back(attr->name);
}
return tmp;
}
template <typename T>
void SetAttr(const std::string& name, const T& v);
template <typename T>
T GetAttr(const std::string& name) const;
proto::OpDescT* raw_desc() { return desc_; }
~OpDesc() {
if (owned_) {
delete desc_;
}
}
private:
bool owned_{false};
proto::OpDescT* desc_{nullptr};
};
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -19,14 +19,15 @@ namespace lite {
namespace fbs {
template <>
proto::BlockDesc const* ProgramDesc::GetBlock<proto::BlockDesc>(
proto::BlockDesc const* ProgramDescView::GetBlock<proto::BlockDesc>(
int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return desc_->blocks()->Get(idx);
}
template <>
BlockDesc const* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) const {
BlockDescView const* ProgramDescView::GetBlock<BlockDescView>(
int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return &blocks_[idx];
}
......
......@@ -26,11 +26,11 @@ namespace paddle {
namespace lite {
namespace fbs {
class ProgramDesc : public ProgramDescAPI {
class ProgramDescView : public ProgramDescAPI {
public:
ProgramDesc() = default;
explicit ProgramDesc(const std::vector<char>& buf) { Init(buf); }
explicit ProgramDesc(std::vector<char>&& buf) {
ProgramDescView() = default;
explicit ProgramDescView(const std::vector<char>& buf) { Init(buf); }
explicit ProgramDescView(std::vector<char>&& buf) {
Init(std::forward<std::vector<char>>(buf));
}
......@@ -50,11 +50,11 @@ class ProgramDesc : public ProgramDescAPI {
desc_ = proto::GetProgramDesc(buf_.data());
blocks_.reserve(BlocksSize());
for (size_t idx = 0; idx < BlocksSize(); ++idx) {
blocks_.push_back(BlockDesc(desc_->blocks()->Get(idx)));
blocks_.push_back(BlockDescView(desc_->blocks()->Get(idx)));
}
}
void CopyFrom(const ProgramDesc& other) {
void CopyFrom(const ProgramDescView& other) {
buf_ = other.buf();
Init(buf_);
}
......@@ -70,7 +70,7 @@ class ProgramDesc : public ProgramDescAPI {
return nullptr;
}
const std::vector<BlockDesc>& GetBlocks() const { return blocks_; }
const std::vector<BlockDescView>& GetBlocks() const { return blocks_; }
bool HasVersion() const override { return desc_->version() != nullptr; }
......@@ -86,13 +86,13 @@ class ProgramDesc : public ProgramDescAPI {
private:
proto::ProgramDesc const* desc_;
std::vector<char> buf_;
std::vector<BlockDesc> blocks_;
std::vector<BlockDescView> blocks_;
private:
ProgramDesc& operator=(const ProgramDesc&) = delete;
ProgramDesc(const ProgramDesc&) = delete;
ProgramDescView& operator=(const ProgramDescView&) = delete;
ProgramDescView(const ProgramDescView&) = delete;
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of ProgramDesc is temporarily "
LOG(FATAL) << "The additional interfaces of ProgramDescView is temporarily "
"unavailable in read-only mode.";
}
};
......
......@@ -14,6 +14,11 @@
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/model_parser/base/traits.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
......@@ -139,6 +144,71 @@ inline proto::AttrType ConvertAttrType(lite::OpAttrType type) {
#undef CASE
}
template <typename FlatbuffersMapT, typename KeyT = std::string>
KeyT GetKey(const std::unique_ptr<FlatbuffersMapT>& object);
#define GET_KEY_INSTANCE(type, key, key_type) \
template <> \
inline key_type GetKey<proto::type>( \
const std::unique_ptr<proto::type>& object) { \
return object->key; \
}
GET_KEY_INSTANCE(OpDesc_::VarT, parameter, std::string);
GET_KEY_INSTANCE(OpDesc_::AttrT, name, std::string);
#undef GET_KEY_INSTANCE
template <typename MapT, typename KeyT = std::string>
struct CompareLessThanKey {
bool operator()(const std::unique_ptr<MapT>& lhs, const KeyT& rhs) {
return GetKey(lhs) < rhs;
}
bool operator()(const KeyT& lhs, const std::unique_ptr<MapT>& rhs) {
return lhs < GetKey(rhs);
}
};
template <typename MapT>
struct CompareLessThan {
bool operator()(const std::unique_ptr<MapT>& lhs,
const std::unique_ptr<MapT>& rhs) {
return GetKey(lhs) < GetKey(rhs);
}
};
template <typename MapT,
typename KeyT = std::string,
typename CompareFunc = CompareLessThanKey<MapT, KeyT>>
typename std::vector<std::unique_ptr<MapT>>::const_iterator GetKeyIterator(
const KeyT& key, const std::vector<std::unique_ptr<MapT>>& vector) {
auto iter =
std::lower_bound(vector.begin(), vector.end(), key, CompareFunc());
CHECK(GetKey(*iter) == key);
return iter;
}
template <typename MapT,
typename KeyT = std::string,
typename CompareFunc = CompareLessThanKey<MapT, KeyT>>
void InsertPair(const KeyT& key,
std::unique_ptr<MapT>&& val,
std::vector<std::unique_ptr<MapT>>* vector) {
auto iter =
std::lower_bound(vector->begin(), vector->end(), key, CompareFunc());
vector->insert(iter, std::forward<std::unique_ptr<MapT>>(val));
}
template <typename MapT,
typename KeyT = std::string,
typename CompareFunc = CompareLessThanKey<MapT, KeyT>>
bool HasKey(const KeyT& key, const std::vector<std::unique_ptr<MapT>>& vector) {
return std::binary_search(vector.begin(), vector.end(), key, CompareFunc());
}
template <typename MapT, typename CompareFunc = CompareLessThan<MapT>>
void Sort(std::vector<std::unique_ptr<MapT>>* vector) {
std::sort(vector->begin(), vector->end(), CompareFunc());
}
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -26,9 +26,9 @@ namespace paddle {
namespace lite {
namespace fbs {
class VarDesc : public VarDescAPI {
class VarDescView : public VarDescAPI {
public:
explicit VarDesc(proto::VarDesc const* desc) : desc_(desc) {}
explicit VarDescView(proto::VarDesc const* desc) : desc_(desc) {}
std::string Name() const override { return desc_->name()->str(); }
......@@ -66,18 +66,79 @@ class VarDesc : public VarDescAPI {
// caused by different building options.
public:
VarDesc() { NotImplemented(); }
VarDescView() { NotImplemented(); }
void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of VarDesc is temporarily "
LOG(FATAL) << "The additional interfaces of VarDescView is temporarily "
"unavailable in read-only mode.";
}
std::vector<int64_t> shape_;
};
class VarDesc : public VarDescAPI {
public:
VarDesc() : owned_(true), desc_(new proto::VarDescT()) {}
explicit VarDesc(proto::VarDescT* desc) : desc_(desc) {
CHECK(desc_);
InitType();
}
std::string Name() const override { return desc_->name; }
void SetName(std::string name) override { desc_->name = name; }
Type GetType() const override { return ConvertVarType(type_->type); }
void SetType(Type type) override {
CHECK(type == VarDescAPI::Type::LOD_TENSOR);
type_->type = ConvertVarType(type);
}
bool Persistable() const override { return desc_->persistable; }
void SetPersistable(bool persistable) override {
desc_->persistable = persistable;
}
std::vector<int64_t> GetShape() const override {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
return type_->lod_tensor->tensor->dims;
}
void SetShape(const std::vector<int64_t>& dims) override {
type_->lod_tensor->tensor->dims = dims;
}
proto::VarDescT* raw_desc() { return desc_; }
~VarDesc() {
if (owned_) {
delete desc_;
}
}
private:
void InitType() {
if (!desc_->type) {
desc_->type = std::unique_ptr<proto::VarTypeT>(new proto::VarTypeT());
desc_->type->lod_tensor =
std::unique_ptr<proto::VarType_::LoDTensorDescT>(
new proto::VarType_::LoDTensorDescT());
desc_->type->lod_tensor->tensor =
std::unique_ptr<proto::VarType_::TensorDescT>(
new proto::VarType_::TensorDescT());
}
type_ = desc_->type.get();
}
bool owned_{false};
proto::VarDescT* desc_{nullptr};
paddle::lite::fbs::proto::VarTypeT* type_{nullptr};
};
} // namespace fbs
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册