Unverified · Commit a0530c3b authored by Qiao Longfei, committed by GitHub

Merge pull request #12123 from jacquesqiao/distribute-transpiler-handle-adam-accumulator

Distribute transpiler handle adam accumulator
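For context: the "adam accumulator" in the title refers to the beta1_pow_acc and beta2_pow_acc variables that the Adam optimizer keeps as running powers of its beta1/beta2 hyperparameters. The removed comments and the removed _is_adam_connected_op method show that these were updated by "scale" operators treated as global ops, which the transpiler used to special-case by variable-name prefix; the diff below deletes that special-casing. A minimal plain-Python sketch (illustrative only, not PaddlePaddle code) of what those accumulator updates compute:

# Hedged sketch: what the "scale" ops on beta1_pow_acc / beta2_pow_acc compute.
# The variable names mirror the prefixes the removed hack matched on; the loop
# itself is purely illustrative and is not part of the transpiler.
beta1, beta2 = 0.9, 0.999                # typical Adam hyperparameters
beta1_pow_acc, beta2_pow_acc = 1.0, 1.0

for step in range(1, 4):
    beta1_pow_acc *= beta1               # beta1_pow_acc == beta1 ** step
    beta2_pow_acc *= beta2               # beta2_pow_acc == beta2 ** step
    # Adam uses these running powers for bias correction (1 - beta_pow_acc).
    print(step, beta1_pow_acc, beta2_pow_acc)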
@@ -377,11 +377,6 @@ class DistributeTranspiler(object):
         # append it into the sub program.
         global_ops = []
-        # HACK: optimization global ops only used to scale beta1 and beta2
-        # replace it with dependency engine.
-        for op in self.optimize_ops:
-            if self._is_adam_connected_op(op):
-                global_ops.append(op)

         def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
                                    lr_ops):
@@ -1289,22 +1284,16 @@ class DistributeTranspiler(object):
         # If one op's input is another op's output or
         # one op's output is another op's input, we say
         # the two operator is connected.
-        def _append_inname_remove_beta(varname_list):
+        def _append_inname(varname_list):
             op_input_names = []
             for in_name in varname_list:
-                # HACK: remove beta1 and beta2 to avoid let all
-                # ops connected.
-                if in_name.startswith("beta2_pow_acc") or \
-                        in_name.startswith("beta1_pow_acc"):
-                    continue
-                else:
-                    op_input_names.append(in_name)
+                op_input_names.append(in_name)
             return op_input_names

-        op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names())
+        op1_input_names = _append_inname(op1.desc.input_arg_names())
         op1_output_names = op1.desc.output_arg_names()

-        op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names())
+        op2_input_names = _append_inname(op2.desc.input_arg_names())
         op2_output_names = op2.desc.output_arg_names()

         if set(op1_output_names) & set(op2_input_names) or \
@@ -1413,7 +1402,7 @@ class DistributeTranspiler(object):
     def _get_optimize_pass(self):
         """
-        Get optimizer operators, paramters and gradients from origin_program
+        Get optimizer operators, parameters and gradients from origin_program
         Returns:
             opt_ops (list): optimize operators.
             params_grads (dict): paramter->gradient.
@@ -1436,20 +1425,6 @@ class DistributeTranspiler(object):
                         origin_var_dict[param_name],
                         origin_var_dict[input_name]
                     ])
-            elif self._is_adam_connected_op(op):
-                opt_ops.append(op)
             else:
                 pass
         return opt_ops, params_grads
-
-    def _is_adam_connected_op(self, op):
-        """
-        A hack function to determinate whether the input operator
-        is connected to optimize operator.
-        """
-        if op.type == "scale":
-            for in_name in op.input_arg_names:
-                if in_name.startswith("beta1_pow_acc") or \
-                        in_name.startswith("beta2_pow_acc"):
-                    return True
-        return False
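With the name-prefix hack gone, the operator connectivity test in the second hunk reduces to a plain intersection of output and input argument names. A standalone sketch of that simplified test (ops_connected is a hypothetical helper for illustration, not the transpiler's actual method):

# Hedged sketch: two ops count as connected when one op's outputs overlap the
# other op's inputs; beta1_pow_acc / beta2_pow_acc names are no longer filtered out.
def ops_connected(op1_inputs, op1_outputs, op2_inputs, op2_outputs):
    return bool(set(op1_outputs) & set(op2_inputs)) or \
           bool(set(op2_outputs) & set(op1_inputs))

# Hypothetical example: a scale op that updates beta1_pow_acc in place is now
# connected to an adam op that reads it, purely through the shared variable name.
scale_in, scale_out = ["beta1_pow_acc"], ["beta1_pow_acc"]
adam_in = ["Param", "Grad", "Moment1", "Moment2", "beta1_pow_acc"]
adam_out = ["ParamOut"]
print(ops_connected(scale_in, scale_out, adam_in, adam_out))   # True
print(ops_connected(adam_in, adam_out, scale_in, scale_out))   # True (symmetric)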