Unverified commit d101334c, authored by caozhou, committed by GitHub

[Auto Parallel] Update reshard (#40865)

* fix code style

* update unit tests
Parent: 86554d91
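In short, this commit replaces the module-level reshard() entry point with a Resharder class: callers now build a Resharder and invoke its reshard() method, the former private helpers (_compute_complete_shape, _compute_partition_index, _compute_concat_info) are exposed as Resharder static methods (compute_complete_shape, compute_partition_index, compute_concat_info), and the module-level HAS_SENT / HAS_RECV / HAS_ALLGATHER caches no longer need to be imported and cleared by callers. A minimal before/after sketch of the call site, assuming dist_main_prog, dist_startup_prog, rank, dist_context, and dist_params_grads were already produced by the completer/partitioner as in the code changed below:

# Old API (removed in this commit):
#     from paddle.distributed.auto_parallel.reshard import reshard
#     reshard(dist_main_prog, dist_startup_prog, rank, dist_context,
#             dist_params_grads)

# New API: construct a Resharder and run its reshard() method.
from paddle.distributed.auto_parallel.reshard import Resharder

resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                      dist_context, dist_params_grads)
resharder.reshard()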
@@ -15,7 +15,7 @@
 from .interface import shard_tensor  # noqa: F401
 from .interface import shard_op  # noqa: F401
 from .process_mesh import ProcessMesh
-from .reshard import reshard  # noqa: F401
+from .reshard import Resharder  # noqa: F401
 from .cost_model import estimate_cost
 
 __all__ = []
@@ -235,19 +235,19 @@ class Converter(object):
     @staticmethod
     def merge_with_dist_attr(tensor_list, dist_attr):
         """ Merge tensor with distributed attribute """
-        from .reshard import _compute_complete_shape, _compute_partition_index
+        from .reshard import Resharder
 
         dims_mapping = dist_attr["dims_mapping"]
         process_shape = dist_attr["process_shape"]
         process_group = dist_attr["process_group"]
         # get the complete shape of the tensor
-        complete_shape = _compute_complete_shape(tensor_list[0].shape,
-                                                 process_shape, dims_mapping)
+        complete_shape = Resharder.compute_complete_shape(
+            tensor_list[0].shape, process_shape, dims_mapping)
         # merge the tensor with dist_attr
         partition_tensor_list = []
         merged_partiton = []
         for process in process_group:
-            partition_index = _compute_partition_index(
+            partition_index = Resharder.compute_partition_index(
                 process, complete_shape, dims_mapping, process_shape,
                 process_group)
             index = process_group.index(process)
@@ -302,7 +302,7 @@ class Converter(object):
             _merge_tensor(partition_tensor_list, tensor, partition_index)
             # partition_tensor_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
         """
-        from .reshard import _compute_concat_info
+        from .reshard import Resharder
 
         if len(partition_tensor_list) == 1:
             is_complete_data = True
@@ -318,7 +318,7 @@
         else:
             i = 0
             while i < len(partition_tensor_list):
-                concat_axis, first_order, new_partition = _compute_concat_info(
+                concat_axis, first_order, new_partition = Resharder.compute_concat_info(
                     partition_tensor_list[i][1], partition_index)
                 if concat_axis != -1:
                     if first_order == 0:
@@ -391,11 +391,11 @@ class Converter(object):
             index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
             # index: [[], [], [2, 4]]
         """
-        from .reshard import _compute_partition_index
+        from .reshard import Resharder
 
         split_indices_list = []
         for process in process_group:
-            partition_index = _compute_partition_index(
+            partition_index = Resharder.compute_partition_index(
                 process, complete_shape, dims_mapping, process_shape,
                 process_group)
             if split_indices_list:
@@ -437,9 +437,9 @@ class Converter(object):
                                             process_shape, process_group)
             # index: 2
         """
-        from .reshard import _compute_partition_index
+        from .reshard import Resharder
 
-        partition_index = _compute_partition_index(
+        partition_index = Resharder.compute_partition_index(
             rank_id, complete_shape, dims_mapping, process_shape, process_group)
         sliced_index = 0
         for i, shape in enumerate(complete_shape):
...
@@ -32,7 +32,7 @@ from paddle.distributed.utils import get_logger
 from .mapper import mapping
 from .cluster import Cluster
-from .reshard import reshard
+from .reshard import Resharder
 from .planner import Planner
 from .completion import Completer
 from .partitioner import Partitioner
@@ -187,8 +187,9 @@ class Engine:
         # Do reshard process
         set_grad_var_shape(dist_main_prog, dist_context)
         make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
-        reshard(dist_main_prog, dist_startup_prog, rank, dist_context,
-                dist_params_grads)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
+                              dist_context, dist_params_grads)
+        resharder.reshard()
         # Apply post optimization passes
         self._apply_post_optimization(dist_main_prog, dist_startup_prog,
                                       rank, dist_params_grads)
@@ -199,8 +200,9 @@ class Engine:
             serial_main_program, serial_startup_program, [])
         # Do reshard process
         make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
-        reshard(dist_main_prog, dist_startup_prog, rank, dist_context, [],
-                1)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
+                              dist_context, [], 1)
+        resharder.reshard()
 
         # clone program for test
         if mode != 'train':
...
@@ -42,7 +42,7 @@ from .utils import make_data_unshard
 from .utils import set_grad_var_shape
 from .utils import print_program_with_dist_attr
 from .utils import SerialProgramInfo
-from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
+from .reshard import Resharder
 from .cluster import Cluster
 from .mapper import mapping
 from .dist_op import DistributedOperator
@@ -213,17 +213,15 @@ class AutoParallelizer:
         make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
 
-        reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context,
-                dist_params_grads)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
+                              self._dist_context, dist_params_grads)
+        resharder.reshard()
 
         self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog,
                                              rank, dist_params_grads)
         g_process_group_map = None
         if not relaunch_phase:
             g_process_group_map = copy.deepcopy(_g_process_group_map)
-            HAS_SENT.clear()
-            HAS_RECV.clear()
-            HAS_ALLGATHER.clear()
             _g_process_group_map.clear()
             _g_process_group_map[0] = ProcessGroup(0, [])
             for process_mesh in dist_context._process_meshes:
...
@@ -775,19 +775,19 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
 def _merge_parameter_with_dist_attr(param_list, dist_attr):
     """ Merge parameter with distributed attribute """
-    from .reshard import _compute_complete_shape, _compute_partition_index
+    from .reshard import Resharder
 
     dims_mapping = dist_attr["dims_mapping"]
     process_shape = dist_attr["process_shape"]
     process_group = dist_attr["process_group"]
     # get the complete shape of the parameter
-    complete_shape = _compute_complete_shape(param_list[0].shape, process_shape,
-                                             dims_mapping)
+    complete_shape = Resharder.compute_complete_shape(
+        param_list[0].shape, process_shape, dims_mapping)
     # merge the parameter with dist_attr
     partition_param_list = []
     merged_partiton = []
     for process in process_group:
-        partition_index = _compute_partition_index(
+        partition_index = Resharder.compute_partition_index(
             process, complete_shape, dims_mapping, process_shape, process_group)
         index = process_group.index(process)
         if partition_index not in merged_partiton:
@@ -840,7 +840,7 @@ def _merge_parameter(partition_param_list, param, partition_index,
         _merge_parameter(partition_param_list, param, partition_index)
         # partition_param_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
     """
-    from .reshard import _compute_concat_info
+    from .reshard import Resharder
 
     if len(partition_param_list) == 1:
         is_complete_data = True
@@ -856,7 +856,7 @@ def _merge_parameter(partition_param_list, param, partition_index,
     else:
         i = 0
         while i < len(partition_param_list):
-            concat_axis, first_order, new_partition = _compute_concat_info(
+            concat_axis, first_order, new_partition = Resharder.compute_concat_info(
                 partition_param_list[i][1], partition_index)
             if concat_axis != -1:
                 if first_order == 0:
@@ -933,9 +933,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
                                         process_shape, process_group)
         # index: 2
     """
-    from .reshard import _compute_partition_index
+    from .reshard import Resharder
 
-    partition_index = _compute_partition_index(
+    partition_index = Resharder.compute_partition_index(
         rank, complete_shape, dims_mapping, process_shape, process_group)
     sliced_param_index = 0
     for i, shape in enumerate(complete_shape):
@@ -972,11 +972,11 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape,
         index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
         # index: [[], [], [2, 4]]
     """
-    from .reshard import _compute_partition_index
+    from .reshard import Resharder
 
     split_indices_list = []
     for process in process_group:
-        partition_index = _compute_partition_index(
+        partition_index = Resharder.compute_partition_index(
             process, complete_shape, dims_mapping, process_shape, process_group)
         if split_indices_list:
             for dim in range(len(partition_index)):
...
@@ -31,7 +31,6 @@ from paddle.distributed import fleet
 from paddle.fluid.initializer import NumpyArrayInitializer
 from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program
 from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program
-from paddle.distributed.auto_parallel.reshard import HAS_SENT, HAS_RECV, HAS_ALLGATHER
 from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
 
 paddle.enable_static()
@@ -258,9 +257,6 @@ class TestMLPAutoConvert2(unittest.TestCase):
         paddle.seed(2021)
         random.seed(2021)
         np.random.seed(2021)
-        HAS_SENT.clear()
-        HAS_RECV.clear()
-        HAS_ALLGATHER.clear()
 
     def tearDown(self):
         os.remove("./model_state_rank{}.pdmodel".format(
...
@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.partitioner import Partitioner
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
-from paddle.distributed.auto_parallel.reshard import reshard
+from paddle.distributed.auto_parallel.reshard import Resharder
 from paddle.distributed.auto_parallel.cost_model import estimate_cost
 import paddle.fluid.core as core
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
@@ -232,8 +232,9 @@ class TestCostModel(unittest.TestCase):
             dist_context = DistributedContext()
             distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog(
                 train_program, startup_program, dist_context, rank_id)
-            reshard(distributed_program, dist_startup_prog, rank_id,
-                    dist_context, dist_params_grads)
+            resharder = Resharder(distributed_program, dist_startup_prog,
+                                  rank_id, dist_context, dist_params_grads)
+            resharder.reshard()
             dist_program.append(distributed_program)
         cluster = None
         cost = estimate_cost(
...
@@ -40,7 +40,7 @@ from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import reshard
+from paddle.distributed.auto_parallel.reshard import Resharder
 from paddle.distributed.auto_parallel.process_group import get_all_process_groups
 from paddle.distributed.auto_parallel.process_group import new_process_group
 from paddle.distributed.auto_parallel.cluster import Cluster
@@ -502,8 +502,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     partitioned_optimize_ops = parallelizer._apply_optimize(
         dist_train_program, dist_startup_prog, dist_params_grads)
 
-    reshard(dist_train_program, dist_startup_prog, rank_id, dist_context,
-            dist_params_grads)
+    resharder = Resharder(dist_train_program, dist_startup_prog, rank_id,
+                          dist_context, dist_params_grads)
+    resharder.reshard()
 
     return dist_train_program, dist_startup_prog
...
@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
+from paddle.distributed.auto_parallel.reshard import Resharder
 from paddle.distributed.auto_parallel.process_group import _g_process_group_map
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
@@ -310,8 +310,9 @@ class TestMLPReshard(unittest.TestCase):
             train_program, startup_program, dist_context, rank_id)
         for key in list(_g_process_group_map.keys()):
             del _g_process_group_map[key]
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
-                dist_params_grads)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
+                              dist_context, dist_params_grads)
+        resharder.reshard()
 
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
@@ -320,9 +321,6 @@
         self.assertTrue(check_initialization(dist_startup_prog, rank_id))
 
     def test_mlp_pp_diff_process_mesh(self):
-        HAS_SENT.clear()
-        HAS_RECV.clear()
-        HAS_ALLGATHER.clear()
         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
@@ -331,8 +329,9 @@
             train_program, startup_program, dist_context, rank_id, True)
         for key in list(_g_process_group_map.keys()):
             del _g_process_group_map[key]
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
-                dist_params_grads)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
+                              dist_context, dist_params_grads)
+        resharder.reshard()
         print_program_with_dist_attr(dist_main_prog, dist_context)
 
         # check send and recv result
@@ -351,8 +350,9 @@
         rank_id = 0
         dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
-                dist_params_grads)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
+                              dist_context, dist_params_grads)
+        resharder.reshard()
 
         # send and recv should not exist in dp scene.
         self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))
...
@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import reshard
+from paddle.distributed.auto_parallel.reshard import Resharder
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 
 paddle.enable_static()
@@ -179,8 +179,9 @@ class TestMLPReshard(unittest.TestCase):
         rank_id = 2
         dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
-                dist_params_grads)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
+                              dist_context, dist_params_grads)
+        resharder.reshard()
         # print_program_with_dist_attr(dist_main_prog, dist_context)
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
...
@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import reshard
+from paddle.distributed.auto_parallel.reshard import Resharder
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 
 paddle.enable_static()
@@ -213,8 +213,9 @@ class TestMLPReshard(unittest.TestCase):
         rank_id = 2
         dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
-        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context,
-                dist_params_grads)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
+                              dist_context, dist_params_grads)
+        resharder.reshard()
 
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
@@ -272,8 +273,9 @@ class TestMLPReshard(unittest.TestCase):
         dist_context.block_state.parse_forward_blocks(complete_train_program)
         partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
             complete_train_program, startup_program, [])
-        reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
-                dist_context, partitioned_params_grads)
+        resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
+                              rank_id, dist_context, partitioned_params_grads)
+        resharder.reshard()
 
         # the x should not be slice
         self.assertTrue(check_allgather(partitioned_main_prog))
...
@@ -29,7 +29,7 @@ import paddle.distributed.auto_parallel as auto
 from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import reshard
+from paddle.distributed.auto_parallel.reshard import Resharder
 from paddle.distributed.auto_parallel.process_group import new_process_group
 
 paddle.enable_static()
...