提交 63b58dc2 编写于 作者: S sandyhouse

update optimizer, test=develop

上级 96aa0973
...@@ -47,8 +47,9 @@ __all__ = [ ...@@ -47,8 +47,9 @@ __all__ = [
'AdamOptimizer', 'AdamaxOptimizer', 'DpsgdOptimizer', 'AdamOptimizer', 'AdamaxOptimizer', 'DpsgdOptimizer',
'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'FtrlOptimizer', 'Adadelta', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'FtrlOptimizer', 'Adadelta',
'AdadeltaOptimizer', 'ModelAverage', 'LarsMomentum', 'AdadeltaOptimizer', 'ModelAverage', 'LarsMomentum',
'LarsMomentumOptimizer', 'LambOptimizer', 'ExponentialMovingAverage', 'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer',
'PipelineOptimizer', 'LookaheadOptimizer', 'RecomputeOptimizer' 'ExponentialMovingAverage', 'PipelineOptimizer', 'LookaheadOptimizer',
'RecomputeOptimizer'
] ]
...@@ -3771,30 +3772,30 @@ class PipelineOptimizer(object): ...@@ -3771,30 +3772,30 @@ class PipelineOptimizer(object):
return programs return programs
def _find_post_op(self, ops, cur_op, var_name): #def _find_post_op(self, ops, cur_op, var_name):
""" # """
Find the real post op that has variable named var_name as input. # Find the real post op that has variable named var_name as input.
Args: # Args:
ops (list): A list of ops. # ops (list): A list of ops.
cur_op (Operator): Current operator which has variable named # cur_op (Operator): Current operator which has variable named
var_name as output. # var_name as output.
var_name (string): Variable name. # var_name (string): Variable name.
""" # """
post_op = [] # post_op = []
before = True # before = True
for op in ops: # for op in ops:
if op == cur_op: # if op == cur_op:
before = False # before = False
continue # continue
if before: # if before:
continue # continue
for in_var_name in op.input_arg_names: # for in_var_name in op.input_arg_names:
if in_var_name == var_name: # if in_var_name == var_name:
post_op.append(op) # post_op.append(op)
if post_op: # if post_op:
return post_op[0] # return post_op[0]
return None # return None
def _find_real_prev_op(self, ops, cur_op, var_name): def _find_real_prev_op(self, ops, cur_op, var_name):
""" """
...@@ -3972,7 +3973,7 @@ class PipelineOptimizer(object): ...@@ -3972,7 +3973,7 @@ class PipelineOptimizer(object):
assert self._op_role_var_key in op.attr_names assert self._op_role_var_key in op.attr_names
op_role_var = op.all_attrs()[self._op_role_var_key] op_role_var = op.all_attrs()[self._op_role_var_key]
assert len(op_role_var) == 2 assert len(op_role_var) == 2
param_name = block.vars[op_role_var[0]].name param_name = op_role_var[0]
device = self._param_device_map[param_name] device = self._param_device_map[param_name]
op._set_attr(self._op_device_key, device) op._set_attr(self._op_device_key, device)
...@@ -4008,8 +4009,12 @@ class PipelineOptimizer(object): ...@@ -4008,8 +4009,12 @@ class PipelineOptimizer(object):
assert '@RENAME@' in name assert '@RENAME@' in name
assert len(op.desc.output_arg_names()) == 1 assert len(op.desc.output_arg_names()) == 1
out_name = op.desc.output_arg_names()[0] out_name = op.desc.output_arg_names()[0]
post_op = self._find_post_op(block.ops, op, out_name) assert core.grad_var_suffix() in out_name
device = post_op.attr(self._op_device_key) param_name = self._strip_grad_suffix(out_name)
assert param_name in self._param_device_map
device = self._param_device_map[param_name]
#post_op = self._find_post_op(block.ops, op, out_name)
#device = post_op.attr(self._op_device_key)
assert device assert device
op._set_attr(self._op_device_key, device) op._set_attr(self._op_device_key, device)
continue continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册