diff --git a/paddle/fluid/framework/ir/identity_op_clean_pass.cc b/paddle/fluid/framework/ir/identity_op_clean_pass.cc index 255d1b06ea0f62d9972ec2a789630f06b2343030..152ac0860366b9be3c79c10e7a2e69403169777f 100644 --- a/paddle/fluid/framework/ir/identity_op_clean_pass.cc +++ b/paddle/fluid/framework/ir/identity_op_clean_pass.cc @@ -124,7 +124,14 @@ FindTwoCastOpPattern::FindTwoCastOpPattern(PDPattern* pattern, }); auto* cast_op_1 = pattern->NewNode(cast_op_1_repr())->assert_is_op("cast"); - auto* cast_op_1_out = pattern->NewNode(cast_op_1_out_repr())->assert_is_var(); + auto* cast_op_1_out = pattern->NewNode(cast_op_1_out_repr()) + ->assert_is_var() + ->assert_is_op_output("cast", "Out") + ->assert_more([](Node* x) { + const auto& var_type = x->Var()->GetDataType(); + return var_type != proto::VarType::INT32 && + var_type != proto::VarType::INT64; + }); auto* cast_op_2 = pattern->NewNode(cast_op_2_repr())->assert_is_op("cast"); auto* cast_op_2_out = pattern->NewNode(cast_op_2_out_repr())->assert_is_var();