未验证 提交 73321264 编写于 作者: W WangXi 提交者: GitHub

[hybrid][npu] fix npu clear float status in pipeline (#35165)

上级 669853f5
...@@ -4654,15 +4654,22 @@ class PipelineOptimizer(object): ...@@ -4654,15 +4654,22 @@ class PipelineOptimizer(object):
op.type == 'elementwise_div'): op.type == 'elementwise_div'):
device = f"{self._device}:all" device = f"{self._device}:all"
op._set_attr(self._op_device_key, device) op._set_attr(self._op_device_key, device)
elif self._is_weight_decay_op(op) and op.type == 'scale':
# set AdamW decay_coeff to device:all
op._set_attr(self._op_device_key, f"{self._device}:all")
elif op.type == "alloc_float_status" or op.type == "clear_float_status": elif op.type == "alloc_float_status" or op.type == "clear_float_status":
op._set_attr(self._op_device_key, f"{self._device}:all") op._set_attr(self._op_device_key, f"{self._device}:all")
# NOTE(wangxi): NPU should only clear the float status
# once at each batch step
op._set_attr(self._op_role_key, self._op_role.LRSched)
float_status_name = op.output_arg_names[0]
float_status_var = block.var(float_status_name)
# FIXME(wangxi): in pipeline mode the lr schedule executes on sub_scope(0)
# while the update executes on sub_scope(last_micro_step), so the float
# status var is set persistable to make it live in the global scope
float_status_var.persistable = True
else: else:
other_known_ops = [ other_known_ops = [
'update_loss_scaling', 'reduce_any', 'concat', 'sum', 'update_loss_scaling', 'reduce_any', 'concat', 'sum',
'check_finite_and_unscale', 'alloc_float_status', 'memcpy' 'check_finite_and_unscale', 'memcpy'
] ]
assert op.type in other_known_ops, "For other ops without " \ assert op.type in other_known_ops, "For other ops without " \
"op_device set, they must be one of {}, but it " \ "op_device set, they must be one of {}, but it " \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册