未验证 提交 17c751be 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Fix to #25078 (#25256)

上级 3b8f0a64
...@@ -66,17 +66,17 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -66,17 +66,17 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
return; return;
} }
VLOG(3) << "DNNL Inplace op(" << current_op->id() << ") " VLOG(3) << "oneDNN Inplace op(" << current_op->id() << ") "
<< "Curr Node In: " << current_op_in->Name() << "Curr Node In: " << current_op_in->Name()
<< " Curr Node out: " << current_op_out->Name(); << " Curr Node out: " << current_op_out->Name();
VLOG(3) << "DNNL Inplace next op(" << next_op->id() << ") " VLOG(3) << "oneDNN Inplace next op(" << next_op->id() << ") "
<< " next Node out: " << next_op_out->Name(); << " next Node out: " << next_op_out->Name();
auto inputs = current_op->Op()->Inputs(); auto inputs = current_op->Op()->Inputs();
auto outputs = current_op->Op()->Outputs(); auto outputs = current_op->Op()->Outputs();
auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN
VLOG(3) << "DNNL InplaceInferer op(" << current_op->id() << ") " VLOG(3) << "oneDNN InplaceInferer op(" << current_op->id() << ") "
<< in_to_outs.begin()->first << ": " << in_to_outs.begin()->first << ": "
<< inputs[in_to_outs.begin()->first][0] << " " << inputs[in_to_outs.begin()->first][0] << " "
<< in_to_outs.begin()->second << ": " << in_to_outs.begin()->second << ": "
...@@ -85,7 +85,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -85,7 +85,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
auto inplace_input_vec = inputs[in_to_outs.begin()->first]; auto inplace_input_vec = inputs[in_to_outs.begin()->first];
if (std::find(inplace_input_vec.begin(), inplace_input_vec.end(), if (std::find(inplace_input_vec.begin(), inplace_input_vec.end(),
current_op_in->Name()) == inplace_input_vec.end()) { current_op_in->Name()) == inplace_input_vec.end()) {
VLOG(3) << "DNNL in-place pass SKIP pattern "; VLOG(3) << "oneDNN in-place pass SKIP pattern ";
return; return;
} }
...@@ -93,7 +93,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -93,7 +93,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
// is used anywhere else apart from inplaced op // is used anywhere else apart from inplaced op
auto input_consumers = current_op_in->outputs; auto input_consumers = current_op_in->outputs;
if (input_consumers.size() > 1) { if (input_consumers.size() > 1) {
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot " VLOG(3) << "oneDNN in-place pass FAIL: in-place var cannot "
"be an input to multiple operators"; "be an input to multiple operators";
return; return;
} else { } else {
...@@ -106,7 +106,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -106,7 +106,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
if ((n->id() != current_op_in->id()) && if ((n->id() != current_op_in->id()) &&
(n->id() != next_op_out->id()) && (n->id() != next_op_out->id()) &&
(n->Name() == current_op_in->Name())) { (n->Name() == current_op_in->Name())) {
VLOG(3) << "DNNL in-place pass FAIL var used in diffrent part of " VLOG(3) << "oneDNN in-place pass FAIL var used in diffrent part of "
"graph "; "graph ";
return; return;
} }
...@@ -122,7 +122,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -122,7 +122,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
original_output_names[current_op->Name() + current_op_in->Name()] = original_output_names[current_op->Name() + current_op_in->Name()] =
current_op_out->Name(); current_op_out->Name();
} else { } else {
VLOG(3) << "DNNL Inplace: Current op already inplaced! "; VLOG(3) << "oneDNN Inplace: Current op already inplaced! ";
} }
// It may be that next op is reusing some of vars, we need to // It may be that next op is reusing some of vars, we need to
...@@ -133,7 +133,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -133,7 +133,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
if ((n_op_infer_inplace == nullptr)) { if ((n_op_infer_inplace == nullptr)) {
for (auto& m : n->outputs) { for (auto& m : n->outputs) {
if (m->Name() == current_op_in->Name()) { if (m->Name() == current_op_in->Name()) {
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot " VLOG(3) << "oneDNN in-place pass FAIL: in-place var cannot "
"be an output to non-inplaced next op"; "be an output to non-inplaced next op";
return; return;
} }
...@@ -173,7 +173,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -173,7 +173,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
(std::find(next_op_inplace_inputs.begin(), (std::find(next_op_inplace_inputs.begin(),
next_op_inplace_inputs.end(), next_op_inplace_inputs.end(),
original_name) != next_op_inplace_inputs.end())) { original_name) != next_op_inplace_inputs.end())) {
VLOG(3) << "DNNL InPlace: Next Op is in-placed , updating its " VLOG(3) << "oneDNN InPlace: Next Op is in-placed , updating its "
"input " "input "
"and output var!"; "and output var!";
next_op->Op()->SetOutput( next_op->Op()->SetOutput(
...@@ -190,10 +190,24 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -190,10 +190,24 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
next_op->Op()->RenameInput(original_name, current_op_out->Name()); next_op->Op()->RenameInput(original_name, current_op_out->Name());
found_inplace_count++; found_inplace_count++;
VLOG(3) << "DNNL InPlace applied!"; VLOG(3) << "oneDNN InPlace applied!";
}; };
gpd(graph, handler); // TODO(jczaja): inplace pass does not influece ops inside block ops
auto should_inplace = [&](Graph* g) {
std::unordered_set<std::string> unwanted_ops(
{"conditional_block", "While", "while_loop"});
for (auto& node : g->Nodes()) {
if (node->IsOp() &&
unwanted_ops.find(node->Name()) != unwanted_ops.end()) {
VLOG(3) << "oneDNN InPlace FAILED: unsupported op: " << node->Name();
return false;
}
}
return true;
};
if (should_inplace(graph)) gpd(graph, handler);
} }
} // namespace ir } // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册