未验证 提交 b972b0df 编写于 作者: C Chen Weihang 提交者: GitHub

polish attr get impl (#42337)

上级 22d3c560
...@@ -18,35 +18,37 @@ namespace paddle { ...@@ -18,35 +18,37 @@ namespace paddle {
namespace framework { namespace framework {
paddle::any GetAttrValue(const Attribute& attr) { paddle::any GetAttrValue(const Attribute& attr) {
if (attr.type() == typeid(int)) { switch (AttrTypeID(attr)) {
return paddle::any(BOOST_GET_CONST(int, attr)); case proto::AttrType::INT:
} else if (attr.type() == typeid(float)) { return BOOST_GET_CONST(int, attr);
return paddle::any(BOOST_GET_CONST(float, attr)); case proto::AttrType::FLOAT:
} else if (attr.type() == typeid(std::string)) { return BOOST_GET_CONST(float, attr);
return paddle::any(BOOST_GET_CONST(std::string, attr)); case proto::AttrType::STRING:
} else if (attr.type() == typeid(std::vector<int>)) { return BOOST_GET_CONST(std::string, attr);
return paddle::any(BOOST_GET_CONST(std::vector<int>, attr)); case proto::AttrType::INTS:
} else if (attr.type() == typeid(std::vector<float>)) { return BOOST_GET_CONST(std::vector<int>, attr);
return paddle::any(BOOST_GET_CONST(std::vector<float>, attr)); case proto::AttrType::FLOATS:
} else if (attr.type() == typeid(std::vector<std::string>)) { return BOOST_GET_CONST(std::vector<float>, attr);
return paddle::any(BOOST_GET_CONST(std::vector<std::string>, attr)); case proto::AttrType::STRINGS:
} else if (attr.type() == typeid(bool)) { return BOOST_GET_CONST(std::vector<std::string>, attr);
return paddle::any(BOOST_GET_CONST(bool, attr)); case proto::AttrType::BOOLEAN:
} else if (attr.type() == typeid(std::vector<bool>)) { return BOOST_GET_CONST(bool, attr);
return paddle::any(BOOST_GET_CONST(std::vector<bool>, attr)); case proto::AttrType::BOOLEANS:
} else if (attr.type() == typeid(BlockDesc*)) { return BOOST_GET_CONST(std::vector<bool>, attr);
return paddle::any(BOOST_GET_CONST(BlockDesc*, attr)); case proto::AttrType::LONG:
} else if (attr.type() == typeid(int64_t)) { return BOOST_GET_CONST(int64_t, attr);
return paddle::any(BOOST_GET_CONST(int64_t, attr)); case proto::AttrType::LONGS:
} else if (attr.type() == typeid(std::vector<BlockDesc*>)) { return BOOST_GET_CONST(std::vector<int64_t>, attr);
return paddle::any(BOOST_GET_CONST(std::vector<BlockDesc*>, attr)); case proto::AttrType::FLOAT64S:
} else if (attr.type() == typeid(std::vector<int64_t>)) { return BOOST_GET_CONST(std::vector<double>, attr);
return paddle::any(BOOST_GET_CONST(std::vector<int64_t>, attr)); case proto::AttrType::BLOCK:
} else if (attr.type() == typeid(std::vector<double>)) { return BOOST_GET_CONST(BlockDesc*, attr);
return paddle::any(BOOST_GET_CONST(std::vector<double>, attr)); case proto::AttrType::BLOCKS:
} else { return BOOST_GET_CONST(std::vector<BlockDesc*>, attr);
PADDLE_THROW( default:
platform::errors::Unimplemented("Unsupported Attribute value type.")); PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported Attribute value type `%s` for phi.",
platform::demangle(attr.type().name())));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册