Unverified commit 529f1425, authored by caozhou, committed by GitHub

【Auto Parallel】Update reshard for complete (#39073)

* update reshard for newest completion

* update unittest

* merge newest
Parent commit: 0c3657ad
@@ -279,7 +279,7 @@ def _is_overlapped(shape_x, shape_y):
     return overlapped


-def _need_reshard(dist_tensor, dist_op):
+def _need_reshard(dist_tensor, dist_op, op_input=True):
     """Judge the tensor whether needs to be resharded."""
     is_reshard = False
     tensor_dist_attr = dist_tensor.dist_attr
@@ -287,15 +287,33 @@ def _need_reshard(dist_tensor, dist_op):
     tensor_dims_mapping = tensor_dist_attr.dims_mapping
     tensor_process_mesh = tensor_dist_attr.process_mesh
     op_dist_attr = dist_op.dist_attr
-    op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
-    op_process_mesh = op_dist_attr.process_mesh
-    if all(
-            map(lambda x: x is not None, [
-                tensor_dims_mapping, tensor_process_mesh, op_input_dims_mapping,
-                op_process_mesh
-            ])):
-        if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh:
-            is_reshard = True
+    if op_input:
+        op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
+        op_process_mesh = op_dist_attr.process_mesh
+        if all(
+                map(lambda x: x is not None, [
+                    tensor_dims_mapping, tensor_process_mesh,
+                    op_input_dims_mapping, op_process_mesh
+                ])):
+            if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh:
+                is_reshard = True
+    else:
+        op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
+            tensor_name)
+        op_process_mesh = op_dist_attr.process_mesh
+        if all(
+                map(lambda x: x is not None, [
+                    tensor_dims_mapping, tensor_process_mesh,
+                    op_output_dims_mapping, op_process_mesh
+                ])):
+            if tensor_process_mesh != op_process_mesh:
+                is_reshard = True
+            if tensor_dims_mapping != op_output_dims_mapping:
+                raise ValueError(
+                    "It is not supported that tensor dims mapping is different from op output dims mapping."
+                )
     return is_reshard
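The input branch keeps the existing rule: an input tensor needs resharding when either its dims_mapping or its process_mesh disagrees with the op's attributes. The new output branch only tolerates a process-mesh mismatch and rejects a dims-mapping mismatch. A small standalone sketch of that decision rule (plain Python with simplified stand-ins for the real dist_attr objects, not Paddle code):

```python
def needs_reshard(tensor_dims_mapping, tensor_mesh, op_dims_mapping, op_mesh,
                  op_input=True):
    """Simplified stand-in for _need_reshard: compare a tensor's distributed
    attributes with the attributes the op records for the same tensor."""
    if any(attr is None for attr in
           (tensor_dims_mapping, tensor_mesh, op_dims_mapping, op_mesh)):
        return False
    if op_input:
        # Input side: any disagreement in sharding or placement needs a reshard.
        return tensor_dims_mapping != op_dims_mapping or tensor_mesh != op_mesh
    # Output side: only a placement (process mesh) mismatch is resharded;
    # a sharding (dims mapping) mismatch is rejected, mirroring the ValueError above.
    if tensor_dims_mapping != op_dims_mapping:
        raise ValueError("tensor dims mapping differs from op output dims mapping")
    return tensor_mesh != op_mesh


# Same sharding, but the output is placed on a different pipeline stage -> reshard.
print(needs_reshard([-1, -1], [0], [-1, -1], [1], op_input=False))  # True
```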
@@ -948,12 +966,13 @@ def remove_no_need_in_startup(auto_parallel_main_prog,


 def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
             dist_context):
     """
-    Reshard tensor in the program according to its dist attr and corresponding op dist attr.
+    Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.

     Args:
         auto_parallel_main_prog (Program): An auto parallel main program.
         auto_parallel_startup_prog (Program): An auto parallel startup program.
         rank_id (int): The process id.
+        dist_context (DistributedContext): The distributed context of this rank.
     """
     assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
         "but got {}.".format(type(auto_parallel_main_prog))
@@ -1001,6 +1020,34 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
         else:
             idx += 1

+    # insert send and recv op if output process mesh is different from tensor process mesh
+    idx = 0
+    skip_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
+    while idx < len(block.ops):
+        pre_op_count = len(block.ops)
+        op = block.ops[idx]
+        dist_op = dist_context.get_dist_op_for_program(op)
+        if dist_op is not None and op.type not in skip_ops:
+            idx_offset = 0
+            for var_name in op.output_arg_names:
+                var = block.vars[var_name]
+                dist_tensor = dist_context.get_dist_tensor_for_program(var)
+                if dist_tensor is not None and _need_reshard(dist_tensor,
+                                                             dist_op, False):
+                    for index, item in enumerate(
+                            dist_op.dist_attr.process_mesh.processes):
+                        recv_rank = dist_tensor.dist_attr.process_mesh.processes[
+                            index]
+                        if rank_id == item:
+                            _insert_send_op(block, idx + 1, var, recv_rank)
+                        if rank_id == recv_rank:
+                            _insert_recv_op(block, idx + 1, var, item)
+                    cur_op_count = len(block.ops)
+                    idx_offset = idx_offset + cur_op_count - pre_op_count
+                    pre_op_count = cur_op_count
+            idx = idx + idx_offset + 1
+        else:
+            idx += 1
+
     # remove no need vars and ops in the main program
     remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id)
......
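The block added to reshard() pairs processes index by index: the i-th process of the op's process mesh sends the output variable through _insert_send_op, and the i-th process of the tensor's process mesh receives it through _insert_recv_op, each only on the rank that actually plays that role. A standalone sketch of just the pairing logic, with plain rank lists standing in for ProcessMesh.processes:

```python
def plan_send_recv(op_mesh_processes, tensor_mesh_processes, rank_id):
    """Return the communication this rank takes part in when an op's output
    tensor is placed on a different process mesh than the op itself.
    Processes are paired index by index, as in the loop above."""
    actions = []
    for index, sender in enumerate(op_mesh_processes):
        receiver = tensor_mesh_processes[index]
        if rank_id == sender:    # the rank that computes the output sends it
            actions.append(("send", sender, receiver))
        if rank_id == receiver:  # the rank that owns the tensor receives it
            actions.append(("recv", sender, receiver))
    return actions


# Two-stage pipeline: the op runs on mesh [0], its output tensor lives on mesh [1].
print(plan_send_recv([0], [1], rank_id=0))  # [('send', 0, 1)]
print(plan_send_recv([0], [1], rank_id=1))  # [('recv', 0, 1)]
```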
@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import reshard
+from paddle.distributed.auto_parallel.reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
 from paddle.distributed.auto_parallel.process_group import _g_process_group_map
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
@@ -143,7 +143,11 @@ def mlp_forward(train_program, start_program):
     return loss, train_program, start_program


-def get_dist_prog(train_program, startup_program, dist_context, rank_id):
+def get_dist_prog(train_program,
+                  startup_program,
+                  dist_context,
+                  rank_id,
+                  change_process_mesh=False):
     loss, train_program, startup_program = mlp_forward(train_program,
                                                        startup_program)
@@ -157,6 +161,12 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     complete_train_program = completer.complete_forward_annotation(
         train_program)

+    if change_process_mesh:
+        global PP_MESH_1
+        dist_context.get_tensor_dist_attr_for_program(
+            train_program.global_block().vars[
+                "gelu_0.tmp_0"]).process_mesh = PP_MESH_1
+
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
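PP_MESH_1 is a module-level mesh that is not shown in this diff; presumably the test module defines two single-process pipeline meshes along the following lines (the concrete ranks are an illustrative assumption, only the name PP_MESH_1 appears in the change):

```python
import paddle.distributed.auto_parallel as auto

# Assumed definitions: one single-process mesh per pipeline stage.
PP_MESH_0 = auto.ProcessMesh(mesh=[0])
PP_MESH_1 = auto.ProcessMesh(mesh=[1])
```

Moving gelu_0.tmp_0 onto PP_MESH_1 while its producing op keeps its original mesh is what makes the new output-side branch of _need_reshard fire, so the send/recv insertion path gets covered by the new test below.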
@@ -308,6 +318,25 @@ class TestMLPReshard(unittest.TestCase):
         # parameter initialization of every rank should be different in the pipeline scene
         self.assertTrue(check_initialization(dist_startup_prog, rank_id))

+    def test_mlp_pp_diff_process_mesh(self):
+        HAS_SENT.clear()
+        HAS_RECV.clear()
+        HAS_ALLGATHER.clear()
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        dist_context = DistributedContext()
+        rank_id = 1
+        dist_main_prog, dist_startup_prog = get_dist_prog(
+            train_program, startup_program, dist_context, rank_id, True)
+        for key in list(_g_process_group_map.keys()):
+            del _g_process_group_map[key]
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        print_program_with_dist_attr(dist_main_prog, dist_context)
+
+        # check send and recv result
+        self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
+        self.assertTrue(check_initialization(dist_startup_prog, rank_id))
+
     def test_mlp_dp(self):
         global _global_parallel_strategy
         _global_parallel_strategy = "dp"
......
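check_send_recv_result and check_initialization are helpers defined earlier in this test file and are outside the diff. A plausible, simplified sketch of what the send/recv check verifies, assuming the ops inserted by _insert_send_op/_insert_recv_op are the static-graph send_v2 and recv_v2 ops (the real helper may additionally match specific variable names per rank):

```python
def check_send_recv_result(dist_main_prog, rank_id):
    """Simplified check: after reshard(), the program partitioned for this
    rank should contain both a send_v2 op and a recv_v2 op."""
    op_types = [op.type for op in dist_main_prog.global_block().ops]
    # rank_id is kept for signature parity; a stricter check could assert
    # rank-specific variable names and directions.
    return "send_v2" in op_types and "recv_v2" in op_types
```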