Unverified commit ea465fa5
Authored Apr 26, 2021 by ShenLiang; committed via GitHub on Apr 26, 2021
[HybridParallel]Fix model parallel bug by using C++ op (#32536)
* fix model parallel
* rm parallel_help.py
* add embedding
Parent: 40e51b25

Showing 12 changed files with 220 additions and 182 deletions (+220 -182)
paddle/fluid/operators/collective/c_allreduce_max_op.cc  +7 -2
paddle/fluid/operators/collective/c_allreduce_min_op.cc  +7 -2
paddle/fluid/operators/collective/c_allreduce_prod_op.cc  +7 -2
paddle/fluid/operators/collective/c_allreduce_sum_op.cc  +3 -1
paddle/fluid/pybind/op_function_generator.cc  +0 -4
python/paddle/distributed/collective.py  +112 -31
python/paddle/distributed/fleet/meta_parallel/parallel_layers/layers_help.py  +0 -116
python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py  +25 -7
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py  +27 -14
python/paddle/fluid/tests/unittests/CMakeLists.txt  +3 -0
python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py  +0 -3
python/paddle/fluid/tests/unittests/test_parallel_dygraph_mp_layers.py  +29 -0
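At a glance, the commit moves the dygraph model-parallel path off the Python PyLayer helpers in layers_help.py (deleted below) and onto C++ collective ops: c_identity, c_concat, c_split, the new _mp_allreduce wrapper, and in-place c_allreduce_*_ variants backed by the inplace inferers registered in the .cc files. A minimal sketch of the public entry point whose dygraph branch changes here, assuming a multi-GPU job launched with paddle.distributed.launch (tensor values are illustrative, not part of the commit):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

# Every rank contributes its own value; after all_reduce each rank holds the global sum.
# With this commit the dygraph branch dispatches to the in-place core.ops.c_allreduce_sum_.
data = paddle.to_tensor([float(dist.get_rank() + 1)])
dist.all_reduce(data)
print(data.numpy())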
paddle/fluid/operators/collective/c_allreduce_max_op.cc

@@ -37,14 +37,19 @@ class CAllReduceMaxOpMaker : public CAllReduceOpMaker {
   std::string GetName() const override { return "Max"; }
 };
 
+DECLARE_INPLACE_OP_INFERER(AllreduceMaxInplaceInferer, {"X", "Out"});
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_max, ops::CAllReduceOp,
-                             ops::CAllReduceMaxOpMaker);
+REGISTER_OPERATOR(
+    c_allreduce_max, ops::CAllReduceOp, ops::CAllReduceMaxOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::AllreduceMaxInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(c_allreduce_max,
                        ops::CAllReduceOpCPUKernel<ops::kRedMax, float>,
...
paddle/fluid/operators/collective/c_allreduce_min_op.cc

@@ -37,14 +37,19 @@ class CAllReduceMinOpMaker : public CAllReduceOpMaker {
   std::string GetName() const override { return "Min"; }
 };
 
+DECLARE_INPLACE_OP_INFERER(AllreduceMinInplaceInferer, {"X", "Out"});
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_min, ops::CAllReduceOp,
-                             ops::CAllReduceMinOpMaker);
+REGISTER_OPERATOR(
+    c_allreduce_min, ops::CAllReduceOp, ops::CAllReduceMinOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::AllreduceMinInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(c_allreduce_min,
                        ops::CAllReduceOpCPUKernel<ops::kRedMin, float>,
...
paddle/fluid/operators/collective/c_allreduce_prod_op.cc

@@ -37,14 +37,19 @@ class CAllReduceProdOpMaker : public CAllReduceOpMaker {
   std::string GetName() const override { return "Prod"; }
 };
 
+DECLARE_INPLACE_OP_INFERER(AllreduceProdInplaceInferer, {"X", "Out"});
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_prod, ops::CAllReduceOp,
-                             ops::CAllReduceProdOpMaker);
+REGISTER_OPERATOR(
+    c_allreduce_prod, ops::CAllReduceOp, ops::CAllReduceProdOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::AllreduceProdInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(c_allreduce_prod,
                        ops::CAllReduceOpCPUKernel<ops::kRedProd, float>,
...
paddle/fluid/operators/collective/c_allreduce_sum_op.cc

@@ -54,6 +54,8 @@ class CAllReduceSumOpMaker : public CAllReduceOpMaker {
   std::string GetName() const override { return "Sum"; }
 };
 
+DECLARE_INPLACE_OP_INFERER(AllreduceSumInplaceInferer, {"X", "Out"});
+
 }  // namespace operators
 }  // namespace paddle

@@ -63,7 +65,7 @@ namespace plat = paddle::platform;
 REGISTER_OPERATOR(c_allreduce_sum, ops::CAllReduceOp,
                   ops::CAllReduceSumOpGradMaker<paddle::framework::OpDesc>,
                   ops::CAllReduceSumOpGradMaker<paddle::imperative::OpBase>,
-                  ops::CAllReduceSumOpMaker);
+                  ops::CAllReduceSumOpMaker, ops::AllreduceSumInplaceInferer);
 
 REGISTER_OP_CPU_KERNEL(c_allreduce_sum,
                        ops::CAllReduceOpCPUKernel<ops::kRedSum, float>,
...
paddle/fluid/pybind/op_function_generator.cc

@@ -127,10 +127,6 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"c_broadcast", {"Out"}},
     {"c_sync_calc_stream", {"Out"}},
     {"c_sync_comm_stream", {"Out"}},
-    {"c_allreduce_sum", {"Out"}},
-    {"c_allreduce_max", {"Out"}},
-    {"c_allreduce_min", {"Out"}},
-    {"c_allreduce_prod", {"Out"}},
     {"c_reduce_sum", {"Out"}},
     {"c_reduce_max", {"Out"}},
     {"c_reduce_min", {"Out"}},
...
python/paddle/distributed/collective.py

@@ -397,23 +397,22 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
         return
     ring_id = 0 if group is None else group.id
 
     if in_dygraph_mode():
         if op == ReduceOp.SUM:
-            return core.ops.c_allreduce_sum(tensor, tensor, 'use_calc_stream',
-                                            use_calc_stream, 'ring_id', ring_id)
+            return core.ops.c_allreduce_sum_(tensor, 'use_calc_stream',
+                                             use_calc_stream, 'ring_id', ring_id)
         elif op == ReduceOp.MAX:
-            return core.ops.c_allreduce_max(tensor, tensor, 'use_calc_stream',
-                                            use_calc_stream, 'ring_id', ring_id)
+            return core.ops.c_allreduce_max_(tensor, 'use_calc_stream',
+                                             use_calc_stream, 'ring_id', ring_id)
         elif op == ReduceOp.MIN:
-            return core.ops.c_allreduce_min(tensor, tensor, 'use_calc_stream',
-                                            use_calc_stream, 'ring_id', ring_id)
+            return core.ops.c_allreduce_min_(tensor, 'use_calc_stream',
+                                             use_calc_stream, 'ring_id', ring_id)
         elif op == ReduceOp.PROD:
-            return core.ops.c_allreduce_prod(tensor, tensor, 'use_calc_stream',
-                                             use_calc_stream, 'ring_id', ring_id)
+            return core.ops.c_allreduce_prod_(tensor, 'use_calc_stream',
+                                              use_calc_stream, 'ring_id', ring_id)
         else:
             raise ValueError("Unknown parameter: {}.".format(op))
-        return out
 
     check_variable_and_dtype(tensor, 'tensor', [
         'float16', 'float32', 'float64', 'int32', 'int64'
     ],
...
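The trailing underscore in the new calls follows Paddle's in-place naming convention: the op writes its result back into the input instead of taking a separate output variable, which is what the inplace inferers added in the .cc files above make possible. A single-process sketch of the same convention on an ordinary tensor op, assuming a Paddle 2.1-era build (not part of the commit):

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
x.add_(paddle.to_tensor([1.0, 1.0, 1.0]))  # in-place: x itself is modified, no new output tensor
print(x.numpy())  # [2. 3. 4.]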
@@ -692,7 +691,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
         })


-def _c_identity(tensor, group=0):
+def _c_identity(tensor, group=None):
     """
     Return a copy of the tensor, mainly used with model parallel.

@@ -704,30 +703,76 @@ def _c_identity(tensor, group=0):
     Returns:
         Tensor.
     """
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    if in_dygraph_mode():
+        return core.ops.c_identity(tensor, 'use_calc_stream', True, 'ring_id',
+                                   ring_id, 'use_model_parallel', True)
     op_type = 'c_identity'
     helper = LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
-    if in_dygraph_mode():
-        return core.ops.c_identity(out, tensor, 'use_calc_stream', True,
-                                   'ring_id', group, 'use_model_parallel', True)
     check_variable_and_dtype(tensor, 'tensor', [
         'float16', 'float32', 'float64', 'int32', 'int64'
     ], '_c_identity')
-    if not isinstance(group, int):
-        raise ValueError("The type of 'group' for _c_identity should be int.")
     helper.append_op(
         type=op_type,
         inputs={'X': tensor},
         outputs={'Out': out},
         attrs={
             'ring_id': ring_id,
             'use_calc_stream': True,
             'use_model_parallel': True,
         })
     return out


+def _c_concat(tensor, nranks, group=None):
+    """
+    Return allgather of the tensor, mainly used with model parallel.
+
+    Args:
+        tensor (Tensor): The input Tensor. Its data type
+            should be float16, float32, float64, int32 or int64.
+        group (int): The id of the process group to work on.
+
+    Returns:
+        Tensor.
+    """
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    if in_dygraph_mode():
+        return core.ops.c_concat(tensor, 'ring_id', ring_id, 'use_calc_stream',
+                                 True, 'nranks', nranks, 'use_model_parallel',
+                                 True)
+
+    op_type = 'c_concat'
+    helper = LayerHelper(op_type, **locals())
+    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
+
+    check_variable_and_dtype(tensor, 'tensor', [
+        'float16', 'float32', 'float64', 'int32', 'int64'
+    ], '_c_concat')
+
+    helper.append_op(
+        type=op_type,
+        inputs={'X': tensor},
+        outputs={'Out': out},
+        attrs={
+            'ring_id': ring_id,
+            'use_calc_stream': True,
+            'use_model_parallel': True,
+            'nranks': nranks
+        })
+    return out
+
+
-def _c_split(tensor, rank, nranks, group=0):
+def _c_split(tensor, rank, nranks, group=None):
     """
     Split tensor evenly among all members, mainly used with model parallel.
...
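For orientation, _c_identity and _c_concat are the two halves of a column-parallel linear layer: an identity on the way in (whose backward reduces the input gradient across the model-parallel group, exactly what the deleted _IdentityInModelParallel PyLayer did below) and an all-gather of the column-sharded outputs on the way out. A hedged sketch of how they combine, mirroring ColumnParallelLinear.forward in mp_layers.py further down; the function name and the mp_group/world_size arguments are illustrative and assume fleet's hybrid-parallel environment is already initialized:

import paddle
import paddle.nn.functional as F
from paddle.distributed import collective


def column_parallel_matmul(x, weight, bias, mp_group, world_size, gather_output=True):
    # identity in the forward pass; the C++ op handles the model-parallel backward
    x_parallel = collective._c_identity(x, group=mp_group)
    out_parallel = F.linear(x_parallel, weight, bias)
    if gather_output:
        # all-gather the column-sharded partial outputs along the last dimension
        return collective._c_concat(out_parallel, nranks=world_size, group=mp_group)
    return out_parallel

Routing these steps through C++ ops lets the dygraph path skip the extra Python autograd hops that the PyLayer helpers required.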
@@ -740,23 +785,29 @@ def _c_split(tensor, rank, nranks, group=0):
     Returns:
         Tensor.
     """
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    if in_dygraph_mode():
+        return core.ops.c_split(tensor, 'use_calc_stream', True, 'ring_id',
+                                ring_id, 'rank', rank, 'nranks', nranks,
+                                'use_model_parallel', True)
     op_type = 'c_split'
     helper = LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
-    if in_dygraph_mode():
-        return core.ops.c_split(out, tensor, 'use_calc_stream', True,
-                                'ring_id', group, 'rank', rank,
-                                'use_model_parallel', True)
     check_variable_and_dtype(tensor, 'tensor', [
         'float16', 'float32', 'float64', 'int32', 'int64'
     ], '_c_split')
-    if not isinstance(group, int):
-        raise ValueError("The type of 'group' for _identity should be int.")
     helper.append_op(
         type=op_type,
         inputs={'X': tensor},
         outputs={'Out': out},
         attrs={
-            'ring_id': group,
+            'ring_id': ring_id,
             'use_calc_stream': True,
             'rank': rank,
             'nranks': nranks,
...

@@ -765,6 +816,28 @@ def _c_split(tensor, rank, nranks, group=0):
     return out


+def _mp_allreduce(tensor,
+                  op=ReduceOp.SUM,
+                  group=None,
+                  use_calc_stream=True,
+                  use_model_parallel=True):
+    """[it is same as allreduce above, but it suuports model parallel. And it support inplace startegy]
+    """
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    if in_dygraph_mode():
+        if op == ReduceOp.SUM:
+            return core.ops.c_allreduce_sum_(
+                tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
+                "use_model_parallel", use_model_parallel)
+        else:
+            raise ValueError("Unknown parameter: {}.".format(op))
+    else:
+        raise NotImplementedError("No support _mp_allreduce in dygraph mode.")
+
+
 def barrier(group=None):
     """
...

@@ -816,10 +889,14 @@ def _parallel_linear(x,
                      nranks,
                      split_tensor,
                      name,
-                     group=0):
+                     group=None):
     """
     Parallel Linear
     """
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
     if axis == 0:
         if split_tensor:
             x = _c_split(x, inner_rank, nranks, group=group)
...

@@ -858,7 +935,7 @@ def _parallel_linear(x,
         inputs={'X': linear_out},
         outputs={'Out': out},
         attrs={
-            'ring_id': group,
+            'ring_id': ring_id,
             'use_calc_stream': True,
             'use_model_parallel': True
         })
...

@@ -868,7 +945,7 @@ def _parallel_linear(x,
         inputs={'X': linear_out},
         outputs={'Out': out},
         attrs={
-            'ring_id': group,
+            'ring_id': ring_id,
             'nranks': nranks,
             'use_calc_stream': True,
             'use_model_parallel': True
...

@@ -883,10 +960,14 @@ def _parallel_embedding(x,
                         inner_rank,
                         num_partitions,
                         name,
-                        group=0):
+                        group=None):
     """
     Parallel Embedding
     """
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
     origin_num_embeddings = origin_size[0]
     embedding = paddle.nn.Embedding(
         per_part_embeddings,
...

@@ -924,7 +1005,7 @@ def _parallel_embedding(x,
         inputs={'X': emb_out},
         outputs={'Out': out},
         attrs={
-            'ring_id': group,
+            'ring_id': ring_id,
             'use_calc_stream': True,
             'use_model_parallel': True
         })
...

@@ -1050,7 +1131,7 @@ def split(x,
             inner_rank,
             num_partitions,
             name,
-            group=0)
+            group=None)
         return emb_out
     else:
         should_split = False
...

@@ -1086,5 +1167,5 @@ def split(x,
         num_partitions,
         should_split,
         name=name,
-        group=0)
+        group=None)
     return linear_out
python/paddle/distributed/fleet/meta_parallel/parallel_layers/layers_help.py
deleted file mode 100644 (entire file removed)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.autograd import PyLayer
from ...base import topology as tp
import paddle

# Follow this paper to achieve the file:
# Shoeybi M, Patwary M, Puri R, et al. Megatron-lm: Training multi-billion parameter
# language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053)


def mp_reduce(x):
    if tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size() == 1:
        return x

    paddle.distributed.all_reduce(
        x, group=tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group())

    return x


def mp_split(x):
    world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()

    if world_size == 1:
        return x

    rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
    last_dim = len(x.shape) - 1
    input_list = paddle.split(x, num_or_sections=world_size, axis=last_dim)
    output = input_list[rank]

    return output


def mp_gather(x):
    world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()

    if world_size == 1:
        return x

    output = []
    paddle.distributed.all_gather(
        output, x, group=tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group())

    output = paddle.concat(output, axis=len(x.shape) - 1)

    return output


class _IdentityInModelParallel(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, dx):
        return mp_reduce(dx)


class _ReduceInModelParallel(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return mp_reduce(x)

    @staticmethod
    def backward(ctx, dx):
        return dx


class _ScatterInModelParallel(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return mp_split(x)

    @staticmethod
    def backward(ctx, dx):
        return mp_gather(dx)


class _GatherInModelParallel(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return mp_gather(x)

    @staticmethod
    def backward(ctx, dx):
        return mp_split(dx)


def identity_in_model_parallel(x):
    return _IdentityInModelParallel.apply(x)


def reduce_in_model_parallel(x):
    return _ReduceInModelParallel.apply(x)


def scatter_in_model_parallel(x):
    return _ScatterInModelParallel.apply(x)


def gather_in_model_parallel(x):
    return _GatherInModelParallel.apply(x)
python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py

@@ -18,7 +18,6 @@ from .random import get_rng_state_tracker
 from paddle.nn import functional as F
 from paddle import framework
 from ...base import topology as tp
-from .layers_help import identity_in_model_parallel, gather_in_model_parallel, reduce_in_model_parallel, scatter_in_model_parallel

 __all__ = [
     'VocabParallelEmbedding', 'ColumnParallelLinear', 'RowParallelLinear'
...

@@ -75,8 +74,13 @@ class VocabParallelEmbedding(Layer):
         if len(origin_input_shape) == 2:
             x_shard = paddle.squeeze(x_shard, axis=-1)
 
-        emb_out_ = self.embedding(x_shard)
-        emb_out = reduce_in_model_parallel(emb_out_)
+        emb_out = self.embedding(x_shard)
+        if self.world_size > 1:
+            emb_out = paddle.distributed.collective._mp_allreduce(
+                emb_out,
+                group=self.model_parallel_group,
+                use_calc_stream=True,
+                use_model_parallel=True)
         return emb_out
...

@@ -123,11 +127,16 @@ class ColumnParallelLinear(Layer):
             self.bias = None
 
     def forward(self, x):
-        input_parallel = identity_in_model_parallel(x)
+        # use inner api to process identity
+        input_parallel = paddle.distributed.collective._c_identity(
+            x, group=self.model_parallel_group)
         output_parallel = F.linear(
             input_parallel, self.weight, self.bias, name=self.name)
         if self.gather_output:
-            output = gather_in_model_parallel(output_parallel)
+            output = paddle.distributed.collective._c_concat(
+                output_parallel,
+                nranks=self.world_size,
+                group=self.model_parallel_group)
         else:
             output = output_parallel
         return output
...

@@ -182,9 +191,18 @@ class RowParallelLinear(Layer):
             input_parallel = x
         else:
             # split last dim
-            input_parallel = scatter_in_model_parallel(x)
+            input_parallel = paddle.distributed.collective._c_split(
+                x,
+                rank=self.rank,
+                nranks=self.world_size,
+                group=self.model_parallel_group)
 
         output_parallel = F.linear(input_parallel, self.weight, name=self.name)
-        output_ = reduce_in_model_parallel(output_parallel)
+        output_ = paddle.distributed.collective._mp_allreduce(
+            output_parallel,
+            group=self.model_parallel_group,
+            use_calc_stream=True,
+            use_model_parallel=True)
         output = output_ + self.bias if self.bias is not None else output_
         return output
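A hedged usage sketch for the layers touched above, assuming a multi-GPU job launched with paddle.distributed.launch and a 2-way model-parallel setup; the hybrid_configs fields follow the fleet configuration of this era of Paddle, and the layer sizes are illustrative:

import paddle
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

# Each rank holds a [64, 128 / mp_degree] shard of the weight; the forward pass now routes
# through the C++-backed helpers _c_identity / _c_concat shown in collective.py above.
layer = fleet.meta_parallel.ColumnParallelLinear(
    64, 128, has_bias=True, gather_output=True)
out = layer(paddle.randn([8, 64]))
print(out.shape)  # [8, 128] on every rank when gather_output=True

With gather_output=False the _c_concat call is skipped and each rank keeps its output shard, which is the layout RowParallelLinear expects on its input side.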
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py

@@ -48,29 +48,42 @@ def _apply_collective_grads(parameters, comm_group):
     _split_tensors(coalesced_grads_and_vars)


-def broadcast_input_data(hcg, *inputs, **kwargs):
+def _broadcast_data_help(data, shape, dtype, hcg):
     model_parallel_group = hcg.get_model_parallel_group()
     src_rank = hcg.get_model_parallel_group_src_rank()
+    mp_rank = hcg.get_model_parallel_rank()
 
-    for input_ in inputs:
-        if isinstance(input_, core.VarBase):
-            with framework.no_grad():
-                paddle.distributed.broadcast(
-                    input_,
-                    src=src_rank,
-                    group=model_parallel_group,
-                    use_calc_stream=True)
-        else:
-            logger.error("it doesn't support data type {}".format(type(input_)))
+    shape_gpu = paddle.to_tensor(shape, dtype="int32")
+    paddle.distributed.broadcast(
+        shape_gpu, src=src_rank, group=model_parallel_group, use_calc_stream=True)
+
+    if mp_rank != 0:
+        input_data = paddle.zeros(shape_gpu, dtype=dtype)
+    else:
+        input_data = data
+
+    paddle.distributed.broadcast(
+        input_data,
+        src=src_rank,
+        group=model_parallel_group,
+        use_calc_stream=True)
+
+
+def broadcast_input_data(hcg, *inputs, **kwargs):
+    for v in inputs:
+        if isinstance(v, core.VarBase):
+            with framework.no_grad():
+                _broadcast_data_help(v, v.shape, v.dtype, hcg)
+        else:
+            logger.error("it doesn't support data type {}".format(type(v)))
 
     for k, v in kwargs.items():
         if isinstance(v, core.VarBase):
             with framework.no_grad():
-                paddle.distributed.broadcast(
-                    v,
-                    src=src_rank,
-                    group=model_parallel_group,
-                    use_calc_stream=True)
+                _broadcast_data_help(v, v.shape, v.dtype, hcg)
+            kwargs[k] = v
         else:
             logger.error("it doesn't support data type {}".format(type(v)))
...
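The refactor above splits broadcast_input_data into a shape/dtype broadcast followed by the data broadcast, so model-parallel ranks other than the source rank can allocate a correctly shaped buffer before receiving. A hedged sketch of the call site, assuming fleet.init was already run with hybrid (model-parallel) configs; the batch tensor is a placeholder for the user's dygraph input:

import paddle
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.utils.hybrid_parallel_util import broadcast_input_data

hcg = fleet.get_hybrid_communicate_group()  # available after fleet.init(...) with hybrid configs
batch = paddle.ones([4, 8])                 # illustrative input tensor
broadcast_input_data(hcg, batch)            # broadcast from the model-parallel source rank to the group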
python/paddle/fluid/tests/unittests/CMakeLists.txt

@@ -23,6 +23,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_layer)
+list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
 set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
 #remove distribute unittests.
 list(APPEND MIXED_DIST_TEST_OPS test_dgc_op)
...

@@ -175,6 +176,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
     list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_control_flow)
     list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel)
     list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_layer)
+    list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
     LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single)
     LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute)
...

@@ -861,6 +863,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
     set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_dygraph_pipeline_layer PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
     if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
         set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120)
         set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120)
...
python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py

@@ -21,9 +21,6 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus


 class TestHybridParallel(TestMultipleGpus):
-    def test_hybrid_parallel_mp_layers(self):
-        self.run_mnist_2gpu('hybrid_parallel_mp_layers.py')
-
     def test_hybrid_parallel_mp_random(self):
         self.run_mnist_2gpu('hybrid_parallel_mp_random.py')
...
python/paddle/fluid/tests/unittests/test_parallel_dygraph_mp_layers.py
new file mode 100644

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle.fluid as fluid

from test_parallel_dygraph_dataparallel import TestMultipleGpus


class TestModelParallelLayer(TestMultipleGpus):
    def test_hybrid_parallel_mp_layer(self):
        self.run_mnist_2gpu('hybrid_parallel_mp_layers.py')


if __name__ == "__main__":
    unittest.main()