diff --git a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h index 4ac885dbe3f9739f78f6fbf66b3b5e57859918d2..7e53cf0c90e9e5ac9c04440d10898d3551bf5f71 100644 --- a/paddle/fluid/imperative/execution_context.h +++ b/paddle/fluid/imperative/execution_context.h @@ -102,11 +102,8 @@ class DygraphExecutionContext : public framework::ExecutionContext { } bool HasAttr(const std::string& name) const override { - if (attrs_.find(name) == attrs_.end()) { - return &default_attrs_ != nullptr && - default_attrs_.find(name) != default_attrs_.end(); - } - return true; + return attrs_.find(name) != attrs_.end() || + default_attrs_.find(name) != default_attrs_.end(); } const framework::AttributeMap& Attrs() const override { return attrs_; } diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index cee0ad250730ec65c270c449617f13b919177ace..fedcdcbf93a86a655ccca8871f1f275163edabc2 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -30,6 +30,8 @@ namespace paddle { namespace imperative { +const static framework::AttributeMap empty_default_attr_map; // NOLINT + // TODO(zjl): to support py_func layer class OpBase { public: @@ -123,7 +125,12 @@ class OpBase { const framework::AttributeMap& Attrs() { return attrs_; } - const framework::AttributeMap& DefaultAttrsMap() { return *default_attrs_; } + const framework::AttributeMap& DefaultAttrsMap() { + if (default_attrs_ == nullptr) { + return empty_default_attr_map; + } + return *default_attrs_; + } bool HasAttr(const std::string& name) const { VLOG(6) << "Default attrs: " << default_attrs_;