diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 079d90f585abd5d3094dfb281f955c754fd33474..c180e7b21042a1ceb2651d8f7a48a02c1abfd02c 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -401,11 +401,8 @@ class DistributeTranspiler:
         # HACK: optimization global ops only used to scale beta1 and beta2
         # replace it with dependency engine.
         for op in self.optimize_ops:
-            if op.type == "scale":
-                for in_name in op.input_arg_names:
-                    if in_name.startswith("beta1_pow_acc") or \
-                            in_name.startswith("beta2_pow_acc"):
-                        global_ops.append(op)
+            if self._is_adam_connected_op(op):
+                global_ops.append(op)
 
         def __append_optimize_op__(op, block, grad_to_block_id):
             if self._is_opt_op(op):
@@ -1152,13 +1149,20 @@ class DistributeTranspiler:
                         op.input("Param")[0]),
                     self.origin_program.global_block().var(
                         op.input("Grad")[0])))
-            elif op.type == "scale":
-                # for adam optimize op
-                for in_name in op.input_arg_names:
-                    if in_name.startswith("beta1_pow_acc") or \
-                            in_name.startswith("beta2_pow_acc"):
-                        opt_ops.append(op)
-                        break
+            elif self._is_adam_connected_op(op):
+                opt_ops.append(op)
             else:
                 pass
         return opt_ops, params_grads
+
+    def _is_adam_connected_op(self, op):
+        """
+        A hack function to determine whether the input operator
+        is connected to an optimize operator.
+        """
+        if op.type == "scale":
+            for in_name in op.input_arg_names:
+                if in_name.startswith("beta1_pow_acc") or \
+                        in_name.startswith("beta2_pow_acc"):
+                    return True
+        return False
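
Note: the patch replaces two copies of the same "scale op reading beta1_pow_acc/beta2_pow_acc" check with one shared helper, _is_adam_connected_op. The snippet below is a minimal, self-contained sketch of that check in isolation; _FakeOp is a hypothetical stand-in for a fluid framework Operator and is not part of the patch.

# Sketch only: mirrors the logic of DistributeTranspiler._is_adam_connected_op.
class _FakeOp(object):
    def __init__(self, type, input_arg_names):
        self.type = type
        self.input_arg_names = input_arg_names


def is_adam_connected_op(op):
    # A "scale" op whose input comes from beta1_pow_acc/beta2_pow_acc is
    # treated as part of the Adam optimizer's update graph.
    if op.type == "scale":
        for in_name in op.input_arg_names:
            if in_name.startswith("beta1_pow_acc") or \
                    in_name.startswith("beta2_pow_acc"):
                return True
    return False


assert is_adam_connected_op(_FakeOp("scale", ["beta1_pow_acc_0"]))
assert not is_adam_connected_op(_FakeOp("scale", ["learning_rate_0"]))
assert not is_adam_connected_op(_FakeOp("sum", ["beta1_pow_acc_0"]))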