未验证 提交 772746c0 编写于 作者: L lidanqing 提交者: GitHub

[oneDNN] Fix to inplace pass (#24442) (#25182)

上级 ddc7f39e
...@@ -32,6 +32,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -32,6 +32,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL.")); "Pointer to graph argument should not be NULL."));
std::unordered_map<std::string, std::string> original_output_names; std::unordered_map<std::string, std::string> original_output_names;
std::unordered_set<std::string> inplaced_vars;
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(), patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
"mkldnn_inplace"}; "mkldnn_inplace"};
...@@ -94,6 +95,22 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -94,6 +95,22 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot " VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
"be an input to multiple operators"; "be an input to multiple operators";
return; return;
} else {
// We will prevent in-place when
// input is used in other part of graph, unless it was a result of
// inplacing
// Allow to next op out reuse inpuit var, as this is the same chaing
if (inplaced_vars.find(current_op_in->Name()) == inplaced_vars.end()) {
for (const Node* n : graph->Nodes()) {
if ((n->id() != current_op_in->id()) &&
(n->id() != next_op_out->id()) &&
(n->Name() == current_op_in->Name())) {
VLOG(3) << "DNNL in-place pass FAIL var used in diffrent part of "
"graph ";
return;
}
}
}
} }
// If this op was alrady inplaced in previous pass placements // If this op was alrady inplaced in previous pass placements
...@@ -132,6 +149,8 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -132,6 +149,8 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
auto out_name = in_to_outs.begin()->second; auto out_name = in_to_outs.begin()->second;
current_op->Op()->SetOutput( current_op->Op()->SetOutput(
out_name, std::vector<std::string>({current_op_out->Name()})); out_name, std::vector<std::string>({current_op_out->Name()}));
// Record var name
inplaced_vars.insert(current_op_out->Name());
// If next op in a line is doing inplace // If next op in a line is doing inplace
// then we need to update its output as well // then we need to update its output as well
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册