From 63b58dc2777960a978289a52d872a45bc5447d95 Mon Sep 17 00:00:00 2001 From: sandyhouse Date: Thu, 6 Aug 2020 07:46:19 +0000 Subject: [PATCH] update optimizer, test=develop --- python/paddle/fluid/optimizer.py | 63 +++++++++++++++++--------------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index e078dbf507a..5ecda645474 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -47,8 +47,9 @@ __all__ = [ 'AdamOptimizer', 'AdamaxOptimizer', 'DpsgdOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'FtrlOptimizer', 'Adadelta', 'AdadeltaOptimizer', 'ModelAverage', 'LarsMomentum', - 'LarsMomentumOptimizer', 'LambOptimizer', 'ExponentialMovingAverage', - 'PipelineOptimizer', 'LookaheadOptimizer', 'RecomputeOptimizer' + 'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer', + 'ExponentialMovingAverage', 'PipelineOptimizer', 'LookaheadOptimizer', + 'RecomputeOptimizer' ] @@ -3771,30 +3772,30 @@ class PipelineOptimizer(object): return programs - def _find_post_op(self, ops, cur_op, var_name): - """ - Find the real post op that has variable named var_name as input. - - Args: - ops (list): A list of ops. - cur_op (Operator): Current operator which has variable named - var_name as output. - var_name (string): Variable name. - """ - post_op = [] - before = True - for op in ops: - if op == cur_op: - before = False - continue - if before: - continue - for in_var_name in op.input_arg_names: - if in_var_name == var_name: - post_op.append(op) - if post_op: - return post_op[0] - return None + #def _find_post_op(self, ops, cur_op, var_name): + # """ + # Find the real post op that has variable named var_name as input. + + # Args: + # ops (list): A list of ops. + # cur_op (Operator): Current operator which has variable named + # var_name as output. + # var_name (string): Variable name. + # """ + # post_op = [] + # before = True + # for op in ops: + # if op == cur_op: + # before = False + # continue + # if before: + # continue + # for in_var_name in op.input_arg_names: + # if in_var_name == var_name: + # post_op.append(op) + # if post_op: + # return post_op[0] + # return None def _find_real_prev_op(self, ops, cur_op, var_name): """ @@ -3972,7 +3973,7 @@ class PipelineOptimizer(object): assert self._op_role_var_key in op.attr_names op_role_var = op.all_attrs()[self._op_role_var_key] assert len(op_role_var) == 2 - param_name = block.vars[op_role_var[0]].name + param_name = op_role_var[0] device = self._param_device_map[param_name] op._set_attr(self._op_device_key, device) @@ -4008,8 +4009,12 @@ class PipelineOptimizer(object): assert '@RENAME@' in name assert len(op.desc.output_arg_names()) == 1 out_name = op.desc.output_arg_names()[0] - post_op = self._find_post_op(block.ops, op, out_name) - device = post_op.attr(self._op_device_key) + assert core.grad_var_suffix() in out_name + param_name = self._strip_grad_suffix(out_name) + assert param_name in self._param_device_map + device = self._param_device_map[param_name] + #post_op = self._find_post_op(block.ops, op, out_name) + #device = post_op.attr(self._op_device_key) assert device op._set_attr(self._op_device_key, device) continue -- GitLab