From 5ace20fc3f7b986abe710c8058fca24b46c31b2e Mon Sep 17 00:00:00 2001 From: alncat Date: Wed, 27 Jan 2021 14:58:33 +0800 Subject: [PATCH] modified conv+bn fuse pass to fix wrong mask in mask rcnn (#30704) --- paddle/fluid/framework/ir/conv_bn_fuse_pass.cc | 13 ++++++++++++- .../framework/ir/graph_pattern_detector.cc | 17 ++++++++++------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index a232f7ebb89..0801ecf1a5f 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -95,13 +95,24 @@ void recompute_bias_and_weights(const Scope* scope, variance_array += epsilon; variance_array = variance_array.sqrt(); variance_array = scale_array / variance_array; - + for (int i = 0; i < variance_tensor->numel(); i++) { + PADDLE_ENFORCE_EQ( + isfinite(variance_array[i]), true, + platform::errors::InvalidArgument("fuse batch norm variance should be " + "finite. Found nonfinite values!")); + } EigenVectorArrayMap eltwise_y_in_array( eltwise_y_in_tensor->mutable_data(platform::CPUPlace()), eltwise_y_in_tensor->numel(), 1); eltwise_y_in_array = ((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array; + for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) { + PADDLE_ENFORCE_EQ( + isfinite(eltwise_y_in_array[i]), true, + platform::errors::InvalidArgument("fused batch norm bias should be " + "finite. Found nonfinite values!")); + } // Re-compute weight of conv2d from BN auto* weights = scope->FindVar(conv_weight->Name())->GetMutable(); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 185f6454ca7..173734cb0da 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -824,22 +824,25 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input, auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr()) ->AsOutput() - ->assert_is_op_output("batch_norm", "MeanOut"); + ->assert_is_op_output("batch_norm", "MeanOut") + ->assert_has_n_outputs(0); auto *bn_variance_out_var = pattern->NewNode(bn_variance_out_repr()) ->AsOutput() - ->assert_is_op_output("batch_norm", "VarianceOut"); + ->assert_is_op_output("batch_norm", "VarianceOut") + ->assert_has_n_outputs(0); - auto *bn_saved_mean_var = - pattern->NewNode(bn_saved_mean_repr()) - ->AsOutput() - ->assert_is_op_output("batch_norm", "SavedMean"); + auto *bn_saved_mean_var = pattern->NewNode(bn_saved_mean_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm", "SavedMean") + ->assert_has_n_outputs(0); auto *bn_saved_variance_var = pattern->NewNode(bn_saved_variance_repr()) ->AsOutput() - ->assert_is_op_output("batch_norm", "SavedVariance"); + ->assert_is_op_output("batch_norm", "SavedVariance") + ->assert_has_n_outputs(0); conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var}); -- GitLab