diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index e8ecd90502933a049cc8f886212579fc061d44ff..a31c5336a1a03084705dd49c566ffd5d4aa25ad9 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -237,6 +237,17 @@ void OpDesc::SetOutput(const std::string ¶m_name, this->outputs_[param_name] = args; } +bool OpDesc::HasAttr(const std::string &name) const { + const proto::OpProto &proto = OpInfoMap::Instance().Get(desc_.type()).Proto(); + for (int i = 0; i != proto.attrs_size(); ++i) { + const proto::OpProto::Attr &attr = proto.attrs(i); + if (attr.name() == name) { + return true; + } + } + return false; +} + proto::AttrType OpDesc::GetAttrType(const std::string &name) const { auto it = attrs_.find(name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 30c8a26c3d2f0068674aa70b4ff875a2f73c1dca..3da7cdcef391fa5d4038cc8db9b512c3e36ff572 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -61,9 +61,7 @@ class OpDesc { void SetOutput(const std::string ¶m_name, const std::vector &args); - bool HasAttr(const std::string &name) const { - return attrs_.find(name) != attrs_.end(); - } + bool HasAttr(const std::string &name) const; proto::AttrType GetAttrType(const std::string &name) const;