未验证 提交 37cb221e 编写于 作者: 石晓伟 提交者: GitHub

align opdesc interfaces, test=develop (#3922)

* align opdesc interfaces, test=develop

* align opdesc interfaces, test=develop
上级 e1e2bf38
...@@ -189,7 +189,7 @@ void OpLite::AttachOutput(const cpp::OpDesc &op_desc, ...@@ -189,7 +189,7 @@ void OpLite::AttachOutput(const cpp::OpDesc &op_desc,
bool OpInfo::GetInputArgname(const std::string &value_name, bool OpInfo::GetInputArgname(const std::string &value_name,
std::string *out) const { std::string *out) const {
for (auto &item : inputs_) { for (auto &item : inputs()) {
auto it = std::find(item.second.begin(), item.second.end(), value_name); auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) { if (it != item.second.end()) {
*out = item.first; *out = item.first;
...@@ -201,7 +201,7 @@ bool OpInfo::GetInputArgname(const std::string &value_name, ...@@ -201,7 +201,7 @@ bool OpInfo::GetInputArgname(const std::string &value_name,
bool OpInfo::GetOutputArgname(const std::string &value_name, bool OpInfo::GetOutputArgname(const std::string &value_name,
std::string *out) const { std::string *out) const {
for (auto &item : outputs_) { for (auto &item : outputs()) {
auto it = std::find(item.second.begin(), item.second.end(), value_name); auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) { if (it != item.second.end()) {
*out = item.first; *out = item.first;
...@@ -212,7 +212,7 @@ bool OpInfo::GetOutputArgname(const std::string &value_name, ...@@ -212,7 +212,7 @@ bool OpInfo::GetOutputArgname(const std::string &value_name,
} }
bool OpInfo::GetInputIndex(const std::string &input_name, int *out) const { bool OpInfo::GetInputIndex(const std::string &input_name, int *out) const {
for (auto &item : inputs_) { for (auto &item : inputs()) {
auto it = std::find(item.second.begin(), item.second.end(), input_name); auto it = std::find(item.second.begin(), item.second.end(), input_name);
if (it != item.second.end()) { if (it != item.second.end()) {
*out = it - item.second.begin(); *out = it - item.second.begin();
...@@ -223,7 +223,7 @@ bool OpInfo::GetInputIndex(const std::string &input_name, int *out) const { ...@@ -223,7 +223,7 @@ bool OpInfo::GetInputIndex(const std::string &input_name, int *out) const {
} }
bool OpInfo::GetOutputIndex(const std::string &output_name, int *out) const { bool OpInfo::GetOutputIndex(const std::string &output_name, int *out) const {
for (auto &item : outputs_) { for (auto &item : outputs()) {
auto it = std::find(item.second.begin(), item.second.end(), output_name); auto it = std::find(item.second.begin(), item.second.end(), output_name);
if (it != item.second.end()) { if (it != item.second.end()) {
*out = it - item.second.begin(); *out = it - item.second.begin();
......
...@@ -230,7 +230,7 @@ class OpInfo : public cpp::OpDesc { ...@@ -230,7 +230,7 @@ class OpInfo : public cpp::OpDesc {
} }
void UpdateAllInputs(const std::string &from, const std::string &to) { void UpdateAllInputs(const std::string &from, const std::string &to) {
for (auto &item : inputs_) { for (auto &item : *mutable_inputs()) {
for (auto &var : item.second) { for (auto &var : item.second) {
if (var == from) var = to; if (var == from) var = to;
} }
...@@ -238,7 +238,7 @@ class OpInfo : public cpp::OpDesc { ...@@ -238,7 +238,7 @@ class OpInfo : public cpp::OpDesc {
} }
void UpdateAllOutputs(const std::string &from, const std::string &to) { void UpdateAllOutputs(const std::string &from, const std::string &to) {
for (auto &item : outputs_) { for (auto &item : *mutable_outputs()) {
for (auto &var : item.second) { for (auto &var : item.second) {
if (var == from) var = to; if (var == from) var = to;
} }
......
...@@ -62,7 +62,7 @@ class BlockDescWriteAPI { ...@@ -62,7 +62,7 @@ class BlockDescWriteAPI {
virtual ~BlockDescWriteAPI() = default; virtual ~BlockDescWriteAPI() = default;
private: private:
void NotImplemented() { void NotImplemented() const {
LOG(FATAL) << "BlockDescWriteAPI is not available in model read-only mode."; LOG(FATAL) << "BlockDescWriteAPI is not available in model read-only mode.";
} }
}; };
......
...@@ -78,7 +78,7 @@ class OpDescWriteAPI { ...@@ -78,7 +78,7 @@ class OpDescWriteAPI {
virtual ~OpDescWriteAPI() = default; virtual ~OpDescWriteAPI() = default;
private: private:
void NotImplemented() { void NotImplemented() const {
LOG(FATAL) << "OpDescWriteAPI is not available in model read-only mode."; LOG(FATAL) << "OpDescWriteAPI is not available in model read-only mode.";
} }
}; };
......
...@@ -45,7 +45,7 @@ class ProgramDescWriteAPI { ...@@ -45,7 +45,7 @@ class ProgramDescWriteAPI {
virtual ~ProgramDescWriteAPI() = default; virtual ~ProgramDescWriteAPI() = default;
private: private:
void NotImplemented() { void NotImplemented() const {
LOG(FATAL) LOG(FATAL)
<< "ProgramDescWriteAPI is not available in model read-only mode."; << "ProgramDescWriteAPI is not available in model read-only mode.";
} }
......
...@@ -70,7 +70,7 @@ class VarDescWriteAPI { ...@@ -70,7 +70,7 @@ class VarDescWriteAPI {
virtual ~VarDescWriteAPI() = default; virtual ~VarDescWriteAPI() = default;
private: private:
void NotImplemented() { void NotImplemented() const {
LOG(FATAL) << "VarDescWriteAPI is not available in model read-only mode."; LOG(FATAL) << "VarDescWriteAPI is not available in model read-only mode.";
} }
}; };
......
...@@ -65,7 +65,7 @@ class VectorView { ...@@ -65,7 +65,7 @@ class VectorView {
typename Traits::const_iterator begin() const { return cvec_->begin(); } typename Traits::const_iterator begin() const { return cvec_->begin(); }
typename Traits::const_iterator end() const { return cvec_->end(); } typename Traits::const_iterator end() const { return cvec_->end(); }
size_t size() const { return cvec_->size(); } size_t size() const { return cvec_->size(); }
operator std::vector<T>() { operator std::vector<T>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance."; VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<T> tmp; std::vector<T> tmp;
tmp.reserve(cvec_->size()); tmp.reserve(cvec_->size());
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <map>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -83,7 +84,7 @@ class OpDesc : public OpDescAPI { ...@@ -83,7 +84,7 @@ class OpDesc : public OpDescAPI {
} }
bool HasAttr(const std::string& name) const override { bool HasAttr(const std::string& name) const override {
return desc_->attrs()->LookupByKey(name.c_str()) == nullptr; return desc_->attrs()->LookupByKey(name.c_str()) != nullptr;
} }
size_t AttrsSize() const { return desc_->attrs()->size(); } size_t AttrsSize() const { return desc_->attrs()->size(); }
...@@ -127,6 +128,71 @@ class OpDesc : public OpDescAPI { ...@@ -127,6 +128,71 @@ class OpDesc : public OpDescAPI {
private: private:
proto::OpDesc* desc_; proto::OpDesc* desc_;
// To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
// and flatbuffers::Desc replace each other. However, there is no direct
// inheritance relationship between the two data types, and the read-only
// version of flatbuffers lacks some write implementations. Therefore, at
// present, we are temporarily providing a default interface that triggers
// execution-time errors to avoid type ambiguity and compile-time errors
// caused by different building options.
public:
bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
}
const std::map<std::string, std::vector<std::string>>& inputs() const {
NotImplemented();
return inputs_;
}
const std::map<std::string, std::vector<std::string>>& outputs() const {
NotImplemented();
return outputs_;
}
std::map<std::string, std::vector<std::string>>* mutable_inputs() {
NotImplemented();
return &inputs_;
}
std::map<std::string, std::vector<std::string>>* mutable_outputs() {
NotImplemented();
return &outputs_;
}
std::vector<std::string> input_vars() const {
NotImplemented();
return std::vector<std::string>();
}
std::vector<std::string> output_vars() const {
NotImplemented();
return std::vector<std::string>();
}
bool HasOutput(const std::string& param) const {
NotImplemented();
return false;
}
const std::map<std::string, Any>& attrs() const {
NotImplemented();
return attrs_;
}
const std::map<std::string, AttrType>& attr_types() const {
NotImplemented();
return attr_types_;
}
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of OpDesc is temporarily "
"unavailable in read-only mode.";
}
std::string type_;
std::map<std::string, std::vector<std::string>> inputs_;
std::map<std::string, std::vector<std::string>> outputs_;
std::map<std::string, Any> attrs_;
std::map<std::string, AttrType> attr_types_;
}; };
} // namespace fbs } // namespace fbs
......
...@@ -25,6 +25,7 @@ namespace fbs { ...@@ -25,6 +25,7 @@ namespace fbs {
class ProgramDesc : public ProgramDescAPI { class ProgramDesc : public ProgramDescAPI {
public: public:
ProgramDesc() = default;
explicit ProgramDesc(proto::ProgramDesc *desc) : desc_(desc) { CHECK(desc); } explicit ProgramDesc(proto::ProgramDesc *desc) : desc_(desc) { CHECK(desc); }
size_t BlocksSize() const override { return desc_->blocks()->size(); } size_t BlocksSize() const override { return desc_->blocks()->size(); }
......
...@@ -52,6 +52,30 @@ class VarDesc : public VarDescAPI { ...@@ -52,6 +52,30 @@ class VarDesc : public VarDescAPI {
private: private:
proto::VarDesc* desc_; proto::VarDesc* desc_;
// To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
// and flatbuffers::Desc replace each other. However, there is no direct
// inheritance relationship between the two data types, and the read-only
// version of flatbuffers lacks some write implementations. Therefore, at
// present, we are temporarily providing a default interface that triggers
// execution-time errors to avoid type ambiguity and compile-time errors
// caused by different building options.
public:
VarDescAPI::Type GetDataType() const {
NotImplemented();
return data_type_;
}
void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of VarDesc is temporarily "
"unavailable in read-only mode.";
}
Type data_type_;
std::vector<int64_t> shape_;
}; };
} // namespace fbs } // namespace fbs
......
...@@ -112,7 +112,7 @@ class VectorView<std::string, Flatbuffers> { ...@@ -112,7 +112,7 @@ class VectorView<std::string, Flatbuffers> {
return vector_view::FBSStrIterator(cvec_->end()); return vector_view::FBSStrIterator(cvec_->end());
} }
size_t size() const { return cvec_->size(); } size_t size() const { return cvec_->size(); }
operator std::vector<std::string>() { operator std::vector<std::string>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance."; VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<std::string> tmp; std::vector<std::string> tmp;
tmp.reserve(cvec_->size()); tmp.reserve(cvec_->size());
......
...@@ -74,7 +74,7 @@ class ConvOpLite : public OpLite { ...@@ -74,7 +74,7 @@ class ConvOpLite : public OpLite {
param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>(); param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.strides = op_desc.GetAttr<std::vector<int>>("strides"); param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings"); std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups"); param_.groups = op_desc.GetAttr<int>("groups");
auto dilations = op_desc.GetAttr<std::vector<int>>("dilations"); auto dilations = op_desc.GetAttr<std::vector<int>>("dilations");
param_.dilations = std::make_shared<std::vector<int>>(dilations); param_.dilations = std::make_shared<std::vector<int>>(dilations);
......
...@@ -106,7 +106,7 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -106,7 +106,7 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc,
param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>(); param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.strides = op_desc.GetAttr<std::vector<int>>("strides"); param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings"); std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups"); param_.groups = op_desc.GetAttr<int>("groups");
auto dilations = op_desc.GetAttr<std::vector<int>>("dilations"); auto dilations = op_desc.GetAttr<std::vector<int>>("dilations");
......
...@@ -54,7 +54,7 @@ class PoolOpLite : public OpLite { ...@@ -54,7 +54,7 @@ class PoolOpLite : public OpLite {
param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize"); param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize");
param_.global_pooling = op_desc.GetAttr<bool>("global_pooling"); param_.global_pooling = op_desc.GetAttr<bool>("global_pooling");
param_.strides = op_desc.GetAttr<std::vector<int>>("strides"); param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings"); std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
if (op_desc.HasAttr("exclusive")) { if (op_desc.HasAttr("exclusive")) {
param_.exclusive = op_desc.GetAttr<bool>("exclusive"); param_.exclusive = op_desc.GetAttr<bool>("exclusive");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册