diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 9e87681c4bef306f7504f8678004fbdec9b7a12e..378902d8dde81388368597709d02d77d7ea7c68f 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4654,15 +4654,22 @@ class PipelineOptimizer(object):
               op.type == 'elementwise_div'):
             device = f"{self._device}:all"
             op._set_attr(self._op_device_key, device)
-        elif self._is_weight_decay_op(op) and op.type == 'scale':
-            # set AdamW decay_coeff to device:all
-            op._set_attr(self._op_device_key, f"{self._device}:all")
         elif op.type == "alloc_float_status" or op.type == "clear_float_status":
             op._set_attr(self._op_device_key, f"{self._device}:all")
+            # NOTE(wangxi): NPU should only clear the float status
+            # once at each batch step
+            op._set_attr(self._op_role_key, self._op_role.LRSched)
+
+            float_status_name = op.output_arg_names[0]
+            float_status_var = block.var(float_status_name)
+            # FIXME(wangxi): pipeline lr schedule will exec on sub_scope(0)
+            # while update will exec on sub_scope(last_micro_step), should
+            # set persistable to use global scope
+            float_status_var.persistable = True
         else:
             other_known_ops = [
                 'update_loss_scaling', 'reduce_any', 'concat', 'sum',
-                'check_finite_and_unscale', 'alloc_float_status', 'memcpy'
+                'check_finite_and_unscale', 'memcpy'
             ]
             assert op.type in other_known_ops, "For other ops without " \
                 "op_device set, they must be one of {}, but it " \
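
For context on why the `persistable` change matters: in pipeline training each micro-step executes in its own sub-scope, and only variables marked persistable are kept in the shared global scope, so an op run at LR-schedule time (sub_scope(0)) and an op run at update time (the last micro-step's sub-scope) only see the same buffer if that buffer is persistable. Below is a minimal sketch of that idea using the public static-graph API; the variable name `status_demo` is purely illustrative and is not the PR's actual NPU float-status variable.

```python
# Sketch only: marking a static-graph variable persistable so it lives in the
# global scope instead of a per-(micro-)step sub-scope. "status_demo" is an
# assumed name for illustration.
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    block = main_prog.global_block()
    status = block.create_var(
        name="status_demo", shape=[8], dtype="float32")

    # Non-persistable variables are created and destroyed inside the scope of
    # the step that runs them; persistable variables are stored in the global
    # scope, so ops scheduled at different micro-steps share one buffer. This
    # is the same reason the PR sets float_status_var.persistable = True.
    status.persistable = True
```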