未验证 提交 17c751be 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Fix to #25078 (#25256)

上级 3b8f0a64
......@@ -66,17 +66,17 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
return;
}
VLOG(3) << "DNNL Inplace op(" << current_op->id() << ") "
VLOG(3) << "oneDNN Inplace op(" << current_op->id() << ") "
<< "Curr Node In: " << current_op_in->Name()
<< " Curr Node out: " << current_op_out->Name();
VLOG(3) << "DNNL Inplace next op(" << next_op->id() << ") "
VLOG(3) << "oneDNN Inplace next op(" << next_op->id() << ") "
<< " next Node out: " << next_op_out->Name();
auto inputs = current_op->Op()->Inputs();
auto outputs = current_op->Op()->Outputs();
auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN
VLOG(3) << "DNNL InplaceInferer op(" << current_op->id() << ") "
VLOG(3) << "oneDNN InplaceInferer op(" << current_op->id() << ") "
<< in_to_outs.begin()->first << ": "
<< inputs[in_to_outs.begin()->first][0] << " "
<< in_to_outs.begin()->second << ": "
......@@ -85,7 +85,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
auto inplace_input_vec = inputs[in_to_outs.begin()->first];
if (std::find(inplace_input_vec.begin(), inplace_input_vec.end(),
current_op_in->Name()) == inplace_input_vec.end()) {
VLOG(3) << "DNNL in-place pass SKIP pattern ";
VLOG(3) << "oneDNN in-place pass SKIP pattern ";
return;
}
......@@ -93,7 +93,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
// is used anywhere else apart from inplaced op
auto input_consumers = current_op_in->outputs;
if (input_consumers.size() > 1) {
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
VLOG(3) << "oneDNN in-place pass FAIL: in-place var cannot "
"be an input to multiple operators";
return;
} else {
......@@ -106,7 +106,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
if ((n->id() != current_op_in->id()) &&
(n->id() != next_op_out->id()) &&
(n->Name() == current_op_in->Name())) {
VLOG(3) << "DNNL in-place pass FAIL var used in diffrent part of "
VLOG(3) << "oneDNN in-place pass FAIL var used in diffrent part of "
"graph ";
return;
}
......@@ -122,7 +122,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
original_output_names[current_op->Name() + current_op_in->Name()] =
current_op_out->Name();
} else {
VLOG(3) << "DNNL Inplace: Current op already inplaced! ";
VLOG(3) << "oneDNN Inplace: Current op already inplaced! ";
}
// It may be that next op is reusing some of vars, we need to
......@@ -133,7 +133,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
if ((n_op_infer_inplace == nullptr)) {
for (auto& m : n->outputs) {
if (m->Name() == current_op_in->Name()) {
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
VLOG(3) << "oneDNN in-place pass FAIL: in-place var cannot "
"be an output to non-inplaced next op";
return;
}
......@@ -173,7 +173,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
(std::find(next_op_inplace_inputs.begin(),
next_op_inplace_inputs.end(),
original_name) != next_op_inplace_inputs.end())) {
VLOG(3) << "DNNL InPlace: Next Op is in-placed , updating its "
VLOG(3) << "oneDNN InPlace: Next Op is in-placed , updating its "
"input "
"and output var!";
next_op->Op()->SetOutput(
......@@ -190,10 +190,24 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
next_op->Op()->RenameInput(original_name, current_op_out->Name());
found_inplace_count++;
VLOG(3) << "DNNL InPlace applied!";
VLOG(3) << "oneDNN InPlace applied!";
};
gpd(graph, handler);
// TODO(jczaja): inplace pass does not influece ops inside block ops
auto should_inplace = [&](Graph* g) {
std::unordered_set<std::string> unwanted_ops(
{"conditional_block", "While", "while_loop"});
for (auto& node : g->Nodes()) {
if (node->IsOp() &&
unwanted_ops.find(node->Name()) != unwanted_ops.end()) {
VLOG(3) << "oneDNN InPlace FAILED: unsupported op: " << node->Name();
return false;
}
}
return true;
};
if (should_inplace(graph)) gpd(graph, handler);
}
} // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册