Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
25409dcc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
25409dcc
编写于
6月 08, 2023
作者:
R
ronnywang
提交者:
GitHub
6月 08, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CustomDevice] add sharding support (#54384)
* [CustomDevice] add sarding support * update
上级
3535049a
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
140 addition
and
30 deletion
+140
-30
paddle/fluid/distributed/collective/process_group_custom.cc
paddle/fluid/distributed/collective/process_group_custom.cc
+37
-0
paddle/fluid/distributed/collective/process_group_custom.h
paddle/fluid/distributed/collective/process_group_custom.h
+6
-0
paddle/fluid/pybind/custom_device_py.cc
paddle/fluid/pybind/custom_device_py.cc
+4
-0
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
.../meta_parallel/sharding/group_sharded_optimizer_stage2.py
+42
-15
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
...uted/fleet/meta_parallel/sharding/group_sharded_stage3.py
+25
-7
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
...ted/fleet/meta_parallel/sharding/group_sharded_storage.py
+17
-7
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
...buted/fleet/meta_parallel/sharding/group_sharded_utils.py
+9
-1
未找到文件。
paddle/fluid/distributed/collective/process_group_custom.cc
浏览文件 @
25409dcc
...
@@ -722,6 +722,43 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
...
@@ -722,6 +722,43 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
false
,
false
,
false
);
false
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupCustom
::
Reduce
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
ReduceOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
{
phi
::
distributed
::
CommStaticCheck
::
SameShape
(
*
out_tensor
,
in_tensor
,
/*dst_rank*/
opts
.
root_rank
,
/*cur_rank*/
rank_
,
size_
,
phi
::
AllocationType
::
CUSTOM
);
std
::
vector
<
phi
::
DenseTensor
>
in_wrapper
{
in_tensor
};
std
::
vector
<
phi
::
DenseTensor
>
out_wrapper
{
*
out_tensor
};
return
Collective
(
in_wrapper
,
out_wrapper
,
[
&
](
phi
::
DenseTensor
&
input
,
phi
::
DenseTensor
&
output
,
phi
::
ccl
::
CCLComm
comm
,
const
phi
::
stream
::
Stream
&
stream
)
{
phi
::
DeviceManager
::
CCLReduce
(
device_type_
,
input
.
data
(),
output
.
data
(),
input
.
numel
(),
phi
::
ccl
::
ToCCLDataType
(
input
.
dtype
()),
ToCustomCCLRedType
(
opts
.
reduce_op
),
opts
.
root_rank
,
comm
,
stream
);
},
CommType
::
REDUCE
,
sync_op
,
use_calc_stream
);
}
std
::
shared_ptr
<
ProcessGroupCustom
>
std
::
shared_ptr
<
ProcessGroupCustom
>
ProcessGroupCustom
::
CreateProcessGroupCustom
(
ProcessGroupCustom
::
CreateProcessGroupCustom
(
const
std
::
shared_ptr
<
phi
::
distributed
::
Store
>&
store
,
const
std
::
shared_ptr
<
phi
::
distributed
::
Store
>&
store
,
...
...
paddle/fluid/distributed/collective/process_group_custom.h
浏览文件 @
25409dcc
...
@@ -163,6 +163,12 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
...
@@ -163,6 +163,12 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Recv
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Recv
(
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
int
src_rank
)
override
;
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
int
src_rank
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
ReduceOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
override
;
protected:
protected:
virtual
std
::
shared_ptr
<
ProcessGroupCustom
::
CustomTask
>
CreateTask
(
virtual
std
::
shared_ptr
<
ProcessGroupCustom
::
CustomTask
>
CreateTask
(
std
::
vector
<
Place
>
places
,
std
::
vector
<
Place
>
places
,
...
...
paddle/fluid/pybind/custom_device_py.cc
浏览文件 @
25409dcc
...
@@ -29,6 +29,10 @@ namespace pybind {
...
@@ -29,6 +29,10 @@ namespace pybind {
void
BindCustomDevicePy
(
py
::
module
*
m_ptr
)
{
void
BindCustomDevicePy
(
py
::
module
*
m_ptr
)
{
auto
&
m
=
*
m_ptr
;
auto
&
m
=
*
m_ptr
;
// Bind Methods
// Bind Methods
m
.
def
(
"_get_device_min_chunk_size"
,
[](
const
std
::
string
&
device_type
)
{
auto
place
=
paddle
::
platform
::
CustomPlace
(
device_type
);
return
phi
::
DeviceManager
::
GetMinChunkSize
(
place
);
});
m
.
def
(
m
.
def
(
"_get_device_total_memory"
,
"_get_device_total_memory"
,
[](
const
std
::
string
&
device_type
,
int
device_id
)
{
[](
const
std
::
string
&
device_type
,
int
device_id
)
{
...
...
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
浏览文件 @
25409dcc
...
@@ -82,8 +82,10 @@ class GroupShardedOptimizerStage2(Optimizer):
...
@@ -82,8 +82,10 @@ class GroupShardedOptimizerStage2(Optimizer):
super
().
__init__
(
learning_rate
=
optim
.
_learning_rate
,
parameters
=
params
)
super
().
__init__
(
learning_rate
=
optim
.
_learning_rate
,
parameters
=
params
)
assert
(
assert
(
core
.
is_compiled_with_cuda
()
or
core
.
is_compiled_with_xpu
()
core
.
is_compiled_with_cuda
()
),
"Only GPU and XPU is supported now"
or
core
.
is_compiled_with_xpu
()
or
(
device
in
core
.
get_all_custom_device_type
())
),
"Only GPU and XPU and CustomDevice is supported now"
# Segmentation information
# Segmentation information
self
.
_dtype_rank_params
=
(
self
.
_dtype_rank_params
=
(
...
@@ -371,6 +373,13 @@ class GroupShardedOptimizerStage2(Optimizer):
...
@@ -371,6 +373,13 @@ class GroupShardedOptimizerStage2(Optimizer):
Count the memory size of the parameters corresponding to rank under the corresponding dtype.
Count the memory size of the parameters corresponding to rank under the corresponding dtype.
"""
"""
# CUDA alignment 256 bytes
# CUDA alignment 256 bytes
if
self
.
_default_device
in
core
.
get_all_custom_device_type
():
device_alignment
=
core
.
libpaddle
.
_get_device_min_chunk_size
(
self
.
_default_device
)
else
:
device_alignment
=
alignment
[
self
.
_default_device
]
if
len
(
self
.
_rank_buffer_size
)
==
0
:
if
len
(
self
.
_rank_buffer_size
)
==
0
:
for
dtype
in
self
.
dtype_rank_params
.
keys
():
for
dtype
in
self
.
dtype_rank_params
.
keys
():
if
dtype
not
in
self
.
_rank_buffer_size
.
keys
():
if
dtype
not
in
self
.
_rank_buffer_size
.
keys
():
...
@@ -384,11 +393,11 @@ class GroupShardedOptimizerStage2(Optimizer):
...
@@ -384,11 +393,11 @@ class GroupShardedOptimizerStage2(Optimizer):
if
not
param
.
trainable
:
if
not
param
.
trainable
:
continue
continue
size
=
param
.
_numel
()
*
align
[
dtype
]
size
=
param
.
_numel
()
*
align
[
dtype
]
remaining
=
size
%
alignment
[
self
.
_default_device
]
remaining
=
size
%
device_alignment
ali
=
(
ali
=
(
0
0
if
remaining
==
0
if
remaining
==
0
else
alignment
[
self
.
_default_device
]
-
remaining
else
device_alignment
-
remaining
)
)
align_
=
ali
//
align
[
dtype
]
align_
=
ali
//
align
[
dtype
]
self
.
_rank_buffer_size
[
dtype
][
dst_rank
]
+=
(
self
.
_rank_buffer_size
[
dtype
][
dst_rank
]
+=
(
...
@@ -439,14 +448,17 @@ class GroupShardedOptimizerStage2(Optimizer):
...
@@ -439,14 +448,17 @@ class GroupShardedOptimizerStage2(Optimizer):
if
self
.
offload
:
if
self
.
offload
:
self
.
_optim
.
_master_weights
=
self
.
_master_params
self
.
_optim
.
_master_weights
=
self
.
_master_params
cpu_master_params
=
list
(
self
.
_master_params
.
values
())
cpu_master_params
=
list
(
self
.
_master_params
.
values
())
if
self
.
_default_device
in
core
.
get_all_custom_device_type
():
device_alignment
=
core
.
libpaddle
.
_get_device_min_chunk_size
(
self
.
_default_device
)
else
:
device_alignment
=
alignment
[
self
.
_default_device
]
for
param
in
cpu_master_params
:
for
param
in
cpu_master_params
:
size
=
param
.
_numel
()
*
align
[
Type
.
fp32
.
value
]
size
=
param
.
_numel
()
*
align
[
Type
.
fp32
.
value
]
remaining
=
size
%
alignment
[
self
.
offload_device
]
remaining
=
size
%
device_alignment
ali
=
(
ali
=
0
if
remaining
==
0
else
device_alignment
-
remaining
0
if
remaining
==
0
else
alignment
[
self
.
offload_device
]
-
remaining
)
align_
=
ali
//
align
[
Type
.
fp32
.
value
]
align_
=
ali
//
align
[
Type
.
fp32
.
value
]
self
.
offload_buffer_size
+=
param
.
_numel
()
+
align_
self
.
offload_buffer_size
+=
param
.
_numel
()
+
align_
self
.
offload_param2align
[
param
.
name
]
=
align_
self
.
offload_param2align
[
param
.
name
]
=
align_
...
@@ -528,6 +540,21 @@ class GroupShardedOptimizerStage2(Optimizer):
...
@@ -528,6 +540,21 @@ class GroupShardedOptimizerStage2(Optimizer):
for
param
in
self
.
_local_params
:
for
param
in
self
.
_local_params
:
if
param
.
name
in
self
.
_master_params
.
keys
():
if
param
.
name
in
self
.
_master_params
.
keys
():
if
(
self
.
_default_device
in
core
.
get_all_custom_device_type
()
):
param
.
set_value
(
self
.
_master_params
[
param
.
name
]
.
_copy_to
(
paddle
.
CustomPlace
(
self
.
_default_device
,
self
.
dev_id
),
True
,
)
.
cast
(
dtype
=
param
.
dtype
)
)
else
:
param
.
set_value
(
param
.
set_value
(
self
.
_master_params
[
param
.
name
]
self
.
_master_params
[
param
.
name
]
.
cuda
(
self
.
dev_id
)
.
cuda
(
self
.
dev_id
)
...
...
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
浏览文件 @
25409dcc
...
@@ -89,7 +89,10 @@ class GroupShardedStage3(nn.Layer):
...
@@ -89,7 +89,10 @@ class GroupShardedStage3(nn.Layer):
super
().
__init__
()
super
().
__init__
()
# Default configs
# Default configs
assert
core
.
is_compiled_with_cuda
(),
"Only support CUDA."
assert
core
.
is_compiled_with_cuda
()
or
(
device
in
core
.
get_all_custom_device_type
()
),
"Only support CUDA / CustomDevice."
self
.
_layer
=
layer
self
.
_layer
=
layer
self
.
_default_device
=
device
self
.
_default_device
=
device
self
.
__sync_buffers
=
sync_buffers
self
.
__sync_buffers
=
sync_buffers
...
@@ -243,6 +246,14 @@ class GroupShardedStage3(nn.Layer):
...
@@ -243,6 +246,14 @@ class GroupShardedStage3(nn.Layer):
else
:
else
:
for
param
in
list
(
self
.
_unslice_params
):
for
param
in
list
(
self
.
_unslice_params
):
param
.
clear_gradient
(
False
)
param
.
clear_gradient
(
False
)
if
(
self
.
_default_device
in
paddle
.
device
.
get_all_custom_device_type
()
):
tmp_var
=
param
.
_copy_to
(
paddle
.
CustomPlace
(
self
.
_default_device
,
DEV_ID
),
True
)
else
:
tmp_var
=
param
.
cuda
(
DEV_ID
)
tmp_var
=
param
.
cuda
(
DEV_ID
)
if
(
if
(
...
@@ -718,10 +729,14 @@ class GroupShardedStage3(nn.Layer):
...
@@ -718,10 +729,14 @@ class GroupShardedStage3(nn.Layer):
def
_param2align
(
self
,
param
):
def
_param2align
(
self
,
param
):
# CUDA alignment 256 bytes
# CUDA alignment 256 bytes
size
=
param
.
_numel
()
*
align
[
param
.
dtype
]
size
=
param
.
_numel
()
*
align
[
param
.
dtype
]
remaining
=
size
%
alignment
[
self
.
_default_device
]
if
self
.
_default_device
in
core
.
get_all_custom_device_type
():
ali
=
(
device_alignment
=
core
.
libpaddle
.
_get_device_min_chunk_size
(
0
if
remaining
==
0
else
alignment
[
self
.
_default_device
]
-
remaining
self
.
_default_device
)
)
else
:
device_alignment
=
alignment
[
self
.
_default_device
]
remaining
=
size
%
device_alignment
ali
=
0
if
remaining
==
0
else
device_alignment
-
remaining
align_
=
ali
//
align
[
param
.
dtype
]
align_
=
ali
//
align
[
param
.
dtype
]
return
align_
return
align_
...
@@ -1095,6 +1110,9 @@ def _device2cpu(trans_param, convert_dtype=False):
...
@@ -1095,6 +1110,9 @@ def _device2cpu(trans_param, convert_dtype=False):
def
_cpu2device
(
param
):
def
_cpu2device
(
param
):
if
DEV
in
paddle
.
device
.
get_all_custom_device_type
():
tmp_p
=
param
.
fw_storage
.
_copy_to
(
paddle
.
CustomPlace
(
DEV
,
DEV_ID
),
True
)
else
:
tmp_p
=
param
.
fw_storage
.
cuda
(
DEV_ID
)
tmp_p
=
param
.
fw_storage
.
cuda
(
DEV_ID
)
if
(
if
(
tmp_p
.
dtype
==
Type
.
fp32
.
value
tmp_p
.
dtype
==
Type
.
fp32
.
value
...
...
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
浏览文件 @
25409dcc
...
@@ -76,6 +76,11 @@ class InternalStorage:
...
@@ -76,6 +76,11 @@ class InternalStorage:
),
"Conversion type is not supported now"
),
"Conversion type is not supported now"
if
self
.
_device
!=
device
:
if
self
.
_device
!=
device
:
if
device
in
paddle
.
device
.
get_all_custom_device_type
():
tmp_buffer
=
self
.
buffer
.
_copy_to
(
paddle
.
CustomPlace
(
device
,
self
.
dev_id
),
True
)
else
:
tmp_buffer
=
(
tmp_buffer
=
(
cvt_to_device
(
self
.
buffer
,
self
.
dev_id
)
cvt_to_device
(
self
.
buffer
,
self
.
dev_id
)
if
device
in
[
"gpu"
,
"xpu"
]
if
device
in
[
"gpu"
,
"xpu"
]
...
@@ -133,6 +138,11 @@ class ParamStorage(InternalStorage):
...
@@ -133,6 +138,11 @@ class ParamStorage(InternalStorage):
cpu_param_shape
.
append
(
p_shape
)
cpu_param_shape
.
append
(
p_shape
)
if
convert_gpu
:
if
convert_gpu
:
if
self
.
_device
in
paddle
.
device
.
get_all_custom_device_type
():
self
.
buffer
=
self
.
buffer
.
_copy_to
(
paddle
.
CustomPlace
(
self
.
_device
,
self
.
dev_id
),
True
)
else
:
# buffer convert from cpu to cuda
# buffer convert from cpu to cuda
self
.
buffer
=
cvt_to_device
(
self
.
buffer
,
self
.
dev_id
)
self
.
buffer
=
cvt_to_device
(
self
.
buffer
,
self
.
dev_id
)
...
...
python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
浏览文件 @
25409dcc
...
@@ -162,7 +162,13 @@ class GroupShardedClipGrad:
...
@@ -162,7 +162,13 @@ class GroupShardedClipGrad:
# add all reduce to get global norm of distributed params_and_grads
# add all reduce to get global norm of distributed params_and_grads
dev_id
=
int
(
self
.
_device
.
split
(
":"
)[
1
])
dev_id
=
int
(
self
.
_device
.
split
(
":"
)[
1
])
dev_type
=
self
.
_device
.
split
(
':'
)[
0
]
if
paddle
.
device
.
get_device
()
==
"cpu"
:
if
paddle
.
device
.
get_device
()
==
"cpu"
:
if
dev_type
in
paddle
.
device
.
get_all_custom_device_type
():
global_norm_var
=
global_norm_var
.
_copy_to
(
paddle
.
CustomPlace
(
dev_type
,
dev_id
),
True
)
else
:
global_norm_var
=
global_norm_var
.
cuda
(
dev_id
)
global_norm_var
=
global_norm_var
.
cuda
(
dev_id
)
with
device_guard
(
dev_id
,
self
.
_device
.
split
(
":"
)[
0
]):
with
device_guard
(
dev_id
,
self
.
_device
.
split
(
":"
)[
0
]):
...
@@ -207,6 +213,8 @@ def device_guard(dev_id=0, device="cpu"):
...
@@ -207,6 +213,8 @@ def device_guard(dev_id=0, device="cpu"):
paddle
.
set_device
(
device
)
paddle
.
set_device
(
device
)
elif
device
in
[
"gpu"
,
"xpu"
]:
elif
device
in
[
"gpu"
,
"xpu"
]:
paddle
.
set_device
(
f
"
{
device
}
:
{
dev_id
}
"
)
paddle
.
set_device
(
f
"
{
device
}
:
{
dev_id
}
"
)
elif
device
in
paddle
.
device
.
get_all_custom_device_type
():
paddle
.
set_device
(
f
"
{
device
}
:
{
dev_id
}
"
)
try
:
try
:
yield
yield
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录