From 8b88cd5167aae149db4fff0525522f1031a567fe Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Wed, 13 May 2020 03:46:43 +0200 Subject: [PATCH] [oneDNN] Fix to inplace pass (#24442) * - Disabling inplace pass test=develop - Disable cycles test=develop - fix test=develop - Enhancement to in-place - Lint fixes test=develop * - Lint fixes test=develop --- .../ir/mkldnn/mkldnn_inplace_pass.cc | 19 +++++++++++++++++++ .../inference/api/paddle_pass_builder.cc | 4 ++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc index d7ab51fa73f..9ca08ff777b 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc @@ -32,6 +32,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { platform::errors::InvalidArgument( "Pointer to graph argument should not be NULL.")); std::unordered_map original_output_names; + std::unordered_set inplaced_vars; GraphPatternDetector gpd; patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(), "mkldnn_inplace"}; @@ -95,6 +96,22 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot " "be an input to multiple operators"; return; + } else { + // We will prevent in-place when + // input is used in other part of graph, unless it was a result of + // inplacing + // Allow to next op out reuse inpuit var, as this is the same chaing + if (inplaced_vars.find(current_op_in->Name()) == inplaced_vars.end()) { + for (const Node* n : graph->Nodes()) { + if ((n->id() != current_op_in->id()) && + (n->id() != next_op_out->id()) && + (n->Name() == current_op_in->Name())) { + VLOG(3) << "DNNL in-place pass FAIL var used in diffrent part of " + "graph "; + return; + } + } + } } // If this op was alrady inplaced in previous pass placements @@ -132,6 +149,8 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { auto out_name = in_to_outs.begin()->second; current_op->Op()->SetOutput( out_name, std::vector({current_op_out->Name()})); + // Record var name + inplaced_vars.insert(current_op_out->Name()); // If next op in a line is doing inplace // then we need to update its output as well diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index d5dfcf80bee..c07ac11e278 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -200,8 +200,8 @@ void CpuPassStrategy::EnableMKLDNN() { "matmul_transpose_reshape_fuse_pass", // // Disabled due to topology-dependent speed-up // "fc_mkldnn_pass", - // "mkldnn_inplace_pass", // This pass should be activated after - // fuses + "mkldnn_inplace_pass", // This pass should be activated after + // fuses })) { passes_.push_back(pass); } -- GitLab