BaiXuePrincess / Paddle · Commit 19c554f9
Forked from PaddlePaddle / Paddle (in sync with the fork source)
Commit 19c554f9
Authored on Jan 19, 2018 by fengjiayi
Commit message: update
Parent: 538f1ad2

Showing 2 changed files with 59 additions and 67 deletions (+59 / -67)
python/paddle/v2/fluid/clip.py (+38, -44)
python/paddle/v2/fluid/tests/test_gradient_clip.py (+21, -23)
python/paddle/v2/fluid/clip.py
@@ -112,58 +112,52 @@ class GradientClipByNorm(BaseGradientClipAttr):
 class GradientClipByGlobalNorm(BaseGradientClipAttr):
-    global_norm_var = None
-    local_norm_var = None
-    clip_norm_var = None
-    scale_var = None
-
-    @classmethod
-    def init(cls, clip_norm):
-        if not (isinstance(clip_norm, int) or isinstance(clip_norm, float)):
-            raise TypeError("The 'clip_norm' must be a value of int or float")
-
-        cls.global_norm_var = layers.fill_constant(
-            shape=[1], dtype="float32", value=0.0)
-        cls.local_norm_var = layers.create_tensor(dtype="float32")
-        cls.clip_norm_var = layers.fill_constant(
-            shape=[1], dtype="float32", value=clip_norm)
-
-    @classmethod
-    def check_init(cls):
-        if not (isinstance(cls.global_norm_var, framework.Variable) and
-                isinstance(cls.local_norm_var, framework.Variable) and
-                isinstance(cls.clip_norm_var, framework.Variable)):
-            raise ValueError(
-                "Class 'GradientClipByGlobalNorm' has not been properly initialized. \
-                 Please call GradientClipByGlobalNorm.init() first.")
+    def __init__(self, clip_norm, group_name="default_group"):
+        if not isinstance(group_name, basestring):
+            raise TypeError("'group_name' must be a basestring.")
+
+        self.clip_norm = clip_norm
+        self.group_name = group_name

     def process_context(self, context, param, grad):
-        cls = self.__class__
-        cls.check_init()
-
-        cls.local_norm_var = layers.reduce_sum(
-            input=layers.pow(x=grad, factor=2.0))
-        layers.sums(
-            input=[cls.local_norm_var, cls.global_norm_var],
-            out=[cls.global_norm_var])
-
-    def create_operators(self, param, grad):
-        cls = self.__class__
-        cls.check_init()
-
-        if cls.scale_var is None:
-            layers.sqrt(x=cls.global_norm_var, out=cls.global_norm_var)
-            cls.scale_var = layers.elementwise_div(
-                x=cls.clip_norm_var,
-                y=layers.elementwise_max(
-                    x=cls.clip_norm_var, y=cls.global_norm_var))
-            assert cls.scale_var.shape == (1L, )
-
-        new_grad = layers.elementwise_mul(x=grad, y=cls.scale_var)
+        if self.group_name not in context:
+            context[self.group_name] = []
+            context[self.group_name + "_clip_value"] = self.clip_norm
+            context[self.group_name + "_clip"] = layers.fill_constant(
+                shape=[1], dtype="float32", value=self.clip_norm)
+        else:
+            if not self.clip_norm == context[self.group_name + "_clip_value"]:
+                raise ValueError(
+                    "All parameters' 'clip_norm' of a same group should be the same"
+                )
+
+        local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0))
+        context[self.group_name].append(local_norm_var)
+        self.context = context
+
+    def create_operators(self, param, grad):
+        group_scale_name = self.group_name + "_scale"
+        if group_scale_name not in self.context:
+            group_norm_var = layers.sums(input=self.context[self.group_name])
+            layers.sqrt(x=group_norm_var, out=group_norm_var)
+            clip_var = self.context[self.group_name + "_clip"]
+            group_scale_var = layers.elementwise_div(
+                x=clip_var,
+                y=layers.elementwise_max(x=clip_var, y=group_norm_var))
+            assert group_scale_var.shape == (1L, )
+            self.context[group_scale_name] = group_scale_var
+
+        new_grad = layers.elementwise_mul(
+            x=grad, y=self.context[group_scale_name])
         return param, new_grad


-def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None):
+def gradient_clip_by_global_norm(clip_norm,
+                                 param_list=None,
+                                 group_name="default_group",
+                                 program=None):
     if program is None:
         program = framework.default_main_program()
     if param_list is None:
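For reference, here is a minimal NumPy sketch of the group-based scheme the reworked class implements: each gradient contributes its squared L2 norm to a per-group list in a shared context, and the first create_operators call for that group turns the accumulated sum into a single scale, clip_norm / max(clip_norm, global_norm), that is then reused for every gradient in the group. The class and helper names below are illustrative only; this is not the fluid implementation.

import numpy as np

class ToyGlobalNormClipper(object):
    # Illustrative re-statement of the clipping math; names are hypothetical.
    def __init__(self, clip_norm, group_name="default_group"):
        self.clip_norm = clip_norm
        self.group_name = group_name

    def process_context(self, context, grad):
        # Accumulate this gradient's squared L2 norm into its group.
        context.setdefault(self.group_name, []).append(np.sum(np.power(grad, 2)))
        self.context = context

    def create_operators(self, grad):
        scale_name = self.group_name + "_scale"
        if scale_name not in self.context:
            # One scale per group: clip_norm / max(clip_norm, global_norm).
            global_norm = np.sqrt(np.sum(self.context[self.group_name]))
            self.context[scale_name] = self.clip_norm / max(self.clip_norm, global_norm)
        return grad * self.context[scale_name]

clipper = ToyGlobalNormClipper(clip_norm=1.0)
grads = [np.array([3.0, 4.0]), np.array([1.0, 2.0, 2.0])]
context = {}
for g in grads:
    clipper.process_context(context, g)
clipped = [clipper.create_operators(g) for g in grads]
print(np.sqrt(sum(np.sum(g ** 2) for g in clipped)))  # ~1.0, i.e. the clip_norm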
@@ -175,9 +169,9 @@ def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None):
             "'param_list' should be a list of Parameter or basestring(parameter's name)."
         )

-    GradientClipByGlobalNorm.init(clip_norm)
     for param in param_list:
-        param.gradient_clip_attr = GradientClipByGlobalNorm()
+        param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm, group_name)


 def append_gradient_clip_ops(param_grad):
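With the new signature, callers can keep the default behaviour or clip several parameter groups under separate thresholds by passing distinct group_name values. A usage sketch, modelled on the updated test below (the network layers here are assumptions for illustration, not part of the commit):

import paddle.v2.fluid as fluid

prog = fluid.framework.Program()
with fluid.program_guard(main_program=prog):
    image = fluid.layers.data(name='x', shape=[784], dtype='float32')
    label = fluid.layers.data(name='y', shape=[1], dtype='int64')
    hidden = fluid.layers.fc(input=image, size=128, act='relu')
    predict = fluid.layers.fc(input=hidden, size=10, act='softmax')
    avg_cost = fluid.layers.mean(x=fluid.layers.cross_entropy(input=predict, label=label))

    p_g = fluid.backward.append_backward(loss=avg_cost)
    # group_name is new: parameters tagged with the same group share one
    # global-norm threshold; different groups are clipped independently.
    fluid.clip.gradient_clip_by_global_norm(clip_norm=1.0, group_name="default_group")
    p_g = fluid.clip.append_gradient_clip_ops(p_g)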
python/paddle/v2/fluid/tests/test_gradient_clip.py
@@ -15,21 +15,10 @@ import numpy as np
 import paddle.v2 as paddle
 import paddle.v2.fluid as fluid

-def _get_global_param_norm_(params_grads):
-    res = fluid.layers.fill_constant(shape=[1], dtype="float32", value=0.0)
-    for _, grad in params_grads:
-        norm_var = fluid.layers.reduce_sum(
-            input=fluid.layers.pow(x=grad, factor=2.0))
-        fluid.layers.sums(input=[norm_var, res], out=[res])
-    fluid.layers.sqrt(x=res, out=res)
-    return res
-
 BATCH_SIZE = 128
-CLIP = 0.5
-prog = fluid.framework.Program()
+CLIP = 1
+
+prog = fluid.framework.Program()
 with fluid.program_guard(main_program=prog):
     image = fluid.layers.data(name='x', shape=[784], dtype='float32')
@@ -49,13 +38,12 @@ avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
 p_g = fluid.backward.append_backward(loss=avg_cost)
 p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)

-with fluid.program_guard(main_program=prog):
-    gloabl_norm = _get_global_param_norm_(p_g)
-
 with fluid.program_guard(main_program=prog_clip):
     fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP)
     p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
-    gloabl_norm_clip = _get_global_param_norm_(p_g_clip)
+
+grad_list = [elem[1] for elem in p_g]
+grad_clip_list = [elem[1] for elem in p_g_clip]

 train_reader = paddle.batch(
     paddle.reader.shuffle(
@@ -72,11 +60,21 @@ for data in train_reader():
     count += 1
     if count > 5:
         break
-    out, = exe.run(prog, feed=feeder.feed(data), fetch_list=[gloabl_norm])
-    out_clip, = exe.run(prog_clip,
-                        feed=feeder.feed(data),
-                        fetch_list=[gloabl_norm_clip])
-    if not np.allclose(out_clip, np.minimum(out, np.array([CLIP]))):
+    out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
+    out_clip = exe.run(prog_clip,
+                       feed=feeder.feed(data),
+                       fetch_list=grad_clip_list)
+
+    global_norm = 0
+    for v in out[1:]:
+        global_norm += np.sum(np.power(v, 2))
+    global_norm = np.sqrt(global_norm)
+
+    global_norm_clip = 0
+    for v in out_clip[1:]:
+        global_norm_clip += np.sum(np.power(v, 2))
+    global_norm_clip = np.sqrt(global_norm_clip)
+
+    if not np.isclose(
+            a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3):
         exit(1)

 exit(0)
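The rewritten assertion follows directly from the scale defined in clip.py above: with scale = clip_norm / max(clip_norm, global_norm), the clipped gradients satisfy

    global_norm_clip = global_norm * scale = min(global_norm, clip_norm)

so comparing global_norm_clip against np.minimum(global_norm, CLIP) with rtol=5e-3 exercises the whole pipeline while leaving room for float32 accumulation error.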