Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
4dccb584
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4dccb584
编写于
7月 12, 2018
作者:
Y
yuyang18
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Hide clip APIs
上级
0c8f69c3
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
15 addition
and
15 deletion
+15
-15
python/paddle/fluid/clip.py
python/paddle/fluid/clip.py
+15
-15
未找到文件。
python/paddle/fluid/clip.py
浏览文件 @
4dccb584
...
...
@@ -31,7 +31,7 @@ class BaseErrorClipAttr(object):
def
__str__
(
self
):
raise
NotImplementedError
()
def
append_clip_op
(
self
,
block
,
grad_name
):
def
_
append_clip_op
(
self
,
block
,
grad_name
):
raise
NotImplementedError
()
...
...
@@ -67,7 +67,7 @@ class ErrorClipByValue(BaseErrorClipAttr):
def
__str__
(
self
):
return
"ByValue, min=%f, max=%f"
%
(
self
.
min
,
self
.
max
)
def
append_clip_op
(
self
,
block
,
grad_name
):
def
_
append_clip_op
(
self
,
block
,
grad_name
):
clip_op_desc
=
block
.
desc
.
append_op
()
clip_op_desc
.
set_type
(
"clip"
)
clip_op_desc
.
set_input
(
"X"
,
[
grad_name
])
...
...
@@ -90,17 +90,17 @@ def error_clip_callback(block, context):
"Variable's error_clip should be an instance of BaseErrorClipAttr or None."
)
if
error_clip
is
not
None
:
error_clip
.
append_clip_op
(
block
,
grad_n
)
error_clip
.
_
append_clip_op
(
block
,
grad_n
)
class
BaseGradientClipAttr
(
object
):
def
__str__
(
self
):
raise
NotImplementedError
()
def
process_context
(
self
,
context
,
param
,
grad
):
def
_
process_context
(
self
,
context
,
param
,
grad
):
raise
NotImplementedError
()
def
create_operators
(
self
,
param
,
grad
):
def
_
create_operators
(
self
,
param
,
grad
):
raise
NotImplementedError
()
...
...
@@ -108,10 +108,10 @@ class NullGradientClipAttr(BaseGradientClipAttr):
def
__str__
(
self
):
return
"Null"
def
process_context
(
self
,
context
,
param
,
grad
):
def
_
process_context
(
self
,
context
,
param
,
grad
):
pass
def
create_operators
(
self
,
param
,
grad
):
def
_
create_operators
(
self
,
param
,
grad
):
return
param
,
grad
...
...
@@ -153,10 +153,10 @@ class GradientClipByValue(BaseGradientClipAttr):
def
__str__
(
self
):
return
"ByValue, min=%f, max=%f"
%
(
self
.
min
,
self
.
max
)
def
process_context
(
self
,
context
,
param
,
grad
):
def
_
process_context
(
self
,
context
,
param
,
grad
):
pass
def
create_operators
(
self
,
param
,
grad
):
def
_
create_operators
(
self
,
param
,
grad
):
new_grad
=
layers
.
clip
(
x
=
grad
,
min
=
self
.
min
,
max
=
self
.
max
)
return
param
,
new_grad
...
...
@@ -199,10 +199,10 @@ class GradientClipByNorm(BaseGradientClipAttr):
def
__str__
(
self
):
return
"ByNorm, clip_norm=%f"
%
self
.
clip_norm
def
process_context
(
self
,
context
,
param
,
grad
):
def
_
process_context
(
self
,
context
,
param
,
grad
):
pass
def
create_operators
(
self
,
param
,
grad
):
def
_
create_operators
(
self
,
param
,
grad
):
new_grad
=
layers
.
clip_by_norm
(
x
=
grad
,
max_norm
=
self
.
clip_norm
)
return
param
,
new_grad
...
...
@@ -257,7 +257,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
return
"ByGlobalNorm, group_name=%s, clip_norm=%f"
%
(
self
.
group_name
,
self
.
clip_norm
)
def
process_context
(
self
,
context
,
param
,
grad
):
def
_
process_context
(
self
,
context
,
param
,
grad
):
if
self
.
group_name
not
in
context
:
context
[
self
.
group_name
]
=
[]
context
[
self
.
group_name
+
"_clip_value"
]
=
self
.
clip_norm
...
...
@@ -274,7 +274,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
self
.
context
=
context
def
create_operators
(
self
,
param
,
grad
):
def
_
create_operators
(
self
,
param
,
grad
):
group_scale_name
=
self
.
group_name
+
"_scale"
if
group_scale_name
not
in
self
.
context
:
group_norm_var
=
layers
.
sums
(
input
=
self
.
context
[
self
.
group_name
])
...
...
@@ -336,12 +336,12 @@ def append_gradient_clip_ops(param_grad):
"clip attribute should be an instance of BaseGradientClipAttr"
)
clip_attr
.
process_context
(
context
=
context
,
param
=
p
,
grad
=
g
)
clip_attr
.
_
process_context
(
context
=
context
,
param
=
p
,
grad
=
g
)
res
=
[]
for
p
,
g
in
param_grad
:
with
p
.
block
.
program
.
optimized_guard
(
p
):
res
.
append
(
clip_attr
.
create_operators
(
param
=
p
,
grad
=
g
))
res
.
append
(
clip_attr
.
_
create_operators
(
param
=
p
,
grad
=
g
))
return
res
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录