提交 5c12c5eb 编写于 作者: Q qiaolongfei

update distribute transformer for adam and adamax optimizer

上级 6ff7f238
......@@ -377,11 +377,6 @@ class DistributeTranspiler(object):
# append it into the sub program.
global_ops = []
# HACK: optimization global ops only used to scale beta1 and beta2
# replace it with dependency engine.
for op in self.optimize_ops:
if self._is_adam_connected_op(op):
global_ops.append(op)
def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
lr_ops):
......@@ -1289,22 +1284,16 @@ class DistributeTranspiler(object):
# If one op's input is another op's output or
# one op's output is another op's input, we say
# the two operator is connected.
def _append_inname_remove_beta(varname_list):
def _append_inname(varname_list):
op_input_names = []
for in_name in varname_list:
# HACK: remove beta1 and beta2 to avoid let all
# ops connected.
if in_name.startswith("beta2_pow_acc") or \
in_name.startswith("beta1_pow_acc"):
continue
else:
op_input_names.append(in_name)
return op_input_names
op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names())
op1_input_names = _append_inname(op1.desc.input_arg_names())
op1_output_names = op1.desc.output_arg_names()
op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names())
op2_input_names = _append_inname(op2.desc.input_arg_names())
op2_output_names = op2.desc.output_arg_names()
if set(op1_output_names) & set(op2_input_names) or \
......@@ -1413,7 +1402,7 @@ class DistributeTranspiler(object):
def _get_optimize_pass(self):
"""
Get optimizer operators, paramters and gradients from origin_program
Get optimizer operators, parameters and gradients from origin_program
Returns:
opt_ops (list): optimize operators.
params_grads (dict): paramter->gradient.
......@@ -1436,20 +1425,6 @@ class DistributeTranspiler(object):
origin_var_dict[param_name],
origin_var_dict[input_name]
])
elif self._is_adam_connected_op(op):
opt_ops.append(op)
else:
pass
return opt_ops, params_grads
def _is_adam_connected_op(self, op):
"""
A hack function to determinate whether the input operator
is connected to optimize operator.
"""
if op.type == "scale":
for in_name in op.input_arg_names:
if in_name.startswith("beta1_pow_acc") or \
in_name.startswith("beta2_pow_acc"):
return True
return False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册