未验证 提交 0a7e6f90 编写于 作者: H Hui Zhang 提交者: GitHub

[jit] pe engine with mkldnn (#45728)

* using mkldnn

* using with mkldnn macro

* fix use mkldnn
上级 1967c6a6
......@@ -170,8 +170,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
void AppendOpFusePasses() {
// 1. infernce pass if enabled.
AppendPassWithCheck(strategy_.inference_ && strategy_.del_dropout_,
"delete_dropout_op_x_pass");
AppendPassWithCheck(
strategy_.enable_inference_pass_ && strategy_.delete_dropout_,
"delete_dropout_op_x_pass");
AppendPassWithCheck(
strategy_.enable_inference_pass_ && strategy_.use_mkldnn_,
"mkldnn_placement_pass");
// 2. trainning pass
AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
......
......@@ -148,8 +148,13 @@ struct BuildStrategy {
bool allow_cuda_graph_capture_{false};
// Inference pass
bool inference_{false}; // switch for infernce pass
bool del_dropout_{false};
bool enable_inference_pass_{false}; // switch for infernce pass
bool delete_dropout_{true}; // delte dropout op
#ifdef PADDLE_WITH_MKLDNN
bool use_mkldnn_{true}; // use mkdnn to do inference
#else
bool use_mkldnn_{false}; // use mkdnn to do inference
#endif
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
......
......@@ -74,8 +74,7 @@ PEEngine::PEEngine(const std::shared_ptr<FunctionInfo> &info,
void PEEngine::CreateGraphAndPE() {
framework::details::BuildStrategy build_strategy;
build_strategy.inference_ = true;
build_strategy.del_dropout_ = true;
build_strategy.enable_inference_pass_ = true; // use pe to inference
auto execution_strategy = GetExecutionStrategy(place_);
auto &program_desc = info_->ProgramDesc();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册