提交 d1e85e33 编写于 作者: T tangwei12

shape type to int64_t, test=develop

上级 39b3bf24
......@@ -26,6 +26,113 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
T* operator()(Attribute& attr) const {
T* attr_value = nullptr;
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template <>
struct ExtractAttribute<bool> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
bool* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<bool>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
float val = boost::get<float>(attr);
attr = static_cast<bool>(val);
}
bool* attr_value = nullptr;
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<std::vector<int64_t>> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
std::vector<int64_t>* operator()(Attribute& attr) const {
if (attr.type() == typeid(std::vector<int>)) { // NOLINT
std::vector<int> val = boost::get<std::vector<int>>(attr);
std::vector<int64_t> vec(val.begin(), val.end());
attr = vec;
} else if (attr.type() == typeid(std::vector<float>)) { // NOLINT
std::vector<float> val = boost::get<std::vector<float>>(attr);
std::vector<int64_t> vec(val.begin(), val.end());
attr = vec;
}
std::vector<int64_t>* attr_value = nullptr;
try {
attr_value = &boost::get<std::vector<int64_t>>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <typename T>
inline proto::AttrType AttrTypeID() {
Attribute tmp = T();
......@@ -42,7 +149,11 @@ class AttrReader {
inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name);
return boost::get<T>(attrs_.at(name));
Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
ExtractAttribute<T> extract_attr(name);
T* attr_value = extract_attr(attr);
return *attr_value;
}
private:
......@@ -82,7 +193,7 @@ class DefaultValueSetter {
public:
explicit DefaultValueSetter(T default_value)
: default_value_(default_value) {}
void operator()(T& value) const { value = default_value_; }
void operator()(T& value) const { value = default_value_; } // NOLINT
private:
T default_value_;
......@@ -117,84 +228,6 @@ class EnumInContainer {
std::unordered_set<T> container_;
};
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
T* operator()(Attribute& attr) const {
T* attr_value = nullptr;
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template <>
struct ExtractAttribute<bool> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
bool* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<bool>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
float val = boost::get<float>(attr);
attr = static_cast<bool>(val);
}
bool* attr_value = nullptr;
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// check whether a certain attribute fit its limits
// an attribute can have more than one limits
template <typename T>
......@@ -235,7 +268,7 @@ class TypedAttrChecker {
return *this;
}
void operator()(AttributeMap& attr_map) const {
void operator()(AttributeMap& attr_map) const { // NOLINT
if (!attr_map.count(attr_name_)) {
// user do not set this attr
PADDLE_ENFORCE(!default_value_setter_.empty(),
......@@ -271,7 +304,7 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>());
}
void Check(AttributeMap& attr_map) const {
void Check(AttributeMap& attr_map) const { // NOLINT
for (const auto& checker : attr_checkers_) {
checker(attr_map);
}
......
......@@ -415,11 +415,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(const std::vector<BlockDesc *> &v) const {
std::vector<int> blocks_idx;
for (auto blk : v) {
blocks_idx.push_back(blk->ID());
blocks_idx.push_sback(blk->ID());
}
VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
}
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
void operator()(BlockDesapply_visitorc *desc) const {
attr_->set_block_idx(desc->ID());
}
void operator()(int64_t v) const { attr_->set_l(v); }
void operator()(const std::vector<int64_t> &v) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册