Commit c809530e (unverified)
Authored May 17, 2021 by ShenLiang; committed via GitHub on May 17, 2021

[HybridParallel] Fix precision problem of model parallel (#32897)

* fix precision of mp
* fix bug of seed
* fix dp
* print group

Parent: 906db719
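The commit surfaces the new seed through fleet.DistributedStrategy. A minimal sketch, assuming a script launched with paddle.distributed.launch, of how the option added here would be set (the degree 4 and seed 123 follow the docstring example in this commit and are illustrative):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.tensor_parallel = True
    # tensor_parallel_degree: number of ranks each layer is split across.
    # tensor_init_seed: fixed seed for tensor-parallel parameter init;
    # the proto default of -1 means "pick random seeds".
    strategy.tensor_parallel_configs = {
        "tensor_parallel_degree": 4,
        "tensor_init_seed": 123,
    }

    # fleet.init() is where the seed is consumed: with mp_degree > 1 it calls
    # model_parallel_random_seed(123), or model_parallel_random_seed() for -1.
    fleet.init(is_collective=True, strategy=strategy)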
Showing 14 changed files with 151 additions and 58 deletions (+151, −58)
paddle/fluid/framework/distributed_strategy.proto  (+1, −0)
python/paddle/distributed/collective.py  (+7, −0)
python/paddle/distributed/fleet/base/distributed_strategy.py  (+4, −1)
python/paddle/distributed/fleet/base/fleet_base.py  (+12, −3)
python/paddle/distributed/fleet/base/topology.py  (+3, −3)
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py  (+1, −1)
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py  (+2, −2)
python/paddle/distributed/fleet/meta_parallel/__init__.py  (+1, −1)
python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py  (+97, −38)
python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py  (+9, −4)
python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py  (+3, −3)
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py  (+9, −1)
python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py  (+1, −1)
python/paddle/fluid/tests/unittests/new_group.py  (+1, −0)
paddle/fluid/framework/distributed_strategy.proto

@@ -141,6 +141,7 @@ message PipelineConfig {
 
 message TensorParallelConfig {
   optional int32 tensor_parallel_degree = 1 [ default = 1 ];
+  optional int32 tensor_init_seed = 2 [ default = -1 ];
 }
 
 message DistributedStrategy {
python/paddle/distributed/collective.py

@@ -99,6 +99,13 @@ class Group():
         else:
             return -1
 
+    def __repr__(self):
+        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
+            self.rank, self.nranks, self.id)
+        debug_str += ", ".join(map(str, self.ranks))
+        debug_str += ". "
+        return debug_str
+
 
 _global_env = None
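This __repr__ is what the commit message's "print group" item refers to: a Group now prints its rank, size, id and member ranks. A minimal sketch of the string it builds, with field values hard-coded for illustration (in real use the object would come from paddle.distributed.new_group under a multi-process launch):

    # Reproduce the debug string built by the new Group.__repr__ for sample values.
    rank, nranks, group_id, ranks = 0, 2, 1, [0, 1]

    debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(rank, nranks, group_id)
    debug_str += ", ".join(map(str, ranks))
    debug_str += ". "

    print(debug_str)  # -> rank: 0, nranks: 2, id: 1, ranks: 0, 1.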
python/paddle/distributed/fleet/base/distributed_strategy.py

@@ -949,6 +949,8 @@ class DistributedStrategy(object):
         **Notes**:
             **Detailed arguments for tensor_parallel_configs**
             **tensor_parallel_degree**: degree of tensor parallel
+            **tensor_init_seed**: parameter initialization random seed
+
 
         Examples:

@@ -957,7 +959,8 @@ class DistributedStrategy(object):
             import paddle.distributed.fleet as fleet
             strategy = fleet.DistributedStrategy()
             strategy.tensor_parallel = True
-            strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4}
+            strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4,
+                                                "tensor_init_seed": 123}
 
         """
         return get_msg_dict(self.strategy.tensor_parallel_configs)
python/paddle/distributed/fleet/base/fleet_base.py

@@ -17,6 +17,7 @@ import copy
 import warnings
 import paddle
 import os
+import numpy as np
 from paddle.fluid.framework import dygraph_only
 from paddle.fluid import compiler
 from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase

@@ -28,7 +29,7 @@ from paddle.fluid.wrapped_decorator import wrap_decorator
 from paddle.fluid.dygraph import parallel_helper
 from . import topology as tp
 from .topology import ParallelMode
-from ..meta_parallel import ModelParallel
+from ..meta_parallel import TensorParallel, model_parallel_random_seed
 from ..meta_parallel import PipelineParallel
 from ..meta_optimizers import HybridParallelOptimizer
 from ..meta_optimizers import HybridParallelGradScaler

@@ -279,6 +280,14 @@ class Fleet(object):
         self._hcg = tp.HybridCommunicateGroup(self._topology)
 
+        if self.mp_degree > 1:
+            tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
+            tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
+            if tensor_init_seed == -1:
+                model_parallel_random_seed()
+            else:
+                model_parallel_random_seed(tensor_init_seed)
+
     def get_hybrid_communicate_group(self):
         assert self._hcg is not None
         return self._hcg

@@ -829,8 +838,8 @@ class Fleet(object):
                 last_comm_group_size_MB,
                 find_unused_parameters=self._user_defined_strategy.
                 find_unused_parameters)
-        elif self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL:
-            distributed_model = ModelParallel(
+        elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
+            distributed_model = TensorParallel(
                 model, self._hcg, strategy=self._user_defined_strategy)
         elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
             distributed_model = PipelineParallel(
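The block added here ties the strategy option to the RNG setup: tensor_init_seed == -1 selects randomly drawn seeds, any other value derives them deterministically, and in both cases each model-parallel rank ends up with its own local seed. A small numpy sketch, using illustrative shapes and seeds rather than Paddle's real initializer, of why a shared seed across ranks is part of the precision problem being fixed: every rank would draw the identical weight shard.

    import numpy as np

    def init_shard(local_seed, shape=(4, 8)):
        # Stand-in for initializing one rank's weight shard under its RNG state.
        return np.random.RandomState(local_seed).normal(size=shape)

    # Same seed on every rank: the shards are identical copies, which does not
    # match what a single-device initialization of the full weight would give.
    shards = [init_shard(2048) for _rank in range(2)]
    print(np.allclose(shards[0], shards[1]))   # True

    # Rank-dependent local seeds (what model_parallel_random_seed sets up):
    # each shard is an independent draw.
    shards = [init_shard(2048 + rank) for rank in range(2)]
    print(np.allclose(shards[0], shards[1]))   # False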
python/paddle/distributed/fleet/base/topology.py

@@ -28,7 +28,7 @@ _HYBRID_PARALLEL_GROUP = None
 
 class ParallelMode(object):
     DATA_PARALLEL = 0
-    MODEL_PARALLEL = 1
+    TENSOR_PARALLEL = 1
     PIPELINE_PARALLEL = 2

@@ -155,12 +155,12 @@ class HybridCommunicateGroup(object):
         _HYBRID_PARALLEL_GROUP = self
 
     def get_parallel_mode(self):
-        # there are three modes : DataParallel / ModelParallel / PipelineParallel
+        # there are three modes : DataParallel / TensorParallel / PipelineParallel
        if self._mp_degree == 1 and self._pp_degree == 1:
             return ParallelMode.DATA_PARALLEL
         elif self._mp_degree > 1 and self._pp_degree == 1:
             # initialize the seed
-            return ParallelMode.MODEL_PARALLEL
+            return ParallelMode.TENSOR_PARALLEL
         elif self._pp_degree > 1:
             return ParallelMode.PIPELINE_PARALLEL
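The enum keeps its numeric values, so mode selection is purely a function of the two degrees. A compact restatement of get_parallel_mode as a standalone sketch (the parallel_mode helper here is hypothetical, written only to mirror the branches in the diff):

    DATA_PARALLEL = 0
    TENSOR_PARALLEL = 1      # named MODEL_PARALLEL before this commit
    PIPELINE_PARALLEL = 2

    def parallel_mode(mp_degree, pp_degree):
        # Mirrors HybridCommunicateGroup.get_parallel_mode().
        if mp_degree == 1 and pp_degree == 1:
            return DATA_PARALLEL
        elif mp_degree > 1 and pp_degree == 1:
            return TENSOR_PARALLEL
        elif pp_degree > 1:
            return PIPELINE_PARALLEL

    assert parallel_mode(1, 1) == DATA_PARALLEL
    assert parallel_mode(4, 1) == TENSOR_PARALLEL
    assert parallel_mode(2, 2) == PIPELINE_PARALLEL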
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py

@@ -31,7 +31,7 @@ class HybridParallelGradScaler:
         self._scaler = scaler
         self._hcg = hcg
         self._is_mp = (
-            self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
+            self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
 
     def scale(self, var):
         return self._scaler.scale(var)
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

@@ -90,12 +90,12 @@ class HybridParallelOptimizer:
         self._strategy = strategy
         self._hcg = hcg
         self._is_mp = (
-            self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
+            self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
         self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)
 
         if isinstance(self._inner_opt._grad_clip,
                       ClipGradByGlobalNorm) and self._is_mp:
-            logger.warning("using ClipGradByGlobalNorm in ModelParallel, the origin " \
+            logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
                     "optmizer'grad clip will be changed.")
             self._inner_opt._grad_clip = HybridParallelClipGrad(
                 self._inner_opt._grad_clip, hcg)
python/paddle/distributed/fleet/meta_parallel/__init__.py

@@ -20,7 +20,7 @@ from .parallel_layers import PipelineLayer  # noqa: F401
 from .parallel_layers import RNGStatesTracker  # noqa: F401
 from .parallel_layers import model_parallel_random_seed  # noqa: F401
 from .parallel_layers import get_rng_state_tracker  # noqa: F401
-from .model_parallel import ModelParallel  # noqa: F401
+from .tensor_parallel import TensorParallel  # noqa: F401
 from .pipeline_parallel import PipelineParallel  # noqa: F401
 
 __all__ = []
python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py

@@ -41,6 +41,7 @@ class VocabParallelEmbedding(Layer):
         self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
 
         self.origin_num_embeddings = num_embeddings
+        self.is_mp = (self.world_size > 1)
 
         per_part_size = (
             num_embeddings + self.world_size - 1) // self.world_size

@@ -50,16 +51,36 @@ class VocabParallelEmbedding(Layer):
             per_part_size += 1  # make the last row as the padding index
         self.per_part_size = per_part_size
 
-        self.embedding = paddle.nn.Embedding(
-            per_part_size,
-            embedding_dim,
-            padding_idx=per_part_size - 1,
-            sparse=False,
-            weight_attr=weight_attr,
-            name=name)
-        self.embedding.weight.is_distributed = True
+        self._dtype = self._helper.get_default_dtype()
+        self._size = [per_part_size, embedding_dim]
+        self._weight_attr = weight_attr
+        self._name = name
+
+        if self.is_mp:
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    attr=self._weight_attr,
+                    shape=self._size,
+                    dtype=self._dtype,
+                    is_bias=False)
+                self.weight[per_part_size - 1] = 0.0
+                self.weight.is_distributed = True
+        else:
+            self.weight = self.create_parameter(
+                attr=self._weight_attr,
+                shape=[num_embeddings, embedding_dim],
+                dtype=self._dtype,
+                is_bias=False)
 
     def forward(self, x):
+        if not self.is_mp:
+            return F.embedding(
+                x,
+                weight=self.weight,
+                padding_idx=None,
+                sparse=False,
+                name=self._name)
+
         origin_input_shape = x.shape
         if len(origin_input_shape) == 2:
             x = paddle.unsqueeze(x, axis=-1)

@@ -72,13 +93,18 @@ class VocabParallelEmbedding(Layer):
         if len(origin_input_shape) == 2:
             x_shard = paddle.squeeze(x_shard, axis=-1)
 
-        emb_out = self.embedding(x_shard)
-        if self.world_size > 1:
-            emb_out = paddle.distributed.collective._mp_allreduce(
-                emb_out,
-                group=self.model_parallel_group,
-                use_calc_stream=True,
-                use_model_parallel=True)
+        emb_out = F.embedding(
+            x_shard,
+            weight=self.weight,
+            padding_idx=self.per_part_size - 1,
+            sparse=False,
+            name=self._name)
+
+        emb_out = paddle.distributed.collective._mp_allreduce(
+            emb_out,
+            group=self.model_parallel_group,
+            use_calc_stream=True,
+            use_model_parallel=True)
         return emb_out

@@ -96,8 +122,9 @@ class ColumnParallelLinear(Layer):
         )
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
         )
-        self.name = name
+        self._name = name
+        self.is_mp = (self.world_size > 1)
 
         self.gather_output = gather_output
         assert out_features % self.world_size == 0, (
             "Number of column of the weight for linear ({}) must be"

@@ -108,10 +135,20 @@ class ColumnParallelLinear(Layer):
         self._weight_attr = weight_attr
         self._dtype = self._helper.get_default_dtype()
 
-        self.weight = self.create_parameter(
-            shape=[in_features, self.output_size_per_partition],
-            attr=self._weight_attr,
-            dtype=self._dtype)
+        if self.is_mp:
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    shape=[in_features, self.output_size_per_partition],
+                    attr=self._weight_attr,
+                    dtype=self._dtype,
+                    is_bias=False)
+        else:
+            self.weight = self.create_parameter(
+                shape=[in_features, self.output_size_per_partition],
+                attr=self._weight_attr,
+                dtype=self._dtype,
+                is_bias=False)
+
         self.weight.is_distributed = True
 
         if has_bias:

@@ -119,18 +156,24 @@ class ColumnParallelLinear(Layer):
             self.bias = self.create_parameter(
                 shape=[self.output_size_per_partition],
                 attr=paddle.nn.initializer.Constant(value=0.0),
-                dtype=self._dtype)
+                dtype=self._dtype,
+                is_bias=True)
             self.bias.is_distributed = True
         else:
             self.bias = None
 
     def forward(self, x):
         # use inner api to process identity
-        input_parallel = paddle.distributed.collective._c_identity(
-            x, group=self.model_parallel_group)
+        if self.is_mp:
+            input_parallel = paddle.distributed.collective._c_identity(
+                x, group=self.model_parallel_group)
+        else:
+            input_parallel = x
 
         output_parallel = F.linear(
-            input_parallel, self.weight, self.bias, name=self.name)
-        if self.gather_output:
+            input_parallel, self.weight, self.bias, name=self._name)
+
+        if self.gather_output and self.is_mp:
             output = paddle.distributed.collective._c_concat(
                 output_parallel,
                 nranks=self.world_size,

@@ -155,7 +198,7 @@ class RowParallelLinear(Layer):
         self.input_is_parallel = input_is_parallel
         self._weight_attr = weight_attr
         self._dtype = self._helper.get_default_dtype()
-        self.name = name
+        self._name = name
 
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
         )

@@ -163,6 +206,7 @@ class RowParallelLinear(Layer):
         )
         self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        self.is_mp = (self.world_size > 1)
 
         assert in_features % self.world_size == 0, (
             "Number of row of the weight for linear ({}) must be"
             " divisible by model parallel size ({})".format(in_features,

@@ -170,22 +214,33 @@ class RowParallelLinear(Layer):
         self.input_size_per_partition = in_features // self.world_size
 
-        self.weight = self.create_parameter(
-            shape=[self.input_size_per_partition, self.out_features],
-            attr=self._weight_attr,
-            dtype=self._dtype)
+        if self.is_mp:
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    shape=[self.input_size_per_partition, self.out_features],
+                    attr=self._weight_attr,
+                    dtype=self._dtype,
+                    is_bias=False)
+        else:
+            self.weight = self.create_parameter(
+                shape=[self.input_size_per_partition, self.out_features],
+                attr=self._weight_attr,
+                dtype=self._dtype,
+                is_bias=False)
 
         self.weight.is_distributed = True
 
         if has_bias:
             self.bias = self.create_parameter(
                 shape=[self.out_features],
                 attr=paddle.nn.initializer.Constant(value=0.0),
-                dtype=self._dtype)
+                dtype=self._dtype,
+                is_bias=True)
         else:
             self.bias = None
 
     def forward(self, x):
-        if self.input_is_parallel:
+        if self.input_is_parallel or (not self.is_mp):
             input_parallel = x
         else:
             # split last dim

@@ -195,12 +250,16 @@ class RowParallelLinear(Layer):
                 nranks=self.world_size,
                 group=self.model_parallel_group)
 
-        output_parallel = F.linear(input_parallel, self.weight, name=self.name)
-        output_ = paddle.distributed.collective._mp_allreduce(
-            output_parallel,
-            group=self.model_parallel_group,
-            use_calc_stream=True,
-            use_model_parallel=True)
+        output_parallel = F.linear(input_parallel, self.weight, name=self._name)
+
+        if self.is_mp:
+            output_ = paddle.distributed.collective._mp_allreduce(
+                output_parallel,
+                group=self.model_parallel_group,
+                use_calc_stream=True,
+                use_model_parallel=True)
+        else:
+            output_ = output_parallel
+
         output = output_ + self.bias if self.bias is not None else output_
         return output
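The rewritten VocabParallelEmbedding.forward leans on two facts: token ids owned by other ranks resolve to a padding row that is pinned to zero (self.weight[per_part_size - 1] = 0.0), and the per-rank partial lookups are summed by _mp_allreduce. A numpy sketch of that invariant with the all-reduce simulated by a plain sum (vocab size, dimensions and ids are illustrative, and the even split ignores the padded last shard of the real layer):

    import numpy as np

    vocab_size, dim, world_size = 8, 3, 2
    rows_per_rank = vocab_size // world_size          # 4 real rows per shard

    rng = np.random.RandomState(0)
    full_table = rng.normal(size=(vocab_size, dim))   # reference single-device table
    ids = np.array([1, 6, 3, 7])

    partials = []
    for rank in range(world_size):
        lo, hi = rank * rows_per_rank, (rank + 1) * rows_per_rank
        # Shard = this rank's rows plus one trailing padding row kept at zero,
        # mirroring self.weight[per_part_size - 1] = 0.0 in the diff.
        shard = np.vstack([full_table[lo:hi], np.zeros((1, dim))])
        # Ids owned by other ranks are redirected to the zero padding row.
        local_ids = np.where((ids >= lo) & (ids < hi), ids - lo, rows_per_rank)
        partials.append(shard[local_ids])

    # _mp_allreduce sums the partial lookups; the result equals the full lookup.
    print(np.allclose(sum(partials), full_table[ids]))   # True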
python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py

@@ -14,6 +14,7 @@
 import paddle
 import contextlib
+import numpy as np
 
 __all__ = []

@@ -65,14 +66,18 @@ def get_rng_state_tracker():
     return RNG_STATE_TRACKER
 
 
-def model_parallel_random_seed(seed=2048):
+def model_parallel_random_seed(seed=None):
     import paddle.distributed.fleet as fleet
     hcg = fleet.get_hybrid_communicate_group()
     rank = hcg.get_model_parallel_rank()
 
-    local_seed = seed + 1024 + rank
-    global_seed = seed
+    if seed:
+        global_seed = seed
+        local_seed = seed * 1024 + rank * 100
+    else:
+        global_seed = np.random.randint(0, 655350)
+        local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
 
     RNG_STATE_TRACKER.reset()
-    paddle.seed(global_seed)
     RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
+    paddle.seed(global_seed)
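The reworked seeding keeps one global seed for parameters shared across model-parallel ranks and derives a distinct local seed per rank for the tracked RNG state. A small sketch of the two branches of that arithmetic outside of Paddle (the derive_seeds helper is hypothetical and just restates the formulas above):

    import numpy as np

    def derive_seeds(rank, seed=None):
        # Mirrors the arithmetic in model_parallel_random_seed().
        if seed:
            global_seed = seed
            local_seed = seed * 1024 + rank * 100
        else:
            global_seed = np.random.randint(0, 655350)
            # Disjoint per-rank ranges keep the random local seeds from colliding.
            local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
        return global_seed, local_seed

    for rank in range(4):
        print(rank, derive_seeds(rank, seed=123))
    # With a fixed seed every rank reproducibly gets (123, 125952 + 100 * rank);
    # with seed=None each rank draws a local seed from its own 10000-wide range.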
python/paddle/distributed/fleet/meta_parallel/model_parallel.py → python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py (renamed)

@@ -22,15 +22,15 @@ from ..utils.log_util import logger
 __all__ = []
 
 
-class ModelParallel(MetaParallelBase):
+class TensorParallel(MetaParallelBase):
     def __init__(self, layers, hcg, **kwargs):
-        super(ModelParallel, self).__init__(layers, hcg, **kwargs)
+        super(TensorParallel, self).__init__(layers, hcg, **kwargs)
 
     def _prepare_for_model(self):
         logger.info("start broadcast mp parameters")
         broadcast_mp_parameters(self._layers, self._hcg)
 
-        logger.info("start broadcast mp parameters")
+        logger.info("start broadcast dp parameters")
         broadcast_dp_parameters(self._layers, self._hcg)
 
         logger.info("mp's parameters is ready")
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py

@@ -44,7 +44,15 @@ def _apply_collective_grads(parameters, comm_group):
 
     for coalesced_grad, _, _ in coalesced_grads_and_vars:
         # need to div nranks
-        coalesced_grad = coalesced_grad / comm_group.nranks
+        div_factor = paddle.to_tensor(comm_group.nranks, dtype=coalesced_grad.dtype)
+        paddle.fluid.framework._dygraph_tracer().trace_op(
+            type="elementwise_div",
+            inputs={'X': coalesced_grad,
+                    'Y': div_factor},
+            outputs={'Out': coalesced_grad},
+            attrs={'axis': -1})
+
         paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
 
     _split_tensors(coalesced_grads_and_vars)
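The data-parallel fix replaces the out-of-place coalesced_grad / comm_group.nranks, which rebinds the name to a new tensor, with an elementwise_div traced in place on the same buffer, so the tensor that is later all-reduced and split back into per-parameter grads is the one that was actually divided. The numerics themselves are unchanged; a numpy sketch of the equivalence (values are illustrative):

    import numpy as np

    nranks = 4
    grads = [np.random.randn(6) for _ in range(nranks)]   # one coalesced grad per worker

    # What each worker does: divide by nranks, then sum-all-reduce.
    allreduced = sum(g / nranks for g in grads)

    # Equivalent to averaging the raw per-worker gradients.
    print(np.allclose(allreduced, np.mean(grads, axis=0)))   # True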
python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py

@@ -231,7 +231,7 @@ class TestDistTraning(unittest.TestCase):
         # model_b
         check_group = dist.new_group(list(range(self.model_parallel_size)))
 
         integral_w = []
-        partial_w = model_a.embedding.embedding.weight.clone().detach()
+        partial_w = model_a.embedding.weight.clone().detach()
         paddle.distributed.all_gather(integral_w, partial_w, group=check_group)
 
         result_w = []
         for idx in range(len(integral_w)):
python/paddle/fluid/tests/unittests/new_group.py

@@ -27,6 +27,7 @@ class TestNewGroupAPI(object):
     def test_all(self):
         gp = paddle.distributed.new_group([0, 1])
+        print("gp info:", gp)
         print("test new group api ok")
 
         tmp = np.array([0, 0, 0])