Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
65777601
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
65777601
编写于
11月 22, 2017
作者:
C
Cao Ying
提交者:
GitHub
11月 22, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5692 from peterzhang2029/add_bn_eq
Make epsilon in BatchNormLayer a configurable variable.
上级
6ab78aee
90e05a4b
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
41 addition
and
24 deletion
+41
-24
paddle/gserver/layers/BatchNormBaseLayer.cpp
paddle/gserver/layers/BatchNormBaseLayer.cpp
+1
-0
paddle/gserver/layers/BatchNormBaseLayer.h
paddle/gserver/layers/BatchNormBaseLayer.h
+2
-0
paddle/gserver/layers/BatchNormalizationLayer.cpp
paddle/gserver/layers/BatchNormalizationLayer.cpp
+2
-4
paddle/gserver/layers/BatchNormalizationLayer.h
paddle/gserver/layers/BatchNormalizationLayer.h
+0
-3
paddle/gserver/layers/CudnnBatchNormLayer.cpp
paddle/gserver/layers/CudnnBatchNormLayer.cpp
+10
-6
paddle/gserver/layers/CudnnBatchNormLayer.h
paddle/gserver/layers/CudnnBatchNormLayer.h
+4
-6
paddle/gserver/layers/MKLDNNBatchNormLayer.cpp
paddle/gserver/layers/MKLDNNBatchNormLayer.cpp
+4
-4
paddle/gserver/layers/MKLDNNBatchNormLayer.h
paddle/gserver/layers/MKLDNNBatchNormLayer.h
+2
-1
proto/ModelConfig.proto
proto/ModelConfig.proto
+4
-0
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+4
-0
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+5
-0
python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
...config_helpers/tests/configs/protostr/img_layers.protostr
+1
-0
python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
..._helpers/tests/configs/protostr/img_trans_layers.protostr
+1
-0
python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr
..._helpers/tests/configs/protostr/test_BatchNorm3D.protostr
+1
-0
未找到文件。
paddle/gserver/layers/BatchNormBaseLayer.cpp
浏览文件 @
65777601
...
...
@@ -41,6 +41,7 @@ bool BatchNormBaseLayer::init(const LayerMap& layerMap,
useGlobalStats_
=
config_
.
use_global_stats
();
}
movingAvgFraction_
=
config_
.
moving_average_fraction
();
epsilon_
=
config_
.
epsilon
();
weight_
.
reset
(
new
Weight
(
1
,
channels_
,
parameters_
[
0
]));
movingMean_
.
reset
(
new
Weight
(
1
,
channels_
,
parameters_
[
1
]));
...
...
paddle/gserver/layers/BatchNormBaseLayer.h
浏览文件 @
65777601
...
...
@@ -94,6 +94,8 @@ protected:
bool
useGlobalStats_
;
// use to compute moving mean and variance.
real
movingAvgFraction_
;
// Epsilon is a small random noise used in batch normalization for stability.
real
epsilon_
;
};
}
// namespace paddle
paddle/gserver/layers/BatchNormalizationLayer.cpp
浏览文件 @
65777601
...
...
@@ -22,8 +22,6 @@ namespace paddle {
REGISTER_LAYER
(
batch_norm
,
BatchNormalizationLayer
);
const
real
BatchNormalizationLayer
::
EPS
=
1E-5
;
bool
BatchNormalizationLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
/* Initialize the basic parent class */
...
...
@@ -53,7 +51,7 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) {
calMovingMeanAndVar
();
savedInvVar_
->
subScalar
(
-
EPS
);
savedInvVar_
->
subScalar
(
-
epsilon_
);
savedInvVar_
->
sqrt2
(
*
savedInvVar_
);
}
...
...
@@ -74,7 +72,7 @@ void BatchNormalizationLayer::setMeanAndStd() {
savedInvVar_
->
copyFrom
(
*
(
movingVar_
->
getW
()));
savedInvVar_
->
downClip
(
real
(
0.0
));
savedInvVar_
->
subScalar
(
-
EPS
);
savedInvVar_
->
subScalar
(
-
epsilon_
);
savedInvVar_
->
sqrt2
(
*
savedInvVar_
);
}
...
...
paddle/gserver/layers/BatchNormalizationLayer.h
浏览文件 @
65777601
...
...
@@ -39,9 +39,6 @@ public:
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
)
override
;
protected:
/// Epsilon value used in the batch normalization formula.
static
const
real
EPS
;
/// Load pre-calculated mean and std.
void
setMeanAndStd
();
...
...
paddle/gserver/layers/CudnnBatchNormLayer.cpp
浏览文件 @
65777601
...
...
@@ -21,8 +21,6 @@ namespace paddle {
REGISTER_LAYER
(
cudnn_batch_norm
,
CudnnBatchNormLayer
);
const
double
CudnnBatchNormLayer
::
EPS
=
1E-5
;
bool
CudnnBatchNormLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
/* Initialize the basic parent class */
...
...
@@ -61,6 +59,9 @@ void CudnnBatchNormLayer::forward(PassType passType) {
real
*
movingMean
=
movingMean_
->
getW
()
->
getData
();
real
*
movingVar
=
movingVar_
->
getW
()
->
getData
();
// cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON.
eps_
=
std
::
max
(
CUDNN_BN_MIN_EPSILON
,
static_cast
<
double
>
(
epsilon_
));
if
(
!
useGlobalStats_
)
{
REGISTER_TIMER_INFO
(
"CudnnBatchFwTimer"
,
getName
().
c_str
());
real
*
savedMean
=
savedMean_
->
getData
();
...
...
@@ -75,7 +76,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
1.0
-
movingAvgFraction_
,
movingMean
,
movingVar
,
EPS
,
eps_
,
savedMean
,
savedInvVar
);
}
else
{
...
...
@@ -90,7 +91,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
beta
,
movingMean
,
movingVar
,
EPS
);
eps_
);
}
else
{
// There is a limitation in cudnn library.
// When the batch size is larger than 1024 in cuDNN v5.1,
...
...
@@ -101,7 +102,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
beta
,
movingMean
,
movingVar
,
EPS
,
eps_
,
batchSize
,
channels_
,
imageH_
*
imageD_
,
...
...
@@ -128,6 +129,9 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
real
*
savedMean
=
savedMean_
->
getData
();
real
*
savedInvVar
=
savedInvVar_
->
getData
();
// cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON.
eps_
=
std
::
max
(
CUDNN_BN_MIN_EPSILON
,
static_cast
<
double
>
(
epsilon_
));
auto
create
=
[](
MatrixPtr
&
m
,
size_t
h
,
size_t
w
,
real
**
p
)
{
Matrix
::
resizeOrCreate
(
m
,
h
,
w
,
false
,
true
);
m
->
zeroMem
();
...
...
@@ -157,7 +161,7 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
gamma
,
gammaGrad
,
betaGrad
,
EPS
,
eps_
,
savedMean
,
savedInvVar
);
...
...
paddle/gserver/layers/CudnnBatchNormLayer.h
浏览文件 @
65777601
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <cudnn.h>
#include "BatchNormBaseLayer.h"
#include "Layer.h"
#include "paddle/utils/Stat.h"
...
...
@@ -46,12 +47,9 @@ public:
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
)
override
;
protected:
/**
* Epsilon value used in the batch normalization formula.
* Minimum allowed value is CUDNN_BN_MIN_EPSILON defined in cudnn.h.
* Same epsilon value should be used in forward and backward functions.
*/
static
const
double
EPS
;
/// Epsilon value used in the batch normalization formula.
/// Same epsilon value should be used in forward and backward functions.
double
eps_
;
/// Input/output tensor descriptor desc
hl_tensor_descriptor
ioDesc_
;
...
...
paddle/gserver/layers/MKLDNNBatchNormLayer.cpp
浏览文件 @
65777601
...
...
@@ -21,8 +21,6 @@ namespace paddle {
REGISTER_LAYER
(
mkldnn_batch_norm
,
MKLDNNBatchNormLayer
);
const
real
MKLDNNBatchNormLayer
::
EPS
=
1E-5
;
bool
MKLDNNBatchNormLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
if
(
!
MKLDNNLayer
::
init
(
layerMap
,
parameterMap
))
{
...
...
@@ -50,6 +48,8 @@ bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
useGlobalStats_
=
config_
.
use_global_stats
();
}
movingAvgFraction_
=
config_
.
moving_average_fraction
();
epsilon_
=
config_
.
epsilon
();
VLOG
(
MKLDNN_BASE
)
<<
"--- "
<<
(
useGlobalStats_
?
"use"
:
"do not use"
)
<<
" --- global stats"
;
VLOG
(
MKLDNN_BASE
)
<<
"Moving average fraction: "
<<
movingAvgFraction_
;
...
...
@@ -210,7 +210,7 @@ void MKLDNNBatchNormLayer::resetFwdPD(
if
(
wgt
)
{
flags_
=
(
flags_
|
batch_normalization_flag
::
use_scale_shift
);
}
auto
fwdDesc
=
bn_fwd
::
desc
(
pk
,
in
->
getMemoryDesc
(),
EPS
,
flags_
);
auto
fwdDesc
=
bn_fwd
::
desc
(
pk
,
in
->
getMemoryDesc
(),
epsilon_
,
flags_
);
pd
.
reset
(
new
bn_fwd
::
primitive_desc
(
fwdDesc
,
engine_
));
CHECK_PRIMITIVE_DESC_EQ
(
out
,
pd
->
dst_primitive_desc
());
if
(
wgt
)
{
...
...
@@ -277,7 +277,7 @@ void MKLDNNBatchNormLayer::resetBwdPD(
}
CHECK_PRIMITIVE_DESC_EQ
(
out
,
in
->
getPrimitiveDesc
());
auto
md
=
in
->
getMemoryDesc
();
auto
bwdDesc
=
bn_bwd
::
desc
(
prop_kind
::
backward
,
md
,
md
,
EPS
,
flags_
);
auto
bwdDesc
=
bn_bwd
::
desc
(
prop_kind
::
backward
,
md
,
md
,
epsilon_
,
flags_
);
pd
.
reset
(
new
bn_bwd
::
primitive_desc
(
bwdDesc
,
engine_
,
*
fwdPD_
));
CHECK
(
pd
->
weights_primitive_desc
()
==
fwdPD_
->
weights_primitive_desc
());
CHECK_PRIMITIVE_DESC_EQ
(
wgt
,
pd
->
diff_weights_primitive_desc
());
...
...
paddle/gserver/layers/MKLDNNBatchNormLayer.h
浏览文件 @
65777601
...
...
@@ -32,7 +32,8 @@ protected:
std
::
shared_ptr
<
bn_fwd
::
primitive_desc
>
fwdPD_
;
// Epsilon value used in the batch normalization formula.
static
const
real
EPS
;
real
epsilon_
;
// weight and bias in paddle
std
::
unique_ptr
<
Weight
>
weight_
;
std
::
unique_ptr
<
Weight
>
biases_
;
...
...
proto/ModelConfig.proto
浏览文件 @
65777601
...
...
@@ -540,6 +540,10 @@ message LayerConfig {
// for switch order layer
optional
ReshapeConfig
reshape_conf
=
59
;
// for batch normalization layer
// The small constant added to the variance to improve numeric stability.
optional
double
epsilon
=
60
[
default
=
0.00001
];
}
message
EvaluatorConfig
{
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
65777601
...
...
@@ -2412,6 +2412,7 @@ class BatchNormLayer(LayerBase):
bias
=
True
,
img3D
=
False
,
use_global_stats
=
True
,
epsilon
=
1e-5
,
moving_average_fraction
=
0.9
,
batch_norm_type
=
None
,
mean_var_names
=
None
,
...
...
@@ -2460,6 +2461,9 @@ class BatchNormLayer(LayerBase):
self
.
config
.
use_global_stats
=
use_global_stats
if
moving_average_fraction
is
not
None
:
self
.
config
.
moving_average_fraction
=
moving_average_fraction
if
epsilon
is
not
None
:
assert
epsilon
>=
1e-5
,
"epsilon must be no less than 1e-5."
self
.
config
.
epsilon
=
epsilon
input_layer
=
self
.
get_input_layer
(
0
)
image_conf
=
self
.
config
.
inputs
[
0
].
image_conf
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
65777601
...
...
@@ -3118,6 +3118,7 @@ def batch_norm_layer(input,
param_attr
=
None
,
layer_attr
=
None
,
batch_norm_type
=
None
,
epsilon
=
1e-5
,
moving_average_fraction
=
0.9
,
use_global_stats
=
None
,
mean_var_names
=
None
):
...
...
@@ -3188,6 +3189,8 @@ def batch_norm_layer(input,
will use the mean and variance of the current batch
of test data.
:type use_global_stats: bool | None.
:param epsilon: The small constant added to the variance to improve numeric stability.
:type epsilon: float.
:param moving_average_fraction: Factor used in the moving average computation.
:math:`runningMean = newMean*(1-factor) + runningMean*factor`
:type moving_average_fraction: float.
...
...
@@ -3205,6 +3208,7 @@ def batch_norm_layer(input,
assert
(
batch_norm_type
is
None
)
or
(
batch_norm_type
==
"batch_norm"
)
or
\
(
batch_norm_type
==
"mkldnn_batch_norm"
)
or
\
(
batch_norm_type
==
"cudnn_batch_norm"
)
l
=
Layer
(
name
=
name
,
img3D
=
img3D
,
...
...
@@ -3214,6 +3218,7 @@ def batch_norm_layer(input,
type
=
LayerType
.
BATCH_NORM_LAYER
,
batch_norm_type
=
batch_norm_type
,
bias
=
ParamAttr
.
to_bias
(
bias_attr
),
epsilon
=
epsilon
,
moving_average_fraction
=
moving_average_fraction
,
use_global_stats
=
use_global_stats
,
mean_var_names
=
mean_var_names
,
...
...
python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
浏览文件 @
65777601
...
...
@@ -65,6 +65,7 @@ layers {
height: 227
width: 227
depth: 1
epsilon: 1e-05
}
layers {
name: "__crmnorm_0__"
...
...
python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
浏览文件 @
65777601
...
...
@@ -65,6 +65,7 @@ layers {
height: 256
width: 256
depth: 1
epsilon: 1e-05
}
layers {
name: "__crmnorm_0__"
...
...
python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr
浏览文件 @
65777601
...
...
@@ -36,6 +36,7 @@ layers {
height: 6
width: 20
depth: 3
epsilon: 1e-05
}
parameters {
name: "___batch_norm_0__.w0"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录