未验证 提交 bfb8a642 编写于 作者: A alncat 提交者: GitHub

updated conv bn fuse pass to make it compatible with latest batch_norm op (#31272)

上级 a37658da
......@@ -790,27 +790,31 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Scale");
->assert_is_op_input("batch_norm", "Scale")
->assert_has_n_outputs(1);
// BN Bias
auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Bias");
->assert_is_op_input("batch_norm", "Bias")
->assert_has_n_outputs(1);
// BN Mean
auto *bn_mean_var = pattern->NewNode(bn_mean_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Mean");
->assert_is_op_input("batch_norm", "Mean")
->assert_has_n_outputs(1);
// BN Variance
auto *bn_variance_var = pattern->NewNode(bn_variance_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Variance");
->assert_is_op_input("batch_norm", "Variance")
->assert_has_n_outputs(1);
// BN output
auto *bn_out_var = pattern->NewNode(bn_out_repr())
->AsOutput()
->assert_is_op_output("batch_norm");
->assert_is_op_output("batch_norm", "Y");
auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr())
->AsOutput()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册