diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index fda89252e35c382468877e8cab148e5f91d77ac2..e705da0131aa48442ddd73a467c31b53d2938fd6 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -28,47 +28,6 @@ ProgramDesc& GetProgramDesc() { return *g_program_desc; } -template <> -AttrType AttrTypeID() { - return BOOLEAN; -} -template <> -AttrType AttrTypeID() { - return INT; -} -template <> -AttrType AttrTypeID() { - return FLOAT; -} -template <> -AttrType AttrTypeID() { - return STRING; -} -template <> -AttrType AttrTypeID>() { - return BOOLEANS; -} -template <> -AttrType AttrTypeID>() { - return INTS; -} -template <> -AttrType AttrTypeID>() { - return FLOATS; -} -template <> -AttrType AttrTypeID>() { - return STRINGS; -} -template <> -AttrType AttrTypeID>>() { - return INT_PAIRS; -} -template <> -AttrType AttrTypeID() { - return BLOCK; -} - Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { case framework::AttrType::BOOLEAN: { diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 48b54b5422de8c45e15a1b7040b78373dce8fa3a..13f2877226fe876b8448bf4ce7e1dc77149b79c6 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -27,10 +27,11 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef boost::variant, std::vector, std::vector, - std::vector, - std::vector>, BlockDesc*> +// The order should be as same as framework.proto +typedef boost::variant, + std::vector, std::vector, + std::vector>, bool, + std::vector, BlockDesc*> Attribute; typedef std::unordered_map AttributeMap; @@ -38,7 +39,10 @@ typedef std::unordered_map AttributeMap; ProgramDesc& GetProgramDesc(); template -AttrType AttrTypeID(); +inline AttrType AttrTypeID() { + Attribute tmp = T(); + return static_cast(tmp.which() - 1); +} Attribute GetAttrValue(const OpDesc::Attr& attr_desc);