未验证 提交 37cb221e 编写于 作者: 石晓伟 提交者: GitHub

align opdesc interfaces, test=develop (#3922)

* align opdesc interfaces, test=develop

* align opdesc interfaces, test=develop
上级 e1e2bf38
......@@ -189,7 +189,7 @@ void OpLite::AttachOutput(const cpp::OpDesc &op_desc,
bool OpInfo::GetInputArgname(const std::string &value_name,
std::string *out) const {
for (auto &item : inputs_) {
for (auto &item : inputs()) {
auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) {
*out = item.first;
......@@ -201,7 +201,7 @@ bool OpInfo::GetInputArgname(const std::string &value_name,
bool OpInfo::GetOutputArgname(const std::string &value_name,
std::string *out) const {
for (auto &item : outputs_) {
for (auto &item : outputs()) {
auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) {
*out = item.first;
......@@ -212,7 +212,7 @@ bool OpInfo::GetOutputArgname(const std::string &value_name,
}
bool OpInfo::GetInputIndex(const std::string &input_name, int *out) const {
for (auto &item : inputs_) {
for (auto &item : inputs()) {
auto it = std::find(item.second.begin(), item.second.end(), input_name);
if (it != item.second.end()) {
*out = it - item.second.begin();
......@@ -223,7 +223,7 @@ bool OpInfo::GetInputIndex(const std::string &input_name, int *out) const {
}
bool OpInfo::GetOutputIndex(const std::string &output_name, int *out) const {
for (auto &item : outputs_) {
for (auto &item : outputs()) {
auto it = std::find(item.second.begin(), item.second.end(), output_name);
if (it != item.second.end()) {
*out = it - item.second.begin();
......
......@@ -230,7 +230,7 @@ class OpInfo : public cpp::OpDesc {
}
void UpdateAllInputs(const std::string &from, const std::string &to) {
for (auto &item : inputs_) {
for (auto &item : *mutable_inputs()) {
for (auto &var : item.second) {
if (var == from) var = to;
}
......@@ -238,7 +238,7 @@ class OpInfo : public cpp::OpDesc {
}
void UpdateAllOutputs(const std::string &from, const std::string &to) {
for (auto &item : outputs_) {
for (auto &item : *mutable_outputs()) {
for (auto &var : item.second) {
if (var == from) var = to;
}
......
......@@ -62,7 +62,7 @@ class BlockDescWriteAPI {
virtual ~BlockDescWriteAPI() = default;
private:
void NotImplemented() {
void NotImplemented() const {
LOG(FATAL) << "BlockDescWriteAPI is not available in model read-only mode.";
}
};
......
......@@ -78,7 +78,7 @@ class OpDescWriteAPI {
virtual ~OpDescWriteAPI() = default;
private:
void NotImplemented() {
void NotImplemented() const {
LOG(FATAL) << "OpDescWriteAPI is not available in model read-only mode.";
}
};
......
......@@ -45,7 +45,7 @@ class ProgramDescWriteAPI {
virtual ~ProgramDescWriteAPI() = default;
private:
void NotImplemented() {
void NotImplemented() const {
LOG(FATAL)
<< "ProgramDescWriteAPI is not available in model read-only mode.";
}
......
......@@ -70,7 +70,7 @@ class VarDescWriteAPI {
virtual ~VarDescWriteAPI() = default;
private:
void NotImplemented() {
void NotImplemented() const {
LOG(FATAL) << "VarDescWriteAPI is not available in model read-only mode.";
}
};
......
......@@ -65,7 +65,7 @@ class VectorView {
typename Traits::const_iterator begin() const { return cvec_->begin(); }
typename Traits::const_iterator end() const { return cvec_->end(); }
size_t size() const { return cvec_->size(); }
operator std::vector<T>() {
operator std::vector<T>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<T> tmp;
tmp.reserve(cvec_->size());
......
......@@ -14,6 +14,7 @@
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
......@@ -83,7 +84,7 @@ class OpDesc : public OpDescAPI {
}
bool HasAttr(const std::string& name) const override {
return desc_->attrs()->LookupByKey(name.c_str()) == nullptr;
return desc_->attrs()->LookupByKey(name.c_str()) != nullptr;
}
size_t AttrsSize() const { return desc_->attrs()->size(); }
......@@ -127,6 +128,71 @@ class OpDesc : public OpDescAPI {
private:
proto::OpDesc* desc_;
// To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
// and flatbuffers::Desc replace each other. However, there is no direct
// inheritance relationship between the two data types, and the read-only
// version of flatbuffers lacks some write implementations. Therefore, at
// present, we are temporarily providing a default interface that triggers
// execution-time errors to avoid type ambiguity and compile-time errors
// caused by different building options.
public:
bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
}
const std::map<std::string, std::vector<std::string>>& inputs() const {
NotImplemented();
return inputs_;
}
const std::map<std::string, std::vector<std::string>>& outputs() const {
NotImplemented();
return outputs_;
}
std::map<std::string, std::vector<std::string>>* mutable_inputs() {
NotImplemented();
return &inputs_;
}
std::map<std::string, std::vector<std::string>>* mutable_outputs() {
NotImplemented();
return &outputs_;
}
std::vector<std::string> input_vars() const {
NotImplemented();
return std::vector<std::string>();
}
std::vector<std::string> output_vars() const {
NotImplemented();
return std::vector<std::string>();
}
bool HasOutput(const std::string& param) const {
NotImplemented();
return false;
}
const std::map<std::string, Any>& attrs() const {
NotImplemented();
return attrs_;
}
const std::map<std::string, AttrType>& attr_types() const {
NotImplemented();
return attr_types_;
}
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of OpDesc is temporarily "
"unavailable in read-only mode.";
}
std::string type_;
std::map<std::string, std::vector<std::string>> inputs_;
std::map<std::string, std::vector<std::string>> outputs_;
std::map<std::string, Any> attrs_;
std::map<std::string, AttrType> attr_types_;
};
} // namespace fbs
......
......@@ -25,6 +25,7 @@ namespace fbs {
class ProgramDesc : public ProgramDescAPI {
public:
ProgramDesc() = default;
explicit ProgramDesc(proto::ProgramDesc *desc) : desc_(desc) { CHECK(desc); }
size_t BlocksSize() const override { return desc_->blocks()->size(); }
......
......@@ -52,6 +52,30 @@ class VarDesc : public VarDescAPI {
private:
proto::VarDesc* desc_;
// To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
// and flatbuffers::Desc replace each other. However, there is no direct
// inheritance relationship between the two data types, and the read-only
// version of flatbuffers lacks some write implementations. Therefore, at
// present, we are temporarily providing a default interface that triggers
// execution-time errors to avoid type ambiguity and compile-time errors
// caused by different building options.
public:
VarDescAPI::Type GetDataType() const {
NotImplemented();
return data_type_;
}
void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of VarDesc is temporarily "
"unavailable in read-only mode.";
}
Type data_type_;
std::vector<int64_t> shape_;
};
} // namespace fbs
......
......@@ -112,7 +112,7 @@ class VectorView<std::string, Flatbuffers> {
return vector_view::FBSStrIterator(cvec_->end());
}
size_t size() const { return cvec_->size(); }
operator std::vector<std::string>() {
operator std::vector<std::string>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<std::string> tmp;
tmp.reserve(cvec_->size());
......
......@@ -74,7 +74,7 @@ class ConvOpLite : public OpLite {
param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings");
std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups");
auto dilations = op_desc.GetAttr<std::vector<int>>("dilations");
param_.dilations = std::make_shared<std::vector<int>>(dilations);
......
......@@ -106,7 +106,7 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc,
param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings");
std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups");
auto dilations = op_desc.GetAttr<std::vector<int>>("dilations");
......
......@@ -54,7 +54,7 @@ class PoolOpLite : public OpLite {
param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize");
param_.global_pooling = op_desc.GetAttr<bool>("global_pooling");
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings");
std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
if (op_desc.HasAttr("exclusive")) {
param_.exclusive = op_desc.GetAttr<bool>("exclusive");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册