diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc
index d7ab51fa73f1938d8fb9c4720ad128138492924d..9ca08ff777ba7d6032bffb6c358c030b1cd1366c 100644
--- a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc
@@ -32,6 +32,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
                           platform::errors::InvalidArgument(
                               "Pointer to graph argument should not be NULL."));
   std::unordered_map<std::string, std::string> original_output_names;
+  std::unordered_set<std::string> inplaced_vars;
   GraphPatternDetector gpd;
   patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
                                          "mkldnn_inplace"};
@@ -95,6 +96,22 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
       VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
                  "be an input to multiple operators";
       return;
+    } else {
+      // Prevent in-place when the input var is used in another part of the
+      // graph, unless that use is itself a result of inplacing.
+      // Allow the next op's output to reuse the input var, as that is part
+      // of the same in-place chain.
+      if (inplaced_vars.find(current_op_in->Name()) == inplaced_vars.end()) {
+        for (const Node* n : graph->Nodes()) {
+          if ((n->id() != current_op_in->id()) &&
+              (n->id() != next_op_out->id()) &&
+              (n->Name() == current_op_in->Name())) {
+            VLOG(3) << "DNNL in-place pass FAIL: var used in different part "
+                       "of graph";
+            return;
+          }
+        }
+      }
     }
 
     // If this op was alrady inplaced in previous pass placements
@@ -132,6 +149,8 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
     auto out_name = in_to_outs.begin()->second;
     current_op->Op()->SetOutput(
         out_name, std::vector<std::string>({current_op_out->Name()}));
+    // Record the in-placed var name so later ops in the chain can reuse it
+    inplaced_vars.insert(current_op_out->Name());
 
     // If next op in a line is doing inplace
     // then we need to update its output as well
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index d5dfcf80bee70abbaeedbd8a970eb5a474f21c7c..c07ac11e278901e9b9475492ca38411dcf8184d3 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -200,8 +200,8 @@ void CpuPassStrategy::EnableMKLDNN() {
              "matmul_transpose_reshape_fuse_pass",  //
              // Disabled due to topology-dependent speed-up
              // "fc_mkldnn_pass",
-             // "mkldnn_inplace_pass",  // This pass should be activated after
-             // fuses
+             "mkldnn_inplace_pass",  // This pass should be activated after
+                                     // fuses
          })) {
       passes_.push_back(pass);
     }
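For context, a minimal usage sketch (not part of the change): `CpuPassStrategy::EnableMKLDNN()` is reached through the C++ inference API's `AnalysisConfig::EnableMKLDNN()`, so a predictor built with IR optimization enabled will now run `mkldnn_inplace_pass` after the fuse passes. The model paths and include path below are hypothetical placeholders and depend on how the inference library is packaged.

    // Sketch of exercising the re-enabled pass via the analysis predictor.
    #include "paddle_inference_api.h"

    int main() {
      paddle::AnalysisConfig config;
      // Hypothetical model location, for illustration only.
      config.SetModel("/path/to/__model__", "/path/to/__params__");
      config.SwitchIrOptim(true);  // run the IR passes from CpuPassStrategy
      config.EnableMKLDNN();       // pass list now includes "mkldnn_inplace_pass"
      auto predictor = paddle::CreatePaddlePredictor(config);
      // ... prepare inputs and run the predictor as usual ...
      return 0;
    }

Note that the added guard in the pass itself scans all graph nodes for a name clash, which keeps the check simple at the cost of an O(nodes) lookup per candidate in-place match.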