From 5c12c5eb421996ab25553a2cef488ed95f993aff Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 12 Jul 2018 17:32:45 +0800 Subject: [PATCH] update distribute transformer for adam and adamax optimizer --- .../fluid/transpiler/distribute_transpiler.py | 35 +++---------------- 1 file changed, 5 insertions(+), 30 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 53d6ca86a0..92cdff04a0 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -377,11 +377,6 @@ class DistributeTranspiler(object): # append it into the sub program. global_ops = [] - # HACK: optimization global ops only used to scale beta1 and beta2 - # replace it with dependency engine. - for op in self.optimize_ops: - if self._is_adam_connected_op(op): - global_ops.append(op) def __append_optimize_op__(op, block, grad_to_block_id, merged_var, lr_ops): @@ -1289,22 +1284,16 @@ class DistributeTranspiler(object): # If one op's input is another op's output or # one op's output is another op's input, we say # the two operator is connected. - def _append_inname_remove_beta(varname_list): + def _append_inname(varname_list): op_input_names = [] for in_name in varname_list: - # HACK: remove beta1 and beta2 to avoid let all - # ops connected. - if in_name.startswith("beta2_pow_acc") or \ - in_name.startswith("beta1_pow_acc"): - continue - else: - op_input_names.append(in_name) + op_input_names.append(in_name) return op_input_names - op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names()) + op1_input_names = _append_inname(op1.desc.input_arg_names()) op1_output_names = op1.desc.output_arg_names() - op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names()) + op2_input_names = _append_inname(op2.desc.input_arg_names()) op2_output_names = op2.desc.output_arg_names() if set(op1_output_names) & set(op2_input_names) or \ @@ -1413,7 +1402,7 @@ class DistributeTranspiler(object): def _get_optimize_pass(self): """ - Get optimizer operators, paramters and gradients from origin_program + Get optimizer operators, parameters and gradients from origin_program Returns: opt_ops (list): optimize operators. params_grads (dict): paramter->gradient. @@ -1436,20 +1425,6 @@ class DistributeTranspiler(object): origin_var_dict[param_name], origin_var_dict[input_name] ]) - elif self._is_adam_connected_op(op): - opt_ops.append(op) else: pass return opt_ops, params_grads - - def _is_adam_connected_op(self, op): - """ - A hack function to determinate whether the input operator - is connected to optimize operator. - """ - if op.type == "scale": - for in_name in op.input_arg_names: - if in_name.startswith("beta1_pow_acc") or \ - in_name.startswith("beta2_pow_acc"): - return True - return False -- GitLab