Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
7ef1de67
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7ef1de67
编写于
4月 25, 2021
作者:
S
ShenLiang
提交者:
GitHub
4月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[HybridParallel] Add pipeline layer in dygraph (#32449)
* add pipeline layer
上级
976fe6f9
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
379 addition
and
10 deletion
+379
-10
python/paddle/distributed/fleet/__init__.py
python/paddle/distributed/fleet/__init__.py
+1
-1
python/paddle/distributed/fleet/base/topology.py
python/paddle/distributed/fleet/base/topology.py
+24
-3
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
...ptimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
+1
-1
python/paddle/distributed/fleet/meta_parallel/__init__.py
python/paddle/distributed/fleet/meta_parallel/__init__.py
+1
-1
python/paddle/distributed/fleet/meta_parallel/parallel_layers/__init__.py
...stributed/fleet/meta_parallel/parallel_layers/__init__.py
+2
-1
python/paddle/distributed/fleet/meta_parallel/parallel_layers/layers_help.py
...ibuted/fleet/meta_parallel/parallel_layers/layers_help.py
+0
-0
python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py
...tributed/fleet/meta_parallel/parallel_layers/mp_layers.py
+0
-0
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
...tributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+156
-0
python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py
...distributed/fleet/meta_parallel/parallel_layers/random.py
+0
-0
python/paddle/distributed/fleet/utils/log_util.py
python/paddle/distributed/fleet/utils/log_util.py
+13
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+3
-2
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py
.../paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py
+148
-0
python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_layer.py
...d/tests/unittests/test_parallel_dygraph_pipeline_layer.py
+29
-0
python/setup.py.in
python/setup.py.in
+1
-1
未找到文件。
python/paddle/distributed/fleet/__init__.py
浏览文件 @
7ef1de67
...
@@ -21,7 +21,7 @@ from .dataset import *
...
@@ -21,7 +21,7 @@ from .dataset import *
from
.data_generator
import
MultiSlotDataGenerator
,
MultiSlotStringDataGenerator
from
.data_generator
import
MultiSlotDataGenerator
,
MultiSlotStringDataGenerator
from
.
import
metrics
from
.
import
metrics
from
.base.topology
import
CommunicateTopology
,
HybridCommunicateGroup
from
.base.topology
import
CommunicateTopology
,
HybridCommunicateGroup
from
.meta_parallel
import
random
,
layers
from
.meta_parallel
import
*
__all__
=
[
__all__
=
[
"DistributedStrategy"
,
"UtilBase"
,
"UserDefinedRoleMaker"
,
"DistributedStrategy"
,
"UtilBase"
,
"UserDefinedRoleMaker"
,
...
...
python/paddle/distributed/fleet/base/topology.py
浏览文件 @
7ef1de67
...
@@ -120,6 +120,7 @@ class HybridCommunicateGroup(object):
...
@@ -120,6 +120,7 @@ class HybridCommunicateGroup(object):
self
.
_data_parallel_id
=
self
.
_get_data_parallel_id
()
self
.
_data_parallel_id
=
self
.
_get_data_parallel_id
()
self
.
_model_parallel_id
=
self
.
_get_model_parallel_id
()
self
.
_model_parallel_id
=
self
.
_get_model_parallel_id
()
self
.
stage_id
=
self
.
_get_pipe_parallel_id
()
assert
self
.
_check_vaild_topo
(
assert
self
.
_check_vaild_topo
(
),
"Here is an unreasonable topogy setting. world_size: {}, but"
\
),
"Here is an unreasonable topogy setting. world_size: {}, but"
\
...
@@ -132,15 +133,22 @@ class HybridCommunicateGroup(object):
...
@@ -132,15 +133,22 @@ class HybridCommunicateGroup(object):
# create comm group for model parallel
# create comm group for model parallel
self
.
_mp_group
,
self
.
_mp_comm_group
=
self
.
_set_comm_group
(
"model"
)
self
.
_mp_group
,
self
.
_mp_comm_group
=
self
.
_set_comm_group
(
"model"
)
# create comm group for pipe parallel
self
.
_pp_group
,
self
.
_pp_comm_group
=
self
.
_set_comm_group
(
"pipe"
)
# create global group for check inf_nan / clip global norm
# create global group for check inf_nan / clip global norm
self
.
_check_group
,
self
.
_check_comm_group
=
self
.
_set_check_group
(
self
.
_check_group
,
self
.
_check_comm_group
=
self
.
_set_check_group
(
"data"
)
"data"
)
# create p2p group
self
.
is_first_stage
=
(
self
.
stage_id
==
0
)
self
.
is_last_stage
=
(
self
.
stage_id
==
(
self
.
_pp_degree
-
1
))
debug_str
=
"HybridParallelInfo: rank_id: %d, dp_degree: %d, "
\
debug_str
=
"HybridParallelInfo: rank_id: %d, dp_degree: %d, "
\
"mp_degree: %d, pp_degree: %d
\n
"
%
(
self
.
global_rank
,
self
.
_dp_degree
,
"mp_degree: %d, pp_degree: %d"
%
(
self
.
global_rank
,
self
.
_dp_degree
,
self
.
_mp_degree
,
self
.
_pp_degree
)
self
.
_mp_degree
,
self
.
_pp_degree
)
debug_str
+=
"dp_group: %s, mp_group: %s, check/clip group: %s"
%
(
debug_str
+=
"dp_group: %s, mp_group: %s,
pp_group: %s,
check/clip group: %s"
%
(
self
.
_dp_group
,
self
.
_mp_group
,
self
.
_check_group
)
self
.
_dp_group
,
self
.
_mp_group
,
self
.
_
pp_group
,
self
.
_
check_group
)
logger
.
info
(
debug_str
)
logger
.
info
(
debug_str
)
global
_HYBRID_PARALLEL_GROUP
global
_HYBRID_PARALLEL_GROUP
...
@@ -229,6 +237,19 @@ class HybridCommunicateGroup(object):
...
@@ -229,6 +237,19 @@ class HybridCommunicateGroup(object):
def
get_model_parallel_group_src_rank
(
self
):
def
get_model_parallel_group_src_rank
(
self
):
return
self
.
_mp_comm_group
.
ranks
[
0
]
return
self
.
_mp_comm_group
.
ranks
[
0
]
# pipeline parallel message
def
_get_pipe_parallel_id
(
self
):
return
self
.
_topo
.
get_coord
(
self
.
global_rank
).
pipe
def
get_stage_id
(
self
):
return
self
.
stage_id
def
get_pipe_parallel_world_size
(
self
):
return
self
.
_pp_degree
def
get_pipe_parallel_group
(
self
):
return
self
.
_pp_comm_group
# check parallel group
# check parallel group
def
get_check_parallel_group
(
self
):
def
get_check_parallel_group
(
self
):
return
self
.
_check_comm_group
return
self
.
_check_comm_group
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
浏览文件 @
7ef1de67
...
@@ -66,7 +66,7 @@ class HybridParallelGradScaler:
...
@@ -66,7 +66,7 @@ class HybridParallelGradScaler:
self
.
_found_inf
)
self
.
_found_inf
)
# allreduce_max found_inf in check_group
# allreduce_max found_inf in check_group
if
self
.
_is_mp
:
if
self
.
_is_mp
:
self
.
_found_inf
=
paddle
.
cast
(
self
.
_found_inf
,
dtype
=
"int
64
"
)
self
.
_found_inf
=
paddle
.
cast
(
self
.
_found_inf
,
dtype
=
"int
32
"
)
paddle
.
distributed
.
all_reduce
(
paddle
.
distributed
.
all_reduce
(
self
.
_found_inf
,
self
.
_found_inf
,
op
=
paddle
.
distributed
.
ReduceOp
.
MAX
,
op
=
paddle
.
distributed
.
ReduceOp
.
MAX
,
...
...
python/paddle/distributed/fleet/meta_parallel/__init__.py
浏览文件 @
7ef1de67
...
@@ -12,5 +12,5 @@
...
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.
mp_util
s
import
*
from
.
parallel_layer
s
import
*
from
.model_parallel
import
ModelParallel
from
.model_parallel
import
ModelParallel
python/paddle/distributed/fleet/meta_parallel/
mp_util
s/__init__.py
→
python/paddle/distributed/fleet/meta_parallel/
parallel_layer
s/__init__.py
浏览文件 @
7ef1de67
...
@@ -12,5 +12,6 @@
...
@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.layers
import
*
from
.mp_layers
import
*
from
.pp_layers
import
*
from
.random
import
*
from
.random
import
*
python/paddle/distributed/fleet/meta_parallel/
mp_util
s/layers_help.py
→
python/paddle/distributed/fleet/meta_parallel/
parallel_layer
s/layers_help.py
浏览文件 @
7ef1de67
文件已移动
python/paddle/distributed/fleet/meta_parallel/
mp_utils/
layers.py
→
python/paddle/distributed/fleet/meta_parallel/
parallel_layers/mp_
layers.py
浏览文件 @
7ef1de67
文件已移动
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
0 → 100644
浏览文件 @
7ef1de67
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
paddle
from
paddle.fluid.dygraph.layers
import
Layer
from
...utils.log_util
import
logger
,
layer_to_str
__all__
=
[
'LayerDesc'
,
'PipelineLayer'
]
class
SegmentLayers
(
object
):
def
__init__
(
self
,
layers_desc
,
num_parts
,
method
=
"uniform"
):
self
.
_layers_desc
=
layers_desc
self
.
method
=
method
self
.
num_parts
=
num_parts
self
.
num_items
=
len
(
layers_desc
)
assert
self
.
num_items
>=
self
.
num_parts
,
"layer number should be greater than number of segments"
def
do_segment
(
self
):
if
self
.
method
==
"uniform"
:
return
self
.
uniform
(
self
.
num_items
,
self
.
num_parts
)
def
uniform
(
self
,
num_items
,
num_parts
):
result
=
[
0
for
_
in
range
(
num_parts
+
1
)]
part_size
=
math
.
floor
(
num_items
/
num_parts
)
for
i
in
range
(
num_parts
):
result
[
i
]
=
int
(
min
(
part_size
*
i
,
num_items
))
result
[
num_parts
]
=
num_items
return
result
class
LayerDesc
(
object
):
def
__init__
(
self
,
layer_func
,
*
inputs
,
**
kwargs
):
self
.
layer_func
=
layer_func
self
.
inputs
=
inputs
self
.
kwargs
=
kwargs
if
not
issubclass
(
layer_func
,
Layer
):
raise
TypeError
(
"The input(layer_func) should be a derived class of Layer."
)
def
build_layer
(
self
):
return
self
.
layer_func
(
*
self
.
inputs
,
**
self
.
kwargs
)
def
__repr__
(
self
):
return
layer_to_str
(
self
.
layer_func
.
__name__
,
*
self
.
inputs
,
**
self
.
kwargs
)
class
PipelineLayer
(
Layer
):
def
__init__
(
self
,
layers
,
num_stages
=
None
,
topology
=
None
,
loss_fn
=
None
,
seg_method
=
"uniform"
):
super
(
PipelineLayer
,
self
).
__init__
()
if
num_stages
is
None
and
topology
is
None
:
raise
ValueError
(
"should provide num_stages or topology"
)
# lazy import
import
paddle.distributed
as
dist
from
paddle.distributed
import
fleet
self
.
device_id
=
dist
.
ParallelEnv
().
device_id
self
.
layers
=
layers
self
.
_loss_fn
=
loss_fn
self
.
_topo
=
topology
word_size
=
dist
.
get_world_size
()
self
.
global_rank
=
dist
.
get_rank
()
if
self
.
_topo
:
self
.
_stage_id
=
self
.
_topo
.
get_coord
(
self
.
global_rank
).
pipe
self
.
_num_stages
=
self
.
_topo
.
get_dim_size
(
"pipe"
)
if
num_stages
:
assert
self
.
_num_stages
==
num_stages
,
"num_stages should be equal to be %d"
%
(
self
.
_num_stages
)
else
:
# construct default topology
if
word_size
%
num_stages
!=
0
:
raise
ValueError
(
"should provide correct num_stages({}) "
"which can be divided by word_size({})"
.
format
(
num_stages
,
word_size
))
dp_num
=
word_size
//
num_stages
self
.
_topo
=
fleet
.
CommunicateTopology
([
"data"
,
"pipe"
,
"model"
],
[
dp_num
,
num_stages
,
1
])
self
.
_stage_id
=
self
.
_topo
.
get_coord
(
self
.
global_rank
).
pipe
self
.
_num_stages
=
self
.
_topo
.
get_dim_size
(
"pipe"
)
# initialize segment
self
.
_layers_desc
=
list
(
self
.
layers
)
self
.
_num_layers
=
len
(
self
.
_layers_desc
)
self
.
_start_pos
=
0
self
.
_end_pos
=
self
.
_num_layers
-
1
self
.
_segment_network
(
seg_method
)
# construct layer
self
.
run_function
=
[]
self
.
_build_layer
()
self
.
to
(
paddle
.
CUDAPlace
(
self
.
device_id
))
def
_segment_network
(
self
,
seg_method
):
logger
.
info
(
"start segment network.."
)
seg
=
SegmentLayers
(
self
.
_layers_desc
,
num_parts
=
self
.
_num_stages
,
method
=
seg_method
)
self
.
segment_parts
=
seg
.
do_segment
()
self
.
_start_pos
=
self
.
segment_parts
[
self
.
_stage_id
]
self
.
_end_pos
=
self
.
segment_parts
[
self
.
_stage_id
+
1
]
# print information for debug
for
stage
in
range
(
self
.
_num_stages
):
start
=
self
.
segment_parts
[
stage
]
end
=
self
.
segment_parts
[
stage
+
1
]
logger
.
info
(
"stage={}, global_rank={} ,layer_number={}"
.
format
(
stage
,
self
.
global_rank
,
end
-
start
))
for
index
,
layer
in
enumerate
(
self
.
_layers_desc
[
start
:
end
]):
logger
.
info
(
"{}: {}"
.
format
(
index
+
start
,
str
(
layer
)))
if
self
.
_loss_fn
:
try
:
logger
.
info
(
"loss: {}"
.
format
(
self
.
_loss_fn
.
__name__
))
except
AttributeError
:
logger
.
info
(
"loss: {}"
.
format
(
self
.
_loss_fn
.
__class__
.
__name__
))
def
_build_layer
(
self
):
start
=
self
.
_start_pos
end
=
self
.
_end_pos
for
index
,
layer
in
enumerate
(
self
.
_layers_desc
[
start
:
end
]):
layer_index
=
start
+
index
if
isinstance
(
layer
,
Layer
):
self
.
run_function
.
append
(
layer
)
self
.
add_sublayer
(
str
(
layer_index
),
layer
)
elif
isinstance
(
layer
,
LayerDesc
):
model
=
layer
.
build_layer
()
self
.
run_function
.
append
(
model
)
self
.
add_sublayer
(
str
(
layer_index
),
model
)
else
:
self
.
run_function
.
append
(
layer
)
def
forward
(
self
,
input
):
for
layer
in
self
.
run_function
:
input
=
layer
(
input
)
return
input
python/paddle/distributed/fleet/meta_parallel/
mp_util
s/random.py
→
python/paddle/distributed/fleet/meta_parallel/
parallel_layer
s/random.py
浏览文件 @
7ef1de67
文件已移动
python/paddle/distributed/fleet/utils/log_util.py
浏览文件 @
7ef1de67
...
@@ -36,3 +36,16 @@ class LoggerFactory:
...
@@ -36,3 +36,16 @@ class LoggerFactory:
logger
=
LoggerFactory
.
build_logger
(
name
=
"HybridParallel"
,
level
=
logging
.
INFO
)
logger
=
LoggerFactory
.
build_logger
(
name
=
"HybridParallel"
,
level
=
logging
.
INFO
)
def
layer_to_str
(
base
,
*
args
,
**
kwargs
):
name
=
base
+
"("
if
args
:
name
+=
", "
.
join
(
str
(
arg
)
for
arg
in
args
)
if
kwargs
:
name
+=
", "
if
kwargs
:
name
+=
", "
.
join
(
"{}={}"
.
format
(
key
,
str
(
value
))
for
key
,
value
in
kwargs
.
items
())
name
+=
")"
return
name
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
7ef1de67
...
@@ -22,7 +22,7 @@ list(APPEND DIST_TEST_OPS test_gen_nccl_id_op)
...
@@ -22,7 +22,7 @@ list(APPEND DIST_TEST_OPS test_gen_nccl_id_op)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel
)
#list(APPEND DIST_TEST_OPS test_parallel_dygraph_hybrid_parallel
)
list
(
APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_layer
)
set
(
MIXED_DIST_TEST_OPS
${
DIST_TEST_OPS
}
)
set
(
MIXED_DIST_TEST_OPS
${
DIST_TEST_OPS
}
)
#remove distribute unittests.
#remove distribute unittests.
list
(
APPEND MIXED_DIST_TEST_OPS test_dgc_op
)
list
(
APPEND MIXED_DIST_TEST_OPS test_dgc_op
)
...
@@ -173,6 +173,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
...
@@ -173,6 +173,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_sync_batch_norm
)
LIST
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_sync_batch_norm
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_control_flow
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_control_flow
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_layer
)
LIST
(
REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision
)
LIST
(
REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision
)
LIST
(
REMOVE_ITEM TEST_OPS test_fleet_base_single
)
LIST
(
REMOVE_ITEM TEST_OPS test_fleet_base_single
)
elseif
(
WITH_GPU
)
elseif
(
WITH_GPU
)
...
@@ -857,7 +858,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
...
@@ -857,7 +858,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties
(
test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120
)
#set_tests_properties(test_parallel_dygraph_hybrid_parallel PROPERTIES TIMEOUT 200 LABELS "RUN_TYPE=DIST"
)
set_tests_properties
(
test_parallel_dygraph_pipeline_layer PROPERTIES TIMEOUT 120
)
if
(
${
NCCL_VERSION
}
VERSION_GREATER_EQUAL 2212
)
if
(
${
NCCL_VERSION
}
VERSION_GREATER_EQUAL 2212
)
set_tests_properties
(
test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120
)
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py
0 → 100644
浏览文件 @
7ef1de67
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
os
import
paddle
from
paddle.distributed
import
fleet
import
copy
from
paddle.fluid.dygraph.container
import
Sequential
import
paddle.nn
as
nn
from
paddle.fluid.dygraph.layers
import
Layer
from
paddle.distributed.fleet.meta_parallel
import
LayerDesc
,
PipelineLayer
import
paddle.nn.functional
as
F
import
unittest
class
AlexNet
(
Layer
):
def
__init__
(
self
,
num_classes
=
10
):
super
(
AlexNet
,
self
).
__init__
()
self
.
features
=
Sequential
(
nn
.
Conv2D
(
3
,
64
,
kernel_size
=
11
,
stride
=
4
,
padding
=
5
),
nn
.
ReLU
(),
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
),
nn
.
Conv2D
(
64
,
192
,
kernel_size
=
5
,
padding
=
2
),
nn
.
ReLU
(),
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
),
nn
.
Conv2D
(
192
,
384
,
kernel_size
=
3
,
padding
=
1
),
nn
.
ReLU
(),
nn
.
Conv2D
(
384
,
256
,
kernel_size
=
3
,
padding
=
1
),
nn
.
ReLU
(),
nn
.
Conv2D
(
256
,
256
,
kernel_size
=
3
,
padding
=
1
),
nn
.
ReLU
(),
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
),
)
self
.
classifier
=
nn
.
Linear
(
256
,
num_classes
)
self
.
loss_fn
=
nn
.
loss
.
CrossEntropyLoss
()
def
forward
(
self
,
x
,
y
):
x
=
self
.
features
(
x
)
x
.
flatten
()
x
=
self
.
classifier
(
x
)
return
self
.
loss_fn
(
x
,
y
)
class
AlexNetPipe
(
AlexNet
):
def
to_layers
(
self
):
feat
=
[
self
.
features
[
i
]
for
i
in
range
(
len
(
self
.
features
))]
loss_fn
=
[
lambda
x
:
x
.
flatten
(),
self
.
classifier
]
feat
.
extend
(
loss_fn
)
return
feat
class
AlexNetPipeDesc
(
PipelineLayer
):
def
__init__
(
self
,
num_classes
=
10
,
**
kwargs
):
self
.
num_classes
=
num_classes
decs
=
[
LayerDesc
(
nn
.
Conv2D
,
3
,
64
,
kernel_size
=
11
,
stride
=
4
,
padding
=
5
),
LayerDesc
(
nn
.
ReLU
),
LayerDesc
(
nn
.
MaxPool2D
,
kernel_size
=
2
,
stride
=
2
),
LayerDesc
(
nn
.
Conv2D
,
64
,
192
,
kernel_size
=
5
,
padding
=
2
),
F
.
relu
,
LayerDesc
(
nn
.
MaxPool2D
,
kernel_size
=
2
,
stride
=
2
),
LayerDesc
(
nn
.
Conv2D
,
192
,
384
,
kernel_size
=
3
,
padding
=
1
),
F
.
relu
,
LayerDesc
(
nn
.
Conv2D
,
384
,
256
,
kernel_size
=
3
,
padding
=
1
),
F
.
relu
,
LayerDesc
(
nn
.
Conv2D
,
256
,
256
,
kernel_size
=
3
,
padding
=
1
),
F
.
relu
,
LayerDesc
(
nn
.
MaxPool2D
,
kernel_size
=
2
,
stride
=
2
),
lambda
x
:
x
.
flatten
(),
LayerDesc
(
nn
.
Linear
,
256
,
self
.
num_classes
),
# classifier
]
super
(
AlexNetPipeDesc
,
self
).
__init__
(
layers
=
decs
,
loss_fn
=
nn
.
CrossEntropyLoss
(),
**
kwargs
)
class
TestPipeLayerAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
strategy
=
fleet
.
DistributedStrategy
()
self
.
model_parallel_size
=
2
strategy
.
hybrid_configs
=
{
"dp_degree"
:
1
,
"mp_degree"
:
1
,
"pp_degree"
:
self
.
model_parallel_size
}
fleet
.
init
(
is_collective
=
True
,
strategy
=
strategy
)
self
.
hcg
=
fleet
.
get_hybrid_communicate_group
()
def
test_pipelayer_desc
(
self
):
pipe_model
=
AlexNetPipeDesc
(
num_stages
=
self
.
model_parallel_size
)
np
.
testing
.
assert_array_equal
(
len
(
pipe_model
.
parameters
()),
6
)
def
test_pipelayer_sequential
(
self
):
init_net
=
AlexNetPipe
()
pipe_model
=
PipelineLayer
(
layers
=
init_net
.
to_layers
(),
num_stages
=
self
.
model_parallel_size
,
loss_fn
=
nn
.
CrossEntropyLoss
())
stage_id
=
self
.
hcg
.
get_stage_id
()
init_parameters
=
init_net
.
parameters
()
pipe_parameters
=
pipe_model
.
parameters
()
part_number
=
len
(
init_parameters
)
//
2
if
stage_id
==
0
:
for
idx
in
range
(
part_number
):
param_a
=
init_parameters
[
idx
]
param_b
=
pipe_parameters
[
idx
]
np
.
testing
.
assert_array_equal
(
param_a
.
name
,
param_b
.
name
)
np
.
testing
.
assert_allclose
(
param_a
.
numpy
(),
param_b
.
numpy
())
elif
stage_id
==
1
:
for
idx
in
range
(
part_number
):
param_a
=
init_parameters
[
idx
+
part_number
]
param_b
=
pipe_parameters
[
idx
]
np
.
testing
.
assert_array_equal
(
param_a
.
name
,
param_b
.
name
)
np
.
testing
.
assert_allclose
(
param_a
.
numpy
(),
param_b
.
numpy
())
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_layer.py
0 → 100644
浏览文件 @
7ef1de67
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
from
test_parallel_dygraph_dataparallel
import
TestMultipleGpus
class
TestHybridPipeParallel
(
TestMultipleGpus
):
def
test_hybrid_parallel_pp_layer
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_layer.py'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/setup.py.in
浏览文件 @
7ef1de67
...
@@ -159,7 +159,7 @@ packages=['paddle',
...
@@ -159,7 +159,7 @@ packages=['paddle',
'paddle.distributed.fleet.proto',
'paddle.distributed.fleet.proto',
'paddle.distributed.fleet.utils',
'paddle.distributed.fleet.utils',
'paddle.distributed.fleet.meta_parallel',
'paddle.distributed.fleet.meta_parallel',
'paddle.distributed.fleet.meta_parallel.
mp_util
s',
'paddle.distributed.fleet.meta_parallel.
parallel_layer
s',
'paddle.framework',
'paddle.framework',
'paddle.jit',
'paddle.jit',
'paddle.jit.dy2static',
'paddle.jit.dy2static',
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录