Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
53182531
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 大约 1 年
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
53182531
编写于
9月 03, 2022
作者:
O
Olatunji Ruwase
提交者:
GitHub
9月 03, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor universal checkpointing and tensor fragments (#2253)
* Refactor universal checkpointing and tensor fragments * Formatting
上级
47e030f5
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
284 addition
and
234 deletion
+284
-234
deepspeed/checkpoint/__init__.py
deepspeed/checkpoint/__init__.py
+2
-0
deepspeed/checkpoint/constants.py
deepspeed/checkpoint/constants.py
+3
-0
deepspeed/checkpoint/universal_checkpoint.py
deepspeed/checkpoint/universal_checkpoint.py
+110
-0
deepspeed/runtime/bf16_optimizer.py
deepspeed/runtime/bf16_optimizer.py
+17
-234
deepspeed/utils/__init__.py
deepspeed/utils/__init__.py
+2
-0
deepspeed/utils/mixed_precision_linkage.py
deepspeed/utils/mixed_precision_linkage.py
+45
-0
deepspeed/utils/tensor_fragment.py
deepspeed/utils/tensor_fragment.py
+105
-0
未找到文件。
deepspeed/checkpoint/__init__.py
浏览文件 @
53182531
...
...
@@ -11,3 +11,5 @@ from .reshape_utils import (merge_state)
from
.reshape_3d_utils
import
(
model_3d_desc
,
get_model_3d_descriptor
)
from
.zero_checkpoint
import
ZeROCheckpoint
from
.universal_checkpoint
import
enable_universal_checkpoint
deepspeed/checkpoint/constants.py
浏览文件 @
53182531
...
...
@@ -21,8 +21,11 @@ FP32_WEIGHT_KEY = "fp32"
#########################################
# Module checkpoint keys
#########################################
PARAM
=
'param'
PARAM_SHAPES
=
'param_shapes'
BUFFER_NAMES
=
'buffer_names'
VOCAB_DIVISIBILITY_PADDING_TENSOR
=
'vocab_divisibility_padding_tensor'
CAT_DIM
=
"cat_dim"
#########################################
# Checkpoint naming constants
...
...
deepspeed/checkpoint/universal_checkpoint.py
0 → 100644
浏览文件 @
53182531
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import
os
import
torch
import
types
from
.constants
import
(
FP32_WEIGHT_KEY
,
PARAM
,
VOCAB_DIVISIBILITY_PADDING_TENSOR
,
CAT_DIM
)
def
load_hp_checkpoint_state
(
self
,
folder
,
tp_rank
,
tp_world_size
):
hp_mapping
=
self
.
_hp_mapping
optim_state_keys
=
hp_mapping
.
get_optim_state_keys
()
hp_keys
=
[
FP32_WEIGHT_KEY
]
+
optim_state_keys
checkpoint_files
=
{
key
:
os
.
path
.
join
(
folder
,
f
"
{
key
}
.pt"
)
for
key
in
hp_keys
}
for
file
in
checkpoint_files
.
values
():
assert
os
.
path
.
isfile
(
file
),
f
'
{
file
}
is not a valid file'
for
key
in
hp_keys
:
ckpt_file
=
checkpoint_files
[
key
]
ckpt_dict
=
torch
.
load
(
ckpt_file
)
full_hp_param
=
ckpt_dict
[
PARAM
]
# need to deal with slices that were averaged.
# the opposite of averaging here becomes an exact copy of the first slice
# I thought of 2 ways:
# implementation a. find a way for a client to pass a dict with patterns
# if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS):
# tp_rank = 0
# tp_world_size = 1
# the other approach is to assume that the saved data is correct and if full_hp_param.shape ==
# self.shape that means we automatically copy?
# implementation b.
# this version requires no additional data passed from the client
# if the shapes already match it must be slices that were averaged - so we just hack around those
if
full_hp_param
.
shape
==
self
.
shape
:
tp_rank
=
0
tp_world_size
=
1
# special case for word_embeddings weights which get padded differently depending on TP degree.
# the converter to universal currently strips the original padding completely so the saved
# weight is padding-free and we just need to add new padding depending on the target TP
# degree
vocab_divisibility_padding_tensor
=
ckpt_dict
.
get
(
VOCAB_DIVISIBILITY_PADDING_TENSOR
,
None
)
if
vocab_divisibility_padding_tensor
is
not
None
:
# In the absence of data passed from the user wrt new padded vocab specific to tp degree
# we can again derive that data by reverse engineering the target shapes like so:
padded_target_vocab_size
=
self
.
shape
[
0
]
*
tp_world_size
if
padded_target_vocab_size
>
full_hp_param
.
shape
[
0
]:
# Need to expand
padding_tensor
=
vocab_divisibility_padding_tensor
.
expand
(
padded_target_vocab_size
-
full_hp_param
.
shape
[
0
])
# Implement the following concat in efficient way using pad
#full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
full_hp_param
=
torch
.
nn
.
functional
.
pad
(
full_hp_param
,
(
0
,
0
,
0
,
padding_tensor
.
shape
[
0
]),
"constant"
,
0
)
full_hp_param
[:
-
padding_tensor
.
shape
[
0
],
:]
=
padding_tensor
else
:
# Need to shrink or keep the same
full_hp_param
=
full_hp_param
[:
padded_target_vocab_size
,
:]
full_param_numel
=
full_hp_param
.
numel
()
tp_slice_numel
=
self
.
numel
()
# if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder:
# print_rank_0(f'{full_hp_param[:10]=}', force=True)
assert
full_param_numel
==
tp_world_size
*
tp_slice_numel
,
\
f
'Loading
{
ckpt_file
}
full param numel
{
full_param_numel
}
!= tensor slice numel
{
tp_slice_numel
}
* tp_world_size
{
tp_world_size
}
'
dst_tensor
=
hp_mapping
.
hp_fragment
if
key
==
FP32_WEIGHT_KEY
else
hp_mapping
.
get_optim_state_fragment
(
key
)
# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
# since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse
chunk_dim
=
ckpt_dict
.
get
(
CAT_DIM
,
0
)
# this performs the opposite of cat when merging TP slices
tp_hp_slice
=
full_hp_param
.
chunk
(
tp_world_size
,
chunk_dim
)[
tp_rank
]
tp_hp_slice
=
tp_hp_slice
.
flatten
()
lp_frag_address
=
hp_mapping
.
lp_fragment_address
tp_hp_fragment
=
tp_hp_slice
.
narrow
(
0
,
lp_frag_address
.
start
,
lp_frag_address
.
numel
)
assert
dst_tensor
.
numel
()
==
lp_frag_address
.
numel
,
\
f
'Load checkpoint
{
key
}
dst_tensor numel
{
dst_tensor
.
numel
()
}
!= src numel
{
lp_frag_address
.
numel
}
'
# print(f"{key} SHAPE: {tp_hp_slice.shape=}")
# print(f"{key} SHAPE: {dst_tensor.shape=}")
# print(f"{key} SHAPE: {tp_hp_fragment.shape=}")
dst_tensor
.
data
.
copy_
(
tp_hp_fragment
.
data
)
def
enable_universal_checkpoint
(
param_list
):
for
param
in
param_list
:
param
.
load_hp_checkpoint_state
=
types
.
MethodType
(
load_hp_checkpoint_state
,
param
)
deepspeed/runtime/bf16_optimizer.py
浏览文件 @
53182531
...
...
@@ -21,162 +21,15 @@ from deepspeed.runtime.utils import (get_global_norm_of_tensors,
is_model_parallel_parameter
,
see_memory_usage
)
from
deepspeed.utils
import
link_hp_params
from
deepspeed.checkpoint
import
enable_universal_checkpoint
from
deepspeed.checkpoint.constants
import
(
DS_VERSION
,
PARTITION_COUNT
,
BASE_OPTIMIZER_STATE
,
SINGLE_PARTITION_OF_FP32_GROUPS
,
CLIP_GRAD
,
GROUP_PADDINGS
,
PARAM_SLICE_MAPPINGS
,
FP32_WEIGHT_KEY
)
import
types
from
dataclasses
import
dataclass
@
dataclass
class
fragment_address
:
numel
:
int
start
:
int
@
dataclass
class
tensor_fragment
:
lp_fragment
:
torch
.
Tensor
lp_fragment_address
:
fragment_address
hp_fragment
:
torch
.
Tensor
hp_fragment_address
:
fragment_address
optim_fragment
:
{}
def
update_hp
(
self
):
self
.
hp_fragment
.
data
.
copy_
(
self
.
lp_fragment
.
data
)
def
update_lp
(
self
):
self
.
lp_fragment
.
data
.
copy_
(
self
.
hp_fragment
.
data
)
def
get_optim_state_fragment
(
self
,
key
):
if
key
in
self
.
optim_fragment
:
return
self
.
optim_fragment
[
key
]
else
:
raise
ValueError
(
f
'
{
key
}
not found in optimizer state fragment'
)
def
get_hp_fragment_address
(
self
):
return
self
.
hp_fragment_address
def
get_optim_state_keys
(
self
):
return
list
(
self
.
optim_fragment
.
keys
())
def
get_full_hp_param
(
self
,
optim_state_key
=
None
):
reduce_buffer
=
torch
.
zeros_like
(
self
,
dtype
=
torch
.
float32
).
flatten
()
if
self
.
_hp_mapping
is
not
None
:
lp_frag_address
=
self
.
_hp_mapping
.
lp_fragment_address
reduce_fragment
=
torch
.
narrow
(
reduce_buffer
,
0
,
lp_frag_address
.
start
,
lp_frag_address
.
numel
)
if
optim_state_key
is
None
:
hp_fragment
=
self
.
_hp_mapping
.
hp_fragment
else
:
hp_fragment
=
self
.
_hp_mapping
.
get_optim_state_fragment
(
optim_state_key
)
reduce_fragment
.
data
.
copy_
(
hp_fragment
.
data
)
dist
.
all_reduce
(
reduce_buffer
,
group
=
self
.
_dp_group
)
return
reduce_buffer
.
reshape_as
(
self
)
def
load_hp_checkpoint_state
(
self
,
folder
,
tp_rank
,
tp_world_size
):
hp_mapping
=
self
.
_hp_mapping
optim_state_keys
=
hp_mapping
.
get_optim_state_keys
()
hp_keys
=
[
FP32_WEIGHT_KEY
]
+
optim_state_keys
checkpoint_files
=
{
key
:
os
.
path
.
join
(
folder
,
f
"
{
key
}
.pt"
)
for
key
in
hp_keys
}
for
file
in
checkpoint_files
.
values
():
assert
os
.
path
.
isfile
(
file
),
f
'
{
file
}
is not a valid file'
for
key
in
hp_keys
:
ckpt_file
=
checkpoint_files
[
key
]
ckpt_dict
=
torch
.
load
(
ckpt_file
)
full_hp_param
=
ckpt_dict
[
'param'
]
# need to deal with slices that were averaged.
# the opposite of averaging here becomes an exact copy of the first slice
# I thought of 2 ways:
# implementation a. find a way for a client to pass a dict with patterns
# if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS):
# tp_rank = 0
# tp_world_size = 1
# the other approach is to assume that the saved data is correct and if full_hp_param.shape ==
# self.shape that means we automatically copy?
# implementation b.
# this version requires no additional data passed from the client
# if the shapes already match it must be slices that were averaged - so we just hack around those
if
full_hp_param
.
shape
==
self
.
shape
:
tp_rank
=
0
tp_world_size
=
1
# special case for word_embeddings weights which get padded differently depending on TP degree.
# the converter to universal currently strips the original padding completely so the saved
# weight is padding-free and we just need to add new padding depending on the target TP
# degree
vocab_divisibility_padding_tensor
=
ckpt_dict
.
get
(
'vocab_divisibility_padding_tensor'
,
None
)
if
vocab_divisibility_padding_tensor
is
not
None
:
# In the absence of data passed from the user wrt new padded vocab specific to tp degree
# we can again derive that data by reverse engineering the target shapes like so:
padded_target_vocab_size
=
self
.
shape
[
0
]
*
tp_world_size
if
padded_target_vocab_size
>
full_hp_param
.
shape
[
0
]:
# Need to expand
padding_tensor
=
vocab_divisibility_padding_tensor
.
expand
(
padded_target_vocab_size
-
full_hp_param
.
shape
[
0
])
# Implement the following concat in efficient way using pad
#full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
full_hp_param
=
torch
.
nn
.
functional
.
pad
(
full_hp_param
,
(
0
,
0
,
0
,
padding_tensor
.
shape
[
0
]),
"constant"
,
0
)
full_hp_param
[:
-
padding_tensor
.
shape
[
0
],
:]
=
padding_tensor
else
:
# Need to shrink or keep the same
full_hp_param
=
full_hp_param
[:
padded_target_vocab_size
,
:]
full_param_numel
=
full_hp_param
.
numel
()
tp_slice_numel
=
self
.
numel
()
# if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder:
# print_rank_0(f'{full_hp_param[:10]=}', force=True)
assert
full_param_numel
==
tp_world_size
*
tp_slice_numel
,
\
f
'Loading
{
ckpt_file
}
full param numel
{
full_param_numel
}
!= tensor slice numel
{
tp_slice_numel
}
* tp_world_size
{
tp_world_size
}
'
dst_tensor
=
hp_mapping
.
hp_fragment
if
key
==
FP32_WEIGHT_KEY
else
hp_mapping
.
get_optim_state_fragment
(
key
)
# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
# since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse
chunk_dim
=
ckpt_dict
.
get
(
'cat_dim'
,
0
)
# this performs the opposite of cat when merging TP slices
tp_hp_slice
=
full_hp_param
.
chunk
(
tp_world_size
,
chunk_dim
)[
tp_rank
]
tp_hp_slice
=
tp_hp_slice
.
flatten
()
lp_frag_address
=
hp_mapping
.
lp_fragment_address
tp_hp_fragment
=
tp_hp_slice
.
narrow
(
0
,
lp_frag_address
.
start
,
lp_frag_address
.
numel
)
assert
dst_tensor
.
numel
()
==
lp_frag_address
.
numel
,
\
f
'Load checkpoint
{
key
}
dst_tensor numel
{
dst_tensor
.
numel
()
}
!= src numel
{
lp_frag_address
.
numel
}
'
# print(f"{key} SHAPE: {tp_hp_slice.shape=}")
# print(f"{key} SHAPE: {dst_tensor.shape=}")
# print(f"{key} SHAPE: {tp_hp_fragment.shape=}")
dst_tensor
.
data
.
copy_
(
tp_hp_fragment
.
data
)
PARAM_SLICE_MAPPINGS
)
class
BF16_Optimizer
(
ZeROOptimizer
):
...
...
@@ -327,8 +180,13 @@ class BF16_Optimizer(ZeROOptimizer):
# Need optimizer states initialized before linking lp to optimizer state
self
.
_link_all_hp_params
()
self
.
_enable_universal_checkpoint
()
self
.
_param_slice_mappings
=
self
.
_create_param_mapping
()
def
_enable_universal_checkpoint
(
self
):
for
lp_param_group
in
self
.
bf16_groups
:
enable_universal_checkpoint
(
param_list
=
lp_param_group
)
def
_create_param_mapping
(
self
):
param_mapping
=
[]
for
i
,
_
in
enumerate
(
self
.
optimizer
.
param_groups
):
...
...
@@ -344,93 +202,18 @@ class BF16_Optimizer(ZeROOptimizer):
def
_link_all_hp_params
(
self
):
dp_world_size
=
dist
.
get_world_size
(
group
=
self
.
dp_process_group
)
for
i
,
param_group
in
enumerate
(
self
.
optimizer
.
param_groups
):
for
i
,
_
in
enumerate
(
self
.
optimizer
.
param_groups
):
# Link bf16 and fp32 params in partition
partition_id
=
dist
.
get_rank
(
group
=
self
.
real_dp_process_group
[
i
])
partition_size
=
self
.
bf16_groups_flat
[
i
].
numel
()
//
dp_world_size
self
.
_link_hp_params
(
self
.
bf16_groups
[
i
],
self
.
fp32_groups_flat_partition
[
i
],
partition_id
*
partition_size
,
partition_size
,
self
.
real_dp_process_group
[
i
])
def
_init_lp_to_hp_mapping
(
self
,
lp_param_list
,
partition_start
,
partition_size
,
dp_group
):
current_offset
=
0
param_and_offset_list
=
[]
partition_end
=
partition_start
+
partition_size
for
lp_param
in
lp_param_list
:
lp_param
.
_hp_mapping
=
None
lp_param
.
_dp_group
=
dp_group
lp_param
.
get_full_hp_param
=
types
.
MethodType
(
get_full_hp_param
,
lp_param
)
lp_param
.
load_hp_checkpoint_state
=
types
.
MethodType
(
load_hp_checkpoint_state
,
lp_param
)
# lp_param overlaps with partition if both are true
# 1) current_offset < partition_end,
# 2) current_offset + lp_param.numel() >= partition_start
lp_param_end
=
current_offset
+
lp_param
.
numel
()
if
current_offset
<
partition_end
and
lp_param_end
>
partition_start
:
param_and_offset_list
.
append
((
lp_param
,
current_offset
))
current_offset
+=
lp_param
.
numel
()
return
param_and_offset_list
def
_link_hp_params
(
self
,
lp_param_list
,
flat_hp_partition
,
partition_start
,
partition_size
,
dp_group
):
local_lp_param_and_offset
=
self
.
_init_lp_to_hp_mapping
(
lp_param_list
,
partition_start
,
partition_size
,
dp_group
)
hp_end
=
partition_start
+
partition_size
for
lp_param
,
lp_start
in
local_lp_param_and_offset
:
lp_end
=
lp_param
.
numel
()
+
lp_start
hp_start
=
partition_start
fragment_start
=
max
(
lp_start
,
hp_start
)
fragment_end
=
min
(
lp_end
,
hp_end
)
# print(
# f'{self.dp_rank=} {lp_start=} {lp_end-lp_start=} {hp_start=} {hp_end-hp_start=} {fragment_start=} {fragment_end-fragment_start=}'
# )
assert
fragment_start
<
fragment_end
,
\
f
'fragment start
{
fragment_start
}
should be < fragment_end
{
fragment_end
}
'
fragment_numel
=
fragment_end
-
fragment_start
hp_frag_address
=
fragment_address
(
start
=
fragment_start
-
hp_start
,
numel
=
fragment_numel
)
hp_fragment_tensor
=
flat_hp_partition
.
narrow
(
0
,
hp_frag_address
.
start
,
hp_frag_address
.
numel
)
optim_fragment
=
{
key
:
value
.
narrow
(
0
,
hp_frag_address
.
start
,
hp_frag_address
.
numel
)
for
key
,
value
in
self
.
optimizer
.
state
[
flat_hp_partition
].
items
()
if
torch
.
is_tensor
(
value
)
and
value
.
dim
()
>
0
}
lp_frag_address
=
fragment_address
(
start
=
fragment_start
-
lp_start
,
numel
=
fragment_numel
)
lp_fragment_tensor
=
lp_param
.
flatten
().
narrow
(
0
,
lp_frag_address
.
start
,
lp_frag_address
.
numel
)
lp_param
.
_hp_mapping
=
tensor_fragment
(
lp_fragment
=
lp_fragment_tensor
,
lp_fragment_address
=
lp_frag_address
,
hp_fragment
=
hp_fragment_tensor
,
hp_fragment_address
=
hp_frag_address
,
optim_fragment
=
optim_fragment
)
flat_hp_partition
=
self
.
fp32_groups_flat_partition
[
i
]
link_hp_params
(
lp_param_list
=
self
.
bf16_groups
[
i
],
flat_hp_partition
=
flat_hp_partition
,
partition_start
=
partition_id
*
partition_size
,
partition_size
=
partition_size
,
partition_optimizer_state
=
self
.
optimizer
.
state
[
flat_hp_partition
],
dp_group
=
self
.
real_dp_process_group
[
i
])
def
initialize_optimizer_states
(
self
):
"""Take an optimizer step with zero-valued gradients to allocate internal
...
...
deepspeed/utils/__init__.py
浏览文件 @
53182531
...
...
@@ -4,4 +4,6 @@ from .comms_logging import get_caller_func
from
.init_on_device
import
OnDevice
from
.groups
import
*
from
.nvtx
import
instrument_w_nvtx
from
.tensor_fragment
import
tensor_fragment
,
get_full_hp_param
,
get_hp_fragment_mapping
from
.mixed_precision_linkage
import
link_hp_params
from
deepspeed.runtime.dataloader
import
RepeatingLoader
deepspeed/utils/mixed_precision_linkage.py
0 → 100644
浏览文件 @
53182531
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import
types
from
deepspeed.utils
import
get_full_hp_param
,
get_hp_fragment_mapping
def
link_hp_params
(
lp_param_list
,
flat_hp_partition
,
partition_start
,
partition_size
,
partition_optimizer_state
,
dp_group
):
local_lp_param_and_offset
=
_init_lp_to_hp_mapping
(
lp_param_list
,
partition_start
,
partition_size
,
dp_group
)
for
lp_param
,
lp_start
in
local_lp_param_and_offset
:
lp_param
.
_hp_mapping
=
get_hp_fragment_mapping
(
lp_param
,
lp_start
,
flat_hp_partition
,
partition_start
,
partition_size
,
partition_optimizer_state
)
def
_init_lp_to_hp_mapping
(
lp_param_list
,
partition_start
,
partition_size
,
dp_group
):
current_offset
=
0
param_and_offset_list
=
[]
partition_end
=
partition_start
+
partition_size
for
lp_param
in
lp_param_list
:
lp_param
.
_hp_mapping
=
None
lp_param
.
_dp_group
=
dp_group
lp_param
.
get_full_hp_param
=
types
.
MethodType
(
get_full_hp_param
,
lp_param
)
# lp_param overlaps with partition if both are true
# 1) current_offset < partition_end,
# 2) current_offset + lp_param.numel() >= partition_start
lp_param_end
=
current_offset
+
lp_param
.
numel
()
if
current_offset
<
partition_end
and
lp_param_end
>
partition_start
:
param_and_offset_list
.
append
((
lp_param
,
current_offset
))
current_offset
+=
lp_param
.
numel
()
return
param_and_offset_list
deepspeed/utils/tensor_fragment.py
0 → 100644
浏览文件 @
53182531
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import
torch
from
dataclasses
import
dataclass
from
deepspeed
import
comm
as
dist
@
dataclass
class
fragment_address
:
numel
:
int
start
:
int
@
dataclass
class
tensor_fragment
:
lp_fragment
:
torch
.
Tensor
lp_fragment_address
:
fragment_address
hp_fragment
:
torch
.
Tensor
hp_fragment_address
:
fragment_address
optim_fragment
:
{}
def
update_hp
(
self
):
self
.
hp_fragment
.
data
.
copy_
(
self
.
lp_fragment
.
data
)
def
update_lp
(
self
):
self
.
lp_fragment
.
data
.
copy_
(
self
.
hp_fragment
.
data
)
def
get_optim_state_fragment
(
self
,
key
):
if
key
in
self
.
optim_fragment
:
return
self
.
optim_fragment
[
key
]
else
:
raise
ValueError
(
f
'
{
key
}
not found in optimizer state fragment'
)
def
get_hp_fragment_address
(
self
):
return
self
.
hp_fragment_address
def
get_optim_state_keys
(
self
):
return
list
(
self
.
optim_fragment
.
keys
())
def
get_full_hp_param
(
self
,
optim_state_key
=
None
):
reduce_buffer
=
torch
.
zeros_like
(
self
,
dtype
=
torch
.
float32
).
flatten
()
if
self
.
_hp_mapping
is
not
None
:
lp_frag_address
=
self
.
_hp_mapping
.
lp_fragment_address
reduce_fragment
=
torch
.
narrow
(
reduce_buffer
,
0
,
lp_frag_address
.
start
,
lp_frag_address
.
numel
)
if
optim_state_key
is
None
:
hp_fragment
=
self
.
_hp_mapping
.
hp_fragment
else
:
hp_fragment
=
self
.
_hp_mapping
.
get_optim_state_fragment
(
optim_state_key
)
reduce_fragment
.
data
.
copy_
(
hp_fragment
.
data
)
dist
.
all_reduce
(
reduce_buffer
,
group
=
self
.
_dp_group
)
return
reduce_buffer
.
reshape_as
(
self
)
def
get_hp_fragment_mapping
(
lp_param
,
lp_start
,
flat_hp_partition
,
partition_start
,
partition_size
,
optimizer_state_dict
):
lp_end
=
lp_param
.
numel
()
+
lp_start
hp_start
=
partition_start
hp_end
=
partition_start
+
partition_size
fragment_start
=
max
(
lp_start
,
hp_start
)
fragment_end
=
min
(
lp_end
,
hp_end
)
# print(
# f'{self.dp_rank=} {lp_start=} {lp_end-lp_start=} {hp_start=} {hp_end-hp_start=} {fragment_start=} {fragment_end-fragment_start=}'
# )
assert
fragment_start
<
fragment_end
,
\
f
'fragment start
{
fragment_start
}
should be < fragment_end
{
fragment_end
}
'
fragment_numel
=
fragment_end
-
fragment_start
hp_frag_address
=
fragment_address
(
start
=
fragment_start
-
hp_start
,
numel
=
fragment_numel
)
hp_fragment_tensor
=
flat_hp_partition
.
narrow
(
0
,
hp_frag_address
.
start
,
hp_frag_address
.
numel
)
optim_fragment
=
{
key
:
value
.
narrow
(
0
,
hp_frag_address
.
start
,
hp_frag_address
.
numel
)
for
key
,
value
in
optimizer_state_dict
.
items
()
if
torch
.
is_tensor
(
value
)
and
value
.
dim
()
>
0
}
lp_frag_address
=
fragment_address
(
start
=
fragment_start
-
lp_start
,
numel
=
fragment_numel
)
lp_fragment_tensor
=
lp_param
.
flatten
().
narrow
(
0
,
lp_frag_address
.
start
,
lp_frag_address
.
numel
)
return
tensor_fragment
(
lp_fragment
=
lp_fragment_tensor
,
lp_fragment_address
=
lp_frag_address
,
hp_fragment
=
hp_fragment_tensor
,
hp_fragment_address
=
hp_frag_address
,
optim_fragment
=
optim_fragment
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录