未验证 提交 be13a60a 编写于 作者: 石晓伟 提交者: GitHub

Add some write-apis for flatbuffers, test=develop (#4065)

* add view suffix to read-only classes for flatbuffers, test=develop

* add some write-apis for flatbuffers, test=develop
上级 4c49f876
...@@ -13,35 +13,62 @@ ...@@ -13,35 +13,62 @@
// limitations under the License. // limitations under the License.
#include "lite/model_parser/flatbuffers/block_desc.h" #include "lite/model_parser/flatbuffers/block_desc.h"
#include <memory>
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace fbs { namespace fbs {
template <> template <>
proto::VarDesc const* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) const { proto::VarDesc const* BlockDescView::GetVar<proto::VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()"; CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return desc_->vars()->Get(idx); return desc_->vars()->Get(idx);
} }
template <> template <>
proto::OpDesc const* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) const { proto::OpDesc const* BlockDescView::GetOp<proto::OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()"; CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return desc_->ops()->Get(idx); return desc_->ops()->Get(idx);
} }
template <> template <>
VarDesc const* BlockDesc::GetVar<VarDesc>(int32_t idx) const { VarDescView const* BlockDescView::GetVar<VarDescView>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()"; CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return &vars_[idx]; return &vars_[idx];
} }
template <> template <>
OpDesc const* BlockDesc::GetOp<OpDesc>(int32_t idx) const { OpDescView const* BlockDescView::GetOp<OpDescView>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()"; CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return &ops_[idx]; return &ops_[idx];
} }
template <>
proto::VarDescT* BlockDesc::GetVar<proto::VarDescT>(int32_t idx) {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return vars_[idx].raw_desc();
}
template <>
proto::VarDescT* BlockDesc::AddVar<proto::VarDescT>() {
desc_->vars.push_back(std::unique_ptr<proto::VarDescT>(new proto::VarDescT));
SyncVars();
return vars_.back().raw_desc();
}
template <>
proto::OpDescT* BlockDesc::GetOp<proto::OpDescT>(int32_t idx) {
CHECK_LT(idx, OpsSize()) << "idx >= vars.size()";
return ops_[idx].raw_desc();
}
template <>
proto::OpDescT* BlockDesc::AddOp<proto::OpDescT>() {
desc_->ops.push_back(std::unique_ptr<proto::OpDescT>(new proto::OpDescT));
SyncOps();
return ops_.back().raw_desc();
}
} // namespace fbs } // namespace fbs
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -25,17 +25,17 @@ namespace paddle { ...@@ -25,17 +25,17 @@ namespace paddle {
namespace lite { namespace lite {
namespace fbs { namespace fbs {
class BlockDesc : public BlockDescAPI { class BlockDescView : public BlockDescAPI {
public: public:
explicit BlockDesc(proto::BlockDesc const* desc) : desc_(desc) { explicit BlockDescView(proto::BlockDesc const* desc) : desc_(desc) {
CHECK(desc_); CHECK(desc_);
vars_.reserve(VarsSize()); vars_.reserve(VarsSize());
ops_.reserve(OpsSize()); ops_.reserve(OpsSize());
for (size_t idx = 0; idx < VarsSize(); ++idx) { for (size_t idx = 0; idx < VarsSize(); ++idx) {
vars_.push_back(VarDesc(desc_->vars()->Get(idx))); vars_.push_back(VarDescView(desc_->vars()->Get(idx)));
} }
for (size_t idx = 0; idx < OpsSize(); ++idx) { for (size_t idx = 0; idx < OpsSize(); ++idx) {
ops_.push_back(OpDesc(desc_->ops()->Get(idx))); ops_.push_back(OpDescView(desc_->ops()->Get(idx)));
} }
} }
...@@ -69,26 +69,103 @@ class BlockDesc : public BlockDescAPI { ...@@ -69,26 +69,103 @@ class BlockDesc : public BlockDescAPI {
return nullptr; return nullptr;
} }
const std::vector<VarDesc>& GetVars() const { return vars_; } const std::vector<VarDescView>& GetVars() const { return vars_; }
int32_t ForwardBlockIdx() const override { int32_t ForwardBlockIdx() const override {
return desc_->forward_block_idx(); return desc_->forward_block_idx();
} }
BlockDesc() { NotImplemented(); } BlockDescView() { NotImplemented(); }
private: private:
proto::BlockDesc const* desc_; // not_own proto::BlockDesc const* desc_; // not_own
std::vector<VarDesc> vars_; std::vector<VarDescView> vars_;
std::vector<OpDesc> ops_; std::vector<OpDescView> ops_;
private: private:
void NotImplemented() const { void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of BlockDesc is temporarily " LOG(FATAL) << "The additional interfaces of BlockDescView is temporarily "
"unavailable in read-only mode."; "unavailable in read-only mode.";
} }
}; };
class BlockDesc : public BlockDescAPI {
public:
BlockDesc() : owned_(true), desc_(new proto::BlockDescT()) {}
explicit BlockDesc(proto::BlockDescT* desc) : desc_(desc) { CHECK(desc_); }
int32_t Idx() const override { return desc_->idx; }
void SetIdx(int32_t idx) override { desc_->idx = idx; }
int32_t ParentIdx() const override { return desc_->parent_idx; }
void SetParentIdx(int32_t idx) override { desc_->parent_idx = idx; }
size_t VarsSize() const override { return desc_->vars.size(); }
void ClearVars() override {
desc_->vars.clear();
SyncVars();
}
size_t OpsSize() const override { return desc_->ops.size(); }
void ClearOps() override {
desc_->ops.clear();
SyncOps();
}
int32_t ForwardBlockIdx() const override { return desc_->forward_block_idx; }
void SetForwardBlockIdx(int32_t idx_in) override {
desc_->forward_block_idx = idx_in;
}
proto::BlockDescT* raw_desc() { return desc_; }
template <typename T>
T* GetVar(int32_t idx);
template <typename T>
T* AddVar();
template <typename T>
T* GetOp(int32_t idx);
template <typename T>
T* AddOp();
~BlockDesc() {
if (owned_) {
delete desc_;
}
}
private:
void SyncVars() {
vars_.resize(desc_->vars.size());
for (size_t i = 0; i < desc_->vars.size(); ++i) {
if (vars_[i].raw_desc() != desc_->vars[i].get()) {
vars_[i] = VarDesc(desc_->vars[i].get());
}
}
}
void SyncOps() {
ops_.resize(desc_->ops.size());
for (size_t i = 0; i < desc_->ops.size(); ++i) {
if (ops_[i].raw_desc() != desc_->ops[i].get()) {
ops_[i] = OpDesc(desc_->ops[i].get());
}
}
}
bool owned_{false};
proto::BlockDescT* desc_{nullptr};
std::vector<VarDesc> vars_;
std::vector<OpDesc> ops_;
};
} // namespace fbs } // namespace fbs
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -19,7 +19,7 @@ namespace lite { ...@@ -19,7 +19,7 @@ namespace lite {
namespace fbs { namespace fbs {
template <> template <>
std::string OpDesc::GetAttr<std::string>(const std::string& name) const { std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); const auto& it = desc_->attrs()->LookupByKey(name.c_str());
if (!it->s()) { if (!it->s()) {
return std::string(); return std::string();
...@@ -28,7 +28,7 @@ std::string OpDesc::GetAttr<std::string>(const std::string& name) const { ...@@ -28,7 +28,7 @@ std::string OpDesc::GetAttr<std::string>(const std::string& name) const {
} }
template <> template <>
std::string OpDesc::GetAttr<std::string>(size_t idx) const { std::string OpDescView::GetAttr<std::string>(size_t idx) const {
const auto& it = desc_->attrs()->Get(idx); const auto& it = desc_->attrs()->Get(idx);
if (!it->s()) { if (!it->s()) {
return std::string(); return std::string();
...@@ -38,43 +38,43 @@ std::string OpDesc::GetAttr<std::string>(size_t idx) const { ...@@ -38,43 +38,43 @@ std::string OpDesc::GetAttr<std::string>(size_t idx) const {
template <> template <>
lite::VectorView<std::string, Flatbuffers> lite::VectorView<std::string, Flatbuffers>
OpDesc::GetAttr<std::vector<std::string>>(const std::string& name) const { OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); const auto& it = desc_->attrs()->LookupByKey(name.c_str());
CHECK(it) << "Attr " << name << "does not exist."; CHECK(it) << "Attr " << name << "does not exist.";
return VectorView<std::string>(it->strings()); return VectorView<std::string>(it->strings());
} }
template <> template <>
VectorView<std::string, Flatbuffers> OpDesc::GetAttr<std::vector<std::string>>( VectorView<std::string, Flatbuffers>
size_t idx) const { OpDescView::GetAttr<std::vector<std::string>>(size_t idx) const {
const auto& it = desc_->attrs()->Get(idx); const auto& it = desc_->attrs()->Get(idx);
CHECK(it) << "Attr " << idx << "does not exist."; CHECK(it) << "Attr " << idx << "does not exist.";
return VectorView<std::string>(it->strings()); return VectorView<std::string>(it->strings());
} }
#define GET_ATTR_IMPL(T, fb_f__) \ #define GET_ATTR_IMPL(T, fb_f__) \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \ const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \ const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
return it->fb_f__(); \ return it->fb_f__(); \
} \ } \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \ size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \ const auto& it = desc_->attrs()->Get(idx); \
return it->fb_f__(); \ return it->fb_f__(); \
} }
#define GET_ATTRS_IMPL(T, fb_f__) \ #define GET_ATTRS_IMPL(T, fb_f__) \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \ const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \ const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \ return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
} \ } \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \ size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \ const auto& it = desc_->attrs()->Get(idx); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \ return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
...@@ -88,6 +88,27 @@ GET_ATTR_IMPL(int64_t, l); ...@@ -88,6 +88,27 @@ GET_ATTR_IMPL(int64_t, l);
GET_ATTRS_IMPL(std::vector<int>, ints); GET_ATTRS_IMPL(std::vector<int>, ints);
GET_ATTRS_IMPL(std::vector<float>, floats); GET_ATTRS_IMPL(std::vector<float>, floats);
GET_ATTRS_IMPL(std::vector<int64_t>, longs); GET_ATTRS_IMPL(std::vector<int64_t>, longs);
#undef GET_ATTR_IMPL
#undef GET_ATTRS_IMPL
#define ATTR_IMPL(T, fb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
return (*GetKeyIterator(name, desc_->attrs))->fb_f__; \
} \
template <> \
void OpDesc::SetAttr(const std::string& name, const T& v) { \
(*GetKeyIterator(name, desc_->attrs))->fb_f__ = v; \
}
ATTR_IMPL(int32_t, i);
ATTR_IMPL(int16_t, block_idx);
ATTR_IMPL(float, f);
ATTR_IMPL(bool, b);
ATTR_IMPL(int64_t, l);
ATTR_IMPL(std::vector<int>, ints);
ATTR_IMPL(std::vector<float>, floats);
ATTR_IMPL(std::vector<int64_t>, longs);
#undef GET_ATTRS_IMPL
} // namespace fbs } // namespace fbs
} // namespace lite } // namespace lite
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "lite/model_parser/base/op_desc.h" #include "lite/model_parser/base/op_desc.h"
...@@ -29,9 +30,9 @@ namespace paddle { ...@@ -29,9 +30,9 @@ namespace paddle {
namespace lite { namespace lite {
namespace fbs { namespace fbs {
class OpDesc : public OpDescAPI { class OpDescView : public OpDescAPI {
public: public:
explicit OpDesc(proto::OpDesc const* desc) : desc_(desc) { CHECK(desc_); } explicit OpDescView(proto::OpDesc const* desc) : desc_(desc) { CHECK(desc_); }
std::string Type() const override { return desc_->type()->str(); } std::string Type() const override { return desc_->type()->str(); }
...@@ -137,7 +138,7 @@ class OpDesc : public OpDescAPI { ...@@ -137,7 +138,7 @@ class OpDesc : public OpDescAPI {
// caused by different building options. // caused by different building options.
public: public:
OpDesc() { NotImplemented(); } OpDescView() { NotImplemented(); }
bool HasInput(const std::string& param) const { bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr; return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
} }
...@@ -184,7 +185,7 @@ class OpDesc : public OpDescAPI { ...@@ -184,7 +185,7 @@ class OpDesc : public OpDescAPI {
private: private:
void NotImplemented() const { void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of OpDesc is temporarily " LOG(FATAL) << "The additional interfaces of OpDescView is temporarily "
"unavailable in read-only mode."; "unavailable in read-only mode.";
} }
std::string type_; std::string type_;
...@@ -194,6 +195,93 @@ class OpDesc : public OpDescAPI { ...@@ -194,6 +195,93 @@ class OpDesc : public OpDescAPI {
std::map<std::string, AttrType> attr_types_; std::map<std::string, AttrType> attr_types_;
}; };
class OpDesc : public OpDescAPI {
public:
OpDesc() : owned_(true), desc_(new proto::OpDescT()) {}
explicit OpDesc(proto::OpDescT* desc) : desc_(desc) { CHECK(desc_); }
std::string Type() const override { return desc_->type; }
void SetType(const std::string& type) override { desc_->type = type; }
std::vector<std::string> Input(const std::string& param) const override {
return (*GetKeyIterator(param, desc_->inputs))->arguments;
}
std::vector<std::string> InputArgumentNames() const override {
VLOG(5) << "This function call is expensive.";
std::vector<std::string> tmp;
for (const auto& input : desc_->inputs) {
tmp.push_back(input->parameter);
}
return tmp;
}
void SetInput(const std::string& param,
const std::vector<std::string>& args) override {
std::unique_ptr<proto::OpDesc_::VarT> var(new proto::OpDesc_::VarT);
var->parameter = param;
var->arguments = args;
InsertPair(param, std::move(var), &desc_->inputs);
}
std::vector<std::string> Output(const std::string& param) const override {
return (*GetKeyIterator(param, desc_->outputs))->arguments;
}
std::vector<std::string> OutputArgumentNames() const override {
VLOG(5) << "This function call is expensive.";
std::vector<std::string> tmp;
for (const auto& output : desc_->outputs) {
tmp.push_back(output->parameter);
}
return tmp;
}
void SetOutput(const std::string& param,
const std::vector<std::string>& args) override {
std::unique_ptr<proto::OpDesc_::VarT> var(new proto::OpDesc_::VarT);
var->parameter = param;
var->arguments = args;
InsertPair(param, std::move(var), &desc_->outputs);
}
bool HasAttr(const std::string& name) const override {
return HasKey(name, desc_->attrs);
}
OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
return ConvertAttrType((*GetKeyIterator(name, desc_->attrs))->type);
}
std::vector<std::string> AttrNames() const override {
VLOG(5) << "This function call is expensive.";
std::vector<std::string> tmp;
for (const auto& attr : desc_->attrs) {
tmp.push_back(attr->name);
}
return tmp;
}
template <typename T>
void SetAttr(const std::string& name, const T& v);
template <typename T>
T GetAttr(const std::string& name) const;
proto::OpDescT* raw_desc() { return desc_; }
~OpDesc() {
if (owned_) {
delete desc_;
}
}
private:
bool owned_{false};
proto::OpDescT* desc_{nullptr};
};
} // namespace fbs } // namespace fbs
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -19,14 +19,15 @@ namespace lite { ...@@ -19,14 +19,15 @@ namespace lite {
namespace fbs { namespace fbs {
template <> template <>
proto::BlockDesc const* ProgramDesc::GetBlock<proto::BlockDesc>( proto::BlockDesc const* ProgramDescView::GetBlock<proto::BlockDesc>(
int32_t idx) const { int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()"; CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return desc_->blocks()->Get(idx); return desc_->blocks()->Get(idx);
} }
template <> template <>
BlockDesc const* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) const { BlockDescView const* ProgramDescView::GetBlock<BlockDescView>(
int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()"; CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return &blocks_[idx]; return &blocks_[idx];
} }
......
...@@ -26,11 +26,11 @@ namespace paddle { ...@@ -26,11 +26,11 @@ namespace paddle {
namespace lite { namespace lite {
namespace fbs { namespace fbs {
class ProgramDesc : public ProgramDescAPI { class ProgramDescView : public ProgramDescAPI {
public: public:
ProgramDesc() = default; ProgramDescView() = default;
explicit ProgramDesc(const std::vector<char>& buf) { Init(buf); } explicit ProgramDescView(const std::vector<char>& buf) { Init(buf); }
explicit ProgramDesc(std::vector<char>&& buf) { explicit ProgramDescView(std::vector<char>&& buf) {
Init(std::forward<std::vector<char>>(buf)); Init(std::forward<std::vector<char>>(buf));
} }
...@@ -50,11 +50,11 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -50,11 +50,11 @@ class ProgramDesc : public ProgramDescAPI {
desc_ = proto::GetProgramDesc(buf_.data()); desc_ = proto::GetProgramDesc(buf_.data());
blocks_.reserve(BlocksSize()); blocks_.reserve(BlocksSize());
for (size_t idx = 0; idx < BlocksSize(); ++idx) { for (size_t idx = 0; idx < BlocksSize(); ++idx) {
blocks_.push_back(BlockDesc(desc_->blocks()->Get(idx))); blocks_.push_back(BlockDescView(desc_->blocks()->Get(idx)));
} }
} }
void CopyFrom(const ProgramDesc& other) { void CopyFrom(const ProgramDescView& other) {
buf_ = other.buf(); buf_ = other.buf();
Init(buf_); Init(buf_);
} }
...@@ -70,7 +70,7 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -70,7 +70,7 @@ class ProgramDesc : public ProgramDescAPI {
return nullptr; return nullptr;
} }
const std::vector<BlockDesc>& GetBlocks() const { return blocks_; } const std::vector<BlockDescView>& GetBlocks() const { return blocks_; }
bool HasVersion() const override { return desc_->version() != nullptr; } bool HasVersion() const override { return desc_->version() != nullptr; }
...@@ -86,13 +86,13 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -86,13 +86,13 @@ class ProgramDesc : public ProgramDescAPI {
private: private:
proto::ProgramDesc const* desc_; proto::ProgramDesc const* desc_;
std::vector<char> buf_; std::vector<char> buf_;
std::vector<BlockDesc> blocks_; std::vector<BlockDescView> blocks_;
private: private:
ProgramDesc& operator=(const ProgramDesc&) = delete; ProgramDescView& operator=(const ProgramDescView&) = delete;
ProgramDesc(const ProgramDesc&) = delete; ProgramDescView(const ProgramDescView&) = delete;
void NotImplemented() const { void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of ProgramDesc is temporarily " LOG(FATAL) << "The additional interfaces of ProgramDescView is temporarily "
"unavailable in read-only mode."; "unavailable in read-only mode.";
} }
}; };
......
...@@ -14,6 +14,11 @@ ...@@ -14,6 +14,11 @@
#pragma once #pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/model_parser/base/traits.h" #include "lite/model_parser/base/traits.h"
#include "lite/model_parser/flatbuffers/framework_generated.h" #include "lite/model_parser/flatbuffers/framework_generated.h"
...@@ -139,6 +144,71 @@ inline proto::AttrType ConvertAttrType(lite::OpAttrType type) { ...@@ -139,6 +144,71 @@ inline proto::AttrType ConvertAttrType(lite::OpAttrType type) {
#undef CASE #undef CASE
} }
template <typename FlatbuffersMapT, typename KeyT = std::string>
KeyT GetKey(const std::unique_ptr<FlatbuffersMapT>& object);
#define GET_KEY_INSTANCE(type, key, key_type) \
template <> \
inline key_type GetKey<proto::type>( \
const std::unique_ptr<proto::type>& object) { \
return object->key; \
}
GET_KEY_INSTANCE(OpDesc_::VarT, parameter, std::string);
GET_KEY_INSTANCE(OpDesc_::AttrT, name, std::string);
#undef GET_KEY_INSTANCE
template <typename MapT, typename KeyT = std::string>
struct CompareLessThanKey {
bool operator()(const std::unique_ptr<MapT>& lhs, const KeyT& rhs) {
return GetKey(lhs) < rhs;
}
bool operator()(const KeyT& lhs, const std::unique_ptr<MapT>& rhs) {
return lhs < GetKey(rhs);
}
};
template <typename MapT>
struct CompareLessThan {
bool operator()(const std::unique_ptr<MapT>& lhs,
const std::unique_ptr<MapT>& rhs) {
return GetKey(lhs) < GetKey(rhs);
}
};
template <typename MapT,
typename KeyT = std::string,
typename CompareFunc = CompareLessThanKey<MapT, KeyT>>
typename std::vector<std::unique_ptr<MapT>>::const_iterator GetKeyIterator(
const KeyT& key, const std::vector<std::unique_ptr<MapT>>& vector) {
auto iter =
std::lower_bound(vector.begin(), vector.end(), key, CompareFunc());
CHECK(GetKey(*iter) == key);
return iter;
}
template <typename MapT,
typename KeyT = std::string,
typename CompareFunc = CompareLessThanKey<MapT, KeyT>>
void InsertPair(const KeyT& key,
std::unique_ptr<MapT>&& val,
std::vector<std::unique_ptr<MapT>>* vector) {
auto iter =
std::lower_bound(vector->begin(), vector->end(), key, CompareFunc());
vector->insert(iter, std::forward<std::unique_ptr<MapT>>(val));
}
template <typename MapT,
typename KeyT = std::string,
typename CompareFunc = CompareLessThanKey<MapT, KeyT>>
bool HasKey(const KeyT& key, const std::vector<std::unique_ptr<MapT>>& vector) {
return std::binary_search(vector.begin(), vector.end(), key, CompareFunc());
}
template <typename MapT, typename CompareFunc = CompareLessThan<MapT>>
void Sort(std::vector<std::unique_ptr<MapT>>* vector) {
std::sort(vector->begin(), vector->end(), CompareFunc());
}
} // namespace fbs } // namespace fbs
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -26,9 +26,9 @@ namespace paddle { ...@@ -26,9 +26,9 @@ namespace paddle {
namespace lite { namespace lite {
namespace fbs { namespace fbs {
class VarDesc : public VarDescAPI { class VarDescView : public VarDescAPI {
public: public:
explicit VarDesc(proto::VarDesc const* desc) : desc_(desc) {} explicit VarDescView(proto::VarDesc const* desc) : desc_(desc) {}
std::string Name() const override { return desc_->name()->str(); } std::string Name() const override { return desc_->name()->str(); }
...@@ -66,18 +66,79 @@ class VarDesc : public VarDescAPI { ...@@ -66,18 +66,79 @@ class VarDesc : public VarDescAPI {
// caused by different building options. // caused by different building options.
public: public:
VarDesc() { NotImplemented(); } VarDescView() { NotImplemented(); }
void SetDataType(Type data_type) { NotImplemented(); } void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); } void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
private: private:
void NotImplemented() const { void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of VarDesc is temporarily " LOG(FATAL) << "The additional interfaces of VarDescView is temporarily "
"unavailable in read-only mode."; "unavailable in read-only mode.";
} }
std::vector<int64_t> shape_; std::vector<int64_t> shape_;
}; };
class VarDesc : public VarDescAPI {
public:
VarDesc() : owned_(true), desc_(new proto::VarDescT()) {}
explicit VarDesc(proto::VarDescT* desc) : desc_(desc) {
CHECK(desc_);
InitType();
}
std::string Name() const override { return desc_->name; }
void SetName(std::string name) override { desc_->name = name; }
Type GetType() const override { return ConvertVarType(type_->type); }
void SetType(Type type) override {
CHECK(type == VarDescAPI::Type::LOD_TENSOR);
type_->type = ConvertVarType(type);
}
bool Persistable() const override { return desc_->persistable; }
void SetPersistable(bool persistable) override {
desc_->persistable = persistable;
}
std::vector<int64_t> GetShape() const override {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
return type_->lod_tensor->tensor->dims;
}
void SetShape(const std::vector<int64_t>& dims) override {
type_->lod_tensor->tensor->dims = dims;
}
proto::VarDescT* raw_desc() { return desc_; }
~VarDesc() {
if (owned_) {
delete desc_;
}
}
private:
void InitType() {
if (!desc_->type) {
desc_->type = std::unique_ptr<proto::VarTypeT>(new proto::VarTypeT());
desc_->type->lod_tensor =
std::unique_ptr<proto::VarType_::LoDTensorDescT>(
new proto::VarType_::LoDTensorDescT());
desc_->type->lod_tensor->tensor =
std::unique_ptr<proto::VarType_::TensorDescT>(
new proto::VarType_::TensorDescT());
}
type_ = desc_->type.get();
}
bool owned_{false};
proto::VarDescT* desc_{nullptr};
paddle::lite::fbs::proto::VarTypeT* type_{nullptr};
};
} // namespace fbs } // namespace fbs
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册