From b4be9717bfa493abe10c67080ced3c21982b64ec Mon Sep 17 00:00:00 2001 From: alncat Date: Tue, 2 Feb 2021 19:26:22 +0800 Subject: [PATCH] Conv bn fuse fix (#30830) * fixed compilation error on gcc 4.8.x due to the usage of isfinite (#30733) * modified conv+bn fuse pass to fix wrong mask in mask rcnn (#30704) --- paddle/fluid/framework/ir/conv_bn_fuse_pass.cc | 17 ++++++++++++++++- .../framework/ir/graph_pattern_detector.cc | 17 ++++++++++------- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index a232f7ebb89..1eee7c01f48 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -95,13 +95,28 @@ void recompute_bias_and_weights(const Scope* scope, variance_array += epsilon; variance_array = variance_array.sqrt(); variance_array = scale_array / variance_array; - + for (int i = 0; i < variance_tensor->numel(); i++) { + PADDLE_ENFORCE_EQ(std::isfinite(variance_array[i]), true, + platform::errors::InvalidArgument( + "The inverse of Fused batch norm variance " + "should be finite. Found nonfinite values! " + "Please check %s ", + bn_variance.Name())); + } EigenVectorArrayMap eltwise_y_in_array( eltwise_y_in_tensor->mutable_data(platform::CPUPlace()), eltwise_y_in_tensor->numel(), 1); eltwise_y_in_array = ((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array; + for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) { + PADDLE_ENFORCE_EQ(std::isfinite(eltwise_y_in_array[i]), true, + platform::errors::InvalidArgument( + "Fused batch norm bias should be " + "finite. Found nonfinite values! " + "Please check %s and related variables.", + bn_variance.Name())); + } // Re-compute weight of conv2d from BN auto* weights = scope->FindVar(conv_weight->Name())->GetMutable(); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 4efdfb01e38..dd8e942bdea 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -824,22 +824,25 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input, auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr()) ->AsOutput() - ->assert_is_op_output("batch_norm", "MeanOut"); + ->assert_is_op_output("batch_norm", "MeanOut") + ->assert_has_n_outputs(0); auto *bn_variance_out_var = pattern->NewNode(bn_variance_out_repr()) ->AsOutput() - ->assert_is_op_output("batch_norm", "VarianceOut"); + ->assert_is_op_output("batch_norm", "VarianceOut") + ->assert_has_n_outputs(0); - auto *bn_saved_mean_var = - pattern->NewNode(bn_saved_mean_repr()) - ->AsOutput() - ->assert_is_op_output("batch_norm", "SavedMean"); + auto *bn_saved_mean_var = pattern->NewNode(bn_saved_mean_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm", "SavedMean") + ->assert_has_n_outputs(0); auto *bn_saved_variance_var = pattern->NewNode(bn_saved_variance_repr()) ->AsOutput() - ->assert_is_op_output("batch_norm", "SavedVariance"); + ->assert_is_op_output("batch_norm", "SavedVariance") + ->assert_has_n_outputs(0); conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var}); -- GitLab