Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
63b58dc2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
63b58dc2
编写于
8月 06, 2020
作者:
S
sandyhouse
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update optimizer, test=develop
上级
96aa0973
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
34 addition
and
29 deletion
+34
-29
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+34
-29
未找到文件。
python/paddle/fluid/optimizer.py
浏览文件 @
63b58dc2
...
...
@@ -47,8 +47,9 @@ __all__ = [
'AdamOptimizer'
,
'AdamaxOptimizer'
,
'DpsgdOptimizer'
,
'DecayedAdagradOptimizer'
,
'RMSPropOptimizer'
,
'FtrlOptimizer'
,
'Adadelta'
,
'AdadeltaOptimizer'
,
'ModelAverage'
,
'LarsMomentum'
,
'LarsMomentumOptimizer'
,
'LambOptimizer'
,
'ExponentialMovingAverage'
,
'PipelineOptimizer'
,
'LookaheadOptimizer'
,
'RecomputeOptimizer'
'LarsMomentumOptimizer'
,
'DGCMomentumOptimizer'
,
'LambOptimizer'
,
'ExponentialMovingAverage'
,
'PipelineOptimizer'
,
'LookaheadOptimizer'
,
'RecomputeOptimizer'
]
...
...
@@ -3771,30 +3772,30 @@ class PipelineOptimizer(object):
return
programs
def
_find_post_op
(
self
,
ops
,
cur_op
,
var_name
):
"""
Find the real post op that has variable named var_name as input.
Args:
ops (list): A list of ops.
cur_op (Operator): Current operator which has variable named
var_name as output.
var_name (string): Variable name.
"""
post_op
=
[]
before
=
True
for
op
in
ops
:
if
op
==
cur_op
:
before
=
False
continue
if
before
:
continue
for
in_var_name
in
op
.
input_arg_names
:
if
in_var_name
==
var_name
:
post_op
.
append
(
op
)
if
post_op
:
return
post_op
[
0
]
return
None
#
def _find_post_op(self, ops, cur_op, var_name):
#
"""
#
Find the real post op that has variable named var_name as input.
#
Args:
#
ops (list): A list of ops.
#
cur_op (Operator): Current operator which has variable named
#
var_name as output.
#
var_name (string): Variable name.
#
"""
#
post_op = []
#
before = True
#
for op in ops:
#
if op == cur_op:
#
before = False
#
continue
#
if before:
#
continue
#
for in_var_name in op.input_arg_names:
#
if in_var_name == var_name:
#
post_op.append(op)
#
if post_op:
#
return post_op[0]
#
return None
def
_find_real_prev_op
(
self
,
ops
,
cur_op
,
var_name
):
"""
...
...
@@ -3972,7 +3973,7 @@ class PipelineOptimizer(object):
assert
self
.
_op_role_var_key
in
op
.
attr_names
op_role_var
=
op
.
all_attrs
()[
self
.
_op_role_var_key
]
assert
len
(
op_role_var
)
==
2
param_name
=
block
.
vars
[
op_role_var
[
0
]].
name
param_name
=
op_role_var
[
0
]
device
=
self
.
_param_device_map
[
param_name
]
op
.
_set_attr
(
self
.
_op_device_key
,
device
)
...
...
@@ -4008,8 +4009,12 @@ class PipelineOptimizer(object):
assert
'@RENAME@'
in
name
assert
len
(
op
.
desc
.
output_arg_names
())
==
1
out_name
=
op
.
desc
.
output_arg_names
()[
0
]
post_op
=
self
.
_find_post_op
(
block
.
ops
,
op
,
out_name
)
device
=
post_op
.
attr
(
self
.
_op_device_key
)
assert
core
.
grad_var_suffix
()
in
out_name
param_name
=
self
.
_strip_grad_suffix
(
out_name
)
assert
param_name
in
self
.
_param_device_map
device
=
self
.
_param_device_map
[
param_name
]
#post_op = self._find_post_op(block.ops, op, out_name)
#device = post_op.attr(self._op_device_key)
assert
device
op
.
_set_attr
(
self
.
_op_device_key
,
device
)
continue
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录