Unverified commit 2747de2b, authored by caozhou, committed by GitHub

[Auto Parallel]Update reshard for while sub block (#40366)

* update reshard for while sub block

* fix code format error
Parent 575dea8f
@@ -29,6 +29,7 @@ from .process_group import new_process_group, ProcessGroup, _g_process_group_map
# NOTE: If op in _g_special_ops, it will not be resharded.
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
while_block_info = {}


class AllGatherOpDesc:
@@ -280,8 +281,20 @@ def _is_overlapped(shape_x, shape_y):
    return overlapped


def _need_reshard(dist_tensor,
                  dist_op,
                  actual_process_mesh,
                  program,
                  dist_context,
                  op_input=True):
    """Judge the tensor whether needs to be resharded."""

    def _is_unshard(dims_mapping):
        for dim in dims_mapping:
            if dim != -1:
                return False
        return True

    is_reshard = False
    tensor_dist_attr = dist_tensor.dist_attr
    tensor_name = dist_tensor.serial_tensor.name
@@ -289,32 +302,74 @@ def _need_reshard(dist_tensor, dist_op, op_input=True):
    tensor_process_mesh = tensor_dist_attr.process_mesh
    op_dist_attr = dist_op.dist_attr
    op_process_mesh = actual_process_mesh
    if op_input:
        op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
        if all(
                map(lambda x: x is not None, [
                    tensor_dims_mapping, tensor_process_mesh,
                    op_input_dims_mapping, op_process_mesh
                ])):
            # dims_mapping
            if tensor_dims_mapping != op_input_dims_mapping:
                if dist_op.serial_op.type == "while":
                    sub_block = program.blocks[dist_op.serial_op.attr(
                        "sub_block").id]
                    for op in sub_block.ops:
                        for var_name in op.input_arg_names:
                            if var_name == tensor_name:
                                dist_op_attr = dist_context.get_dist_op_for_program(
                                    op).dist_attr
                                var_dims_mapping = dist_op_attr.get_input_dims_mapping(
                                    var_name)
                                if var_dims_mapping != tensor_dims_mapping:
                                    is_reshard = True
                                    break
                else:
                    is_reshard = True
            # process_mesh
            if tensor_process_mesh != op_process_mesh:
                # when processes length is not the same, the dims mapping must be replicative now
                if len(tensor_process_mesh.processes) != len(
                        op_process_mesh.processes):
                    assert _is_unshard(tensor_dims_mapping)
                    assert _is_unshard(op_input_dims_mapping)
                else:
                    if dist_tensor.serial_tensor.dtype == paddle.bool:
                        raise ValueError("Bool var is not supported reshard.")
                    # for while op, it should find the process mesh of op actually used the tensor as input
                    if dist_op.serial_op.type == "while":
                        sub_block = program.blocks[dist_op.serial_op.attr(
                            "sub_block").id]
                        for op in sub_block.ops:
                            for var_name in op.input_arg_names:
                                if var_name == tensor_name:
                                    dist_op_attr = dist_context.get_dist_op_for_program(
                                        op).dist_attr
                                    process_mesh = dist_op_attr.process_mesh
                                    if process_mesh == op_process_mesh:
                                        is_reshard = True
                                        break
                    else:
                        is_reshard = True
    else:
        op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
            tensor_name)
        if all(
                map(lambda x: x is not None, [
                    tensor_dims_mapping, tensor_process_mesh,
                    op_output_dims_mapping, op_process_mesh
                ])):
            if tensor_process_mesh != op_process_mesh:
                if dist_tensor.serial_tensor.dtype == paddle.bool:
                    raise ValueError("Bool var is not supported reshard.")
                is_reshard = True
            if tensor_dims_mapping != op_output_dims_mapping:
                raise ValueError(
                    "It is not supported that tensor dims mapping is different from op output dims mapping."
                )
    return is_reshard
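The check above decides, per input, whether a reshard is required. A simplified standalone sketch of the basic decision (ignoring the while-specific branches), with hypothetical dims mappings and plain rank lists standing in for process meshes:

def needs_reshard(tensor_dims, tensor_mesh, op_dims, op_mesh):
    # reshard when the op expects a different sharding or a different set of ranks
    return tensor_dims != op_dims or tensor_mesh != op_mesh

# hypothetical example: tensor sharded on mesh axis 0, op expects it replicated
print(needs_reshard([0, -1], [0, 1], [-1, -1], [0, 1]))   # True
print(needs_reshard([-1, -1], [0, 1], [-1, -1], [0, 1]))  # False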
@@ -329,13 +384,14 @@ def _compute_complete_shape(slice_shape, process_shape, dims_mapping):
    return complete_shape


def find_op_desc_seq(dist_tensor, dist_op, actual_process_mesh, batch_size):
    """
    Find the op description sequence to reshard the source tensor for matching the op requirement.

    Args:
        dist_tensor (DistributedTensor): A distributed tensor.
        dist_op (DistributedOperator): A distributed operator.
        actual_process_mesh (ProcessMesh): The actual op process mesh.

    Returns:
        Dict, the dict represents the required op description sequence corresponding to process, The key of dict is
@@ -350,11 +406,16 @@ def find_op_desc_seq(dist_tensor, dist_op):
    source_process_shape = source_process_mesh.topology

    op_dist_attr = dist_op.dist_attr
    target_process_mesh = actual_process_mesh
    target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
    target_process_group = target_process_mesh.processes
    target_process_shape = target_process_mesh.topology

    if source_tensor.shape[0] < 0:
        new_shape = list(source_tensor.shape)
        new_shape[0] = batch_size
        source_tensor.desc.set_shape(new_shape)

    complete_shape = _compute_complete_shape(
        source_tensor.shape, source_process_shape, source_dims_mapping)
    op_desc_seq = {}
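find_op_desc_seq now takes batch_size so a dynamic (negative) leading dimension can be replaced by a concrete value before the complete shape is computed. A minimal sketch of that patching step on a plain shape list, with hypothetical values:

def resolve_dynamic_batch(shape, batch_size):
    # a dynamic leading dimension is encoded as a negative value (e.g. -1)
    if shape and shape[0] < 0:
        shape = [batch_size] + list(shape[1:])
    return shape

# hypothetical shapes
print(resolve_dynamic_batch([-1, 1024], 8))  # [8, 1024]
print(resolve_dynamic_batch([16, 1024], 8))  # unchanged: [16, 1024]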
@@ -503,7 +564,7 @@ def find_op_desc_seq(dist_tensor, dist_op):
    return op_desc_seq


def _insert_send_op(block, idx, tensor, dst, op_role):
    """Insert send op into block at the given index."""
    op_type = 'send_v2'
    block._insert_op(
@@ -514,10 +575,11 @@ def _insert_send_op(block, idx, tensor, dst):
            'ring_id': 0,
            'peer': dst,
            'use_calc_stream': True,
            'op_role': op_role
        })


def _insert_recv_op(block, idx, tensor, src, op_role):
    """Insert recv op into block at the given index."""
    op_type = 'recv_v2'
    block._insert_op(
@@ -531,14 +593,16 @@ def _insert_recv_op(block, idx, tensor, src):
            'out_shape': tensor.shape,
            'dtype': tensor.dtype,
            'use_calc_stream': True,
            'op_role': op_role
        })


def _insert_concat_op(block, idx, tensors, axis, op_role):
    """Insert concat op into block at the given block."""
    inputs = {'X': tensors}
    attrs = {}
    attrs['axis'] = axis
    attrs['op_role'] = op_role
    helper = LayerHelper('concat', **locals())
    with paddle.static.program_guard(block.program):
        out = helper.create_variable_for_type_inference(
@@ -548,7 +612,8 @@ def _insert_concat_op(block, idx, tensors, axis):
    return out


def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name,
                     op_role):
    """Insert slice op into block at the given block."""
    inputs = {'Input': tensor}
    infer_flags = list(1 for i in range(len(axes)))
@@ -556,24 +621,23 @@ def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name):
        "axes": axes,
        "starts": starts,
        "ends": ends,
        "infer_flags": infer_flags,
        'op_role': op_role
    }
    helper = LayerHelper('slice', **locals())
    out = block.create_var(
        name=new_var_name, dtype=tensor.dtype, type=tensor.type)
    block._insert_op(
        idx, type="slice", inputs=inputs, outputs={'Out': [out]}, attrs=attrs)
    return out


def _insert_split_op(block, idx, tensor, num_or_sections, op_role):
    """Insert split op into block at the given index."""
    helper = LayerHelper('split', **locals())
    input_shape = tensor.shape
    inputs = {'X': tensor}
    attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role}
    with paddle.static.program_guard(block.program):
        outs = [
            helper.create_variable_for_type_inference(
@@ -584,7 +648,7 @@ def _insert_split_op(block, idx, tensor, num_or_sections):
    return outs


def _insert_allgather_op(block, idx, tensor, ranks, op_role):
    """Insert allgather op into block at the given index."""

    def _insert_fill_constant_op(block, idx):
@@ -597,6 +661,7 @@ def _insert_allgather_op(block, idx, tensor, ranks):
        attrs['str_value'] = str(int("1"))
        attrs['value'] = int("1")
        attrs['dtype'] = out.dtype
        attrs['op_role'] = op_role
        utils.get_shape_tensor_inputs(
            inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant')
        block._insert_op(
@@ -625,14 +690,16 @@ def _insert_allgather_op(block, idx, tensor, ranks):
            inputs={'X': [fill_constant_out]},
            outputs={'Out': [fill_constant_out]},
            attrs={'ring_id': 0,
                   'use_calc_stream': True,
                   'op_role': op_role})

        # insert c_sync_calc_stream op
        block._insert_op(
            idx + 2,
            type="c_sync_calc_stream",
            inputs={'X': [fill_constant_out]},
            outputs={'Out': [fill_constant_out]},
            attrs={'op_role': op_role})
        idx_offset = 3

    # insert c_allgather op
@@ -649,20 +716,21 @@ def _insert_allgather_op(block, idx, tensor, ranks):
        attrs={
            'ring_id': group.id,
            'use_calc_stream': True,
            'nranks': group.nranks,
            'op_role': op_role
        })
    idx_offset += 1

    # insert split op
    split_out = _insert_split_op(block, idx + idx_offset, allgather_out,
                                 group.nranks, op_role)
    idx_offset += 1
    tensor_list.extend(split_out)
    return tensor_list, idx_offset
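Every insert helper now threads an explicit op_role through to the created op, so reshard-inserted ops keep the role of the op that triggered them (parse_op_desc passes reshard_op.attr('op_role')). A toy sketch of the propagation pattern, using plain dicts as hypothetical stand-ins for program ops:

def insert_helper(block_ops, idx, op_type, op_role):
    # every op created by reshard carries the role of the op being resharded
    block_ops.insert(idx, {"type": op_type, "op_role": op_role})

reshard_op = {"type": "matmul_v2", "op_role": 0}  # hypothetical role value
ops = [reshard_op]
insert_helper(ops, 0, "send_v2", reshard_op["op_role"])
assert ops[0]["op_role"] == reshard_op["op_role"]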
def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
                               block, idx, op_role):
    """Concat the tensors and insert concat op."""
    if not partition_tensor_list:
        partition_tensor_list.append((tensor, partition_index))
@@ -674,13 +742,13 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
                partition_tensor_list[i][1], partition_index)
            if concat_axis != -1:
                has_concat = True
                _ = _insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis, op_role) \
                    if first_order == 0 else \
                    _insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis, op_role)
                partition_tensor_list.pop(i)
                idx[0] += 1
                _concat_partitions_with_op(partition_tensor_list, _,
                                           new_partition, block, idx, op_role)
                break
            i += 1
        if not has_concat:
@@ -692,8 +760,47 @@ HAS_RECV = {}
HAS_ALLGATHER = {}


def _get_while_op_actual_process_mesh(op, program, rank_id, dist_context):
    """Get the while op actual Process mesh corresponding to rank"""
    assert op.type == "while"
    while_op_process_mesh = dist_context.get_dist_op_for_program(
        op).dist_attr.process_mesh
    sub_block = program.blocks[op.attr("sub_block").id]
    ops = sub_block.ops
    actual_process_mesh = None
    for op in ops:
        dist_op = dist_context.get_dist_op_for_program(op)
        if not dist_op:
            continue
        process_mesh = dist_op.dist_attr.process_mesh
        if process_mesh == while_op_process_mesh:
            continue
        if rank_id in process_mesh.processes:
            raw_process_mesh = process_mesh
            break

    if actual_process_mesh is None and rank_id in while_op_process_mesh.processes:
        actual_process_mesh = while_op_process_mesh

    assert actual_process_mesh is not None
    return actual_process_mesh
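Per its docstring, this helper resolves which mesh the current rank actually runs the while sub block on, rather than the union mesh attached to the while op. A simplified illustration of that idea (not the exact code above), with plain rank lists standing in for ProcessMesh objects:

def pick_actual_mesh(rank_id, while_op_mesh, sub_block_meshes):
    # prefer a sub-block mesh (different from the union mesh) that contains this rank,
    # otherwise fall back to the while op's own mesh
    for mesh in sub_block_meshes:
        if mesh != while_op_mesh and rank_id in mesh:
            return mesh
    assert rank_id in while_op_mesh
    return while_op_mesh

# hypothetical meshes: union [0, 1, 2, 3] of two pipeline stages [0, 1] and [2, 3]
print(pick_actual_mesh(2, [0, 1, 2, 3], [[0, 1], [2, 3]]))  # [2, 3]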
def _get_var(var_name, block, program):
    """Get var in the parent block if not found in the current block"""
    var = None
    if var_name in block.vars:
        var = block.vars[var_name]
    else:
        parent_block = program.blocks[block.parent_idx]
        if var_name in parent_block.vars:
            var = parent_block.vars[var_name]
    assert var is not None
    return var
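Variables read inside a while sub block may be defined in the parent block, so lookups fall back one level up. A minimal sketch with hypothetical dict-based blocks:

class FakeBlock:
    # hypothetical stand-in for a Paddle Block: a var dict plus a parent index
    def __init__(self, vars, parent_idx=-1):
        self.vars = vars
        self.parent_idx = parent_idx

def get_var(var_name, block, blocks):
    if var_name in block.vars:
        return block.vars[var_name]
    parent = blocks[block.parent_idx]
    return parent.vars[var_name]  # raises KeyError if missing in both

main = FakeBlock({"x": 1.0})
sub = FakeBlock({"i": 0}, parent_idx=0)
print(get_var("x", sub, [main, sub]))  # found in the parent block -> 1.0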
def parse_op_desc(block, rank_id, op_desc_seq, var_name, reshard_op,
                  dist_context, program, actual_process_mesh):
    """Parse op desc sequence and insert op in the block"""
    global HAS_SENT
    global HAS_RECV
@@ -703,9 +810,6 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
    if rank_id not in op_desc_seq.keys():
        return
    op_desc_list = op_desc_seq[rank_id]

    idx = None
    for index, op in list(enumerate(block.ops)):
@@ -716,7 +820,7 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
        rank_id)

    matched_op = block.ops[idx]
    source_tensor = _get_var(var_name, block, program)
    for op_desc in op_desc_list:
        if isinstance(op_desc, AllGatherOpDesc):  # noqa: F401
            if var_name not in HAS_ALLGATHER.keys():
@@ -724,7 +828,8 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
            if not HAS_ALLGATHER[var_name] or op_desc.group not in list(
                    map(lambda x: x[0], HAS_ALLGATHER[var_name])):
                tensor_list, idx_offset = _insert_allgather_op(
                    block, idx, source_tensor, op_desc.group,
                    reshard_op.attr('op_role'))
                idx += idx_offset
                tensor_name_list = [var.name for var in tensor_list]
                HAS_ALLGATHER[var_name].append(
@@ -743,7 +848,8 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
            if var_name not in HAS_SENT.keys():
                HAS_SENT[var_name] = []
            if op_desc.dst not in HAS_SENT[var_name]:
                _insert_send_op(block, idx, source_tensor, op_desc.dst,
                                reshard_op.attr('op_role'))
                idx += 1
                HAS_SENT[var_name].append(op_desc.dst)
@@ -758,8 +864,10 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
                recv_tensor = block.create_var(
                    name=unique_name.generate(var_name + "@recv"),
                    shape=shape,
                    dtype=source_tensor.dtype,
                    type=source_tensor.type)
                _insert_recv_op(block, idx, recv_tensor, op_desc.src,
                                reshard_op.attr('op_role'))
                tensor_list.append(recv_tensor)
                idx += 1
                HAS_RECV[var_name][op_desc.src] = recv_tensor
@@ -772,7 +880,7 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
            for index, tensor in enumerate(tensor_list):
                _concat_partitions_with_op(partition_tensor_list, tensor,
                                           partition_index_list[index], block,
                                           idx_list, reshard_op.attr('op_role'))
            idx = idx_list[0]

        elif isinstance(op_desc, SliceOpDesc):
@@ -787,11 +895,11 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
                starts=op_desc.starts,
                ends=op_desc.ends,
                axes=op_desc.axes,
                new_var_name=new_name,
                op_role=reshard_op.attr('op_role'))

            tensor_attr = TensorDistributedAttribute()
            process_mesh = actual_process_mesh
            dims_mapping = dist_context.get_op_dist_attr_for_program(
                matched_op).get_input_dims_mapping(var_name)
            tensor_attr.dims_mapping = dims_mapping
@@ -799,11 +907,29 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
            dist_context.set_tensor_dist_attr_for_program(target_tensor,
                                                          tensor_attr)

            if op.type == "while":
                global while_block_info
                # var_reshard_mapping means the while op input need be changed to
                if "var_reshard_mapping" not in while_block_info[op.attr(
                        "sub_block").id].keys():
                    while_block_info[op.attr("sub_block").id][
                        "var_reshard_mapping"] = {}
                while_block_info[op.attr("sub_block").id][
                    "var_reshard_mapping"][var_name] = target_tensor.name

            # rename op input name according to new name
            for op in block.ops:
                for name in op.input_arg_names:
                    op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
                    if name == var_name and op_dist_attr is not None:
                        if op.desc.id() == matched_op.desc.id():
                            op.desc._rename_input(name, target_tensor.name)
                            op_dist_attr.set_input_dims_mapping(
                                target_tensor.name, dims_mapping)
                            op_dist_attr.set_input_dist_attr(name, None)
                            continue

                        # NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation.
                        op_process_mesh = op_dist_attr.process_mesh
                        op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
                            var_name)
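For a while op input that gets resharded, parse_op_desc records the new tensor name per sub block; reshard() later uses this mapping to rename the corresponding inputs and outputs inside the sub block. The shape of while_block_info is roughly as follows (the ids and names below are hypothetical):

# hypothetical contents after resharding one while op input
while_block_info = {
    1: {                                  # sub_block id of the while op
        "op_id": 42,                      # desc id of the while op in the parent block
        "actual_process_mesh": None,      # ProcessMesh the current rank actually uses
        "var_reshard_mapping": {
            "tmp_0": "tmp_0@RESHARD_0",   # original input name -> resharded tensor name
        },
    },
}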
@@ -819,102 +945,166 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
    not_remove_op_ref = [
        "create_py_reader", "create_double_buffer_reader", "read"
    ]
    global while_block_info

    # NOTE: The nested sub block is not be supported now.
    remove_block_order = []
    for block_idx in while_block_info:
        remove_block_order.append(block_idx)

    for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
        if block_idx not in remove_block_order:
            remove_block_order.append(block_idx)

    # the sub block should be removed first
    for block_idx in remove_block_order:
        remove_op_idx = []
        block = auto_parallel_main_prog.blocks[block_idx]
        ops = block.ops
        vars = block.vars
        for idx, op in enumerate(ops):
            if op.type == "read":
                dim_list = []
                for var_name in op.output_arg_names:
                    dim_list.extend(
                        _get_var(var_name, block, auto_parallel_main_prog)
                        .shape)
                for i in range(idx, -1, -1):
                    if ops[i].type == "create_py_reader":
                        ops[i]._set_attr("shape_concat", dim_list)
                        break
                continue

            # replace the input and output of c_sync_comm_stream op when in pipeline scene.
            if op.type == "c_sync_comm_stream":
                need_save = []
                for var_name in op.input_arg_names:
                    process_mesh = dist_context.get_tensor_dist_attr_for_program(
                        _get_var(var_name, block,
                                 auto_parallel_main_prog)).process_mesh
                    if rank_id in process_mesh.processes:
                        need_save.append(var_name)
                if not need_save:
                    remove_op_idx.append(idx)
                    continue

                proto = OpProtoHolder.instance().get_op_proto(op.type)
                op.desc.set_input(proto.inputs[0].name, need_save)
                op.desc.set_output(proto.outputs[0].name, need_save)
                continue

            # judge the other op whether should be removed.
            op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
            if op_dist_attr is not None:
                op_process_mesh = op_dist_attr.process_mesh
                if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref:
                    remove_op_idx.append(idx)

        for idx in remove_op_idx[::-1]:
            block._remove_op(idx)
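Blocks are cleaned up with while sub blocks first, so the parent while op can afterwards be pruned against what actually remains in its sub block. A small sketch of building that ordering, with hypothetical block indices:

def build_remove_order(num_blocks, while_sub_block_ids):
    # sub blocks that belong to while ops come first, then every remaining block
    order = list(while_sub_block_ids)
    order += [i for i in range(num_blocks) if i not in order]
    return order

# hypothetical program: 3 blocks, block 1 is a while sub block
print(build_remove_order(3, [1]))  # [1, 0, 2]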
def _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads):
    """Remove no need vars in the main program"""
    for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
        remove_vars = set()
        ops = block.ops
        vars = block.vars
        need_vars = set()
        for op in ops:
            for var_name in op.input_arg_names:
                if var_name in vars:
                    need_vars.add(var_name)
            for var_name in op.output_arg_names:
                if var_name in vars:
                    need_vars.add(var_name)
        for var in vars:
            if var not in need_vars:
                remove_vars.add(var)

        # change dist_params_grads, the optimize op just in block 0.
        if block_idx == 0:
            param_grad_map = {}
            for op in ops:
                if int(op.attr('op_role')) == int(OpRole.Optimize):
                    if "Param" in op.input_names and "Grad" in op.input_names:
                        param_name = op.input("Param")[0]
                        grad_name = op.input("Grad")[0]
                        param_grad_map[param_name] = grad_name

            need_remove_idx = []
            for idx, item in enumerate(dist_params_grads):
                if item[0].name not in param_grad_map.keys():
                    need_remove_idx.append(idx)

            for idx in need_remove_idx[::-1]:
                dist_params_grads.pop(idx)

            idx = 0
            while idx < len(dist_params_grads):
                param_name = dist_params_grads[idx][0].name
                grad_name = dist_params_grads[idx][1].name
                if grad_name != param_grad_map[param_name]:
                    dist_params_grads[idx] = (vars[param_name],
                                              vars[param_grad_map[param_name]])
                idx += 1

        for var in remove_vars:
            block._remove_var(var)


def _change_while_op_input_and_output(auto_parallel_main_prog, dist_context):
    """Change while op input and output after the corresponding sub block ops removed"""
    global while_block_info
    for sub_block_idx in while_block_info:
        sub_block = auto_parallel_main_prog.blocks[sub_block_idx]
        parent_while_op_id = while_block_info[sub_block_idx]["op_id"]
        parent_block = auto_parallel_main_prog.blocks[sub_block.parent_idx]

        sub_block_op_inputs = set()
        sub_block_op_outputs = []
        for op in sub_block.ops:
            # skip the input and output of operators inserted in the reshard phase
            dist_op = dist_context.get_dist_op_for_program(op)
            if dist_op:
                for var_name in op.output_arg_names:
                    if var_name not in sub_block_op_outputs:
                        sub_block_op_outputs.append(var_name)
                for var_name in op.input_arg_names:
                    sub_block_op_inputs.add(var_name)

        # find the while op
        while_op = None
        for op in parent_block.ops:
            if op.desc.id() == parent_while_op_id and op.type == "while":
                while_op = op
                break

        assert while_op is not None

        # find the actual input and output of while op
        proto = OpProtoHolder.instance().get_op_proto(while_op.type)
        new_X = []
        for var_name in while_op.input("X"):
            if var_name in sub_block_op_inputs:
                new_X.append(var_name)
        assert new_X
        while_op.desc.set_input(proto.inputs[0].name, new_X)

        new_Out = []
        for var_name in while_op.output("Out"):
            for output_name in sub_block_op_outputs[::-1]:
                if output_name.find(var_name) != -1:
                    new_Out.append(output_name)
        assert new_Out
        while_op.desc.set_output(proto.outputs[0].name, new_Out)
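After per-rank op removal, the while op's X inputs and Out outputs are narrowed to what its sub block still reads and writes. A small sketch of that filtering on plain name lists (all names hypothetical):

def prune_while_io(while_inputs, while_outputs, sub_inputs, sub_outputs):
    # keep only the inputs still read by the sub block
    new_x = [name for name in while_inputs if name in sub_inputs]
    # keep outputs whose (possibly suffixed) name is still produced by the sub block
    new_out = [o for v in while_outputs for o in reversed(sub_outputs) if v in o]
    return new_x, new_out

# hypothetical variable names
print(prune_while_io(["x", "y"], ["out"], {"x"}, ["out@RESHARD_0"]))
# (['x'], ['out@RESHARD_0'])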
def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
                           dist_params_grads):
    """Remove no need vars and ops in the main program."""
    _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id)
    _change_while_op_input_and_output(auto_parallel_main_prog, dist_context)
    _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads)
@@ -992,8 +1182,70 @@ def remove_no_need_in_startup(auto_parallel_main_prog,
        startup_block._remove_op(idx)


def _get_process_meshes(op, program, dist_context):
    """Get all process meshes when op has sub block."""
    assert op.has_attr("sub_block")
    sub_block = program.blocks[op.attr("sub_block").id]
    ops = sub_block.ops
    op_process_mesh = dist_context.get_dist_op_for_program(
        op).dist_attr.process_mesh
    process_meshes = []
    for op in ops:
        dist_op = dist_context.get_dist_op_for_program(op)
        if not dist_op:
            continue
        process_mesh = dist_op.dist_attr.process_mesh
        if process_mesh not in process_meshes and process_mesh != op_process_mesh:
            process_meshes.append(process_mesh)

    if not process_meshes:
        process_meshes.append(op_process_mesh)

    return process_meshes


def _is_condition_replicative(op, program, dist_context):
    assert op.type == "while"
    sub_block = program.blocks[op.attr("sub_block").id]
    dist_op = dist_context.get_dist_op_for_program(op)
    op_dist_attr = dist_op.dist_attr

    # the dims mapping of condition tensor should be replicative
    for var_name in op.input("Condition"):
        var = _get_var(var_name, sub_block, program)
        dist_tensor = dist_context.get_dist_tensor_for_program(var)
        tensor_dist_attr = dist_tensor.dist_attr
        var_dims_mapping = tensor_dist_attr.dims_mapping
        for dim in var_dims_mapping:
            if dim != -1:
                return False

    return True
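The while condition must look the same on every participating rank, so its dims mapping has to be fully replicated (all -1); reshard() raises an error otherwise. A standalone sketch of that validation, with hypothetical dims mappings:

def check_condition_replicative(condition_dims_mappings):
    # every condition tensor must be fully replicated, i.e. all dims map to -1
    for dims_mapping in condition_dims_mappings:
        if any(dim != -1 for dim in dims_mapping):
            raise ValueError(
                "Please check the condition due to the dims mapping is not replicative."
            )

check_condition_replicative([[-1]])   # passes: scalar condition, replicated
# check_condition_replicative([[0]])  # would raise: condition sharded on mesh axis 0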
def _get_op_process_meshes(op, dist_context):
    process_meshes = []
    dist_op = dist_context.get_dist_op_for_program(op)
    op_process_mesh = dist_op.dist_attr.process_mesh
    for process_mesh in dist_context.process_meshes:
        if set(process_mesh.processes) & (
                set(op_process_mesh.processes)
        ) and len(process_mesh.processes) <= len(op_process_mesh.processes):
            process_meshes.append(process_mesh)

    # it means the process mesh is not a union when process meshes is null
    if not process_meshes:
        process_meshes.append(op_process_mesh)

    return process_meshes
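For an op whose process mesh is a union (for example a while op spanning two pipeline stages), reshard considers every registered mesh that overlaps it and is not larger; if none qualifies, the op's own mesh is used. A sketch of that selection with plain rank lists as hypothetical meshes:

def get_op_process_meshes(op_mesh, all_meshes):
    # keep meshes that overlap the op's mesh and are not larger than it
    selected = [m for m in all_meshes
                if set(m) & set(op_mesh) and len(m) <= len(op_mesh)]
    return selected or [op_mesh]

# hypothetical meshes: the op spans the union [0, 1, 2, 3] of two stages
print(get_op_process_meshes([0, 1, 2, 3], [[0, 1], [2, 3]]))  # [[0, 1], [2, 3]]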
def reshard(auto_parallel_main_prog,
            auto_parallel_startup_prog,
            rank_id,
            dist_context,
            dist_params_grads,
            batch_size=None):
    """
    Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.
@@ -1019,65 +1271,137 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
            return True
        return False

    global while_block_info
    for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
        if block_idx in while_block_info:
            if "var_reshard_mapping" in while_block_info[block_idx]:
                var_reshard_mapping = while_block_info[block_idx][
                    "var_reshard_mapping"]
                for op in block.ops:
                    for var_name in op.input_arg_names:
                        if var_name in var_reshard_mapping:
                            op.desc._rename_input(var_name,
                                                  var_reshard_mapping[var_name])
                            dist_op = dist_context.get_dist_op_for_program(op)
                            op_dist_attr = dist_op.dist_attr
                            if op_dist_attr.process_mesh == while_block_info[
                                    block_idx]["actual_process_mesh"]:
                                dims_mapping = op_dist_attr.get_input_dims_mapping(
                                    var_name)
                                op_dist_attr.set_input_dims_mapping(
                                    var_reshard_mapping[var_name], dims_mapping)
                                op_dist_attr.set_input_dist_attr(var_name, None)

                    # the outputs also need to be renamed when the output name is the same with input name
                    for var_name in op.output_arg_names:
                        if var_name in var_reshard_mapping:
                            op.desc._rename_output(
                                var_name, var_reshard_mapping[var_name])
                            dist_op = dist_context.get_dist_op_for_program(op)
                            op_dist_attr = dist_op.dist_attr
                            if op_dist_attr.process_mesh == while_block_info[
                                    block_idx]["actual_process_mesh"]:
                                dims_mapping = op_dist_attr.get_output_dims_mapping(
                                    var_name)
                                op_dist_attr.set_output_dims_mapping(
                                    var_reshard_mapping[var_name], dims_mapping)
                                op_dist_attr.set_output_dist_attr(var_name,
                                                                  None)

        idx = 0
        while idx < len(block.ops):
            pre_op_count = len(block.ops)
            op = block.ops[idx]

            if _is_special_op(op):
                idx += 1
                continue

            dist_op = dist_context.get_dist_op_for_program(op)
            if dist_op is not None:
                process_meshes = []
                if op.type == "while":
                    if not _is_condition_replicative(
                            op, auto_parallel_main_prog, dist_context):
                        raise ValueError(
                            "Please check the condition due to the dims mapping is not replicative."
                        )
                    process_meshes = _get_process_meshes(
                        op, auto_parallel_main_prog, dist_context)
                    assert process_meshes
                    if op.attr("sub_block").id not in while_block_info:
                        while_block_info[op.attr("sub_block").id] = {}
                    while_block_info[op.attr("sub_block").id][
                        "op_id"] = op.desc.id()
                    while_block_info[op.attr("sub_block").id][
                        "actual_process_mesh"] = _get_while_op_actual_process_mesh(
                            op, auto_parallel_main_prog, rank_id, dist_context)
                else:
                    process_meshes = _get_op_process_meshes(op, dist_context)
                input_vars = None
                if op.type == "while":
                    input_var_names = op.input("X")
                else:
                    input_var_names = op.input_arg_names
                idx_offset = 0
                for var_name in op.input_arg_names:
                    # skip lod_tensor_blocking_queue_0
                    if var_name == "lod_tensor_blocking_queue_0":
                        continue
                    var = _get_var(var_name, block, auto_parallel_main_prog)
                    dist_tensor = dist_context.get_dist_tensor_for_program(var)
                    for process_mesh in process_meshes:
                        if dist_tensor is not None and _need_reshard(
                                dist_tensor, dist_op, process_mesh,
                                auto_parallel_main_prog, dist_context):
                            reshard_op_desc = find_op_desc_seq(
                                dist_tensor, dist_op, process_mesh, batch_size)
                            parse_op_desc(block, rank_id, reshard_op_desc,
                                          var_name, op, dist_context,
                                          auto_parallel_main_prog, process_mesh)
                            cur_op_count = len(block.ops)
                            idx_offset = idx_offset + cur_op_count - pre_op_count
                            pre_op_count = cur_op_count
                idx = idx + idx_offset + 1
            else:
                idx += 1

        # insert send and recv op if output process mesh is different from tensor process mesh
        idx = 0
        # skip reader and ops whose process mesh is union
        skip_ops = [
            "create_py_reader", "create_double_buffer_reader", "read", "while",
            "write_to_array", "read_from_array"
        ]
        skip_ops += _g_special_ops
        while idx < len(block.ops):
            pre_op_count = len(block.ops)
            op = block.ops[idx]
            dist_op = dist_context.get_dist_op_for_program(op)
            if dist_op is not None and op.type not in skip_ops:
                for var_name in op.output_arg_names:
                    var = _get_var(var_name, block, auto_parallel_main_prog)
                    dist_tensor = dist_context.get_dist_tensor_for_program(var)
                    process_mesh = dist_op.dist_attr.process_mesh
                    if dist_tensor is not None and _need_reshard(
                            dist_tensor, dist_op, process_mesh,
                            auto_parallel_main_prog, dist_context, False):
                        for index, item in enumerate(
                                dist_op.dist_attr.process_mesh.processes):
                            recv_rank = dist_tensor.dist_attr.process_mesh.processes[
                                index]
                            if rank_id == item:
                                _insert_send_op(block, idx + 1, var, recv_rank,
                                                op.attr('op_role'))
                            if rank_id == recv_rank:
                                _insert_recv_op(block, idx + 1, var, item,
                                                op.attr('op_role'))
                        cur_op_count = len(block.ops)
                        idx_offset = idx_offset + cur_op_count - pre_op_count
                        pre_op_count = cur_op_count
                idx = idx + idx_offset + 1
            else:
                idx += 1

    # remove no need vars and ops in the main program
    remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
......
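The public entry point keeps its old required arguments and only gains the optional batch_size. A hedged usage sketch; the programs, rank and context below are assumed to come from the usual auto-parallel completion and partition steps and are not defined here:

from paddle.distributed.auto_parallel.reshard import reshard

def apply_reshard(main_prog, startup_prog, rank_id, dist_context, dist_params_grads):
    # arguments are assumed to be produced by the auto-parallel completer/partitioner
    # for the current rank; batch_size resolves dynamic batch dims (see find_op_desc_seq)
    reshard(main_prog, startup_prog, rank_id, dist_context,
            dist_params_grads, batch_size=4)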
@@ -32,6 +32,7 @@ from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program
from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program
from paddle.distributed.auto_parallel.reshard import HAS_SENT, HAS_RECV, HAS_ALLGATHER
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context

paddle.enable_static()
_global_parallel_strategy = None
@@ -185,6 +186,7 @@ class TestMLPAutoConvert(unittest.TestCase):
                        str(paddle.distributed.get_rank())))

    def test_mlp_mp2pp(self):
        set_default_distributed_context(None)
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
@@ -211,6 +213,7 @@ class TestMLPAutoConvert(unittest.TestCase):
                          fetch_list=[loss])
            last_res = res[0]

        set_default_distributed_context(None)
        _global_parallel_strategy = "pp"
        _global_process_mesh = auto.ProcessMesh([0, 1])
        global PP_MESH_0
@@ -266,6 +269,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
                        str(paddle.distributed.get_rank())))

    def test_mlp_pp2mp(self):
        set_default_distributed_context(None)
        global _global_parallel_strategy
        _global_parallel_strategy = "pp"
        global _global_process_mesh
@@ -302,6 +306,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
        if paddle.distributed.get_rank() in [1]:
            last_res = res[0]

        set_default_distributed_context(None)
        _global_parallel_strategy = "mp"
        _global_process_mesh = auto.ProcessMesh([0, 1])
@@ -345,6 +350,7 @@ class TestMLPAutoConvertInvalid(unittest.TestCase):
        np.random.seed(2021)

    def test_input_invalid(self):
        set_default_distributed_context(None)
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
......