Commit e9737d60 authored by Y Yancey1989

add a private function to find adam opt pass

Parent da960ada
@@ -401,10 +401,7 @@ class DistributeTranspiler:
         # HACK: optimization global ops only used to scale beta1 and beta2
         # replace it with dependency engine.
         for op in self.optimize_ops:
-            if op.type == "scale":
-                for in_name in op.input_arg_names:
-                    if in_name.startswith("beta1_pow_acc") or \
-                            in_name.startswith("beta2_pow_acc"):
-                        global_ops.append(op)
+            if self._is_adam_connected_op(op):
+                global_ops.append(op)

         def __append_optimize_op__(op, block, grad_to_block_id):
@@ -1152,13 +1149,20 @@ class DistributeTranspiler:
                         op.input("Param")[0]),
                     self.origin_program.global_block().var(
                         op.input("Grad")[0])))
-            elif op.type == "scale":
-                # for adam optimize op
-                for in_name in op.input_arg_names:
-                    if in_name.startswith("beta1_pow_acc") or \
-                            in_name.startswith("beta2_pow_acc"):
-                        opt_ops.append(op)
-                        break
+            elif self._is_adam_connected_op(op):
+                opt_ops.append(op)
             else:
                 pass
         return opt_ops, params_grads
+
+    def _is_adam_connected_op(self, op):
+        """
+        A hack function to determine whether the input operator
+        is connected to an optimize operator.
+        """
+        if op.type == "scale":
+            for in_name in op.input_arg_names:
+                if in_name.startswith("beta1_pow_acc") or \
+                        in_name.startswith("beta2_pow_acc"):
+                    return True
+        return False
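
For context: per the HACK comment above, the Adam optimizer's global ops are scale operators that consume the beta1_pow_acc / beta2_pow_acc accumulator variables, and that is the pattern the new helper keys off. Below is a minimal, self-contained sketch of the same check, using a hypothetical FakeOp stub in place of a real Paddle operator; everything except the check itself is illustrative only.

class FakeOp(object):
    """Hypothetical stand-in for a Paddle operator: just a type and its input names."""

    def __init__(self, op_type, input_arg_names):
        self.type = op_type
        self.input_arg_names = input_arg_names


def is_adam_connected_op(op):
    # Same logic as the new private method: a "scale" op reading a beta
    # power accumulator is taken as part of Adam's beta1/beta2 update.
    if op.type == "scale":
        for in_name in op.input_arg_names:
            if in_name.startswith("beta1_pow_acc") or \
                    in_name.startswith("beta2_pow_acc"):
                return True
    return False


# A scale op over beta1_pow_acc is picked up as Adam-connected;
# an ordinary SGD update op is not.
print(is_adam_connected_op(FakeOp("scale", ["beta1_pow_acc_0"])))            # True
print(is_adam_connected_op(FakeOp("sgd", ["Param", "Grad", "LearningRate"])))  # False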