diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
index a232f7ebb890a8c6af8346cddaca88d470b438e2..1eee7c01f488661b7b6fdcb535ceaa7b0c9a904d 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -95,13 +95,28 @@ void recompute_bias_and_weights(const Scope* scope,
   variance_array += epsilon;
   variance_array = variance_array.sqrt();
   variance_array = scale_array / variance_array;
-
+  for (int i = 0; i < variance_tensor->numel(); i++) {
+    PADDLE_ENFORCE_EQ(std::isfinite(variance_array[i]), true,
+                      platform::errors::InvalidArgument(
+                          "The inverse of Fused batch norm variance "
+                          "should be finite. Found nonfinite values! "
+                          "Please check %s ",
+                          bn_variance.Name()));
+  }
   EigenVectorArrayMap eltwise_y_in_array(
       eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
       eltwise_y_in_tensor->numel(), 1);
 
   eltwise_y_in_array =
       ((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array;
+  for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) {
+    PADDLE_ENFORCE_EQ(std::isfinite(eltwise_y_in_array[i]), true,
+                      platform::errors::InvalidArgument(
+                          "Fused batch norm bias should be "
+                          "finite. Found nonfinite values! "
+                          "Please check %s and related variables.",
+                          bn_variance.Name()));
+  }
 
   // Re-compute weight of conv2d from BN
   auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 4efdfb01e38d8cc3816e1a139bd752bdcf7fb301..dd8e942bdea2d4487a61cfd3a529dd36722cebb7 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -824,22 +824,25 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
   auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr())
                               ->AsOutput()
-                              ->assert_is_op_output("batch_norm", "MeanOut");
+                              ->assert_is_op_output("batch_norm", "MeanOut")
+                              ->assert_has_n_outputs(0);
 
   auto *bn_variance_out_var =
       pattern->NewNode(bn_variance_out_repr())
           ->AsOutput()
-          ->assert_is_op_output("batch_norm", "VarianceOut");
+          ->assert_is_op_output("batch_norm", "VarianceOut")
+          ->assert_has_n_outputs(0);
 
-  auto *bn_saved_mean_var =
-      pattern->NewNode(bn_saved_mean_repr())
-          ->AsOutput()
-          ->assert_is_op_output("batch_norm", "SavedMean");
+  auto *bn_saved_mean_var = pattern->NewNode(bn_saved_mean_repr())
+                                ->AsOutput()
+                                ->assert_is_op_output("batch_norm", "SavedMean")
+                                ->assert_has_n_outputs(0);
 
   auto *bn_saved_variance_var =
       pattern->NewNode(bn_saved_variance_repr())
           ->AsOutput()
-          ->assert_is_op_output("batch_norm", "SavedVariance");
+          ->assert_is_op_output("batch_norm", "SavedVariance")
+          ->assert_has_n_outputs(0);
 
   conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});