Unverified commit d101334c, authored by caozhou and committed by GitHub

[Auto Parallel] Update reshard (#40865)

* fix code style

* update unittest
Parent 86554d91
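For call sites, the hunks below replace the module-level `reshard(...)` function with a `Resharder` class that is constructed once and then driven through its `reshard()` method. Below is a minimal before/after sketch of that call-site change, using the argument names that appear in the diff; it assumes the distributed programs, rank, context, and parameter gradients have already been produced by the completion/partitioning steps shown later, and the wrapper function name `apply_reshard` is illustrative only, not part of the commit.

```python
# Sketch of the call-site migration in this commit. The arguments are assumed
# to come from the auto-parallel completion/partitioning pipeline; this is not
# a standalone runnable example.
from paddle.distributed.auto_parallel.reshard import Resharder


def apply_reshard(dist_main_prog, dist_startup_prog, rank, dist_context,
                  dist_params_grads):
    # Old API (removed by this commit):
    #   from paddle.distributed.auto_parallel.reshard import reshard
    #   reshard(dist_main_prog, dist_startup_prog, rank, dist_context,
    #           dist_params_grads)
    # New API: build a Resharder instance, then invoke reshard() on it.
    resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                          dist_context, dist_params_grads)
    resharder.reshard()
```

Because the resharding bookkeeping now lives on the instance, the module-level HAS_SENT, HAS_RECV, and HAS_ALLGATHER caches disappear along with the function, and the parallelizer and test hunks below no longer clear them between runs.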
......@@ -15,7 +15,7 @@
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .process_mesh import ProcessMesh
from .reshard import reshard # noqa: F401
from .reshard import Resharder # noqa: F401
from .cost_model import estimate_cost
__all__ = []
......@@ -235,19 +235,19 @@ class Converter(object):
@staticmethod
def merge_with_dist_attr(tensor_list, dist_attr):
""" Merge tensor with distributed attribute """
from .reshard import _compute_complete_shape, _compute_partition_index
from .reshard import Resharder
dims_mapping = dist_attr["dims_mapping"]
process_shape = dist_attr["process_shape"]
process_group = dist_attr["process_group"]
# get the complete shape of the tensor
complete_shape = _compute_complete_shape(tensor_list[0].shape,
process_shape, dims_mapping)
complete_shape = Resharder.compute_complete_shape(
tensor_list[0].shape, process_shape, dims_mapping)
# merge the tensor with dist_attr
partition_tensor_list = []
merged_partiton = []
for process in process_group:
partition_index = _compute_partition_index(
partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape,
process_group)
index = process_group.index(process)
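For context on the helper referenced in this hunk, here is a simplified, hypothetical pure-Python sketch (not the Paddle implementation) of the shape arithmetic that `Resharder.compute_complete_shape` is used for above: recovering a tensor's global shape from a local slice shape, given the dims_mapping and the process mesh shape. The `-1` replicated-axis convention and the even-split assumption are assumptions of this sketch.

```python
# Hypothetical illustration only (not the Paddle implementation): how a
# complete (global) shape can be recovered from a local slice shape.
def complete_shape_sketch(slice_shape, process_shape, dims_mapping):
    complete_shape = []
    for i, dim in enumerate(slice_shape):
        mesh_dim = dims_mapping[i]
        if mesh_dim == -1:
            # Axis i is replicated across processes: size is unchanged.
            complete_shape.append(dim)
        else:
            # Axis i is split over mesh dimension mesh_dim (even split assumed).
            complete_shape.append(dim * process_shape[mesh_dim])
    return complete_shape


# A [1, 1, 4] local slice on a 1-D mesh of two processes, with the last axis
# split over mesh dimension 0, implies a [1, 1, 8] global tensor.
assert complete_shape_sketch([1, 1, 4], [2], [-1, -1, 0]) == [1, 1, 8]
```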
......@@ -302,7 +302,7 @@ class Converter(object):
_merge_tensor(partition_tensor_list, tensor, partition_index)
# partition_tensor_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
"""
from .reshard import _compute_concat_info
from .reshard import Resharder
if len(partition_tensor_list) == 1:
is_complete_data = True
......@@ -318,7 +318,7 @@ class Converter(object):
else:
i = 0
while i < len(partition_tensor_list):
concat_axis, first_order, new_partition = _compute_concat_info(
concat_axis, first_order, new_partition = Resharder.compute_concat_info(
partition_tensor_list[i][1], partition_index)
if concat_axis != -1:
if first_order == 0:
......@@ -391,11 +391,11 @@ class Converter(object):
index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
# index: [[], [], [2, 4]]
"""
from .reshard import _compute_partition_index
from .reshard import Resharder
split_indices_list = []
for process in process_group:
partition_index = _compute_partition_index(
partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape,
process_group)
if split_indices_list:
......@@ -437,9 +437,9 @@ class Converter(object):
process_shape, process_group)
# index: 2
"""
from .reshard import _compute_partition_index
from .reshard import Resharder
partition_index = _compute_partition_index(
partition_index = Resharder.compute_partition_index(
rank_id, complete_shape, dims_mapping, process_shape, process_group)
sliced_index = 0
for i, shape in enumerate(complete_shape):
......
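The docstring examples above (e.g., `_get_split_indices` returning `[[], [], [2, 4]]`) work with per-dimension `[start, end]` partition indices such as `[[0, 1], [0, 1], [0, 4]]`. The sketch below shows, hypothetically, how such an index could be derived for one process; it is not the Paddle implementation of `Resharder.compute_partition_index`, and it assumes evenly divisible shapes and a row-major process mesh.

```python
# Hypothetical illustration only: per-dimension [start, end) ranges owned by
# one process, in the spirit of the partition indices used above.
def partition_index_sketch(process, complete_shape, dims_mapping,
                           process_shape, process_group):
    # Decompose the process's position in the group into mesh coordinates
    # (row-major order assumed).
    relative_rank = process_group.index(process)
    coords = []
    for dim_size in reversed(process_shape):
        coords.append(relative_rank % dim_size)
        relative_rank //= dim_size
    coords.reverse()

    partition_index = []
    for i, dim in enumerate(complete_shape):
        mesh_dim = dims_mapping[i]
        if mesh_dim == -1:
            partition_index.append([0, dim])  # replicated axis: full range
        else:
            chunk = dim // process_shape[mesh_dim]  # even split assumed
            start = coords[mesh_dim] * chunk
            partition_index.append([start, start + chunk])
    return partition_index


# Two processes splitting the last axis of a [1, 1, 8] tensor:
assert partition_index_sketch(0, [1, 1, 8], [-1, -1, 0], [2],
                              [0, 1]) == [[0, 1], [0, 1], [0, 4]]
assert partition_index_sketch(1, [1, 1, 8], [-1, -1, 0], [2],
                              [0, 1]) == [[0, 1], [0, 1], [4, 8]]
```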
......@@ -32,7 +32,7 @@ from paddle.distributed.utils import get_logger
from .mapper import mapping
from .cluster import Cluster
from .reshard import reshard
from .reshard import Resharder
from .planner import Planner
from .completion import Completer
from .partitioner import Partitioner
......@@ -187,8 +187,9 @@ class Engine:
# Do reshard process
set_grad_var_shape(dist_main_prog, dist_context)
make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
reshard(dist_main_prog, dist_startup_prog, rank, dist_context,
dist_params_grads)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
dist_context, dist_params_grads)
resharder.reshard()
# Apply post optimization passes
self._apply_post_optimization(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
......@@ -199,8 +200,9 @@ class Engine:
serial_main_program, serial_startup_program, [])
# Do reshard process
make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
reshard(dist_main_prog, dist_startup_prog, rank, dist_context, [],
1)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
dist_context, [], 1)
resharder.reshard()
# clone program for test
if mode != 'train':
......
......@@ -42,7 +42,7 @@ from .utils import make_data_unshard
from .utils import set_grad_var_shape
from .utils import print_program_with_dist_attr
from .utils import SerialProgramInfo
from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
from .reshard import Resharder
from .cluster import Cluster
from .mapper import mapping
from .dist_op import DistributedOperator
......@@ -213,17 +213,15 @@ class AutoParallelizer:
make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context,
dist_params_grads)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, dist_params_grads)
resharder.reshard()
self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
g_process_group_map = None
if not relaunch_phase:
g_process_group_map = copy.deepcopy(_g_process_group_map)
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
for process_mesh in dist_context._process_meshes:
......
......@@ -775,19 +775,19 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
def _merge_parameter_with_dist_attr(param_list, dist_attr):
""" Merge parameter with distributed attribute """
from .reshard import _compute_complete_shape, _compute_partition_index
from .reshard import Resharder
dims_mapping = dist_attr["dims_mapping"]
process_shape = dist_attr["process_shape"]
process_group = dist_attr["process_group"]
# get the complete shape of the parameter
complete_shape = _compute_complete_shape(param_list[0].shape, process_shape,
dims_mapping)
complete_shape = Resharder.compute_complete_shape(
param_list[0].shape, process_shape, dims_mapping)
# merge the parameter with dist_attr
partition_param_list = []
merged_partiton = []
for process in process_group:
partition_index = _compute_partition_index(
partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape, process_group)
index = process_group.index(process)
if partition_index not in merged_partiton:
......@@ -840,7 +840,7 @@ def _merge_parameter(partition_param_list, param, partition_index,
_merge_parameter(partition_param_list, param, partition_index)
# partition_param_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
"""
from .reshard import _compute_concat_info
from .reshard import Resharder
if len(partition_param_list) == 1:
is_complete_data = True
......@@ -856,7 +856,7 @@ def _merge_parameter(partition_param_list, param, partition_index,
else:
i = 0
while i < len(partition_param_list):
concat_axis, first_order, new_partition = _compute_concat_info(
concat_axis, first_order, new_partition = Resharder.compute_concat_info(
partition_param_list[i][1], partition_index)
if concat_axis != -1:
if first_order == 0:
......@@ -933,9 +933,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
process_shape, process_group)
# index: 2
"""
from .reshard import _compute_partition_index
from .reshard import Resharder
partition_index = _compute_partition_index(
partition_index = Resharder.compute_partition_index(
rank, complete_shape, dims_mapping, process_shape, process_group)
sliced_param_index = 0
for i, shape in enumerate(complete_shape):
......@@ -972,11 +972,11 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape,
index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
# index: [[], [], [2, 4]]
"""
from .reshard import _compute_partition_index
from .reshard import Resharder
split_indices_list = []
for process in process_group:
partition_index = _compute_partition_index(
partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape, process_group)
if split_indices_list:
for dim in range(len(partition_index)):
......
......@@ -31,7 +31,6 @@ from paddle.distributed import fleet
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program
from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program
from paddle.distributed.auto_parallel.reshard import HAS_SENT, HAS_RECV, HAS_ALLGATHER
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
paddle.enable_static()
......@@ -258,9 +257,6 @@ class TestMLPAutoConvert2(unittest.TestCase):
paddle.seed(2021)
random.seed(2021)
np.random.seed(2021)
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
def tearDown(self):
os.remove("./model_state_rank{}.pdmodel".format(
......
......@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.cost_model import estimate_cost
import paddle.fluid.core as core
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
......@@ -232,8 +232,9 @@ class TestCostModel(unittest.TestCase):
dist_context = DistributedContext()
distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
reshard(distributed_program, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder = Resharder(distributed_program, dist_startup_prog,
rank_id, dist_context, dist_params_grads)
resharder.reshard()
dist_program.append(distributed_program)
cluster = None
cost = estimate_cost(
......
......@@ -40,7 +40,7 @@ from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import get_all_process_groups
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.cluster import Cluster
......@@ -502,8 +502,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
partitioned_optimize_ops = parallelizer._apply_optimize(
dist_train_program, dist_startup_prog, dist_params_grads)
reshard(dist_train_program, dist_startup_prog, rank_id, dist_context,
dist_params_grads)
resharder = Resharder(dist_train_program, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
return dist_train_program, dist_startup_prog
......
......@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import _g_process_group_map
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
......@@ -310,8 +310,9 @@ class TestMLPReshard(unittest.TestCase):
train_program, startup_program, dist_context, rank_id)
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
dist_params_grads)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......@@ -320,9 +321,6 @@ class TestMLPReshard(unittest.TestCase):
self.assertTrue(check_initialization(dist_startup_prog, rank_id))
def test_mlp_pp_diff_process_mesh(self):
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
......@@ -331,8 +329,9 @@ class TestMLPReshard(unittest.TestCase):
train_program, startup_program, dist_context, rank_id, True)
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
dist_params_grads)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
......@@ -351,8 +350,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 0
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
dist_params_grads)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
# send and recv should not exist in dp scene.
self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))
......
......@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
......@@ -179,8 +179,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 2
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
dist_params_grads)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
# print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
......@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
......@@ -213,8 +213,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 2
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
dist_params_grads)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......@@ -272,8 +273,9 @@ class TestMLPReshard(unittest.TestCase):
dist_context.block_state.parse_forward_blocks(complete_train_program)
partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
complete_train_program, startup_program, [])
reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
dist_context, partitioned_params_grads)
resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
rank_id, dist_context, partitioned_params_grads)
resharder.reshard()
# the x should not be slice
self.assertTrue(check_allgather(partitioned_main_prog))
......
......@@ -29,7 +29,7 @@ import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import new_process_group
paddle.enable_static()
......