diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index a232f7ebb890a8c6af8346cddaca88d470b438e2..0801ecf1a5f9861bec911cd66dda4279f3e9ee07 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -95,13 +95,24 @@ void recompute_bias_and_weights(const Scope* scope, variance_array += epsilon; variance_array = variance_array.sqrt(); variance_array = scale_array / variance_array; - + for (int i = 0; i < variance_tensor->numel(); i++) { + PADDLE_ENFORCE_EQ( + isfinite(variance_array[i]), true, + platform::errors::InvalidArgument("fuse batch norm variance should be " + "finite. Found nonfinite values!")); + } EigenVectorArrayMap eltwise_y_in_array( eltwise_y_in_tensor->mutable_data(platform::CPUPlace()), eltwise_y_in_tensor->numel(), 1); eltwise_y_in_array = ((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array; + for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) { + PADDLE_ENFORCE_EQ( + isfinite(eltwise_y_in_array[i]), true, + platform::errors::InvalidArgument("fused batch norm bias should be " + "finite. Found nonfinite values!")); + } // Re-compute weight of conv2d from BN auto* weights = scope->FindVar(conv_weight->Name())->GetMutable(); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 185f6454ca7b3541da4163a2503fd26c52db3380..173734cb0da3bf6fb681cd3a2db90071aaed2f0f 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -824,22 +824,25 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input, auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr()) ->AsOutput() - ->assert_is_op_output("batch_norm", "MeanOut"); + ->assert_is_op_output("batch_norm", "MeanOut") + ->assert_has_n_outputs(0); auto *bn_variance_out_var = pattern->NewNode(bn_variance_out_repr()) ->AsOutput() - ->assert_is_op_output("batch_norm", "VarianceOut"); + ->assert_is_op_output("batch_norm", "VarianceOut") + ->assert_has_n_outputs(0); - auto *bn_saved_mean_var = - pattern->NewNode(bn_saved_mean_repr()) - ->AsOutput() - ->assert_is_op_output("batch_norm", "SavedMean"); + auto *bn_saved_mean_var = pattern->NewNode(bn_saved_mean_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm", "SavedMean") + ->assert_has_n_outputs(0); auto *bn_saved_variance_var = pattern->NewNode(bn_saved_variance_repr()) ->AsOutput() - ->assert_is_op_output("batch_norm", "SavedVariance"); + ->assert_is_op_output("batch_norm", "SavedVariance") + ->assert_has_n_outputs(0); conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});