未验证 提交 b972b0df 编写于 作者: C Chen Weihang 提交者: GitHub

polish attr get impl (#42337)

上级 22d3c560
......@@ -18,35 +18,37 @@ namespace paddle {
namespace framework {
paddle::any GetAttrValue(const Attribute& attr) {
if (attr.type() == typeid(int)) {
return paddle::any(BOOST_GET_CONST(int, attr));
} else if (attr.type() == typeid(float)) {
return paddle::any(BOOST_GET_CONST(float, attr));
} else if (attr.type() == typeid(std::string)) {
return paddle::any(BOOST_GET_CONST(std::string, attr));
} else if (attr.type() == typeid(std::vector<int>)) {
return paddle::any(BOOST_GET_CONST(std::vector<int>, attr));
} else if (attr.type() == typeid(std::vector<float>)) {
return paddle::any(BOOST_GET_CONST(std::vector<float>, attr));
} else if (attr.type() == typeid(std::vector<std::string>)) {
return paddle::any(BOOST_GET_CONST(std::vector<std::string>, attr));
} else if (attr.type() == typeid(bool)) {
return paddle::any(BOOST_GET_CONST(bool, attr));
} else if (attr.type() == typeid(std::vector<bool>)) {
return paddle::any(BOOST_GET_CONST(std::vector<bool>, attr));
} else if (attr.type() == typeid(BlockDesc*)) {
return paddle::any(BOOST_GET_CONST(BlockDesc*, attr));
} else if (attr.type() == typeid(int64_t)) {
return paddle::any(BOOST_GET_CONST(int64_t, attr));
} else if (attr.type() == typeid(std::vector<BlockDesc*>)) {
return paddle::any(BOOST_GET_CONST(std::vector<BlockDesc*>, attr));
} else if (attr.type() == typeid(std::vector<int64_t>)) {
return paddle::any(BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (attr.type() == typeid(std::vector<double>)) {
return paddle::any(BOOST_GET_CONST(std::vector<double>, attr));
} else {
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported Attribute value type."));
switch (AttrTypeID(attr)) {
case proto::AttrType::INT:
return BOOST_GET_CONST(int, attr);
case proto::AttrType::FLOAT:
return BOOST_GET_CONST(float, attr);
case proto::AttrType::STRING:
return BOOST_GET_CONST(std::string, attr);
case proto::AttrType::INTS:
return BOOST_GET_CONST(std::vector<int>, attr);
case proto::AttrType::FLOATS:
return BOOST_GET_CONST(std::vector<float>, attr);
case proto::AttrType::STRINGS:
return BOOST_GET_CONST(std::vector<std::string>, attr);
case proto::AttrType::BOOLEAN:
return BOOST_GET_CONST(bool, attr);
case proto::AttrType::BOOLEANS:
return BOOST_GET_CONST(std::vector<bool>, attr);
case proto::AttrType::LONG:
return BOOST_GET_CONST(int64_t, attr);
case proto::AttrType::LONGS:
return BOOST_GET_CONST(std::vector<int64_t>, attr);
case proto::AttrType::FLOAT64S:
return BOOST_GET_CONST(std::vector<double>, attr);
case proto::AttrType::BLOCK:
return BOOST_GET_CONST(BlockDesc*, attr);
case proto::AttrType::BLOCKS:
return BOOST_GET_CONST(std::vector<BlockDesc*>, attr);
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported Attribute value type `%s` for phi.",
platform::demangle(attr.type().name())));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册