From cb12415622351b82bbfab8df67985e844019281b Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Sat, 2 Apr 2022 11:06:36 +0800 Subject: [PATCH] [new-exec] support to enable mkldnn by flags (#41274) --- .../fluid/framework/new_executor/interpretercore.cc | 11 ++++++++--- .../framework/new_executor/interpretercore_util.cc | 13 +++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index a2f9d904067..1b15ca67462 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -425,13 +425,18 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { : global_scope_->GetMutableScope(); auto op_with_kernel = dynamic_cast(op); { + // If it is OperatorBase, InferShape do nothing. if (op_with_kernel != nullptr) { platform::RecordEvent infershape_event( "infer_shape", platform::TracerEventType::OperatorInner, 1, platform::EventRole::kInnerOp); - // If it is OperatorBase, InferShape do nothing. - op_with_kernel->Info().infer_shape_( - instr_node.InnerInferShapeContext().get()); + + // see OperatorWithKernel::RunImpl in operator.cc for why + if (!(op_with_kernel->HasAttr(kAllKernelsMustComputeRuntimeShape) && + op_with_kernel->Attr(kAllKernelsMustComputeRuntimeShape))) { + op_with_kernel->Info().infer_shape_( + instr_node.InnerInferShapeContext().get()); + } } } diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index d56082a91a6..360e0222a51 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -29,6 +29,8 @@ PADDLE_DEFINE_EXPORTED_bool( new_executor_sequential_run, false, "Enable sequential execution for standalone executor, used for debug"); +DECLARE_bool(use_mkldnn); + namespace paddle { namespace framework { namespace interpreter { @@ -192,6 +194,7 @@ void create_all_ops(const framework::BlockDesc& block, const VariableNameMap& inputs_names = op->Inputs(); const VariableNameMap& outputs_names = op->Outputs(); + AttributeMap op_attr_map = op->GetAttrMap(); if (info.Checker() != nullptr) { @@ -199,6 +202,16 @@ void create_all_ops(const framework::BlockDesc& block, } auto op_base = info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map); + +#ifdef PADDLE_WITH_MKLDNN + if (FLAGS_use_mkldnn) { + if (op->HasAttr("use_mkldnn")) { + VLOG(4) << "Set use_mkldnn=True for " << op_base->Type(); + op_base->SetAttr("use_mkldnn", true); + } + } +#endif + ops->emplace_back(std::unique_ptr(op_base)); } } -- GitLab