Unverified commit cb124156, authored by Leo Chen and committed by GitHub

[new-exec] support to enable mkldnn by flags (#41274)

Parent 0fe2001a
@@ -425,13 +425,18 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
                            : global_scope_->GetMutableScope();
   auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
   {
-    // If it is OperatorBase, InferShape do nothing.
     if (op_with_kernel != nullptr) {
       platform::RecordEvent infershape_event(
           "infer_shape", platform::TracerEventType::OperatorInner, 1,
           platform::EventRole::kInnerOp);
-      op_with_kernel->Info().infer_shape_(
-          instr_node.InnerInferShapeContext().get());
+
+      // If it is OperatorBase, InferShape do nothing.
+      // see OperatorWithKernel::RunImpl in operator.cc for why
+      if (!(op_with_kernel->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
+            op_with_kernel->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
+        op_with_kernel->Info().infer_shape_(
+            instr_node.InnerInferShapeContext().get());
+      }
     }
   }
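The new guard mirrors the one in OperatorWithKernel::RunImpl (operator.cc): operators that declare the kAllKernelsMustComputeRuntimeShape attribute compute their output shapes inside the kernel itself, so the interpreter must skip the separate InferShape pass for them. Below is a minimal, self-contained sketch of the same check; FakeOp is a hypothetical stand-in for OperatorWithKernel, and the attribute string is assumed to match the constant declared in paddle/fluid/framework/operator.h.

#include <string>
#include <unordered_map>

// Assumed value of the marker constant from operator.h (for illustration only).
constexpr char kAllKernelsMustComputeRuntimeShape[] =
    "@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@";

// Hypothetical stand-in for OperatorWithKernel, reduced to attribute lookup.
struct FakeOp {
  std::unordered_map<std::string, bool> attrs;
  bool HasAttr(const std::string& name) const { return attrs.count(name) > 0; }
  bool Attr(const std::string& name) const { return attrs.at(name); }
};

// Run the static InferShape pass unless the op opts out via the marker
// attribute -- the same condition the hunk above adds to RunInstruction.
bool ShouldRunStaticInferShape(const FakeOp& op) {
  return !(op.HasAttr(kAllKernelsMustComputeRuntimeShape) &&
           op.Attr(kAllKernelsMustComputeRuntimeShape));
}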
@@ -29,6 +29,8 @@ PADDLE_DEFINE_EXPORTED_bool(
     new_executor_sequential_run, false,
     "Enable sequential execution for standalone executor, used for debug");
 
+DECLARE_bool(use_mkldnn);
+
 namespace paddle {
 namespace framework {
 namespace interpreter {
@@ -192,6 +194,7 @@ void create_all_ops(const framework::BlockDesc& block,
     const VariableNameMap& inputs_names = op->Inputs();
     const VariableNameMap& outputs_names = op->Outputs();
     AttributeMap op_attr_map = op->GetAttrMap();
+
     if (info.Checker() != nullptr) {
@@ -199,6 +202,16 @@ void create_all_ops(const framework::BlockDesc& block,
     }
     auto op_base =
         info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (FLAGS_use_mkldnn) {
+      if (op->HasAttr("use_mkldnn")) {
+        VLOG(4) << "Set use_mkldnn=True for " << op_base->Type();
+        op_base->SetAttr("use_mkldnn", true);
+      }
+    }
+#endif
+
     ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
   }
 }
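With this hunk the new executor honors the global use_mkldnn flag at op-creation time: when FLAGS_use_mkldnn is set and an operator exposes a use_mkldnn attribute, the attribute is forced to true, similar to what the legacy Executor does when MKL-DNN is enabled. A rough sketch of switching the flag on from C++ follows; the helper function is illustrative and not part of the commit, and it assumes the usual gflags declaration style used in the hunk.

#include "gflags/gflags.h"

// The flag itself is defined elsewhere in Paddle; it is only declared here,
// the same way the hunk above declares it.
DECLARE_bool(use_mkldnn);

// Illustrative helper (not part of the commit): enable MKL-DNN before the
// interpreter builds its op list, so create_all_ops() sets use_mkldnn=true
// on every op that carries the attribute.
void EnableMkldnnForNewExecutor() { FLAGS_use_mkldnn = true; }

Because use_mkldnn is an exported Paddle flag, it can typically also be turned on without code changes, for example via the FLAGS_use_mkldnn=1 environment variable or paddle.set_flags({'FLAGS_use_mkldnn': True}) in Python.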