PaddlePaddle / Paddle — commit 988c58e5 (unverified)
Authored by zhaoyingli on Jun 02, 2023; committed by GitHub on Jun 02, 2023
[AutoParallel] Add 1F1B Pass (#54260)
* [AutoParallel] add 1F1B
* rm amp
Parent: 703a64a3

Showing 9 changed files with 552 additions and 43 deletions (+552 -43)
python/paddle/distributed/auto_parallel/static/partitioner.py   +1 -1
python/paddle/distributed/auto_parallel/static/reshard.py   +73 -16
python/paddle/distributed/passes/auto_parallel_pipeline.py   +284 -1
test/auto_parallel/1F1B_pass_unittest.py   +115 -0
test/auto_parallel/CMakeLists.txt   +3 -0
test/auto_parallel/get_gpt_model.py   +8 -1
test/auto_parallel/optimization_tuner_api.py   +3 -0
test/auto_parallel/test_pass_1F1B.py   +55 -0
test/legacy_test/auto_parallel_gpt_model.py   +10 -24
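
For orientation before the per-file diffs: the pass is driven by the pipeline section of the auto-parallel strategy. The following minimal sketch mirrors the apply_pass helper in the unit test added by this commit (test/auto_parallel/1F1B_pass_unittest.py); the commented engine lines are likewise taken from that test and assume its get_gpt_model helpers.

import paddle
from paddle.distributed.fleet import auto

paddle.enable_static()

strategy = auto.Strategy()
strategy.auto_mode = "semi"
# Enable the pipeline pass and pick the 1F1B schedule; the unit test
# accumulates 2 micro-batches per step.
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "1F1B"
pipeline.accumulate_steps = 2

# model, loss = generate_model("pp")             # from get_gpt_model.py
# opt = paddle.optimizer.AdamW(learning_rate=1e-5)
# engine = auto.Engine(model, loss, opt, strategy=strategy)
# engine.fit(dataset, batch_size=2)              # launched on 2 GPUs, see test_pass_1F1B.py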
python/paddle/distributed/auto_parallel/static/partitioner.py

@@ -339,7 +339,7 @@ class Partitioner:
                    **{"grad_var_to_var": grad_var_to_var},
                )
            elif is_optimize_op(op):
-               # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must 2 because of 1F1B PASS
+               # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must be 2 because of 1F1B PASS
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_opt_impl = _get_dist_op_backward_implement(
                    op, self._dist_context, forward_op_id2forward_op
python/paddle/distributed/auto_parallel/static/reshard.py

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
+import copy
from collections import OrderedDict
from functools import reduce

@@ -1292,6 +1293,8 @@ class Resharder:
            shape_x[0] <= shape_y[0] < shape_x[1]
        ):
            overlapped = True
+       if shape_x == [0, 0] and shape_y == [0, 0]:
+           overlapped = True
        return overlapped

    def is_unshard(self, dims_mapping):

@@ -1377,6 +1380,14 @@ class Resharder:
        # judge whether need reshard by process_mesh
        if tensor_process_mesh != op_process_mesh:
            is_reshard = True
+           # not reshard data in send/recv scene
+           if (
+               tensor_process_mesh != op_process_mesh
+               and len(tensor_process_mesh.process_ids)
+               == len(op_process_mesh.process_ids)
+               and dist_tensor.serial_tensor.is_data
+           ):
+               is_reshard = False
        else:
            op_output_dims_mapping = dist_attr[1]
            if all(

@@ -1432,7 +1443,6 @@ class Resharder:
        """
        tensor_dist_attr = dist_tensor.dist_attr
        source_tensor = dist_tensor.serial_tensor
        tensor_name = source_tensor.name
        source_dims_mapping = tensor_dist_attr.dims_mapping
        source_process_mesh = tensor_dist_attr.process_mesh

@@ -1588,6 +1598,11 @@ class Resharder:
                Resharder.concat_partitions(
                    partition_index_list, source_partition_index
                )
+               # TODO(zhaoyingli): Remove the method to a pass.
+               # Current method to get all pp_ranks' relationship must rely on reshard.
+               # When reshard insert send/recv pair, the process_group has the pp relationship.
+               # But the mothod to obtain pp_ranks' relationship is only supported in 'reshard_input',
+               # casue 'reshard_output' only has current process_group view instead of global view.
                if int(op_role) == int(OpRole.Forward):
                    self.dist_context.up_down_streams.add_pair_stream(
                        to_send_process, target_process
                    )

@@ -1658,10 +1673,10 @@ class Resharder:
                    if i == 0:
                        all_partition_index_list.append(process_index[j][1])
                for process in group:
                    # append slice op desc
                    slice_starts = []
                    slice_ends = []
                    slices_axes = []
                    min_comm_group = copy.deepcopy(group)
                    all_partition_index_list_copied = copy.deepcopy(
                        all_partition_index_list
                    )
                    target_partition_index = Resharder.compute_partition_index(
                        process,
                        complete_shape,

@@ -1669,12 +1684,54 @@
                        target_process_shape,
                        target_process_group,
                    )
                    for idx, item in enumerate(target_partition_index):
                        slice_starts.append(item[0])
                        slice_ends.append(item[1])
                        slices_axes.append(idx)
                    for _process in group:
                        source_partition_index = (
                            Resharder.compute_partition_index(
                                _process,
                                complete_shape,
                                source_dims_mapping,
                                source_process_shape,
                                source_process_group,
                            )
                        )
                        if not all(
                            _
                            for _ in list(
                                map(
                                    self.is_overlapped,
                                    source_partition_index,
                                    target_partition_index,
                                )
                            )
                        ):
                            min_comm_group.remove(_process)
                            all_partition_index_list_copied.remove(
                                source_partition_index
                            )
                    concatenated_partition_index_list = []
                    for partition_index in all_partition_index_list_copied:
                        Resharder.concat_partitions(
                            concatenated_partition_index_list, partition_index
                        )
                    concatenated_partition_index = (
                        concatenated_partition_index_list[0]
                    )
-                   to_slice_tensor_shape = dist_tensor.global_sizes()
                    slice_starts = []
                    slice_ends = []
                    slices_axes = []
+                   to_slice_tensor_shape = []
                    for idx, item in enumerate(concatenated_partition_index):
                        slice_starts.append(
                            target_partition_index[idx][0] - item[0]
                        )
                        slice_ends.append(
                            target_partition_index[idx][1] - item[0]
                        )
                        slices_axes.append(idx)
                        to_slice_tensor_shape.append(item[1] - item[0])
                    slice_op_desc = SliceOpDesc(
                        starts=slice_starts,
                        ends=slice_ends,

@@ -1703,18 +1760,18 @@
                    op_desc_seq[process] = (
                        [
                            AllGatherOpDesc(
-                               group=group,
+                               group=min_comm_group,
                                shape=allgather_shape,
                                is_bool=(source_tensor.dtype == paddle.bool),
                            ),
                            ConcatOpDesc(
-                               partition_index_list=all_partition_index_list
+                               partition_index_list=all_partition_index_list_copied
                            ),
                            slice_op_desc,
                        ]
-                       if len(group) > 1
+                       if len(min_comm_group) > 1
                        else [slice_op_desc]
                    )

@@ -2420,7 +2477,7 @@
        else:
            idx += 1

-   def _hadnle_recv(self, block, idx, var, op, send_rank, recv_rank):
+   def _handle_recv(self, block, idx, var, op, send_rank, recv_rank):
        if self.rank_id == recv_rank:
            # if recv bool data, recv then cast
            if var.dtype == paddle.bool:

@@ -2652,7 +2709,7 @@
            )
        elif self.rank_id == recv_rank:
            # if recv bool data, recv then cast
-           self._hadnle_recv(
+           self._handle_recv(
                block,
                idx,
                var,

@@ -2684,7 +2741,7 @@
            )
        elif self.rank_id == recv_rank:
            # if recv bool data, recv then cast
-           self._hadnle_recv(
+           self._handle_recv(
                block, idx, var, op, item, recv_rank
            )
        else:
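
The pruning introduced above (min_comm_group) keeps only the source processes whose partitions actually overlap the target partition, using the per-axis interval test in is_overlapped: half-open [start, end) ranges overlap when either start falls inside the other range, and two empty [0, 0] ranges now count as overlapping. A small standalone sketch of that check, with hypothetical helper names (not part of the Resharder API):

def intervals_overlap(shape_x, shape_y):
    # Mirrors the is_overlapped logic touched in this diff.
    if shape_x == [0, 0] and shape_y == [0, 0]:
        return True
    return (
        shape_y[0] <= shape_x[0] < shape_y[1]
        or shape_x[0] <= shape_y[0] < shape_x[1]
    )


def partitions_overlap(source_partition_index, target_partition_index):
    # A partition index is a list of per-axis [start, end) pairs; a source
    # process contributes to the target only if every axis overlaps.
    return all(
        intervals_overlap(x, y)
        for x, y in zip(source_partition_index, target_partition_index)
    )


# partitions_overlap([[0, 4], [0, 8]], [[2, 6], [0, 8]]) -> True
# partitions_overlap([[0, 4], [0, 8]], [[4, 8], [0, 8]]) -> False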
python/paddle/distributed/passes/auto_parallel_pipeline.py

@@ -17,7 +17,14 @@ import os
from paddle.distributed.auto_parallel.static.process_group import (
    remove_process_group,
)
from paddle.distributed.auto_parallel.static.utils import (
    is_backward_op,
    is_forward_op,
    is_lr_sched_op,
    is_optimize_op,
)
from paddle.distributed.fleet.fleet_executor_utils import TaskNode
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid import core
from paddle.fluid.framework import Parameter, Program

@@ -32,6 +39,12 @@ __not_shape_var_type__ = [
]


def is_reshard_op(op):
    return op.has_attr('op_namescope') and "/auto_parallel/reshard" in op.attr(
        'op_namescope'
    )


@register_pass("auto_parallel_pipeline")
class PipelinePass(PassBase):
    def __init__(self):

@@ -62,7 +75,8 @@ class PipelinePass(PassBase):
        self._cur_pp_stage = self._get_pp_stage(self._cur_rank)

        if self._mode == "1F1B":
-           raise NotImplementedError("1F1B has not been implemented")
+           self._insert_sync_ops_for_1f1b()
+           self._task_1f1b()
        elif self._mode == "F-Then-B":
            raise NotImplementedError("F-Then-B has not been implemented")
        elif self._mode == "stream":

@@ -109,6 +123,98 @@
            block._sync_with_cpp()

    def _insert_sync_ops_for_1f1b(self):
        """
        This implementation refers to lots of Paddle/python/paddle/fluid/optimizer.py.
        The difference between this function with 'PipelineOptimizer' is that
        'send_v2' op and 'recv_v2' op have been inserted in program by 'reshard'.
        """
        for block in self._program.blocks:
            offset = 0
            first_optimize_index = None
            for index, op in enumerate(list(block.ops)):
                if is_optimize_op(op):
                    first_optimize_index = index
                    break

            # insert sync ops
            for index, op in enumerate(list(block.ops)):
                # NOTE: pipeline might hang when dynamic_shape is True
                if op.type in ['send_v2', 'recv_v2']:
                    op._set_attr("dynamic_shape", False)
                # set send op on comm stream
                if op.type == 'send_v2':
                    # step1: set 'use_calc_stream' False
                    op._set_attr("use_calc_stream", False)
                    op_role = op.attr('op_role')
                    ring_id = op.attr('ring_id')
                    # step2: insert 'c_sync_calc_stream' op before 'send_v2' op
                    var_name = op.input_arg_names[0]
                    var = block.var(var_name)
                    block._insert_op_without_sync(
                        index=index + offset,
                        type="c_sync_calc_stream",
                        inputs={'X': [var]},
                        outputs={'Out': [var]},
                        attrs={'op_role': op_role},
                    )
                    offset += 1
                    # step3: insert 'c_sync_comm_stream' op after 'send_v2' op or
                    # before the first optimize op
                    if int(op_role) == int(OpRole.Backward):
                        index = first_optimize_index + offset
                        new_op_role = OpRole.Optimize
                    else:
                        index = index + offset + 1
                        new_op_role = OpRole.Backward
                    sync_comm_op = block._insert_op_without_sync(
                        index=index,
                        type="c_sync_comm_stream",
                        inputs={'X': [var]},
                        outputs={'Out': [var]},
                        attrs={
                            'op_role': new_op_role,
                            'ring_id': ring_id,
                        },
                    )
                    # step4: If 'send_v2' op in forward parse, set 'pipeline_flag' to distinguish
                    # whether the 'c_sync_comm_stream' op is inserted for pipeline.
                    if int(op_role) == int(OpRole.Forward):
                        sync_comm_op._set_attr('pipeline_flag', '')
                    offset += 1
            block._sync_with_cpp()

            offset = 0
            backward_recv_index = None
            for index, op in enumerate(block.ops):
                if op.type == "recv_v2" and is_backward_op(op):
                    backward_recv_index = index
                    break
            if backward_recv_index is None:
                continue

            # replace 'c_sync_comm_stream' op with 'nop' op
            # use nop op for gc
            for index, op in enumerate(list(block.ops)):
                if index >= backward_recv_index:
                    break
                if op.type == 'c_sync_comm_stream' and op.has_attr('pipeline_flag'):
                    var_name = op.output_arg_names[0]
                    var = block.var(var_name)
                    block._remove_op(index + offset, sync=False)
                    offset -= 1
                    block._insert_op_without_sync(
                        index=backward_recv_index,
                        type="nop",
                        inputs={'X': [var]},
                        outputs={'Out': [var]},
                        attrs={'op_role': OpRole.Backward},
                    )
            block._sync_with_cpp()

    def _create_param(self, dst_block, src_var):
        copied_kwargs = {}
        copied_kwargs['trainable'] = src_var.trainable

@@ -196,6 +302,183 @@
                break
        return pp_idx

    def _task_1f1b(self):
        # create fwd, bwd, opt program with op_role
        num_of_functionality = 4
        lr_prog = Program()
        fwd_prog = Program()
        bwd_prog = Program()
        opt_prog = Program()

        for idx, src_block in enumerate(self._program.blocks):
            if idx == 0:
                lr_block = lr_prog.block(0)
                fwd_block = fwd_prog.block(0)
                bwd_block = bwd_prog.block(0)
                opt_block = opt_prog.block(0)
            else:
                lr_block = lr_prog._create_block(parent_idx=src_block.parent_idx)
                fwd_block = fwd_prog._create_block(parent_idx=src_block.parent_idx)
                bwd_block = bwd_prog._create_block(parent_idx=src_block.parent_idx)
                opt_block = opt_prog._create_block(parent_idx=src_block.parent_idx)
                lr_block._set_forward_block_idx(src_block.forward_block_idx)
                fwd_block._set_forward_block_idx(src_block.forward_block_idx)
                bwd_block._set_forward_block_idx(src_block.forward_block_idx)
                opt_block._set_forward_block_idx(src_block.forward_block_idx)

            # split the program based on the op_role
            for op in src_block.ops:
                if is_lr_sched_op(op):
                    self._create_program(src_block, lr_block, op)
                if is_forward_op(op):
                    self._create_program(src_block, fwd_block, op)
                elif is_backward_op(op):
                    self._create_program(src_block, bwd_block, op)
                elif is_optimize_op(op):
                    self._create_program(src_block, opt_block, op)
                else:
                    raise ValueError(
                        "The op role: "
                        + str(op.attr('op_role'))
                        + " isn't one of LRSched, Forward, Backward or Optimizer."
                    )

        lr_prog._sync_with_cpp()
        fwd_prog._sync_with_cpp()
        bwd_prog._sync_with_cpp()
        opt_prog._sync_with_cpp()

        lr_prog._rollback()
        fwd_prog._rollback()
        bwd_prog._rollback()
        opt_prog._rollback()

        # Create task nodes.
        lr_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            program=lr_prog,
            task_id=int(self._cur_rank * num_of_functionality + 0),
            node_type="Amplifier",
            lazy_initialize=True,
        )
        lr_task_node.set_run_pre_steps(self._acc_steps)
        fwd_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            program=fwd_prog,
            task_id=int(self._cur_rank * num_of_functionality + 1),
            node_type="Compute",
            lazy_initialize=True,
        )
        bwd_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            program=bwd_prog,
            task_id=int(self._cur_rank * num_of_functionality + 2),
            node_type="Compute",
            lazy_initialize=True,
        )
        opt_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            program=opt_prog,
            task_id=int(self._cur_rank * num_of_functionality + 3),
            node_type="Amplifier",
            lazy_initialize=True,
        )
        opt_task_node.set_run_pre_steps(self._acc_steps)
        opt_task_node.set_run_at_offset(self._acc_steps - 1)
        task_nodes = {
            "lr": lr_task_node,
            "fwd": fwd_task_node,
            "bwd": bwd_task_node,
            "opt": opt_task_node,
        }

        # get upstream ranks and downstream ranks of cur_rank
        up_down_streams = self._dist_context.up_down_streams
        pp_upstream_ranks = up_down_streams.ups(self._cur_rank)
        pp_downstream_ranks = up_down_streams.downs(self._cur_rank)

        # set upstream/downstream for task_nodes of cur_rank
        for i, (task_role, task_node) in enumerate(task_nodes.items()):
            cur_id = int(self._cur_rank * num_of_functionality + i)
            ups = []
            downs = []

            # set upstream/downstream and buffersize in pipeline stage
            pp_buff_size = int(self._pp_stages - self._cur_pp_stage)
            prev_id = cur_id - 1
            next_id = cur_id + 1
            if task_role != "lr":
                buf_size = pp_buff_size if task_role == "bwd" else 2
                ups.append((prev_id, buf_size))
            if task_role != "opt":
                buf_size = pp_buff_size if task_role == "fwd" else 2
                downs.append((next_id, buf_size))

            # set upstream/downstream and buffersize cross pipeline stage
            for upstream in pp_upstream_ranks:
                upstream_id = int(upstream * num_of_functionality + i)
                if task_role == "fwd":
                    if upstream != -1:
                        ups.append((upstream_id, 2))
                elif task_role == "bwd":
                    if upstream != -1:
                        downs.append((upstream_id, 2))
            for downstream in pp_downstream_ranks:
                downstream_id = int(downstream * num_of_functionality + i)
                if task_role == "fwd":
                    if downstream != -1:
                        downs.append((downstream_id, 2))
                elif task_role == "bwd":
                    if downstream != -1:
                        ups.append((downstream_id, 2))

            for up in ups:
                print(
                    "Task:",
                    cur_id,
                    "'s upstream includes:",
                    up[0],
                    ", buffer size is:",
                    up[1],
                )
                task_node.add_upstream_task(up[0], up[1])
            for down in downs:
                print(
                    "Task:",
                    cur_id,
                    "'s downstream includes:",
                    down[0],
                    ", buffer size is:",
                    down[1],
                )
                task_node.add_downstream_task(down[0], down[1])

        # record global message: task_id_to_rank
        task_id_to_rank = {}
        for i in range(self._nrank):
            for j in range(num_of_functionality):
                task_id_to_rank[int(i * num_of_functionality + j)] = i

        self._program._pipeline_opt = {}
        self._program._pipeline_opt['fleet_opt'] = {
            "tasks": list(task_nodes.values()),
            "task_id_to_rank": task_id_to_rank,
            "num_micro_batches": self._acc_steps,
        }

    def _task_stream(self):
        num_of_functionality = 5
        start_prog = Program()
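
To make the task wiring in _task_1f1b concrete: every rank owns num_of_functionality = 4 tasks (lr, fwd, bwd, opt), a task's global id is rank * 4 + role index, and task_id_to_rank is the inverse of that numbering. A small standalone sketch of the id scheme (illustration only, not Paddle API):

NUM_OF_FUNCTIONALITY = 4  # lr, fwd, bwd, opt, as in _task_1f1b
ROLES = ["lr", "fwd", "bwd", "opt"]


def task_id(rank, role):
    # Global task id used by the pass: rank * num_of_functionality + role index.
    return rank * NUM_OF_FUNCTIONALITY + ROLES.index(role)


def build_task_id_to_rank(nranks):
    # Same mapping the pass records under _pipeline_opt['fleet_opt'].
    return {
        rank * NUM_OF_FUNCTIONALITY + j: rank
        for rank in range(nranks)
        for j in range(NUM_OF_FUNCTIONALITY)
    }


print(build_task_id_to_rank(2))
# {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1}
#
# For a 2-stage pipeline: rank 0 owns tasks 0..3 and rank 1 owns tasks 4..7.
# Within a rank, fwd/bwd edges follow task_id +/- 1; across stages, rank 0's
# fwd (id 1) feeds rank 1's fwd (id 5), and rank 1's bwd (id 6) feeds rank 0's
# bwd (id 2), matching the ups/downs built in the loop above.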
test/auto_parallel/1F1B_pass_unittest.py (new file, mode 100644)

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np
from get_gpt_model import FakeDataset, generate_model

import paddle
from paddle.distributed import ParallelEnv
from paddle.distributed.fleet import auto

paddle.enable_static()


def apply_pass(use_1f1b=False):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_1f1b:
        pipeline = strategy.pipeline
        pipeline.enable = True
        pipeline.schedule_mode = "1F1B"
        pipeline.accumulate_steps = 2
    else:
        gradient_merge = strategy.gradient_merge
        gradient_merge.enable = True
        gradient_merge.k_steps = 2
        gradient_merge.avg = True
    return strategy


def reset_prog():
    paddle.fluid.framework.switch_main_program(paddle.static.Program())
    paddle.fluid.framework.switch_startup_program(paddle.static.Program())


class Test1F1BPass(unittest.TestCase):
    def setUp(self):
        self.rtol = 1e-5
        self.atol = 1e-8
        self.batch_size = 2
        self.batch_num = 10
        self.clip_norm = 0.2
        self.dataset = FakeDataset(self.batch_size * self.batch_num)

    def init(self, engine):
        paddle.seed(2021)
        np.random.seed(2021)
        random.seed(2021)
        paddle.distributed.fleet.init(is_collective=True)
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, use_1f1b=False):
        reset_prog()

        strategy = apply_pass(use_1f1b)
        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model("pp")

        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def check_results(self, ref_losses, check_losses):
        np.testing.assert_allclose(
            ref_losses,
            check_losses,
            rtol=self.rtol,
            atol=self.atol,
            err_msg='pass {} has wrong results!,\nu={}\nv={}\ndiff={}'.format(
                __class__, ref_losses, check_losses, ref_losses - check_losses
            ),
        )

    def test_1f1b_pass(self):
        # navie_pp+gradient_merge training
        engine_pp = self.get_engine()
        history_pp = engine_pp.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert engine_pp._strategy.pipeline.enable is False

        # pp2 1f1b training
        engine_1f1b = self.get_engine(True)
        history_1f1b = engine_1f1b.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert engine_1f1b._strategy.pipeline.enable is True

        # NOTE: every sample data from dataset is all the same
        if paddle.distributed.get_rank() == 1:
            losses_pp = np.array(history_pp.history["loss"])
            losses_1f1b = np.array(history_1f1b.history["loss"])
            self.check_results(losses_pp, losses_1f1b)


if __name__ == "__main__":
    unittest.main()
test/auto_parallel/CMakeLists.txt

@@ -62,6 +62,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
                       test_pass_generation_pipeline)
  set_tests_properties(test_pass_generation_pipeline
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
+ py_test_modules(test_pass_1F1B MODULES test_pass_1F1B)
+ set_tests_properties(test_pass_1F1B PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
+                      TIMEOUT 50)
  # End of unittests WITH multi cards and timeout
  # NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
test/auto_parallel/get_gpt_model.py

@@ -86,8 +86,14 @@ def generate_model(strategy, dropout_prob=0.0):
        modeling._global_parallel_strategy = "mp"
    elif strategy == "dp":
        modeling._global_parallel_strategy = "dp"
+   elif strategy == "pp":
+       modeling._global_parallel_strategy = "pp"
+       modeling.PP_MESH_LIST = [
+           auto.ProcessMesh(mesh=[0]),
+           auto.ProcessMesh(mesh=[1]),
+       ]
    else:
-       raise ValueError("Only support serial, mp2 and dp2.")
+       raise ValueError("Only support serial, mp2, dp2 and pp2.")
    gpt = GPTModel(
        vocab_size=1000,

@@ -105,6 +111,7 @@ def generate_model(strategy, dropout_prob=0.0):
        eos_token_id=7,
        bos_token_id=0,
        eol_token_id=3,
+       pp_degree=2 if strategy == "pp" else None,
    )
    model = GPTForPretraining(
        gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
test/auto_parallel/optimization_tuner_api.py

@@ -85,6 +85,9 @@ def train(fetch):
    dist_strategy = auto.Strategy()
    dist_strategy.auto_mode = "semi"

+   # dp optimization config
+   dp_optimization = dist_strategy.dp_optimization
+   dp_optimization.enable = True
    # sharding config
    sharding = dist_strategy.sharding
    sharding.enable = True
test/auto_parallel/test_pass_1F1B.py (new file, mode 100644)

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import sys
import tempfile
import unittest


class Test1F1BPass(unittest.TestCase):
    def test_pp2(self):
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(file_dir, "1F1B_pass_unittest.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        tmp_dir = tempfile.TemporaryDirectory()
        cmd = (
            [sys.executable, "-u"]
            + coverage_args
            + [
                "-m",
                "paddle.distributed.launch",
                "--devices",
                "0,1",
                "--log_dir",
                tmp_dir.name,
                launch_model_path,
            ]
        )

        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        tmp_dir.cleanup()


if __name__ == "__main__":
    unittest.main()
test/legacy_test/auto_parallel_gpt_model.py

@@ -350,32 +350,18 @@ class TransformerDecoder(nn.Layer):
        output = tgt
        new_caches = []
        self.checkpoints = []

        if _global_parallel_strategy == "pp":
            auto.shard_tensor(
                output,
                PP_MESH_LIST[0],
                [None for i in range(len(output.shape))],
            )
        if _global_parallel_strategy == "dp_pp":
            auto.shard_tensor(
                output,
                DPPP_MESH_LIST[0],
                ["x"] + [None for i in range(len(output.shape) - 1)],
            )
        if _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(
                output,
                MPPP_MESH_LIST[0],
                [None for i in range(len(output.shape))],
            )
        if _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(
                output,
                DPMPPP_MESH_LIST[0],
                ["x"] + [None for i in range(len(output.shape) - 1)],
            )

        for i, mod in enumerate(self.layers):
            if _global_parallel_strategy == "pp":
                mod = auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])
            elif _global_parallel_strategy == "dp_pp":
                mod = auto.shard_op(mod, DPPP_MESH_LIST[mod.mesh_idx])
            elif _global_parallel_strategy == "mp_pp":
                mod = auto.shard_op(mod, MPPP_MESH_LIST[mod.mesh_idx])
            elif _global_parallel_strategy == "dp_mp_pp":
                mod = auto.shard_op(mod, DPMPPP_MESH_LIST[mod.mesh_idx])
            if self.use_new_recompute and self.recompute_granularity == "full":
                mod = auto.recompute(mod)
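
The decoder changes above place whole layers onto pipeline meshes with auto.shard_op, indexed by each layer's mesh_idx. A reduced sketch of the pattern for the plain "pp" strategy, using a hypothetical two-stage mesh list and layer objects that carry a mesh_idx attribute (the real model also threads memory, masks and caches through each call):

from paddle.distributed.fleet import auto

# Two-stage pipeline mesh, matching PP_MESH_LIST in get_gpt_model.py.
PP_MESH_LIST = [auto.ProcessMesh(mesh=[0]), auto.ProcessMesh(mesh=[1])]


def run_decoder_layers(layers, output):
    # Each layer's mesh_idx selects its pipeline stage; auto.shard_op binds the
    # ops created by the layer call to that stage's process mesh.
    for mod in layers:
        mod = auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])
        output = mod(output)
    return output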