diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc
index 6590ef44f89626bdb9574a61ae8b5ced3fdd52d6..df2bc5af8b67924737c80a3df200f0bca2b6f2a4 100644
--- a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc
@@ -32,6 +32,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
                           platform::errors::InvalidArgument(
                               "Pointer to graph argument should not be NULL."));
   std::unordered_map<std::string, std::string> original_output_names;
+  std::unordered_set<std::string> inplaced_vars;
   GraphPatternDetector gpd;
   patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
                                          "mkldnn_inplace"};
@@ -94,6 +95,22 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
       VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
                  "be an input to multiple operators";
       return;
+    } else {
+      // Prevent in-place reuse when the input var is consumed elsewhere in
+      // the graph, unless that var was itself produced by an earlier
+      // in-place substitution made by this pass.
+      // Allow the next op's output to reuse the input var: same in-place chain.
+      if (inplaced_vars.find(current_op_in->Name()) == inplaced_vars.end()) {
+        for (const Node* n : graph->Nodes()) {
+          if ((n->id() != current_op_in->id()) &&
+              (n->id() != next_op_out->id()) &&
+              (n->Name() == current_op_in->Name())) {
+            VLOG(3) << "DNNL in-place pass FAIL var used in different part of "
+                       "graph";
+            return;
+          }
+        }
+      }
     }
 
     // If this op was alrady inplaced in previous pass placements
@@ -132,6 +149,8 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
     auto out_name = in_to_outs.begin()->second;
     current_op->Op()->SetOutput(
         out_name, std::vector<std::string>({current_op_out->Name()}));
+    // Record the reused var name so later members of this in-place chain
+    // are still allowed to reuse it (see the check above).
+    inplaced_vars.insert(current_op_out->Name());
 
     // If next op in a line is doing inplace
     // then we need to update its output as well