From e9737d600f44b13810c91c497a2ce42d96efddfe Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Thu, 3 May 2018 11:03:07 +0800
Subject: [PATCH] add a private function to find adam opt pass

---
 python/paddle/fluid/distribute_transpiler.py | 28 +++++++++++---------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 079d90f585..c180e7b210 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -401,11 +401,8 @@ class DistributeTranspiler:
         # HACK: optimization global ops only used to scale beta1 and beta2
         # replace it with dependency engine.
         for op in self.optimize_ops:
-            if op.type == "scale":
-                for in_name in op.input_arg_names:
-                    if in_name.startswith("beta1_pow_acc") or \
-                        in_name.startswith("beta2_pow_acc"):
-                        global_ops.append(op)
+            if self._is_adam_connected_op(op):
+                global_ops.append(op)
 
         def __append_optimize_op__(op, block, grad_to_block_id):
             if self._is_opt_op(op):
@@ -1152,13 +1149,20 @@ class DistributeTranspiler:
                         op.input("Param")[0]),
                     self.origin_program.global_block().var(
                         op.input("Grad")[0])))
-            elif op.type == "scale":
-                # for adam optimize op
-                for in_name in op.input_arg_names:
-                    if in_name.startswith("beta1_pow_acc") or \
-                        in_name.startswith("beta2_pow_acc"):
-                        opt_ops.append(op)
-                        break
+            elif self._is_adam_connected_op(op):
+                opt_ops.append(op)
             else:
                 pass
         return opt_ops, params_grads
+
+    def _is_adam_connected_op(self, op):
+        """
+        A hack function to determinate whether the input operator
+        is connected to optimize operator.
+        """
+        if op.type == "scale":
+            for in_name in op.input_arg_names:
+                if in_name.startswith("beta1_pow_acc") or \
+                        in_name.startswith("beta2_pow_acc"):
+                    return True
+        return False
--
GitLab
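
For reference, below is a minimal standalone sketch of the check that the extracted helper performs. It is not part of the patch: `FakeOp` is a hypothetical stand-in that models only the `type` and `input_arg_names` attributes the helper reads; in the transpiler the ops come from the fluid `Program` being processed.

```python
class FakeOp(object):
    """Hypothetical stand-in for a fluid operator (illustration only)."""

    def __init__(self, type, input_arg_names):
        self.type = type
        self.input_arg_names = input_arg_names


def is_adam_connected_op(op):
    # Mirrors the helper added by the patch: a "scale" op whose inputs are
    # the Adam beta1/beta2 power accumulators is treated as part of the
    # optimizer pass even though it is not an optimize op itself.
    if op.type == "scale":
        for in_name in op.input_arg_names:
            if in_name.startswith("beta1_pow_acc") or \
                    in_name.startswith("beta2_pow_acc"):
                return True
    return False


if __name__ == "__main__":
    print(is_adam_connected_op(FakeOp("scale", ["beta1_pow_acc_0"])))  # True
    print(is_adam_connected_op(FakeOp("scale", ["learning_rate_0"])))  # False
    print(is_adam_connected_op(FakeOp("sgd", ["Param", "Grad"])))      # False
```

The design choice here is simply deduplication: the same beta-power-accumulator check previously appeared inline in two places, and the patch centralizes it in one private method so both call sites stay consistent.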