未验证 提交 5ace20fc 编写于 作者: A alncat 提交者: GitHub

modified conv+bn fuse pass to fix wrong mask in mask rcnn (#30704)

上级 824a79d3
......@@ -95,13 +95,24 @@ void recompute_bias_and_weights(const Scope* scope,
variance_array += epsilon;
variance_array = variance_array.sqrt();
variance_array = scale_array / variance_array;
for (int i = 0; i < variance_tensor->numel(); i++) {
PADDLE_ENFORCE_EQ(
isfinite(variance_array[i]), true,
platform::errors::InvalidArgument("fuse batch norm variance should be "
"finite. Found nonfinite values!"));
}
EigenVectorArrayMap eltwise_y_in_array(
eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 1);
eltwise_y_in_array =
((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array;
for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) {
PADDLE_ENFORCE_EQ(
isfinite(eltwise_y_in_array[i]), true,
platform::errors::InvalidArgument("fused batch norm bias should be "
"finite. Found nonfinite values!"));
}
// Re-compute weight of conv2d from BN
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
......
......@@ -824,22 +824,25 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr())
->AsOutput()
->assert_is_op_output("batch_norm", "MeanOut");
->assert_is_op_output("batch_norm", "MeanOut")
->assert_has_n_outputs(0);
auto *bn_variance_out_var =
pattern->NewNode(bn_variance_out_repr())
->AsOutput()
->assert_is_op_output("batch_norm", "VarianceOut");
->assert_is_op_output("batch_norm", "VarianceOut")
->assert_has_n_outputs(0);
auto *bn_saved_mean_var =
pattern->NewNode(bn_saved_mean_repr())
->AsOutput()
->assert_is_op_output("batch_norm", "SavedMean");
auto *bn_saved_mean_var = pattern->NewNode(bn_saved_mean_repr())
->AsOutput()
->assert_is_op_output("batch_norm", "SavedMean")
->assert_has_n_outputs(0);
auto *bn_saved_variance_var =
pattern->NewNode(bn_saved_variance_repr())
->AsOutput()
->assert_is_op_output("batch_norm", "SavedVariance");
->assert_is_op_output("batch_norm", "SavedVariance")
->assert_has_n_outputs(0);
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册