未验证 提交 8b88cd51 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Fix to inplace pass (#24442)

* - Disabling inplace pass

test=develop

- Disable cycles

test=develop

 - fix

test=develop

- Enhancement to in-place

- Lint fixes

test=develop

* - Lint fixes

test=develop
上级 f1c4c14c
......@@ -32,6 +32,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
std::unordered_map<std::string, std::string> original_output_names;
std::unordered_set<std::string> inplaced_vars;
GraphPatternDetector gpd;
patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
"mkldnn_inplace"};
......@@ -95,6 +96,22 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
"be an input to multiple operators";
return;
} else {
// We will prevent in-place when
// input is used in other part of graph, unless it was a result of
// inplacing
// Allow to next op out reuse inpuit var, as this is the same chaing
if (inplaced_vars.find(current_op_in->Name()) == inplaced_vars.end()) {
for (const Node* n : graph->Nodes()) {
if ((n->id() != current_op_in->id()) &&
(n->id() != next_op_out->id()) &&
(n->Name() == current_op_in->Name())) {
VLOG(3) << "DNNL in-place pass FAIL var used in diffrent part of "
"graph ";
return;
}
}
}
}
// If this op was alrady inplaced in previous pass placements
......@@ -132,6 +149,8 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
auto out_name = in_to_outs.begin()->second;
current_op->Op()->SetOutput(
out_name, std::vector<std::string>({current_op_out->Name()}));
// Record var name
inplaced_vars.insert(current_op_out->Name());
// If next op in a line is doing inplace
// then we need to update its output as well
......
......@@ -200,7 +200,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"matmul_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
// "mkldnn_inplace_pass", // This pass should be activated after
"mkldnn_inplace_pass", // This pass should be activated after
// fuses
})) {
passes_.push_back(pass);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册