提交 d1e85e33 编写于 作者: T tangwei12

shape type to int64_t, test=develop

上级 39b3bf24
...@@ -26,6 +26,113 @@ limitations under the License. */ ...@@ -26,6 +26,113 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
T* operator()(Attribute& attr) const {
T* attr_value = nullptr;
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template <>
struct ExtractAttribute<bool> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
bool* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<bool>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
float val = boost::get<float>(attr);
attr = static_cast<bool>(val);
}
bool* attr_value = nullptr;
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<std::vector<int64_t>> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
std::vector<int64_t>* operator()(Attribute& attr) const {
if (attr.type() == typeid(std::vector<int>)) { // NOLINT
std::vector<int> val = boost::get<std::vector<int>>(attr);
std::vector<int64_t> vec(val.begin(), val.end());
attr = vec;
} else if (attr.type() == typeid(std::vector<float>)) { // NOLINT
std::vector<float> val = boost::get<std::vector<float>>(attr);
std::vector<int64_t> vec(val.begin(), val.end());
attr = vec;
}
std::vector<int64_t>* attr_value = nullptr;
try {
attr_value = &boost::get<std::vector<int64_t>>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <typename T> template <typename T>
inline proto::AttrType AttrTypeID() { inline proto::AttrType AttrTypeID() {
Attribute tmp = T(); Attribute tmp = T();
...@@ -42,7 +149,11 @@ class AttrReader { ...@@ -42,7 +149,11 @@ class AttrReader {
inline const T& Get(const std::string& name) const { inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name); name);
return boost::get<T>(attrs_.at(name));
Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
ExtractAttribute<T> extract_attr(name);
T* attr_value = extract_attr(attr);
return *attr_value;
} }
private: private:
...@@ -82,7 +193,7 @@ class DefaultValueSetter { ...@@ -82,7 +193,7 @@ class DefaultValueSetter {
public: public:
explicit DefaultValueSetter(T default_value) explicit DefaultValueSetter(T default_value)
: default_value_(default_value) {} : default_value_(default_value) {}
void operator()(T& value) const { value = default_value_; } void operator()(T& value) const { value = default_value_; } // NOLINT
private: private:
T default_value_; T default_value_;
...@@ -117,84 +228,6 @@ class EnumInContainer { ...@@ -117,84 +228,6 @@ class EnumInContainer {
std::unordered_set<T> container_; std::unordered_set<T> container_;
}; };
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
T* operator()(Attribute& attr) const {
T* attr_value = nullptr;
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template <>
struct ExtractAttribute<bool> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
bool* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<bool>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
float val = boost::get<float>(attr);
attr = static_cast<bool>(val);
}
bool* attr_value = nullptr;
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// check whether a certain attribute fit its limits // check whether a certain attribute fit its limits
// an attribute can have more than one limits // an attribute can have more than one limits
template <typename T> template <typename T>
...@@ -235,7 +268,7 @@ class TypedAttrChecker { ...@@ -235,7 +268,7 @@ class TypedAttrChecker {
return *this; return *this;
} }
void operator()(AttributeMap& attr_map) const { void operator()(AttributeMap& attr_map) const { // NOLINT
if (!attr_map.count(attr_name_)) { if (!attr_map.count(attr_name_)) {
// user do not set this attr // user do not set this attr
PADDLE_ENFORCE(!default_value_setter_.empty(), PADDLE_ENFORCE(!default_value_setter_.empty(),
...@@ -271,7 +304,7 @@ class OpAttrChecker { ...@@ -271,7 +304,7 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>()); return *(checker.target<TypedAttrChecker<T>>());
} }
void Check(AttributeMap& attr_map) const { void Check(AttributeMap& attr_map) const { // NOLINT
for (const auto& checker : attr_checkers_) { for (const auto& checker : attr_checkers_) {
checker(attr_map); checker(attr_map);
} }
......
...@@ -415,11 +415,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -415,11 +415,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(const std::vector<BlockDesc *> &v) const { void operator()(const std::vector<BlockDesc *> &v) const {
std::vector<int> blocks_idx; std::vector<int> blocks_idx;
for (auto blk : v) { for (auto blk : v) {
blocks_idx.push_back(blk->ID()); blocks_idx.push_sback(blk->ID());
} }
VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx()); VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
} }
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(BlockDesapply_visitorc *desc) const {
attr_->set_block_idx(desc->ID());
}
void operator()(int64_t v) const { attr_->set_l(v); } void operator()(int64_t v) const { attr_->set_l(v); }
void operator()(const std::vector<int64_t> &v) const { void operator()(const std::vector<int64_t> &v) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册