Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
610ad495
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
610ad495
编写于
1月 22, 2018
作者:
F
fengjiayi
提交者:
GitHub
1月 22, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7637 from JiayiFeng/dev_global_norm_clip
Gradient clip by global norm
上级
f45b0b06
e8adcaf2
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
173 addition
and
27 deletion
+173
-27
python/paddle/v2/fluid/clip.py
python/paddle/v2/fluid/clip.py
+82
-7
python/paddle/v2/fluid/framework.py
python/paddle/v2/fluid/framework.py
+2
-2
python/paddle/v2/fluid/layers/ops.py
python/paddle/v2/fluid/layers/ops.py
+4
-14
python/paddle/v2/fluid/param_attr.py
python/paddle/v2/fluid/param_attr.py
+3
-3
python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
...n/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
+1
-1
python/paddle/v2/fluid/tests/test_error_clip.py
python/paddle/v2/fluid/tests/test_error_clip.py
+0
-0
python/paddle/v2/fluid/tests/test_gradient_clip.py
python/paddle/v2/fluid/tests/test_gradient_clip.py
+81
-0
未找到文件。
python/paddle/v2/fluid/clip.py
浏览文件 @
610ad495
...
...
@@ -14,6 +14,7 @@
import
functools
import
layers
import
framework
from
.
import
core
__all__
=
[
...
...
@@ -66,7 +67,7 @@ def error_clip_callback(block, context):
class
BaseGradientClipAttr
(
object
):
def
process_context
(
self
,
context
,
p
_g
):
def
process_context
(
self
,
context
,
p
aram
,
grad
):
raise
NotImplementedError
()
def
create_operators
(
self
,
param
,
grad
):
...
...
@@ -74,7 +75,7 @@ class BaseGradientClipAttr(object):
class
NullGradientClipAttr
(
BaseGradientClipAttr
):
def
process_context
(
self
,
context
,
p
_g
):
def
process_context
(
self
,
context
,
p
aram
,
grad
):
pass
def
create_operators
(
self
,
param
,
grad
):
...
...
@@ -91,7 +92,7 @@ class GradientClipByValue(BaseGradientClipAttr):
self
.
max
=
max
self
.
min
=
min
def
process_context
(
self
,
context
,
p
_g
):
def
process_context
(
self
,
context
,
p
aram
,
grad
):
pass
def
create_operators
(
self
,
param
,
grad
):
...
...
@@ -99,19 +100,93 @@ class GradientClipByValue(BaseGradientClipAttr):
return
param
,
new_grad
class
GradientClipByNorm
(
BaseGradientClipAttr
):
def
__init__
(
self
,
clip_norm
):
self
.
clip_norm
=
clip_norm
def
process_context
(
self
,
context
,
param
,
grad
):
pass
def
create_operators
(
self
,
param
,
grad
):
new_grad
=
layers
.
clip_by_norm
(
x
=
grad
,
max_norm
=
self
.
clip_norm
)
return
param
,
new_grad
class
GradientClipByGlobalNorm
(
BaseGradientClipAttr
):
def
__init__
(
self
,
clip_norm
,
group_name
=
"default_group"
):
if
not
isinstance
(
group_name
,
basestring
):
raise
TypeError
(
"'group_name' must be a basestring."
)
self
.
clip_norm
=
clip_norm
self
.
group_name
=
group_name
def
process_context
(
self
,
context
,
param
,
grad
):
if
self
.
group_name
not
in
context
:
context
[
self
.
group_name
]
=
[]
context
[
self
.
group_name
+
"_clip_value"
]
=
self
.
clip_norm
context
[
self
.
group_name
+
"_clip"
]
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
"float32"
,
value
=
self
.
clip_norm
)
else
:
if
not
self
.
clip_norm
==
context
[
self
.
group_name
+
"_clip_value"
]:
raise
ValueError
(
"All parameters' 'clip_norm' of a same group should be the same"
)
local_norm_var
=
layers
.
reduce_sum
(
input
=
layers
.
pow
(
x
=
grad
,
factor
=
2.0
))
context
[
self
.
group_name
].
append
(
local_norm_var
)
self
.
context
=
context
def
create_operators
(
self
,
param
,
grad
):
group_scale_name
=
self
.
group_name
+
"_scale"
if
group_scale_name
not
in
self
.
context
:
group_norm_var
=
layers
.
sums
(
input
=
self
.
context
[
self
.
group_name
])
layers
.
sqrt
(
x
=
group_norm_var
,
out
=
group_norm_var
)
clip_var
=
self
.
context
[
self
.
group_name
+
"_clip"
]
group_scale_var
=
layers
.
elementwise_div
(
x
=
clip_var
,
y
=
layers
.
elementwise_max
(
x
=
clip_var
,
y
=
group_norm_var
))
assert
group_scale_var
.
shape
==
(
1L
,
)
self
.
context
[
group_scale_name
]
=
group_scale_var
new_grad
=
layers
.
elementwise_mul
(
x
=
grad
,
y
=
self
.
context
[
group_scale_name
])
return
param
,
new_grad
def
gradient_clip_by_global_norm
(
clip_norm
,
param_list
=
None
,
group_name
=
"default_group"
,
program
=
None
):
if
program
is
None
:
program
=
framework
.
default_main_program
()
if
param_list
is
None
:
param_list
=
program
.
block
(
0
).
all_parameters
()
if
all
(
isinstance
(
elem
,
basestring
)
for
elem
in
param_list
):
param_list
=
[
program
.
block
(
0
).
var
(
elem
)
for
elem
in
param_list
]
if
not
all
(
isinstance
(
elem
,
framework
.
Parameter
)
for
elem
in
param_list
):
raise
TypeError
(
"'param_list' should be a list of Parameter or basestring(parameter's name)."
)
for
param
in
param_list
:
param
.
gradient_clip_attr
=
GradientClipByGlobalNorm
(
clip_norm
,
group_name
)
def
append_gradient_clip_ops
(
param_grad
):
context
=
dict
()
create_op_callbacks
=
[]
for
p
,
g
in
param_grad
:
clip_attr
=
getattr
(
p
,
'clip_attr'
,
NullGradientClipAttr
())
clip_attr
=
getattr
(
p
,
'
gradient_
clip_attr'
,
NullGradientClipAttr
())
if
clip_attr
is
None
:
clip_attr
=
NullGradientClipAttr
()
if
not
isinstance
(
clip_attr
,
BaseGradientClipAttr
):
raise
TypeError
(
"clip attribute should be an instance of BaseGradientClippingAttr"
)
"clip attribute should be an instance of BaseGradientClipAttr"
)
clip_attr
.
process_context
(
context
=
context
,
p
_g
=
param_grad
)
clip_attr
.
process_context
(
context
=
context
,
p
aram
=
p
,
grad
=
g
)
create_op_callbacks
.
append
(
functools
.
partial
(
clip_attr
.
create_operators
,
param
=
p
,
grad
=
g
))
...
...
python/paddle/v2/fluid/framework.py
浏览文件 @
610ad495
...
...
@@ -780,7 +780,7 @@ class Block(object):
trainable
=
p
.
trainable
,
optimize_attr
=
p
.
optimize_attr
,
regularizer
=
p
.
regularizer
,
clip_attr
=
p
.
clip_attr
,
gradient_clip_attr
=
p
.
gradient_
clip_attr
,
error_clip
=
p
.
error_clip
,
name
=
v
.
name
)
self
.
vars
[
new_p
.
name
]
=
new_p
...
...
@@ -948,7 +948,7 @@ class Parameter(Variable):
self
.
regularizer
=
kwargs
.
get
(
'regularizer'
,
None
)
self
.
clip_attr
=
kwargs
.
get
(
'
clip_attr'
,
None
)
self
.
gradient_clip_attr
=
kwargs
.
get
(
'gradient_
clip_attr'
,
None
)
# program is a global instance.
...
...
python/paddle/v2/fluid/layers/ops.py
浏览文件 @
610ad495
...
...
@@ -46,20 +46,10 @@ __activations__ = [
]
__all__
=
[
'mean'
,
'mul'
,
'reshape'
,
'scale'
,
'transpose'
,
'sigmoid_cross_entropy_with_logits'
,
'elementwise_add'
,
'elementwise_div'
,
'elementwise_sub'
,
'elementwise_mul'
,
'elementwise_max'
,
'elementwise_min'
,
'clip'
,
'sequence_softmax'
,
'mean'
,
'mul'
,
'reshape'
,
'scale'
,
'transpose'
,
'sigmoid_cross_entropy_with_logits'
,
'elementwise_add'
,
'elementwise_div'
,
'elementwise_sub'
,
'elementwise_mul'
,
'elementwise_max'
,
'elementwise_min'
,
'clip'
,
'clip_by_norm'
,
'sequence_softmax'
]
+
__activations__
for
_OP
in
set
(
__all__
):
...
...
python/paddle/v2/fluid/param_attr.py
浏览文件 @
610ad495
...
...
@@ -25,13 +25,13 @@ class ParamAttr(object):
learning_rate
=
1.0
,
regularizer
=
None
,
trainable
=
True
,
clip
=
None
):
gradient_
clip
=
None
):
self
.
name
=
name
self
.
initializer
=
initializer
self
.
learning_rate
=
learning_rate
self
.
regularizer
=
regularizer
self
.
trainable
=
trainable
self
.
clip
=
clip
self
.
gradient_clip
=
gradient_
clip
def
set_default_initializer
(
self
,
initializer
):
if
initializer
is
None
:
...
...
@@ -77,7 +77,7 @@ class ParamAttr(object):
},
'regularizer'
:
self
.
regularizer
,
'trainable'
:
self
.
trainable
,
'
clip_attr'
:
self
.
clip
'
gradient_clip_attr'
:
self
.
gradient_
clip
}
if
with_initializer
:
kwargs
[
'initializer'
]
=
self
.
initializer
...
...
python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
浏览文件 @
610ad495
...
...
@@ -27,7 +27,7 @@ hidden1 = fluid.layers.fc(input=image,
act
=
'relu'
,
param_attr
=
fluid
.
ParamAttr
(
regularizer
=
regularizer
,
clip
=
fluid
.
clip
.
ClipByValue
(
10
)))
gradient_
clip
=
fluid
.
clip
.
ClipByValue
(
10
)))
hidden2
=
fluid
.
layers
.
fc
(
input
=
hidden1
,
size
=
64
,
...
...
python/paddle/v2/fluid/tests/test_clip.py
→
python/paddle/v2/fluid/tests/test_
error_
clip.py
浏览文件 @
610ad495
文件已移动
python/paddle/v2/fluid/tests/test_gradient_clip.py
0 → 100644
浏览文件 @
610ad495
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle.v2
as
paddle
import
paddle.v2.fluid
as
fluid
BATCH_SIZE
=
128
CLIP
=
1
prog
=
fluid
.
framework
.
Program
()
with
fluid
.
program_guard
(
main_program
=
prog
):
image
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
784
],
dtype
=
'float32'
)
hidden1
=
fluid
.
layers
.
fc
(
input
=
image
,
size
=
128
,
act
=
'relu'
)
hidden2
=
fluid
.
layers
.
fc
(
input
=
hidden1
,
size
=
64
,
act
=
'relu'
)
predict
=
fluid
.
layers
.
fc
(
input
=
hidden2
,
size
=
10
,
act
=
'softmax'
)
label
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'int64'
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
prog_clip
=
prog
.
clone
()
avg_cost_clip
=
prog_clip
.
block
(
0
).
var
(
avg_cost
.
name
)
p_g
=
fluid
.
backward
.
append_backward
(
loss
=
avg_cost
)
p_g_clip
=
fluid
.
backward
.
append_backward
(
loss
=
avg_cost_clip
)
with
fluid
.
program_guard
(
main_program
=
prog_clip
):
fluid
.
clip
.
gradient_clip_by_global_norm
(
clip_norm
=
CLIP
)
p_g_clip
=
fluid
.
clip
.
append_gradient_clip_ops
(
p_g_clip
)
grad_list
=
[
elem
[
1
]
for
elem
in
p_g
]
grad_clip_list
=
[
elem
[
1
]
for
elem
in
p_g_clip
]
train_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
mnist
.
train
(),
buf_size
=
8192
),
batch_size
=
BATCH_SIZE
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
image
,
label
],
place
=
place
)
exe
.
run
(
fluid
.
default_startup_program
())
count
=
0
for
data
in
train_reader
():
count
+=
1
if
count
>
5
:
break
out
=
exe
.
run
(
prog
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
grad_list
)
out_clip
=
exe
.
run
(
prog_clip
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
grad_clip_list
)
global_norm
=
0
for
v
in
out
[
1
:]:
global_norm
+=
np
.
sum
(
np
.
power
(
v
,
2
))
global_norm
=
np
.
sqrt
(
global_norm
)
global_norm_clip
=
0
for
v
in
out_clip
[
1
:]:
global_norm_clip
+=
np
.
sum
(
np
.
power
(
v
,
2
))
global_norm_clip
=
np
.
sqrt
(
global_norm_clip
)
if
not
np
.
isclose
(
a
=
global_norm_clip
,
b
=
np
.
minimum
(
global_norm
,
CLIP
),
rtol
=
5e-3
):
exit
(
1
)
exit
(
0
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录