Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d101334c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
d101334c
编写于
3月 28, 2022
作者:
C
caozhou
提交者:
GitHub
3月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Update reshard (#40865)
* fix code stype * update unitest
上级
86554d91
变更
13
展开全部
隐藏空白更改
内联
并排
Showing
13 changed file
with
1309 addition
and
1221 deletion
+1309
-1221
python/paddle/distributed/auto_parallel/__init__.py
python/paddle/distributed/auto_parallel/__init__.py
+1
-1
python/paddle/distributed/auto_parallel/converter.py
python/paddle/distributed/auto_parallel/converter.py
+10
-10
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+7
-5
python/paddle/distributed/auto_parallel/parallelizer.py
python/paddle/distributed/auto_parallel/parallelizer.py
+4
-6
python/paddle/distributed/auto_parallel/reshard.py
python/paddle/distributed/auto_parallel/reshard.py
+1247
-1160
python/paddle/distributed/auto_parallel/utils.py
python/paddle/distributed/auto_parallel/utils.py
+10
-10
python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py
...paddle/fluid/tests/unittests/auto_parallel_autoconvert.py
+0
-4
python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
...le/fluid/tests/unittests/test_auto_parallel_cost_model.py
+4
-3
python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
...paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
+4
-3
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
...addle/fluid/tests/unittests/test_auto_parallel_reshard.py
+10
-10
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
...luid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
+4
-3
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
.../fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
+7
-5
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py
...luid/tests/unittests/test_auto_parallel_reshard_serial.py
+1
-1
未找到文件。
python/paddle/distributed/auto_parallel/__init__.py
浏览文件 @
d101334c
...
...
@@ -15,7 +15,7 @@
from
.interface
import
shard_tensor
# noqa: F401
from
.interface
import
shard_op
# noqa: F401
from
.process_mesh
import
ProcessMesh
from
.reshard
import
reshard
# noqa: F401
from
.reshard
import
Resharder
# noqa: F401
from
.cost_model
import
estimate_cost
__all__
=
[]
python/paddle/distributed/auto_parallel/converter.py
浏览文件 @
d101334c
...
...
@@ -235,19 +235,19 @@ class Converter(object):
@
staticmethod
def
merge_with_dist_attr
(
tensor_list
,
dist_attr
):
""" Merge tensor with distributed attribute """
from
.reshard
import
_compute_complete_shape
,
_compute_partition_index
from
.reshard
import
Resharder
dims_mapping
=
dist_attr
[
"dims_mapping"
]
process_shape
=
dist_attr
[
"process_shape"
]
process_group
=
dist_attr
[
"process_group"
]
# get the complete shape of the tensor
complete_shape
=
_compute_complete_shape
(
tensor_list
[
0
].
shape
,
process_shape
,
dims_mapping
)
complete_shape
=
Resharder
.
compute_complete_shape
(
tensor_list
[
0
].
shape
,
process_shape
,
dims_mapping
)
# merge the tensor with dist_attr
partition_tensor_list
=
[]
merged_partiton
=
[]
for
process
in
process_group
:
partition_index
=
_
compute_partition_index
(
partition_index
=
Resharder
.
compute_partition_index
(
process
,
complete_shape
,
dims_mapping
,
process_shape
,
process_group
)
index
=
process_group
.
index
(
process
)
...
...
@@ -302,7 +302,7 @@ class Converter(object):
_merge_tensor(partition_tensor_list, tensor, partition_index)
# partition_tensor_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
"""
from
.reshard
import
_compute_concat_info
from
.reshard
import
Resharder
if
len
(
partition_tensor_list
)
==
1
:
is_complete_data
=
True
...
...
@@ -318,7 +318,7 @@ class Converter(object):
else
:
i
=
0
while
i
<
len
(
partition_tensor_list
):
concat_axis
,
first_order
,
new_partition
=
_
compute_concat_info
(
concat_axis
,
first_order
,
new_partition
=
Resharder
.
compute_concat_info
(
partition_tensor_list
[
i
][
1
],
partition_index
)
if
concat_axis
!=
-
1
:
if
first_order
==
0
:
...
...
@@ -391,11 +391,11 @@ class Converter(object):
index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
# index: [[], [], [2, 4]]
"""
from
.reshard
import
_compute_partition_index
from
.reshard
import
Resharder
split_indices_list
=
[]
for
process
in
process_group
:
partition_index
=
_
compute_partition_index
(
partition_index
=
Resharder
.
compute_partition_index
(
process
,
complete_shape
,
dims_mapping
,
process_shape
,
process_group
)
if
split_indices_list
:
...
...
@@ -437,9 +437,9 @@ class Converter(object):
process_shape, process_group)
# index: 2
"""
from
.reshard
import
_compute_partition_index
from
.reshard
import
Resharder
partition_index
=
_
compute_partition_index
(
partition_index
=
Resharder
.
compute_partition_index
(
rank_id
,
complete_shape
,
dims_mapping
,
process_shape
,
process_group
)
sliced_index
=
0
for
i
,
shape
in
enumerate
(
complete_shape
):
...
...
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
d101334c
...
...
@@ -32,7 +32,7 @@ from paddle.distributed.utils import get_logger
from
.mapper
import
mapping
from
.cluster
import
Cluster
from
.reshard
import
reshard
from
.reshard
import
Resharder
from
.planner
import
Planner
from
.completion
import
Completer
from
.partitioner
import
Partitioner
...
...
@@ -187,8 +187,9 @@ class Engine:
# Do reshard process
set_grad_var_shape
(
dist_main_prog
,
dist_context
)
make_data_unshard
(
dist_main_prog
,
dist_startup_prog
,
dist_context
)
reshard
(
dist_main_prog
,
dist_startup_prog
,
rank
,
dist_context
,
dist_params_grads
)
resharder
=
Resharder
(
dist_main_prog
,
dist_startup_prog
,
rank
,
dist_context
,
dist_params_grads
)
resharder
.
reshard
()
# Apply post optimization passes
self
.
_apply_post_optimization
(
dist_main_prog
,
dist_startup_prog
,
rank
,
dist_params_grads
)
...
...
@@ -199,8 +200,9 @@ class Engine:
serial_main_program
,
serial_startup_program
,
[])
# Do reshard process
make_data_unshard
(
dist_main_prog
,
dist_startup_prog
,
dist_context
)
reshard
(
dist_main_prog
,
dist_startup_prog
,
rank
,
dist_context
,
[],
1
)
resharder
=
Resharder
(
dist_main_prog
,
dist_startup_prog
,
rank
,
dist_context
,
[],
1
)
resharder
.
reshard
()
# clone program for test
if
mode
!=
'train'
:
...
...
python/paddle/distributed/auto_parallel/parallelizer.py
浏览文件 @
d101334c
...
...
@@ -42,7 +42,7 @@ from .utils import make_data_unshard
from
.utils
import
set_grad_var_shape
from
.utils
import
print_program_with_dist_attr
from
.utils
import
SerialProgramInfo
from
.reshard
import
reshard
,
HAS_SENT
,
HAS_RECV
,
HAS_ALLGATHER
from
.reshard
import
Resharder
from
.cluster
import
Cluster
from
.mapper
import
mapping
from
.dist_op
import
DistributedOperator
...
...
@@ -213,17 +213,15 @@ class AutoParallelizer:
make_data_unshard
(
dist_main_prog
,
dist_startup_prog
,
self
.
_dist_context
)
reshard
(
dist_main_prog
,
dist_startup_prog
,
rank
,
self
.
_dist_context
,
dist_params_grads
)
resharder
=
Resharder
(
dist_main_prog
,
dist_startup_prog
,
rank
,
self
.
_dist_context
,
dist_params_grads
)
resharder
.
reshard
()
self
.
_apply_post_optimization_passes
(
dist_main_prog
,
dist_startup_prog
,
rank
,
dist_params_grads
)
g_process_group_map
=
None
if
not
relaunch_phase
:
g_process_group_map
=
copy
.
deepcopy
(
_g_process_group_map
)
HAS_SENT
.
clear
()
HAS_RECV
.
clear
()
HAS_ALLGATHER
.
clear
()
_g_process_group_map
.
clear
()
_g_process_group_map
[
0
]
=
ProcessGroup
(
0
,
[])
for
process_mesh
in
dist_context
.
_process_meshes
:
...
...
python/paddle/distributed/auto_parallel/reshard.py
浏览文件 @
d101334c
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/utils.py
浏览文件 @
d101334c
...
...
@@ -775,19 +775,19 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
def
_merge_parameter_with_dist_attr
(
param_list
,
dist_attr
):
""" Merge parameter with distributed attribute """
from
.reshard
import
_compute_complete_shape
,
_compute_partition_index
from
.reshard
import
Resharder
dims_mapping
=
dist_attr
[
"dims_mapping"
]
process_shape
=
dist_attr
[
"process_shape"
]
process_group
=
dist_attr
[
"process_group"
]
# get the complete shape of the parameter
complete_shape
=
_compute_complete_shape
(
param_list
[
0
].
shape
,
process_shape
,
dims_mapping
)
complete_shape
=
Resharder
.
compute_complete_shape
(
param_list
[
0
].
shape
,
process_shape
,
dims_mapping
)
# merge the parameter with dist_attr
partition_param_list
=
[]
merged_partiton
=
[]
for
process
in
process_group
:
partition_index
=
_
compute_partition_index
(
partition_index
=
Resharder
.
compute_partition_index
(
process
,
complete_shape
,
dims_mapping
,
process_shape
,
process_group
)
index
=
process_group
.
index
(
process
)
if
partition_index
not
in
merged_partiton
:
...
...
@@ -840,7 +840,7 @@ def _merge_parameter(partition_param_list, param, partition_index,
_merge_parameter(partition_param_list, param, partition_index)
# partition_param_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
"""
from
.reshard
import
_compute_concat_info
from
.reshard
import
Resharder
if
len
(
partition_param_list
)
==
1
:
is_complete_data
=
True
...
...
@@ -856,7 +856,7 @@ def _merge_parameter(partition_param_list, param, partition_index,
else
:
i
=
0
while
i
<
len
(
partition_param_list
):
concat_axis
,
first_order
,
new_partition
=
_
compute_concat_info
(
concat_axis
,
first_order
,
new_partition
=
Resharder
.
compute_concat_info
(
partition_param_list
[
i
][
1
],
partition_index
)
if
concat_axis
!=
-
1
:
if
first_order
==
0
:
...
...
@@ -933,9 +933,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
process_shape, process_group)
# index: 2
"""
from
.reshard
import
_compute_partition_index
from
.reshard
import
Resharder
partition_index
=
_
compute_partition_index
(
partition_index
=
Resharder
.
compute_partition_index
(
rank
,
complete_shape
,
dims_mapping
,
process_shape
,
process_group
)
sliced_param_index
=
0
for
i
,
shape
in
enumerate
(
complete_shape
):
...
...
@@ -972,11 +972,11 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape,
index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
# index: [[], [], [2, 4]]
"""
from
.reshard
import
_compute_partition_index
from
.reshard
import
Resharder
split_indices_list
=
[]
for
process
in
process_group
:
partition_index
=
_
compute_partition_index
(
partition_index
=
Resharder
.
compute_partition_index
(
process
,
complete_shape
,
dims_mapping
,
process_shape
,
process_group
)
if
split_indices_list
:
for
dim
in
range
(
len
(
partition_index
)):
...
...
python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py
浏览文件 @
d101334c
...
...
@@ -31,7 +31,6 @@ from paddle.distributed import fleet
from
paddle.fluid.initializer
import
NumpyArrayInitializer
from
paddle.distributed.auto_parallel.utils
import
save_distributed_checkpoint
,
load_distributed_checkpoint
,
load_checkpoint_into_program
from
paddle.distributed.auto_parallel.utils
import
get_dist_attr
,
merge_and_slice_parameter
,
load_parameter_into_program
from
paddle.distributed.auto_parallel.reshard
import
HAS_SENT
,
HAS_RECV
,
HAS_ALLGATHER
from
paddle.distributed.auto_parallel.dist_context
import
set_default_distributed_context
paddle
.
enable_static
()
...
...
@@ -258,9 +257,6 @@ class TestMLPAutoConvert2(unittest.TestCase):
paddle
.
seed
(
2021
)
random
.
seed
(
2021
)
np
.
random
.
seed
(
2021
)
HAS_SENT
.
clear
()
HAS_RECV
.
clear
()
HAS_ALLGATHER
.
clear
()
def
tearDown
(
self
):
os
.
remove
(
"./model_state_rank{}.pdmodel"
.
format
(
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
浏览文件 @
d101334c
...
...
@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from
paddle.distributed
import
fleet
from
paddle.distributed.auto_parallel.partitioner
import
Partitioner
from
paddle.distributed.auto_parallel.parallelizer
import
AutoParallelizer
from
paddle.distributed.auto_parallel.reshard
import
reshard
from
paddle.distributed.auto_parallel.reshard
import
Resharder
from
paddle.distributed.auto_parallel.cost_model
import
estimate_cost
import
paddle.fluid.core
as
core
from
paddle.distributed.auto_parallel.utils
import
print_program_with_dist_attr
...
...
@@ -232,8 +232,9 @@ class TestCostModel(unittest.TestCase):
dist_context
=
DistributedContext
()
distributed_program
,
dist_startup_prog
,
dist_params_grads
=
get_dist_prog
(
train_program
,
startup_program
,
dist_context
,
rank_id
)
reshard
(
distributed_program
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
=
Resharder
(
distributed_program
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
.
reshard
()
dist_program
.
append
(
distributed_program
)
cluster
=
None
cost
=
estimate_cost
(
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
浏览文件 @
d101334c
...
...
@@ -40,7 +40,7 @@ from paddle.distributed.auto_parallel.completion import Completer
from
paddle.distributed.auto_parallel.parallelizer
import
AutoParallelizer
from
paddle.distributed.auto_parallel.dist_context
import
DistributedContext
from
paddle.distributed.auto_parallel.partitioner
import
Partitioner
from
paddle.distributed.auto_parallel.reshard
import
reshard
from
paddle.distributed.auto_parallel.reshard
import
Resharder
from
paddle.distributed.auto_parallel.process_group
import
get_all_process_groups
from
paddle.distributed.auto_parallel.process_group
import
new_process_group
from
paddle.distributed.auto_parallel.cluster
import
Cluster
...
...
@@ -502,8 +502,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
partitioned_optimize_ops
=
parallelizer
.
_apply_optimize
(
dist_train_program
,
dist_startup_prog
,
dist_params_grads
)
reshard
(
dist_train_program
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
=
Resharder
(
dist_train_program
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
.
reshard
()
return
dist_train_program
,
dist_startup_prog
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
浏览文件 @
d101334c
...
...
@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from
paddle.distributed
import
fleet
from
paddle.distributed.auto_parallel.parallelizer
import
AutoParallelizer
from
paddle.distributed.auto_parallel.partitioner
import
Partitioner
from
paddle.distributed.auto_parallel.reshard
import
reshard
,
HAS_SENT
,
HAS_RECV
,
HAS_ALLGATHER
from
paddle.distributed.auto_parallel.reshard
import
Resharder
from
paddle.distributed.auto_parallel.process_group
import
_g_process_group_map
from
paddle.distributed.auto_parallel.utils
import
print_program_with_dist_attr
...
...
@@ -310,8 +310,9 @@ class TestMLPReshard(unittest.TestCase):
train_program
,
startup_program
,
dist_context
,
rank_id
)
for
key
in
list
(
_g_process_group_map
.
keys
()):
del
_g_process_group_map
[
key
]
reshard
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
=
Resharder
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
.
reshard
()
# check send and recv result
self
.
assertTrue
(
check_send_recv_result
(
dist_main_prog
,
rank_id
))
...
...
@@ -320,9 +321,6 @@ class TestMLPReshard(unittest.TestCase):
self
.
assertTrue
(
check_initialization
(
dist_startup_prog
,
rank_id
))
def
test_mlp_pp_diff_process_mesh
(
self
):
HAS_SENT
.
clear
()
HAS_RECV
.
clear
()
HAS_ALLGATHER
.
clear
()
train_program
=
paddle
.
static
.
Program
()
startup_program
=
paddle
.
static
.
Program
()
dist_context
=
DistributedContext
()
...
...
@@ -331,8 +329,9 @@ class TestMLPReshard(unittest.TestCase):
train_program
,
startup_program
,
dist_context
,
rank_id
,
True
)
for
key
in
list
(
_g_process_group_map
.
keys
()):
del
_g_process_group_map
[
key
]
reshard
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
=
Resharder
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
.
reshard
()
print_program_with_dist_attr
(
dist_main_prog
,
dist_context
)
# check send and recv result
...
...
@@ -351,8 +350,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id
=
0
dist_main_prog
,
dist_startup_prog
,
dist_params_grads
=
get_dist_prog
(
train_program
,
startup_program
,
dist_context
,
rank_id
)
reshard
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
=
Resharder
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
.
reshard
()
# send and recv should not exist in dp scene.
self
.
assertFalse
(
check_send_recv_result
(
dist_main_prog
,
rank_id
))
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
浏览文件 @
d101334c
...
...
@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from
paddle.distributed
import
fleet
from
paddle.distributed.auto_parallel.parallelizer
import
AutoParallelizer
from
paddle.distributed.auto_parallel.partitioner
import
Partitioner
from
paddle.distributed.auto_parallel.reshard
import
reshard
from
paddle.distributed.auto_parallel.reshard
import
Resharder
from
paddle.distributed.auto_parallel.utils
import
print_program_with_dist_attr
paddle
.
enable_static
()
...
...
@@ -179,8 +179,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id
=
2
dist_main_prog
,
dist_startup_prog
,
dist_params_grads
=
get_dist_prog
(
train_program
,
startup_program
,
dist_context
,
rank_id
)
reshard
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
=
Resharder
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
.
reshard
()
# print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
self
.
assertTrue
(
check_send_recv_result
(
dist_main_prog
,
rank_id
))
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
浏览文件 @
d101334c
...
...
@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from
paddle.distributed
import
fleet
from
paddle.distributed.auto_parallel.parallelizer
import
AutoParallelizer
from
paddle.distributed.auto_parallel.partitioner
import
Partitioner
from
paddle.distributed.auto_parallel.reshard
import
reshard
from
paddle.distributed.auto_parallel.reshard
import
Resharder
from
paddle.distributed.auto_parallel.utils
import
print_program_with_dist_attr
paddle
.
enable_static
()
...
...
@@ -213,8 +213,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id
=
2
dist_main_prog
,
dist_startup_prog
,
dist_params_grads
=
get_dist_prog
(
train_program
,
startup_program
,
dist_context
,
rank_id
)
reshard
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
=
Resharder
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
.
reshard
()
# check send and recv result
self
.
assertTrue
(
check_send_recv_result
(
dist_main_prog
,
rank_id
))
...
...
@@ -272,8 +273,9 @@ class TestMLPReshard(unittest.TestCase):
dist_context
.
block_state
.
parse_forward_blocks
(
complete_train_program
)
partitioned_main_prog
,
partitioned_startup_prog
,
partitioned_params_grads
=
partitioner
.
partition
(
complete_train_program
,
startup_program
,
[])
reshard
(
partitioned_main_prog
,
partitioned_startup_prog
,
rank_id
,
dist_context
,
partitioned_params_grads
)
resharder
=
Resharder
(
partitioned_main_prog
,
partitioned_startup_prog
,
rank_id
,
dist_context
,
partitioned_params_grads
)
resharder
.
reshard
()
# the x should not be slice
self
.
assertTrue
(
check_allgather
(
partitioned_main_prog
))
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py
浏览文件 @
d101334c
...
...
@@ -29,7 +29,7 @@ import paddle.distributed.auto_parallel as auto
from
paddle.distributed.auto_parallel.dist_context
import
get_default_distributed_context
from
paddle.distributed
import
fleet
from
paddle.distributed.auto_parallel.partitioner
import
Partitioner
from
paddle.distributed.auto_parallel.reshard
import
reshard
from
paddle.distributed.auto_parallel.reshard
import
Resharder
from
paddle.distributed.auto_parallel.process_group
import
new_process_group
paddle
.
enable_static
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录