未验证 提交 cb124156 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] support to enable mkldnn by flags (#41274)

上级 0fe2001a
......@@ -425,13 +425,18 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
: global_scope_->GetMutableScope();
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
{
// If it is OperatorBase, InferShape do nothing.
if (op_with_kernel != nullptr) {
platform::RecordEvent infershape_event(
"infer_shape", platform::TracerEventType::OperatorInner, 1,
platform::EventRole::kInnerOp);
// If it is OperatorBase, InferShape do nothing.
op_with_kernel->Info().infer_shape_(
instr_node.InnerInferShapeContext().get());
// see OperatorWithKernel::RunImpl in operator.cc for why
if (!(op_with_kernel->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
op_with_kernel->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
op_with_kernel->Info().infer_shape_(
instr_node.InnerInferShapeContext().get());
}
}
}
......
......@@ -29,6 +29,8 @@ PADDLE_DEFINE_EXPORTED_bool(
new_executor_sequential_run, false,
"Enable sequential execution for standalone executor, used for debug");
DECLARE_bool(use_mkldnn);
namespace paddle {
namespace framework {
namespace interpreter {
......@@ -192,6 +194,7 @@ void create_all_ops(const framework::BlockDesc& block,
const VariableNameMap& inputs_names = op->Inputs();
const VariableNameMap& outputs_names = op->Outputs();
AttributeMap op_attr_map = op->GetAttrMap();
if (info.Checker() != nullptr) {
......@@ -199,6 +202,16 @@ void create_all_ops(const framework::BlockDesc& block,
}
auto op_base =
info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) {
if (op->HasAttr("use_mkldnn")) {
VLOG(4) << "Set use_mkldnn=True for " << op_base->Type();
op_base->SetAttr("use_mkldnn", true);
}
}
#endif
ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册