From 057e8102866e0f2356050d7cf4bb8fcab968267b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 22 Sep 2017 18:15:43 -0700 Subject: [PATCH] Simplify GetAttrType code --- paddle/framework/attribute.cc | 41 ----------------------------------- paddle/framework/attribute.h | 14 +++++++----- 2 files changed, 9 insertions(+), 46 deletions(-) diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index fda89252e3..e705da0131 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -28,47 +28,6 @@ ProgramDesc& GetProgramDesc() { return *g_program_desc; } -template <> -AttrType AttrTypeID() { - return BOOLEAN; -} -template <> -AttrType AttrTypeID() { - return INT; -} -template <> -AttrType AttrTypeID() { - return FLOAT; -} -template <> -AttrType AttrTypeID() { - return STRING; -} -template <> -AttrType AttrTypeID>() { - return BOOLEANS; -} -template <> -AttrType AttrTypeID>() { - return INTS; -} -template <> -AttrType AttrTypeID>() { - return FLOATS; -} -template <> -AttrType AttrTypeID>() { - return STRINGS; -} -template <> -AttrType AttrTypeID>>() { - return INT_PAIRS; -} -template <> -AttrType AttrTypeID() { - return BLOCK; -} - Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { case framework::AttrType::BOOLEAN: { diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 48b54b5422..13f2877226 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -27,10 +27,11 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef boost::variant, std::vector, std::vector, - std::vector, - std::vector>, BlockDesc*> +// The order should be as same as framework.proto +typedef boost::variant, + std::vector, std::vector, + std::vector>, bool, + std::vector, BlockDesc*> Attribute; typedef std::unordered_map AttributeMap; @@ -38,7 +39,10 @@ typedef std::unordered_map AttributeMap; ProgramDesc& GetProgramDesc(); template -AttrType AttrTypeID(); +inline AttrType AttrTypeID() { + Attribute tmp = T(); + return static_cast(tmp.which() - 1); +} Attribute GetAttrValue(const OpDesc::Attr& attr_desc); -- GitLab