From 73321264c1a5bd29af1e5f14ad312ec07dabe3d0 Mon Sep 17 00:00:00 2001
From: WangXi
Date: Fri, 27 Aug 2021 11:19:12 +0800
Subject: [PATCH] [hybrid][npu] fix npu clear float status in pipeline (#35165)

---
 python/paddle/fluid/optimizer.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 478ea754727..eb3d559ddcd 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4654,15 +4654,22 @@ class PipelineOptimizer(object):
                     op.type == 'elementwise_div'):
                 device = f"{self._device}:all"
             op._set_attr(self._op_device_key, device)
-        elif self._is_weight_decay_op(op) and op.type == 'scale':
-            # set AdamW decay_coeff to device:all
-            op._set_attr(self._op_device_key, f"{self._device}:all")
         elif op.type == "alloc_float_status" or op.type == "clear_float_status":
             op._set_attr(self._op_device_key, f"{self._device}:all")
+            # NOTE(wangxi): NPU should only clear the float status
+            # once at each batch step
+            op._set_attr(self._op_role_key, self._op_role.LRSched)
+
+            float_status_name = op.output_arg_names[0]
+            float_status_var = block.var(float_status_name)
+            # FIXME(wangxi): pipeline lr schedule will exec on sub_scope(0)
+            # while update will exec on sub_scope(last_micro_step), should
+            # set persistable to use global scope
+            float_status_var.persistable = True
         else:
             other_known_ops = [
                 'update_loss_scaling', 'reduce_any', 'concat', 'sum',
-                'check_finite_and_unscale', 'alloc_float_status', 'memcpy'
+                'check_finite_and_unscale', 'memcpy'
             ]
             assert op.type in other_known_ops, "For other ops without " \
                 "op_device set, they must be one of {}, but it " \
--
GitLab
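
To make the intent of the patched branch easier to follow, here is a minimal sketch of the float-status handling it introduces. The standalone helper mark_float_status_op and its parameters are illustrative stand-ins, not Paddle API; in the actual patch this logic lives inside PipelineOptimizer._add_op_device_attr_for_op and reads the keys and role from the optimizer's own state.

    # Minimal sketch (not Paddle API) of the float-status branch added above.
    # `op` and `block` stand in for Paddle's Operator/Block objects; the key
    # and role arguments mirror self._op_device_key, self._op_role_key and
    # self._op_role.LRSched in PipelineOptimizer.
    def mark_float_status_op(op, block, device, op_device_key, op_role_key,
                             lr_sched_role):
        if op.type in ("alloc_float_status", "clear_float_status"):
            # Run the op on every pipeline stage: the float status is
            # per-device state.
            op._set_attr(op_device_key, f"{device}:all")
            # Tag it with the LRSched role so it executes once per batch
            # step rather than once per micro-batch.
            op._set_attr(op_role_key, lr_sched_role)
            # LRSched ops execute in sub_scope(0) while the optimizer update
            # executes in sub_scope(last_micro_step); making the variable
            # persistable places it in the global scope, where both see the
            # same buffer.
            float_status_var = block.var(op.output_arg_names[0])
            float_status_var.persistable = True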