Commit 831db343 (unverified)

Authored on Nov 10, 2022 by caozhou; committed by GitHub on Nov 10, 2022.
[Auto Parallel]Add c_concat pass for reshard (#47809)
* add c_concat pass for reshard
* add unittest
Parent: d01109fc
Showing 2 changed files with 195 additions and 35 deletions (+195 −35).
python/paddle/distributed/auto_parallel/reshard.py (+141 −35)
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py (+54 −0)
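Before the hunks, a conceptual sketch of what the new pass buys (a minimal illustration with assumed shapes, NumPy standing in for the collective; none of this code is in the commit): when the source tensor is split only along its last axis and the target layout is fully replicated, a single c_concat gathers and concatenates the per-rank shards in one collective, replacing the allgather + concat + slice sequence that reshard emitted before.

import numpy as np

# Two ranks each hold a [4, 2] shard of a [4, 4] tensor split on the last
# axis; c_concat's result on every rank is the concatenation of the shards.
shards = [np.full((4, 2), rank, dtype=np.float32) for rank in range(2)]
full = np.concatenate(shards, axis=-1)
assert full.shape == (4, 4)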
python/paddle/distributed/auto_parallel/reshard.py @ 831db343
@@ -92,6 +92,42 @@ class AllGatherOpDesc:
         return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}."
 
 
+class AllGatherConcatOpDesc:
+    """
+    Describe the c_concat op in the reshard phase.
+
+    Args:
+        group (list): Process group.
+        shape (list): The tensor shape.
+        is_bool (bool): Whether c_concat bool data. Default: False.
+    """
+
+    def __init__(self, group, shape, is_bool=False):
+        self._group = group
+        self._desc = "c_concat"
+        self._shape = shape
+        self._is_bool = is_bool
+
+    @property
+    def is_bool(self):
+        return self._is_bool
+
+    @property
+    def group(self):
+        return self._group
+
+    @property
+    def desc(self):
+        return self._desc
+
+    @property
+    def shape(self):
+        return self._shape
+
+    def __repr__(self):
+        return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}."
+
+
 class SendOpDesc:
     """
     Describe the send op in the reshard phase.
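For orientation, a minimal usage sketch of the new descriptor (values are assumed, not from the diff; the descriptor is a plain value object that the Resharder consumes later when it materializes the op sequence):

# Hypothetical values for illustration only.
desc = AllGatherConcatOpDesc(group=[0, 1], shape=[4, 4])
print(desc)
# op: c_concat, group: [0, 1], shape: [4, 4], is_bool: False.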
@@ -640,6 +676,46 @@ class Inserter:
             tensor_list.extend(split_out)
         return tensor_list, idx_offset
 
+    @staticmethod
+    def insert_c_concat_op(block, idx, tensor, ranks, op_role):
+        """Insert c_concat op into block at the given index."""
+        group = new_process_group(ranks)
+        idx_offset = 0
+
+        # insert c_concat op
+        op_type = 'c_concat'
+        # to avoid name conflict with framework
+        helper = LayerHelper(op_type + "@RESHARD", **locals())
+        with paddle.static.program_guard(block.program):
+            c_concat_out = block.create_var(
+                name=paddle.fluid.unique_name.generate_with_ignorable_key(
+                    ".".join([helper.name, 'tmp'])
+                ),
+                dtype=tensor.dtype,
+                shape=None,
+                lod_level=tensor.lod_level,
+                type=tensor.type,
+                persistable=False,
+                stop_gradient=False,
+            )
+            cur_rank = paddle.distributed.get_rank()
+            c_concat_op = block._insert_op(
+                idx + idx_offset,
+                type=op_type,
+                inputs={'X': [tensor]},
+                outputs={'Out': [c_concat_out]},
+                attrs={
+                    'ring_id': group.id,
+                    'use_calc_stream': True,
+                    'use_model_parallel': True,
+                    'nranks': group.nranks,
+                    'op_role': op_role,
+                    'rank': group.ranks.index(cur_rank)
+                    if cur_rank in ranks
+                    else 0,
+                },
+            )
+            c_concat_op._set_attr('op_namescope', "/auto_parallel/reshard")
+            return c_concat_out
+
     @staticmethod
     def concat_partitions_with_op(
         partition_tensor_list, tensor, partition_index, block, idx, op_role
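One detail worth calling out in insert_c_concat_op is the 'rank' attribute: it is the current worker's index inside the freshly created process group, falling back to 0 for ranks outside the group. A standalone restatement of that expression with assumed values:

# Assumed example values, mirroring the expression used in the hunk above.
ranks = [4, 5, 6, 7]  # the process group passed to insert_c_concat_op
cur_rank = 6          # what paddle.distributed.get_rank() returns here
rank_attr = ranks.index(cur_rank) if cur_rank in ranks else 0
assert rank_attr == 2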
@@ -1535,7 +1611,7 @@ class Resharder:
                 )
             )
-            # in the same process group, it will use allgahther and slice op.
+            # In the same process group, it will use allgahther and slice op.
             else:
                 # NOTE: It just supports even partition scene.
                 partition_index_list = []
@@ -1599,21 +1675,37 @@ class Resharder:
                         if not serial
                         else dist_tensor.local_sizes(rank=process)
                     )
-                    op_desc_seq[process] = (
-                        [
-                            AllGatherOpDesc(
-                                group=group,
-                                shape=allgather_shape,
-                                is_bool=(source_tensor.dtype == paddle.bool),
-                            ),
-                            ConcatOpDesc(
-                                partition_index_list=all_partition_index_list
-                            ),
-                            slice_op_desc,
-                        ]
-                        if len(group) > 1
-                        else [slice_op_desc]
-                    )
+                    # c_concat pass
+                    if (
+                        target_dims_mapping.count(-1)
+                        == len(target_dims_mapping)
+                        and source_dims_mapping[:-1].count(-1)
+                        == len(source_dims_mapping[:-1])
+                        and source_dims_mapping[-1] != -1
+                    ):
+                        op_desc_seq[process] = (
+                            [
+                                AllGatherConcatOpDesc(
+                                    group=group, shape=allgather_shape
+                                )
+                            ]
+                            if len(group) > 1
+                            else [slice_op_desc]
+                        )
+                    else:
+                        op_desc_seq[process] = (
+                            [
+                                AllGatherOpDesc(
+                                    group=group,
+                                    shape=allgather_shape,
+                                    is_bool=(
+                                        source_tensor.dtype == paddle.bool
+                                    ),
+                                ),
+                                ConcatOpDesc(
+                                    partition_index_list=all_partition_index_list
+                                ),
+                                slice_op_desc,
+                            ]
+                            if len(group) > 1
+                            else [slice_op_desc]
+                        )
         return op_desc_seq
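The branch condition reads densely; concretely, it fires when the target layout is fully replicated and the source is split only along its last axis. A standalone check with assumed mappings (-1 meaning "not split on this axis"):

# Assumed example mappings, restating the hunk's condition verbatim.
source_dims_mapping = [-1, 0]   # split on the last axis over mesh dim 0
target_dims_mapping = [-1, -1]  # fully replicated

use_c_concat = (
    target_dims_mapping.count(-1) == len(target_dims_mapping)
    and source_dims_mapping[:-1].count(-1) == len(source_dims_mapping[:-1])
    and source_dims_mapping[-1] != -1
)
assert use_c_concat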
@@ -1850,27 +1942,41 @@ class Resharder:
                         )
                     idx = idx_list[0]
-                elif isinstance(op_desc, SliceOpDesc):
-                    assert (
-                        len(partition_tensor_list) == 1
-                        or not partition_tensor_list
-                    )
-                    to_slice_tensor = (
-                        partition_tensor_list[0][0]
-                        if len(partition_tensor_list) == 1
-                        else source_tensor
-                    )
-                    new_name = unique_name.generate(var_name + "@RESHARD")
-                    target_tensor = Inserter.insert_slice_op(
-                        block,
-                        idx,
-                        to_slice_tensor,
-                        starts=op_desc.starts,
-                        ends=op_desc.ends,
-                        axes=op_desc.axes,
-                        new_var_name=new_name,
-                        op_role=reshard_op.attr('op_role'),
-                    )
+                elif isinstance(op_desc, SliceOpDesc) or isinstance(
+                    op_desc, AllGatherConcatOpDesc
+                ):
+                    target_tensor = None
+                    if isinstance(op_desc, SliceOpDesc):
+                        assert (
+                            len(partition_tensor_list) == 1
+                            or not partition_tensor_list
+                        )
+                        to_slice_tensor = (
+                            partition_tensor_list[0][0]
+                            if len(partition_tensor_list) == 1
+                            else source_tensor
+                        )
+                        new_name = unique_name.generate(
+                            var_name + "@RESHARD"
+                        )
+                        target_tensor = Inserter.insert_slice_op(
+                            block,
+                            idx,
+                            to_slice_tensor,
+                            starts=op_desc.starts,
+                            ends=op_desc.ends,
+                            axes=op_desc.axes,
+                            new_var_name=new_name,
+                            op_role=reshard_op.attr('op_role'),
+                        )
+                    else:
+                        target_tensor = Inserter.insert_c_concat_op(
+                            block,
+                            idx,
+                            source_tensor,
+                            op_desc.group,
+                            reshard_op.attr('op_role'),
+                        )
+                    assert target_tensor is not None
 
                     process_mesh = dist_attr[0]
                     dims_mapping = dist_attr[1]
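A small design note on the merged branch: the two isinstance calls could equivalently take a type tuple, the more idiomatic guard when both descriptor kinds share one handling path (a paraphrase only; the diff keeps the explicit or):

# Equivalent guard, illustration only.
if isinstance(op_desc, (SliceOpDesc, AllGatherConcatOpDesc)):
    ...  # slice or c_concat handling as above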
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @ 831db343
@@ -304,6 +304,60 @@ class TestMLPReshard(unittest.TestCase):
         # the x should not be slice
         self.assertTrue(check_allgather(partitioned_main_prog))
 
+    def test_c_concat(self):
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
+        with static.program_guard(train_program, startup_program):
+            x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
+            x = auto.shard_tensor(x, process_mesh, [None, "x"])
+
+            w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
+            w = auto.shard_tensor(w, process_mesh, [None, None])
+
+            y = paddle.distributed.shard_op(
+                paddle.matmul, process_mesh, [[None, None], [None, None]]
+            )(x, w)
+
+        rank_id = 0
+        dist_context = DistributedContext()
+        dist_strategy = fleet.DistributedStrategy()
+        partitioner = Partitioner(dist_context, rank_id)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program
+        )
+        dist_context.block_state.parse_forward_blocks(complete_train_program)
+        (
+            partitioned_main_prog,
+            partitioned_startup_prog,
+            partitioned_params_grads,
+        ) = partitioner.partition(complete_train_program, startup_program, [])
+
+        # test estimator
+        cluster = Cluster()
+        cluster.gen_default_config_cluster(device_count=2)
+        cost_estimator = CostEstimator(train_program, cluster)
+        global_cost = cost_estimator.estimate(dist_context)
+        max_memory = cost_estimator._estimate_max_memory_by_dist_op(
+            dist_context
+        )
+        # test cache
+        global_cost = cost_estimator.estimate(dist_context)
+        max_memory = cost_estimator._estimate_max_memory_by_dist_op(
+            dist_context
+        )
+        assert global_cost.time >= 0
+        assert max_memory > 0
+
+        resharder = Resharder(
+            partitioned_main_prog,
+            partitioned_startup_prog,
+            rank_id,
+            dist_context,
+            partitioned_params_grads,
+        )
+        resharder.reshard()
+
 
 if __name__ == "__main__":
     unittest.main()
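The new test builds and reshards the partitioned program but does not assert that the pass fired. A hedged follow-up one might add locally (not part of the commit), mirroring the file's existing check_allgather style of scanning ops:

# Hypothetical helper, not in the diff: after resharder.reshard(), the
# partitioned main program should contain the inserted c_concat op.
def check_c_concat(prog):
    return any(
        op.type == "c_concat" for block in prog.blocks for op in block.ops
    )

# Usage sketch: self.assertTrue(check_c_concat(partitioned_main_prog))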