未验证 提交 65777601 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #5692 from peterzhang2029/add_bn_eq

Make epsilon in BatchNormLayer a configurable variable.
...@@ -41,6 +41,7 @@ bool BatchNormBaseLayer::init(const LayerMap& layerMap, ...@@ -41,6 +41,7 @@ bool BatchNormBaseLayer::init(const LayerMap& layerMap,
useGlobalStats_ = config_.use_global_stats(); useGlobalStats_ = config_.use_global_stats();
} }
movingAvgFraction_ = config_.moving_average_fraction(); movingAvgFraction_ = config_.moving_average_fraction();
epsilon_ = config_.epsilon();
weight_.reset(new Weight(1, channels_, parameters_[0])); weight_.reset(new Weight(1, channels_, parameters_[0]));
movingMean_.reset(new Weight(1, channels_, parameters_[1])); movingMean_.reset(new Weight(1, channels_, parameters_[1]));
......
...@@ -94,6 +94,8 @@ protected: ...@@ -94,6 +94,8 @@ protected:
bool useGlobalStats_; bool useGlobalStats_;
// Used to compute moving mean and variance. // Used to compute moving mean and variance.
real movingAvgFraction_; real movingAvgFraction_;
// Epsilon is a small constant added to the variance in batch normalization for numerical stability.
real epsilon_;
}; };
} // namespace paddle } // namespace paddle
...@@ -22,8 +22,6 @@ namespace paddle { ...@@ -22,8 +22,6 @@ namespace paddle {
REGISTER_LAYER(batch_norm, BatchNormalizationLayer); REGISTER_LAYER(batch_norm, BatchNormalizationLayer);
const real BatchNormalizationLayer::EPS = 1E-5;
bool BatchNormalizationLayer::init(const LayerMap& layerMap, bool BatchNormalizationLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) {
/* Initialize the basic parent class */ /* Initialize the basic parent class */
...@@ -53,7 +51,7 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) { ...@@ -53,7 +51,7 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) {
calMovingMeanAndVar(); calMovingMeanAndVar();
savedInvVar_->subScalar(-EPS); savedInvVar_->subScalar(-epsilon_);
savedInvVar_->sqrt2(*savedInvVar_); savedInvVar_->sqrt2(*savedInvVar_);
} }
...@@ -74,7 +72,7 @@ void BatchNormalizationLayer::setMeanAndStd() { ...@@ -74,7 +72,7 @@ void BatchNormalizationLayer::setMeanAndStd() {
savedInvVar_->copyFrom(*(movingVar_->getW())); savedInvVar_->copyFrom(*(movingVar_->getW()));
savedInvVar_->downClip(real(0.0)); savedInvVar_->downClip(real(0.0));
savedInvVar_->subScalar(-EPS); savedInvVar_->subScalar(-epsilon_);
savedInvVar_->sqrt2(*savedInvVar_); savedInvVar_->sqrt2(*savedInvVar_);
} }
......
...@@ -39,9 +39,6 @@ public: ...@@ -39,9 +39,6 @@ public:
void backward(const UpdateCallback& callback = nullptr) override; void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
/// Epsilon value used in the batch normalization formula.
static const real EPS;
/// Load pre-calculated mean and std. /// Load pre-calculated mean and std.
void setMeanAndStd(); void setMeanAndStd();
......
...@@ -21,8 +21,6 @@ namespace paddle { ...@@ -21,8 +21,6 @@ namespace paddle {
REGISTER_LAYER(cudnn_batch_norm, CudnnBatchNormLayer); REGISTER_LAYER(cudnn_batch_norm, CudnnBatchNormLayer);
const double CudnnBatchNormLayer::EPS = 1E-5;
bool CudnnBatchNormLayer::init(const LayerMap& layerMap, bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) {
/* Initialize the basic parent class */ /* Initialize the basic parent class */
...@@ -61,6 +59,9 @@ void CudnnBatchNormLayer::forward(PassType passType) { ...@@ -61,6 +59,9 @@ void CudnnBatchNormLayer::forward(PassType passType) {
real* movingMean = movingMean_->getW()->getData(); real* movingMean = movingMean_->getW()->getData();
real* movingVar = movingVar_->getW()->getData(); real* movingVar = movingVar_->getW()->getData();
// cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON.
eps_ = std::max(CUDNN_BN_MIN_EPSILON, static_cast<double>(epsilon_));
if (!useGlobalStats_) { if (!useGlobalStats_) {
REGISTER_TIMER_INFO("CudnnBatchFwTimer", getName().c_str()); REGISTER_TIMER_INFO("CudnnBatchFwTimer", getName().c_str());
real* savedMean = savedMean_->getData(); real* savedMean = savedMean_->getData();
...@@ -75,7 +76,7 @@ void CudnnBatchNormLayer::forward(PassType passType) { ...@@ -75,7 +76,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
1.0 - movingAvgFraction_, 1.0 - movingAvgFraction_,
movingMean, movingMean,
movingVar, movingVar,
EPS, eps_,
savedMean, savedMean,
savedInvVar); savedInvVar);
} else { } else {
...@@ -90,7 +91,7 @@ void CudnnBatchNormLayer::forward(PassType passType) { ...@@ -90,7 +91,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
beta, beta,
movingMean, movingMean,
movingVar, movingVar,
EPS); eps_);
} else { } else {
// There is a limitation in cudnn library. // There is a limitation in cudnn library.
// When the batch size is larger than 1024 in cuDNN v5.1, // When the batch size is larger than 1024 in cuDNN v5.1,
...@@ -101,7 +102,7 @@ void CudnnBatchNormLayer::forward(PassType passType) { ...@@ -101,7 +102,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
beta, beta,
movingMean, movingMean,
movingVar, movingVar,
EPS, eps_,
batchSize, batchSize,
channels_, channels_,
imageH_ * imageD_, imageH_ * imageD_,
...@@ -128,6 +129,9 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) { ...@@ -128,6 +129,9 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
real* savedMean = savedMean_->getData(); real* savedMean = savedMean_->getData();
real* savedInvVar = savedInvVar_->getData(); real* savedInvVar = savedInvVar_->getData();
// cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON.
eps_ = std::max(CUDNN_BN_MIN_EPSILON, static_cast<double>(epsilon_));
auto create = [](MatrixPtr& m, size_t h, size_t w, real** p) { auto create = [](MatrixPtr& m, size_t h, size_t w, real** p) {
Matrix::resizeOrCreate(m, h, w, false, true); Matrix::resizeOrCreate(m, h, w, false, true);
m->zeroMem(); m->zeroMem();
...@@ -157,7 +161,7 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) { ...@@ -157,7 +161,7 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
gamma, gamma,
gammaGrad, gammaGrad,
betaGrad, betaGrad,
EPS, eps_,
savedMean, savedMean,
savedInvVar); savedInvVar);
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <cudnn.h>
#include "BatchNormBaseLayer.h" #include "BatchNormBaseLayer.h"
#include "Layer.h" #include "Layer.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
...@@ -46,12 +47,9 @@ public: ...@@ -46,12 +47,9 @@ public:
void backward(const UpdateCallback& callback = nullptr) override; void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
/** /// Epsilon value used in the batch normalization formula.
* Epsilon value used in the batch normalization formula. /// Same epsilon value should be used in forward and backward functions.
* Minimum allowed value is CUDNN_BN_MIN_EPSILON defined in cudnn.h. double eps_;
* Same epsilon value should be used in forward and backward functions.
*/
static const double EPS;
/// Input/output tensor descriptor desc /// Input/output tensor descriptor desc
hl_tensor_descriptor ioDesc_; hl_tensor_descriptor ioDesc_;
......
...@@ -21,8 +21,6 @@ namespace paddle { ...@@ -21,8 +21,6 @@ namespace paddle {
REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer); REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer);
const real MKLDNNBatchNormLayer::EPS = 1E-5;
bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap, bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) {
if (!MKLDNNLayer::init(layerMap, parameterMap)) { if (!MKLDNNLayer::init(layerMap, parameterMap)) {
...@@ -50,6 +48,8 @@ bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap, ...@@ -50,6 +48,8 @@ bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
useGlobalStats_ = config_.use_global_stats(); useGlobalStats_ = config_.use_global_stats();
} }
movingAvgFraction_ = config_.moving_average_fraction(); movingAvgFraction_ = config_.moving_average_fraction();
epsilon_ = config_.epsilon();
VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use") VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use")
<< " --- global stats"; << " --- global stats";
VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_; VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_;
...@@ -210,7 +210,7 @@ void MKLDNNBatchNormLayer::resetFwdPD( ...@@ -210,7 +210,7 @@ void MKLDNNBatchNormLayer::resetFwdPD(
if (wgt) { if (wgt) {
flags_ = (flags_ | batch_normalization_flag::use_scale_shift); flags_ = (flags_ | batch_normalization_flag::use_scale_shift);
} }
auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_); auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), epsilon_, flags_);
pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_)); pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_));
CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc()); CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc());
if (wgt) { if (wgt) {
...@@ -277,7 +277,7 @@ void MKLDNNBatchNormLayer::resetBwdPD( ...@@ -277,7 +277,7 @@ void MKLDNNBatchNormLayer::resetBwdPD(
} }
CHECK_PRIMITIVE_DESC_EQ(out, in->getPrimitiveDesc()); CHECK_PRIMITIVE_DESC_EQ(out, in->getPrimitiveDesc());
auto md = in->getMemoryDesc(); auto md = in->getMemoryDesc();
auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_); auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, epsilon_, flags_);
pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_)); pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_));
CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc()); CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc());
CHECK_PRIMITIVE_DESC_EQ(wgt, pd->diff_weights_primitive_desc()); CHECK_PRIMITIVE_DESC_EQ(wgt, pd->diff_weights_primitive_desc());
......
...@@ -32,7 +32,8 @@ protected: ...@@ -32,7 +32,8 @@ protected:
std::shared_ptr<bn_fwd::primitive_desc> fwdPD_; std::shared_ptr<bn_fwd::primitive_desc> fwdPD_;
// Epsilon value used in the batch normalization formula. // Epsilon value used in the batch normalization formula.
static const real EPS; real epsilon_;
// weight and bias in paddle // weight and bias in paddle
std::unique_ptr<Weight> weight_; std::unique_ptr<Weight> weight_;
std::unique_ptr<Weight> biases_; std::unique_ptr<Weight> biases_;
......
...@@ -540,6 +540,10 @@ message LayerConfig { ...@@ -540,6 +540,10 @@ message LayerConfig {
// for switch order layer // for switch order layer
optional ReshapeConfig reshape_conf = 59; optional ReshapeConfig reshape_conf = 59;
// for batch normalization layer
// The small constant added to the variance to improve numeric stability.
optional double epsilon = 60 [ default = 0.00001 ];
} }
message EvaluatorConfig { message EvaluatorConfig {
......
...@@ -2412,6 +2412,7 @@ class BatchNormLayer(LayerBase): ...@@ -2412,6 +2412,7 @@ class BatchNormLayer(LayerBase):
bias=True, bias=True,
img3D=False, img3D=False,
use_global_stats=True, use_global_stats=True,
epsilon=1e-5,
moving_average_fraction=0.9, moving_average_fraction=0.9,
batch_norm_type=None, batch_norm_type=None,
mean_var_names=None, mean_var_names=None,
...@@ -2460,6 +2461,9 @@ class BatchNormLayer(LayerBase): ...@@ -2460,6 +2461,9 @@ class BatchNormLayer(LayerBase):
self.config.use_global_stats = use_global_stats self.config.use_global_stats = use_global_stats
if moving_average_fraction is not None: if moving_average_fraction is not None:
self.config.moving_average_fraction = moving_average_fraction self.config.moving_average_fraction = moving_average_fraction
if epsilon is not None:
assert epsilon >= 1e-5, "epsilon must be no less than 1e-5."
self.config.epsilon = epsilon
input_layer = self.get_input_layer(0) input_layer = self.get_input_layer(0)
image_conf = self.config.inputs[0].image_conf image_conf = self.config.inputs[0].image_conf
......
...@@ -3118,6 +3118,7 @@ def batch_norm_layer(input, ...@@ -3118,6 +3118,7 @@ def batch_norm_layer(input,
param_attr=None, param_attr=None,
layer_attr=None, layer_attr=None,
batch_norm_type=None, batch_norm_type=None,
epsilon=1e-5,
moving_average_fraction=0.9, moving_average_fraction=0.9,
use_global_stats=None, use_global_stats=None,
mean_var_names=None): mean_var_names=None):
...@@ -3188,6 +3189,8 @@ def batch_norm_layer(input, ...@@ -3188,6 +3189,8 @@ def batch_norm_layer(input,
will use the mean and variance of the current batch will use the mean and variance of the current batch
of test data. of test data.
:type use_global_stats: bool | None. :type use_global_stats: bool | None.
:param epsilon: The small constant added to the variance to improve numeric stability.
:type epsilon: float.
:param moving_average_fraction: Factor used in the moving average computation. :param moving_average_fraction: Factor used in the moving average computation.
:math:`runningMean = newMean*(1-factor) + runningMean*factor` :math:`runningMean = newMean*(1-factor) + runningMean*factor`
:type moving_average_fraction: float. :type moving_average_fraction: float.
...@@ -3205,6 +3208,7 @@ def batch_norm_layer(input, ...@@ -3205,6 +3208,7 @@ def batch_norm_layer(input,
assert (batch_norm_type is None) or (batch_norm_type == "batch_norm") or \ assert (batch_norm_type is None) or (batch_norm_type == "batch_norm") or \
(batch_norm_type == "mkldnn_batch_norm") or \ (batch_norm_type == "mkldnn_batch_norm") or \
(batch_norm_type == "cudnn_batch_norm") (batch_norm_type == "cudnn_batch_norm")
l = Layer( l = Layer(
name=name, name=name,
img3D=img3D, img3D=img3D,
...@@ -3214,6 +3218,7 @@ def batch_norm_layer(input, ...@@ -3214,6 +3218,7 @@ def batch_norm_layer(input,
type=LayerType.BATCH_NORM_LAYER, type=LayerType.BATCH_NORM_LAYER,
batch_norm_type=batch_norm_type, batch_norm_type=batch_norm_type,
bias=ParamAttr.to_bias(bias_attr), bias=ParamAttr.to_bias(bias_attr),
epsilon=epsilon,
moving_average_fraction=moving_average_fraction, moving_average_fraction=moving_average_fraction,
use_global_stats=use_global_stats, use_global_stats=use_global_stats,
mean_var_names=mean_var_names, mean_var_names=mean_var_names,
......
...@@ -65,6 +65,7 @@ layers { ...@@ -65,6 +65,7 @@ layers {
height: 227 height: 227
width: 227 width: 227
depth: 1 depth: 1
epsilon: 1e-05
} }
layers { layers {
name: "__crmnorm_0__" name: "__crmnorm_0__"
......
...@@ -65,6 +65,7 @@ layers { ...@@ -65,6 +65,7 @@ layers {
height: 256 height: 256
width: 256 width: 256
depth: 1 depth: 1
epsilon: 1e-05
} }
layers { layers {
name: "__crmnorm_0__" name: "__crmnorm_0__"
......
...@@ -36,6 +36,7 @@ layers { ...@@ -36,6 +36,7 @@ layers {
height: 6 height: 6
width: 20 width: 20
depth: 3 depth: 3
epsilon: 1e-05
} }
parameters { parameters {
name: "___batch_norm_0__.w0" name: "___batch_norm_0__.w0"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册