From 9adad42df14b3b8493731714ecd115d073417f48 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 3 Nov 2022 14:02:49 +0800 Subject: [PATCH] [Opt Kernel Selection] Opt CanMKLDNNBeUsed performance (#47563) * opt CanMKLDNNBeUsed performance * fix nullptr bug * fix OpBase default_attrs=nullptr bug * fix OpBase default_attrs=nullptr bug * fix OpBase default_attrs=nullptr bug --- paddle/fluid/framework/operator.cc | 10 ++-------- paddle/fluid/imperative/dygraph_grad_maker.h | 2 +- paddle/fluid/imperative/execution_context.h | 3 ++- paddle/fluid/imperative/op_base.h | 2 +- 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index b471440768..d8312c698b 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1414,16 +1414,10 @@ bool OperatorWithKernel::SupportsKernelType( bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, proto::VarType::Type data_type) const { - // NOTE(jiahongyu): Only mkldnn kernels need to check "use_mkldnn" attribute, - // hence we first call function SupportsMKLDNN. If we check "use_mkldnn" - // attribute first, it will cause error because some codes add "use_mkldnn" - // attribute to non-mkldnn ops. - if (!this->SupportsMKLDNN(data_type)) { - return false; - } const std::string use_mkldnn_attr = "use_mkldnn"; return ctx.HasAttr(use_mkldnn_attr) && ctx.Attr(use_mkldnn_attr) && - platform::is_cpu_place(ctx.GetPlace()); + platform::is_cpu_place(ctx.GetPlace()) && + this->SupportsMKLDNN(data_type); } void OperatorWithKernel::InferShape(InferShapeContext* ctx) const { diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index aaa7f9fa41..e0c943f18c 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -209,7 +209,7 @@ class GradOpBaseMakerBase { const NameVarBaseMap& var_base_map_in_; const NameVarBaseMap& var_base_map_out_; const framework::AttributeMap& attrs_; - const framework::AttributeMap* default_attrs_; + const framework::AttributeMap* default_attrs_ = nullptr; const std::map& inplace_map_; }; diff --git a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h index 6d4f7c347b..4ac885dbe3 100644 --- a/paddle/fluid/imperative/execution_context.h +++ b/paddle/fluid/imperative/execution_context.h @@ -103,7 +103,8 @@ class DygraphExecutionContext : public framework::ExecutionContext { bool HasAttr(const std::string& name) const override { if (attrs_.find(name) == attrs_.end()) { - return default_attrs_.find(name) != default_attrs_.end(); + return &default_attrs_ != nullptr && + default_attrs_.find(name) != default_attrs_.end(); } return true; } diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index 3faa9a0cfb..cee0ad2507 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -221,7 +221,7 @@ class OpBase { NameVarMap ins_; NameVarMap outs_; framework::AttributeMap attrs_; - const framework::AttributeMap* default_attrs_; + const framework::AttributeMap* default_attrs_ = nullptr; std::unique_ptr op_; platform::Place place_; size_t id_{-1UL}; -- GitLab