机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 9c63b7c1 (unverified)
Authored by Kaipeng Deng on Dec 03, 2019; committed via GitHub on Dec 03, 2019.
[cherry-pick] add bn momentum variable (#21435)
* batch_norm momentum support variable. test=develop
Parent: 5c7c6b1e
Showing 5 changed files with 139 additions and 46 deletions (+139, -46)
paddle/fluid/operators/batch_norm_op.cc (+20, -1)
paddle/fluid/operators/batch_norm_op.cu (+10, -1)
python/paddle/fluid/layers/nn.py (+52, -19)
python/paddle/fluid/tests/unittests/test_batch_norm_op.py (+44, -25)
python/paddle/fluid/tests/unittests/test_layers.py (+13, -0)
paddle/fluid/operators/batch_norm_op.cc
@@ -58,6 +58,13 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
   const DataLayout data_layout = framework::StringToDataLayout(
       ctx->Attrs().Get<std::string>("data_layout"));
 
+  if (ctx->IsRuntime() && ctx->HasInput("MomentumTensor")) {
+    auto mom = ctx->Inputs("MomentumTensor");
+    PADDLE_ENFORCE_EQ(mom.size(), 1,
+                      platform::errors::InvalidArgument(
+                          "Input(MomentumTensor) size must be 1"));
+  }
+
   PADDLE_ENFORCE_GE(
       x_dims.size(), 2,
       "ShapeError: the dimension of input X must greater than or equal to 2."
@@ -173,6 +180,11 @@ void BatchNormOpMaker::Make() {
   AddInput("Variance",
            "The global variance (for training) "
            "or estimated Variance (for testing)");
+  AddInput("MomentumTensor",
+           "(Tensor<float32>, optional) If provided, batch_norm will "
+           "use this as momentum, this has a higher priority than "
+           "attr(momentum), the shape of this tensor MUST BE [1].")
+      .AsDispensable();
   AddOutput("Y", "result after normalization");
   AddOutput("MeanOut",
             "Share memory with Mean. "
@@ -221,7 +233,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     const float epsilon = ctx.Attr<float>("epsilon");
-    const float momentum = ctx.Attr<float>("momentum");
+    float momentum = ctx.Attr<float>("momentum");
     const bool is_test = ctx.Attr<bool>("is_test");
     const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
@@ -306,6 +318,13 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
       PADDLE_THROW("Unknown storage order: %s", data_layout_str);
     }
 
+    // if MomentumTensor is set, use MomentumTensor value, momentum
+    // is only used in this training branch
+    if (ctx.HasInput("MomentumTensor")) {
+      const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
+      momentum = mom_tensor->data<float>()[0];
+    }
+
     running_mean_arr =
         running_mean_arr * momentum + saved_mean_e * (1. - momentum);
     running_var_arr =
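Once momentum is resolved, from the attribute or from MomentumTensor, both kernels apply the same exponential moving average to the running statistics. A NumPy sketch of that update (illustration only, not part of the patch):

import numpy as np

# running = running * momentum + batch_stat * (1 - momentum)
momentum = 0.9
running_mean = np.array([0.5], dtype=np.float32)
saved_mean = np.array([1.0], dtype=np.float32)  # mean of the current batch
running_mean = running_mean * momentum + saved_mean * (1.0 - momentum)
print(running_mean)  # [0.55]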
paddle/fluid/operators/batch_norm_op.cu
@@ -43,7 +43,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "It must use CUDAPlace.");
     double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
-    const float momentum = ctx.Attr<float>("momentum");
+    float momentum = ctx.Attr<float>("momentum");
     const bool is_test = ctx.Attr<bool>("is_test");
     const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
     const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
@@ -133,6 +133,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           est_mean->template data<BatchNormParamType<T>>(),
           est_var->template data<BatchNormParamType<T>>(), epsilon));
     } else {
+      // if MomentumTensor is set, use MomentumTensor value, momentum
+      // is only used in this training branch
+      if (ctx.HasInput("MomentumTensor")) {
+        const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
+        Tensor mom_cpu;
+        TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
+        momentum = mom_cpu.data<float>()[0];
+      }
+
       // Run training mode.
       // obtain running mean and running inv var, and see if we need to
       // initialize them.
python/paddle/fluid/layers/nn.py
@@ -4176,13 +4176,14 @@ def batch_norm(input,
         sync_batch_norm automatically.
 
     Args:
-        input(variable): The rank of input variable can be 2, 3, 4, 5. The data type
+        input(Variable): The rank of input variable can be 2, 3, 4, 5. The data type
             is float16 or float32 or float64.
         act(string, Default None): Activation type, linear|relu|prelu|...
         is_test (bool, Default False): A flag indicating whether it is in
             test phrase or not.
-        momentum(float, Default 0.9): The value used for the moving_mean and
-            moving_var computation. The updated formula is:
+        momentum(float|Variable, Default 0.9): The value used for the moving_mean and
+            moving_var computation. This should be a float number or a Variable with
+            shape [1] and data type as float32. The updated formula is:
             :math:`moving\_mean = moving\_mean * momentum + new\_mean * (1. - momentum)`
             :math:`moving\_var = moving\_var * momentum + new\_var * (1. - momentum)`
             Default is 0.9.
@@ -4228,6 +4229,33 @@
             x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32')
             hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
             hidden2 = fluid.layers.batch_norm(input=hidden1)
+
+        .. code-block:: python
+
+            # batch_norm with momentum as Variable
+            import paddle.fluid as fluid
+            import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler
+
+            def get_decay_momentum(momentum_init, decay_steps, decay_rate):
+                global_step = lr_scheduler._decay_step_counter()
+                momentum = fluid.layers.create_global_var(
+                    shape=[1],
+                    value=float(momentum_init),
+                    dtype='float32',
+                    # set persistable for save checkpoints and resume
+                    persistable=True,
+                    name="momentum")
+                div_res = global_step / decay_steps
+                decayed_momentum = momentum_init * (decay_rate**div_res)
+                fluid.layers.assign(decayed_momentum, momentum)
+
+                return momentum
+
+            x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32')
+            hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
+            momentum = get_decay_momentum(0.9, 1e5, 0.9)
+            hidden2 = fluid.layers.batch_norm(input=hidden1, momentum=momentum)
     """
     assert bias_attr is not False, "bias_attr should not be False in batch_norm."
     helper = LayerHelper('batch_norm', **locals())
@@ -4303,15 +4331,28 @@
     batch_norm_out = input if in_place else helper.create_variable_for_type_inference(
         dtype)
 
+    inputs = {
+        "X": input,
+        "Scale": scale,
+        "Bias": bias,
+        "Mean": mean,
+        "Variance": variance
+    }
+    attrs = {
+        "epsilon": epsilon,
+        "is_test": is_test,
+        "data_layout": data_layout,
+        "use_mkldnn": False,
+        "fuse_with_relu": False,
+        "use_global_stats": use_global_stats
+    }
+    if isinstance(momentum, Variable):
+        inputs['MomentumTensor'] = momentum
+    else:
+        attrs['momentum'] = momentum
     helper.append_op(
         type="batch_norm",
-        inputs={
-            "X": input,
-            "Scale": scale,
-            "Bias": bias,
-            "Mean": mean,
-            "Variance": variance
-        },
+        inputs=inputs,
         outputs={
             "Y": batch_norm_out,
             "MeanOut": mean_out,
@@ -4319,15 +4360,7 @@
             "SavedMean": saved_mean,
             "SavedVariance": saved_variance
         },
-        attrs={
-            "momentum": momentum,
-            "epsilon": epsilon,
-            "is_test": is_test,
-            "data_layout": data_layout,
-            "use_mkldnn": False,
-            "fuse_with_relu": False,
-            "use_global_stats": use_global_stats
-        })
+        attrs=attrs)
 
     return helper.append_activation(batch_norm_out)
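After this refactor both call styles reach the op: a Python float travels as the momentum attribute, while a Variable is wired in as the MomentumTensor input, which the kernels read in preference to the attribute. Illustrative sketch, with hidden1 and the momentum variable built as in the earlier sketch:

bn1 = fluid.layers.batch_norm(input=hidden1, momentum=0.9)       # float -> 'momentum' attr
bn2 = fluid.layers.batch_norm(input=hidden1, momentum=momentum)  # Variable -> 'MomentumTensor' input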
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
@@ -310,6 +310,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
         self.fuse_with_relu = False
         self.data_formats = ["NCHW", "NHWC"]
         self.momentum = 0.9
+        self.use_momentum_variable = False
         self.epsilon = 0.00001
         self.init_kernel_type()
         self.init_test_case()
@@ -367,6 +368,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
         bias = np.random.random_sample(scale_shape).astype(np.float32)
         mean, variance = self.set_mean_variance(scale_shape, x, data_layout)
         y_grad = np.random.random_sample(shape).astype(np.float32)
+        momentum_var = np.array([momentum]).astype(np.float32)
 
         y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward(
             x, y_grad, scale, bias, mean, variance, epsilon, momentum,
@@ -380,7 +382,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
         var_names = [
             'x', 'scale', 'bias', 'mean', 'variance', 'y', 'saved_mean',
-            'saved_variance'
+            'saved_variance', 'momentum_var'
         ]
         ground_truth = {name: var_dict[name] for name in var_names}
@@ -392,15 +394,28 @@ class TestBatchNormOpTraining(unittest.TestCase):
                         name=name,
                         dtype='float32',
                         shape=ground_truth[name].shape)
+                inputs = {
+                    "X": block.var('x'),
+                    "Scale": block.var('scale'),
+                    "Bias": block.var('bias'),
+                    "Mean": block.var('mean'),
+                    "Variance": block.var('variance')
+                }
+                attrs = {
+                    "epsilon": epsilon,
+                    "is_test": False,
+                    "data_layout": data_layout,
+                    "use_mkldnn": self.use_mkldnn,
+                    "fuse_with_relu": self.fuse_with_relu,
+                    "use_global_stats": self.use_global_stats
+                }
+                if self.use_momentum_variable:
+                    inputs['MomentumTensor'] = block.var('momentum_var')
+                else:
+                    attrs['momentum'] = momentum
                 bn_op = block.append_op(
                     type="batch_norm",
-                    inputs={
-                        "X": block.var('x'),
-                        "Scale": block.var('scale'),
-                        "Bias": block.var('bias'),
-                        "Mean": block.var('mean'),
-                        "Variance": block.var('variance')
-                    },
+                    inputs=inputs,
                     outputs={
                         "Y": block.var('y'),
                         "MeanOut": block.var('mean'),  # share memory
@@ -408,15 +423,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
                         "SavedMean": block.var('saved_mean'),
                         "SavedVariance": block.var('saved_variance')
                     },
-                    attrs={
-                        "momentum": momentum,
-                        "epsilon": epsilon,
-                        "is_test": False,
-                        "data_layout": data_layout,
-                        "use_mkldnn": self.use_mkldnn,
-                        "fuse_with_relu": self.fuse_with_relu,
-                        "use_global_stats": self.use_global_stats
-                    })
+                    attrs=attrs)
 
                 block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
 
                 # generate backward op_desc
@@ -434,14 +441,15 @@ class TestBatchNormOpTraining(unittest.TestCase):
                     grad_var.set_dtype(core.VarDesc.VarType.FP32)
 
                 exe = fluid.Executor(place)
-                out = exe.run(program,
-                              feed={
-                                  name: var_dict[name]
-                                  for name in
-                                  ['x', 'scale', 'bias', 'mean', 'variance', 'y@GRAD']
-                              },
-                              fetch_list=self.fetch_list)
+                out = exe.run(program,
+                              feed={
+                                  name: var_dict[name]
+                                  for name in [
+                                      'x', 'scale', 'bias', 'mean', 'variance',
+                                      'y@GRAD', 'momentum_var'
+                                  ]
+                              },
+                              fetch_list=self.fetch_list)
 
             for id, name in enumerate(self.fetch_list):
                 if name == 'variance':
@@ -471,6 +479,17 @@ class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining):
         self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD']
 
 
+class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining):
+    def init_test_case(self):
+        self.use_momentum_variable = True
+        self.use_global_stats = False
+        self.no_grad_set = set()
+        self.fetch_list = [
+            'y', 'mean', 'variance', 'saved_mean', 'saved_variance', 'x@GRAD',
+            'scale@GRAD', 'bias@GRAD'
+        ]
+
+
 class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
     def init_test_case(self):
         self.use_global_stats = True
python/paddle/fluid/tests/unittests/test_layers.py
@@ -2465,6 +2465,19 @@ class TestBook(LayerTest):
             out = layers.batch_norm(data)
             return (out)
 
+    def make_batch_norm_momentum_variable(self):
+        with program_guard(fluid.default_main_program(),
+                           fluid.default_startup_program()):
+            data = self._get_data(
+                name='data', shape=[32, 128, 128], dtype="float32")
+            momentum = self._get_data(
+                name='momentum',
+                shape=[1],
+                dtype='float32',
+                append_batch_size=False)
+            out = layers.batch_norm(data, momentum=momentum)
+            return (out)
+
     def make_range(self):
         with program_guard(fluid.default_main_program(),
                            fluid.default_startup_program()):
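End to end, the momentum Variable is then fed like any other input at run time. A self-contained sketch under the same fluid-era API assumptions (illustrative, not part of the patch):

import numpy as np
import paddle.fluid as fluid

# Build a program whose batch_norm momentum comes in as a shape-[1] input.
data = fluid.data(name='data', shape=[-1, 32, 128, 128], dtype='float32')
momentum = fluid.data(name='momentum', shape=[1], dtype='float32')
out = fluid.layers.batch_norm(data, momentum=momentum)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
# Feed the momentum alongside the data; a different value can be fed each step.
y, = exe.run(feed={'data': np.random.rand(4, 32, 128, 128).astype('float32'),
                   'momentum': np.array([0.9], dtype=np.float32)},
             fetch_list=[out])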