From 67e47e693cfb32dad0c1834f177c31ac7556438e Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 26 Dec 2017 11:30:15 +0800 Subject: [PATCH] refine batch_norm --- paddle/operators/batch_norm_op.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/operators/batch_norm_op.cc b/paddle/operators/batch_norm_op.cc index 49cb0fa4d9d..98db28ddee7 100644 --- a/paddle/operators/batch_norm_op.cc +++ b/paddle/operators/batch_norm_op.cc @@ -50,10 +50,6 @@ class BatchNormOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("SavedMean"), ""); PADDLE_ENFORCE(ctx->HasOutput("SavedVariance"), ""); - const float epsilon = ctx->Attrs().Get("epsilon"); - PADDLE_ENFORCE_GE(epsilon, 0.0, "epsilon should be larger than 0"); - PADDLE_ENFORCE_LE(epsilon, 0.001, "epsilon should not be too large"); - // make sure Mean/MeanOut and Variance/VarianceOut share memory in Python PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0], ctx->Outputs("MeanOut")[0], "Mean and MeanOut should share the same memory"); @@ -91,7 +87,12 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddAttr("is_test", "").SetDefault(false); AddAttr("momentum", "").SetDefault(0.9); - AddAttr("epsilon", "").SetDefault(1e-5); + AddAttr("epsilon", "") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f, + "'epsilon' should be between 0.0 and 0.001."); + }); AddAttr("data_layout", "").SetDefault("NCHW"); AddInput("X", "The input tensor"); AddInput("Scale", -- GitLab