未验证 提交 772746c0 编写于 作者: L lidanqing 提交者: GitHub

[oneDNN] Fix to inplace pass (#24442) (#25182)

上级 ddc7f39e
......@@ -32,6 +32,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
std::unordered_map<std::string, std::string> original_output_names;
std::unordered_set<std::string> inplaced_vars;
GraphPatternDetector gpd;
patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
"mkldnn_inplace"};
......@@ -94,6 +95,22 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
"be an input to multiple operators";
return;
} else {
// We will prevent in-place when
// input is used in other part of graph, unless it was a result of
// inplacing
// Allow to next op out reuse inpuit var, as this is the same chaing
if (inplaced_vars.find(current_op_in->Name()) == inplaced_vars.end()) {
for (const Node* n : graph->Nodes()) {
if ((n->id() != current_op_in->id()) &&
(n->id() != next_op_out->id()) &&
(n->Name() == current_op_in->Name())) {
VLOG(3) << "DNNL in-place pass FAIL var used in diffrent part of "
"graph ";
return;
}
}
}
}
// If this op was alrady inplaced in previous pass placements
......@@ -132,6 +149,8 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
auto out_name = in_to_outs.begin()->second;
current_op->Op()->SetOutput(
out_name, std::vector<std::string>({current_op_out->Name()}));
// Record var name
inplaced_vars.insert(current_op_out->Name());
// If next op in a line is doing inplace
// then we need to update its output as well
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册