未验证 提交 bfb8a642 编写于 作者: A alncat 提交者: GitHub

updated conv bn fuse pass to make it compatible with latest batch_norm op (#31272)

上级 a37658da
...@@ -790,27 +790,31 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input, ...@@ -790,27 +790,31 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
auto *bn_scale_var = pattern->NewNode(bn_scale_repr()) auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Scale"); ->assert_is_op_input("batch_norm", "Scale")
->assert_has_n_outputs(1);
// BN Bias // BN Bias
auto *bn_bias_var = pattern->NewNode(bn_bias_repr()) auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Bias"); ->assert_is_op_input("batch_norm", "Bias")
->assert_has_n_outputs(1);
// BN Mean // BN Mean
auto *bn_mean_var = pattern->NewNode(bn_mean_repr()) auto *bn_mean_var = pattern->NewNode(bn_mean_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Mean"); ->assert_is_op_input("batch_norm", "Mean")
->assert_has_n_outputs(1);
// BN Variance // BN Variance
auto *bn_variance_var = pattern->NewNode(bn_variance_repr()) auto *bn_variance_var = pattern->NewNode(bn_variance_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Variance"); ->assert_is_op_input("batch_norm", "Variance")
->assert_has_n_outputs(1);
// BN output // BN output
auto *bn_out_var = pattern->NewNode(bn_out_repr()) auto *bn_out_var = pattern->NewNode(bn_out_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output("batch_norm"); ->assert_is_op_output("batch_norm", "Y");
auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr()) auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr())
->AsOutput() ->AsOutput()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册