未验证 提交 0a7e6f90 编写于 作者: H Hui Zhang 提交者: GitHub

[jit] pe engine with mkldnn (#45728)

* using mkldnn

* using with mkldnn macro

* fix use mkldnn
上级 1967c6a6
...@@ -170,8 +170,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -170,8 +170,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
void AppendOpFusePasses() { void AppendOpFusePasses() {
// 1. infernce pass if enabled. // 1. infernce pass if enabled.
AppendPassWithCheck(strategy_.inference_ && strategy_.del_dropout_, AppendPassWithCheck(
strategy_.enable_inference_pass_ && strategy_.delete_dropout_,
"delete_dropout_op_x_pass"); "delete_dropout_op_x_pass");
AppendPassWithCheck(
strategy_.enable_inference_pass_ && strategy_.use_mkldnn_,
"mkldnn_placement_pass");
// 2. trainning pass // 2. trainning pass
AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_, AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
......
...@@ -148,8 +148,13 @@ struct BuildStrategy { ...@@ -148,8 +148,13 @@ struct BuildStrategy {
bool allow_cuda_graph_capture_{false}; bool allow_cuda_graph_capture_{false};
// Inference pass // Inference pass
bool inference_{false}; // switch for infernce pass bool enable_inference_pass_{false}; // switch for infernce pass
bool del_dropout_{false}; bool delete_dropout_{true}; // delte dropout op
#ifdef PADDLE_WITH_MKLDNN
bool use_mkldnn_{true}; // use mkdnn to do inference
#else
bool use_mkldnn_{false}; // use mkdnn to do inference
#endif
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode, // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if // num_trainers is 1, so the current fields of build_strategy doesn't tell if
......
...@@ -74,8 +74,7 @@ PEEngine::PEEngine(const std::shared_ptr<FunctionInfo> &info, ...@@ -74,8 +74,7 @@ PEEngine::PEEngine(const std::shared_ptr<FunctionInfo> &info,
void PEEngine::CreateGraphAndPE() { void PEEngine::CreateGraphAndPE() {
framework::details::BuildStrategy build_strategy; framework::details::BuildStrategy build_strategy;
build_strategy.inference_ = true; build_strategy.enable_inference_pass_ = true; // use pe to inference
build_strategy.del_dropout_ = true;
auto execution_strategy = GetExecutionStrategy(place_); auto execution_strategy = GetExecutionStrategy(place_);
auto &program_desc = info_->ProgramDesc(); auto &program_desc = info_->ProgramDesc();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册