From 17c751bec6ac73ab84e4189c198ecacc7f5c9eb2 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Thu, 2 Jul 2020 05:37:45 +0200 Subject: [PATCH] [oneDNN] Fix to #25078 (#25256) --- .../ir/mkldnn/mkldnn_inplace_pass.cc | 36 +++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc index 9ca08ff777b..7bd94bf55ea 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc @@ -66,17 +66,17 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { return; } - VLOG(3) << "DNNL Inplace op(" << current_op->id() << ") " + VLOG(3) << "oneDNN Inplace op(" << current_op->id() << ") " << "Curr Node In: " << current_op_in->Name() << " Curr Node out: " << current_op_out->Name(); - VLOG(3) << "DNNL Inplace next op(" << next_op->id() << ") " + VLOG(3) << "oneDNN Inplace next op(" << next_op->id() << ") " << " next Node out: " << next_op_out->Name(); auto inputs = current_op->Op()->Inputs(); auto outputs = current_op->Op()->Outputs(); auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN - VLOG(3) << "DNNL InplaceInferer op(" << current_op->id() << ") " + VLOG(3) << "oneDNN InplaceInferer op(" << current_op->id() << ") " << in_to_outs.begin()->first << ": " << inputs[in_to_outs.begin()->first][0] << " " << in_to_outs.begin()->second << ": " @@ -85,7 +85,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { auto inplace_input_vec = inputs[in_to_outs.begin()->first]; if (std::find(inplace_input_vec.begin(), inplace_input_vec.end(), current_op_in->Name()) == inplace_input_vec.end()) { - VLOG(3) << "DNNL in-place pass SKIP pattern "; + VLOG(3) << "oneDNN in-place pass SKIP pattern "; return; } @@ -93,7 +93,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { // is used anywhere else apart from inplaced op auto input_consumers = current_op_in->outputs; if (input_consumers.size() > 1) { - VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot " + VLOG(3) << "oneDNN in-place pass FAIL: in-place var cannot " "be an input to multiple operators"; return; } else { @@ -106,7 +106,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { if ((n->id() != current_op_in->id()) && (n->id() != next_op_out->id()) && (n->Name() == current_op_in->Name())) { - VLOG(3) << "DNNL in-place pass FAIL var used in diffrent part of " + VLOG(3) << "oneDNN in-place pass FAIL var used in diffrent part of " "graph "; return; } @@ -122,7 +122,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { original_output_names[current_op->Name() + current_op_in->Name()] = current_op_out->Name(); } else { - VLOG(3) << "DNNL Inplace: Current op already inplaced! "; + VLOG(3) << "oneDNN Inplace: Current op already inplaced! "; } // It may be that next op is reusing some of vars, we need to @@ -133,7 +133,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { if ((n_op_infer_inplace == nullptr)) { for (auto& m : n->outputs) { if (m->Name() == current_op_in->Name()) { - VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot " + VLOG(3) << "oneDNN in-place pass FAIL: in-place var cannot " "be an output to non-inplaced next op"; return; } @@ -173,7 +173,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { (std::find(next_op_inplace_inputs.begin(), next_op_inplace_inputs.end(), original_name) != next_op_inplace_inputs.end())) { - VLOG(3) << "DNNL InPlace: Next Op is in-placed , updating its " + VLOG(3) << "oneDNN InPlace: Next Op is in-placed , updating its " "input " "and output var!"; next_op->Op()->SetOutput( @@ -190,10 +190,24 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { next_op->Op()->RenameInput(original_name, current_op_out->Name()); found_inplace_count++; - VLOG(3) << "DNNL InPlace applied!"; + VLOG(3) << "oneDNN InPlace applied!"; }; - gpd(graph, handler); + // TODO(jczaja): inplace pass does not influece ops inside block ops + auto should_inplace = [&](Graph* g) { + std::unordered_set unwanted_ops( + {"conditional_block", "While", "while_loop"}); + for (auto& node : g->Nodes()) { + if (node->IsOp() && + unwanted_ops.find(node->Name()) != unwanted_ops.end()) { + VLOG(3) << "oneDNN InPlace FAILED: unsupported op: " << node->Name(); + return false; + } + } + return true; + }; + + if (should_inplace(graph)) gpd(graph, handler); } } // namespace ir -- GitLab