提交 63b58dc2 编写于 作者: S sandyhouse

update optimizer, test=develop

上级 96aa0973
......@@ -47,8 +47,9 @@ __all__ = [
'AdamOptimizer', 'AdamaxOptimizer', 'DpsgdOptimizer',
'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'FtrlOptimizer', 'Adadelta',
'AdadeltaOptimizer', 'ModelAverage', 'LarsMomentum',
'LarsMomentumOptimizer', 'LambOptimizer', 'ExponentialMovingAverage',
'PipelineOptimizer', 'LookaheadOptimizer', 'RecomputeOptimizer'
'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer',
'ExponentialMovingAverage', 'PipelineOptimizer', 'LookaheadOptimizer',
'RecomputeOptimizer'
]
......@@ -3771,30 +3772,30 @@ class PipelineOptimizer(object):
return programs
def _find_post_op(self, ops, cur_op, var_name):
"""
Find the real post op that has variable named var_name as input.
Args:
ops (list): A list of ops.
cur_op (Operator): Current operator which has variable named
var_name as output.
var_name (string): Variable name.
"""
post_op = []
before = True
for op in ops:
if op == cur_op:
before = False
continue
if before:
continue
for in_var_name in op.input_arg_names:
if in_var_name == var_name:
post_op.append(op)
if post_op:
return post_op[0]
return None
#def _find_post_op(self, ops, cur_op, var_name):
# """
# Find the real post op that has variable named var_name as input.
# Args:
# ops (list): A list of ops.
# cur_op (Operator): Current operator which has variable named
# var_name as output.
# var_name (string): Variable name.
# """
# post_op = []
# before = True
# for op in ops:
# if op == cur_op:
# before = False
# continue
# if before:
# continue
# for in_var_name in op.input_arg_names:
# if in_var_name == var_name:
# post_op.append(op)
# if post_op:
# return post_op[0]
# return None
def _find_real_prev_op(self, ops, cur_op, var_name):
"""
......@@ -3972,7 +3973,7 @@ class PipelineOptimizer(object):
assert self._op_role_var_key in op.attr_names
op_role_var = op.all_attrs()[self._op_role_var_key]
assert len(op_role_var) == 2
param_name = block.vars[op_role_var[0]].name
param_name = op_role_var[0]
device = self._param_device_map[param_name]
op._set_attr(self._op_device_key, device)
......@@ -4008,8 +4009,12 @@ class PipelineOptimizer(object):
assert '@RENAME@' in name
assert len(op.desc.output_arg_names()) == 1
out_name = op.desc.output_arg_names()[0]
post_op = self._find_post_op(block.ops, op, out_name)
device = post_op.attr(self._op_device_key)
assert core.grad_var_suffix() in out_name
param_name = self._strip_grad_suffix(out_name)
assert param_name in self._param_device_map
device = self._param_device_map[param_name]
#post_op = self._find_post_op(block.ops, op, out_name)
#device = post_op.attr(self._op_device_key)
assert device
op._set_attr(self._op_device_key, device)
continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册