diff --git a/paddle/fluid/framework/attribute.cc b/paddle/fluid/framework/attribute.cc index cf7a7c3c9f43dde58cc356fe5dc8e7f92bc1053f..2599e3232cac7657429a47c226e74a9f9425bb4c 100644 --- a/paddle/fluid/framework/attribute.cc +++ b/paddle/fluid/framework/attribute.cc @@ -18,35 +18,37 @@ namespace paddle { namespace framework { paddle::any GetAttrValue(const Attribute& attr) { - if (attr.type() == typeid(int)) { - return paddle::any(BOOST_GET_CONST(int, attr)); - } else if (attr.type() == typeid(float)) { - return paddle::any(BOOST_GET_CONST(float, attr)); - } else if (attr.type() == typeid(std::string)) { - return paddle::any(BOOST_GET_CONST(std::string, attr)); - } else if (attr.type() == typeid(std::vector)) { - return paddle::any(BOOST_GET_CONST(std::vector, attr)); - } else if (attr.type() == typeid(std::vector)) { - return paddle::any(BOOST_GET_CONST(std::vector, attr)); - } else if (attr.type() == typeid(std::vector)) { - return paddle::any(BOOST_GET_CONST(std::vector, attr)); - } else if (attr.type() == typeid(bool)) { - return paddle::any(BOOST_GET_CONST(bool, attr)); - } else if (attr.type() == typeid(std::vector)) { - return paddle::any(BOOST_GET_CONST(std::vector, attr)); - } else if (attr.type() == typeid(BlockDesc*)) { - return paddle::any(BOOST_GET_CONST(BlockDesc*, attr)); - } else if (attr.type() == typeid(int64_t)) { - return paddle::any(BOOST_GET_CONST(int64_t, attr)); - } else if (attr.type() == typeid(std::vector)) { - return paddle::any(BOOST_GET_CONST(std::vector, attr)); - } else if (attr.type() == typeid(std::vector)) { - return paddle::any(BOOST_GET_CONST(std::vector, attr)); - } else if (attr.type() == typeid(std::vector)) { - return paddle::any(BOOST_GET_CONST(std::vector, attr)); - } else { - PADDLE_THROW( - platform::errors::Unimplemented("Unsupported Attribute value type.")); + switch (AttrTypeID(attr)) { + case proto::AttrType::INT: + return BOOST_GET_CONST(int, attr); + case proto::AttrType::FLOAT: + return BOOST_GET_CONST(float, attr); + case proto::AttrType::STRING: + return BOOST_GET_CONST(std::string, attr); + case proto::AttrType::INTS: + return BOOST_GET_CONST(std::vector, attr); + case proto::AttrType::FLOATS: + return BOOST_GET_CONST(std::vector, attr); + case proto::AttrType::STRINGS: + return BOOST_GET_CONST(std::vector, attr); + case proto::AttrType::BOOLEAN: + return BOOST_GET_CONST(bool, attr); + case proto::AttrType::BOOLEANS: + return BOOST_GET_CONST(std::vector, attr); + case proto::AttrType::LONG: + return BOOST_GET_CONST(int64_t, attr); + case proto::AttrType::LONGS: + return BOOST_GET_CONST(std::vector, attr); + case proto::AttrType::FLOAT64S: + return BOOST_GET_CONST(std::vector, attr); + case proto::AttrType::BLOCK: + return BOOST_GET_CONST(BlockDesc*, attr); + case proto::AttrType::BLOCKS: + return BOOST_GET_CONST(std::vector, attr); + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported Attribute value type `%s` for phi.", + platform::demangle(attr.type().name()))); } }