未验证 提交 154021ad 编写于 作者: 石晓伟 提交者: GitHub

refactor the cpp:op_desc, test=develop (#3743)

上级 5169197f
......@@ -20,43 +20,6 @@ namespace paddle {
namespace lite {
namespace cpp {
#define SET_ATTR_IMPL(T, repr__) \
template <> \
void OpDesc::SetAttr<T>(const std::string& name, const T& v) { \
attr_types_[name] = AttrType::repr__; \
attrs_[name].set(v); \
}
SET_ATTR_IMPL(int32_t, INT);
SET_ATTR_IMPL(float, FLOAT);
SET_ATTR_IMPL(std::string, STRING);
SET_ATTR_IMPL(bool, BOOLEAN);
SET_ATTR_IMPL(int64_t, LONG);
SET_ATTR_IMPL(std::vector<int>, INTS);
SET_ATTR_IMPL(std::vector<float>, FLOATS);
SET_ATTR_IMPL(std::vector<std::string>, STRINGS);
SET_ATTR_IMPL(std::vector<int64_t>, LONGS);
std::pair<OpDesc::attrs_t::const_iterator, OpDesc::attr_types_t::const_iterator>
FindAttr(const cpp::OpDesc& desc, const std::string& name) {
auto it = desc.attrs().find(name);
CHECK(it != desc.attrs().end()) << "No attributes called " << name
<< " found";
auto attr_it = desc.attr_types().find(name);
CHECK(attr_it != desc.attr_types().end());
return std::make_pair(it, attr_it);
}
#define GET_IMPL_ONE(T, repr__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
auto pair = FindAttr(*this, name); \
CHECK(pair.second->second == AttrType::repr__) \
<< "required type is " << #repr__ << " not match the true type"; \
return pair.first->second.get<T>(); \
}
GET_IMPL_ONE(int32_t, INT)
std::vector<std::string> OpDesc::OutputArgumentNames() const {
std::vector<std::string> res;
for (const auto& x : outputs_) res.push_back(x.first);
......@@ -106,15 +69,6 @@ bool OpDesc::HasOutput(const std::string& param) const {
return it != outputs_.end();
}
GET_IMPL_ONE(float, FLOAT);
GET_IMPL_ONE(std::string, STRING);
GET_IMPL_ONE(int64_t, LONG);
GET_IMPL_ONE(bool, BOOLEAN);
GET_IMPL_ONE(std::vector<int64_t>, LONGS);
GET_IMPL_ONE(std::vector<float>, FLOATS);
GET_IMPL_ONE(std::vector<int>, INTS);
GET_IMPL_ONE(std::vector<std::string>, STRINGS);
} // namespace cpp
} // namespace lite
} // namespace paddle
......@@ -15,6 +15,7 @@
#pragma once
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "lite/model_parser/desc_apis.h"
#include "lite/utils/any.h"
......@@ -106,10 +107,23 @@ class OpDesc : public OpDescAPI {
}
template <typename T>
void SetAttr(const std::string& name, const T& v);
void SetAttr(const std::string& name, const T& v) {
attr_types_[name] = OpDescAPI::DataTypeTrait<T>::AT;
attrs_[name].set(v);
}
template <typename T>
T GetAttr(const std::string& name) const;
T GetAttr(const std::string& name) const {
auto it = attrs().find(name);
CHECK(it != attrs().end()) << "No attributes called " << name << " found";
auto attr_it = attr_types().find(name);
CHECK(attr_it != attr_types().end());
auto pair = std::make_pair(it, attr_it);
CHECK(pair.second->second == OpDescAPI::DataTypeTrait<T>::AT)
<< "required type is " << OpDescAPI::DataTypeTrait<T>::ATN
<< " not match the true type";
return pair.first->second.get<T>();
}
const std::map<std::string, Any>& attrs() const { return attrs_; }
const std::map<std::string, AttrType>& attr_types() const {
......
......@@ -105,6 +105,12 @@ class OpDescAPI {
UNK,
};
template <AttrType Type>
struct AttrTypeTrait;
template <typename T>
struct DataTypeTrait;
virtual ~OpDescAPI() = default;
/// Get operator's type.
......@@ -162,6 +168,28 @@ class OpDescAPI {
}
};
#define TYPE_TRAIT_IMPL(T, type__) \
template <> \
struct OpDescAPI::AttrTypeTrait<OpDescAPI::AttrType::T> { \
typedef type__ DT; \
}; \
template <> \
struct OpDescAPI::DataTypeTrait<type__> { \
static constexpr AttrType AT = OpDescAPI::AttrType::T; \
static constexpr const char* ATN = #T; \
};
TYPE_TRAIT_IMPL(INT, int32_t);
TYPE_TRAIT_IMPL(FLOAT, float);
TYPE_TRAIT_IMPL(STRING, std::string);
TYPE_TRAIT_IMPL(BOOLEAN, bool);
TYPE_TRAIT_IMPL(LONG, int64_t);
TYPE_TRAIT_IMPL(INTS, std::vector<int>);
TYPE_TRAIT_IMPL(FLOATS, std::vector<float>);
TYPE_TRAIT_IMPL(STRINGS, std::vector<std::string>);
TYPE_TRAIT_IMPL(LONGS, std::vector<int64_t>);
#undef TYPE_TRAIT_IMPL
class BlockDescAPI {
public:
virtual ~BlockDescAPI() = default;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册