diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index a2f9d9040673623f08fb59edf1bd6a57b88f130d..1b15ca67462576443c594aa12ffb9620c169c212 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -425,13 +425,18 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { : global_scope_->GetMutableScope(); auto op_with_kernel = dynamic_cast(op); { + // If it is OperatorBase, InferShape do nothing. if (op_with_kernel != nullptr) { platform::RecordEvent infershape_event( "infer_shape", platform::TracerEventType::OperatorInner, 1, platform::EventRole::kInnerOp); - // If it is OperatorBase, InferShape do nothing. - op_with_kernel->Info().infer_shape_( - instr_node.InnerInferShapeContext().get()); + + // see OperatorWithKernel::RunImpl in operator.cc for why + if (!(op_with_kernel->HasAttr(kAllKernelsMustComputeRuntimeShape) && + op_with_kernel->Attr(kAllKernelsMustComputeRuntimeShape))) { + op_with_kernel->Info().infer_shape_( + instr_node.InnerInferShapeContext().get()); + } } } diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index d56082a91a61f8aa28f679e482760eb777b07dbc..360e0222a516c6220569467f04e62ad3c0d4e41b 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -29,6 +29,8 @@ PADDLE_DEFINE_EXPORTED_bool( new_executor_sequential_run, false, "Enable sequential execution for standalone executor, used for debug"); +DECLARE_bool(use_mkldnn); + namespace paddle { namespace framework { namespace interpreter { @@ -192,6 +194,7 @@ void create_all_ops(const framework::BlockDesc& block, const VariableNameMap& inputs_names = op->Inputs(); const VariableNameMap& outputs_names = op->Outputs(); + AttributeMap op_attr_map = op->GetAttrMap(); if (info.Checker() != nullptr) { @@ -199,6 +202,16 @@ void create_all_ops(const framework::BlockDesc& block, } auto op_base = info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map); + +#ifdef PADDLE_WITH_MKLDNN + if (FLAGS_use_mkldnn) { + if (op->HasAttr("use_mkldnn")) { + VLOG(4) << "Set use_mkldnn=True for " << op_base->Type(); + op_base->SetAttr("use_mkldnn", true); + } + } +#endif + ops->emplace_back(std::unique_ptr(op_base)); } }