From 8585279f549b85a59d0430260ed19ee2244e5779 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Thu, 9 Jun 2022 11:27:10 +0800 Subject: [PATCH] fix scale pass when "conditional_block" or "while" is before "scale" (#43323) --- .../framework/ir/identity_scale_op_clean_pass.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc index 3d60148c170..96f115b2822 100644 --- a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc +++ b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc @@ -34,7 +34,16 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { detector.mutable_pattern() ->NewNode("scale_in") ->assert_is_op_input("scale") - ->assert_more([](Node* x) { return x->outputs.size() == 1UL; }); + ->assert_has_n_outputs(1) + ->assert_more([](Node* x) { + for (auto* op : x->inputs) { + auto op_type = op->Op()->Type(); + if (op_type == "conditional_block" || op_type == "while") { + return false; + } + } + return true; + }); auto scale_op = detector.mutable_pattern() ->NewNode("scale_fuse") ->assert_is_op("scale") -- GitLab