diff --git a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc index 3d60148c170f94fe2019f719554fa1b2c27ea783..96f115b28225077efbb8ceb753775a8926e8f496 100644 --- a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc +++ b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc @@ -34,7 +34,16 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { detector.mutable_pattern() ->NewNode("scale_in") ->assert_is_op_input("scale") - ->assert_more([](Node* x) { return x->outputs.size() == 1UL; }); + ->assert_has_n_outputs(1) + ->assert_more([](Node* x) { + for (auto* op : x->inputs) { + auto op_type = op->Op()->Type(); + if (op_type == "conditional_block" || op_type == "while") { + return false; + } + } + return true; + }); auto scale_op = detector.mutable_pattern() ->NewNode("scale_fuse") ->assert_is_op("scale")