Unverified commit 2747de2b authored by caozhou, committed by GitHub

[Auto Parallel]Update reshard for while sub block (#40366)

* update reshard for while sub block

* fix code format error
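
A minimal calling sketch of the updated entry point follows, assuming the main/startup programs, dist_context, and dist_params_grads have already been produced by the auto-parallel completion and partition passes; the literal batch size and the surrounding variable names are illustrative only:

    from paddle.distributed.auto_parallel.reshard import reshard

    # batch_size is the new optional argument added by this change: when a source
    # tensor has a dynamic first dimension (-1), e.g. inside a while sub block,
    # its shape is completed with batch_size before the reshard op sequence is built.
    reshard(auto_parallel_main_prog,
            auto_parallel_startup_prog,
            rank_id,
            dist_context,
            dist_params_grads,
            batch_size=4)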
Parent 575dea8f
......@@ -29,6 +29,7 @@ from .process_group import new_process_group, ProcessGroup, _g_process_group_map
# NOTE: If op in _g_special_ops, it will not be resharded.
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
while_block_info = {}
class AllGatherOpDesc:
......@@ -280,8 +281,20 @@ def _is_overlapped(shape_x, shape_y):
return overlapped
def _need_reshard(dist_tensor, dist_op, op_input=True):
def _need_reshard(dist_tensor,
dist_op,
actual_process_mesh,
program,
dist_context,
op_input=True):
"""Judge the tensor whether needs to be resharded."""
def _is_unshard(dims_mapping):
for dim in dims_mapping:
if dim != -1:
return False
return True
is_reshard = False
tensor_dist_attr = dist_tensor.dist_attr
tensor_name = dist_tensor.serial_tensor.name
......@@ -289,32 +302,74 @@ def _need_reshard(dist_tensor, dist_op, op_input=True):
tensor_process_mesh = tensor_dist_attr.process_mesh
op_dist_attr = dist_op.dist_attr
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
op_process_mesh = op_dist_attr.process_mesh
op_process_mesh = actual_process_mesh
if op_input:
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
op_process_mesh = op_dist_attr.process_mesh
if all(
map(lambda x: x is not None, [
tensor_dims_mapping, tensor_process_mesh,
op_input_dims_mapping, op_process_mesh
])):
if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh:
is_reshard = True
# dims_mapping
if tensor_dims_mapping != op_input_dims_mapping:
if dist_op.serial_op.type == "while":
sub_block = program.blocks[dist_op.serial_op.attr(
"sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = dist_context.get_dist_op_for_program(
op).dist_attr
var_dims_mapping = dist_op_attr.get_input_dims_mapping(
var_name)
if var_dims_mapping != tensor_dims_mapping:
is_reshard = True
break
else:
is_reshard = True
# process_mesh
if tensor_process_mesh != op_process_mesh:
# when the number of processes is not the same, the dims mapping must be replicative for now
if len(tensor_process_mesh.processes) != len(
op_process_mesh.processes):
assert _is_unshard(tensor_dims_mapping)
assert _is_unshard(op_input_dims_mapping)
else:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError("Bool var is not supported reshard.")
# for while op, find the process mesh of the op that actually uses the tensor as input
if dist_op.serial_op.type == "while":
sub_block = program.blocks[dist_op.serial_op.attr(
"sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = dist_context.get_dist_op_for_program(
op).dist_attr
process_mesh = dist_op_attr.process_mesh
if process_mesh == op_process_mesh:
is_reshard = True
break
else:
is_reshard = True
else:
op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_name)
op_process_mesh = op_dist_attr.process_mesh
if all(
map(lambda x: x is not None, [
tensor_dims_mapping, tensor_process_mesh,
op_output_dims_mapping, op_process_mesh
])):
if tensor_process_mesh != op_process_mesh:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError("Bool var is not supported reshard.")
is_reshard = True
if tensor_dims_mapping != op_output_dims_mapping:
raise ValueError(
"It is not supported that tensor dims mapping is different from op output dims mapping."
)
return is_reshard
......@@ -329,13 +384,14 @@ def _compute_complete_shape(slice_shape, process_shape, dims_mapping):
return complete_shape
def find_op_desc_seq(dist_tensor, dist_op):
def find_op_desc_seq(dist_tensor, dist_op, actual_process_mesh, batch_size):
"""
Find the op description sequence to reshard the source tensor for matching the op requirement.
Args:
dist_tensor (DistributedTensor): A distributed tensor.
dist_op (DistributedOperator): A distributed operator.
actual_process_mesh (ProcessMesh): The actual op process mesh.
Returns:
Dict, the dict represents the required op description sequence corresponding to process, The key of dict is
......@@ -350,11 +406,16 @@ def find_op_desc_seq(dist_tensor, dist_op):
source_process_shape = source_process_mesh.topology
op_dist_attr = dist_op.dist_attr
target_process_mesh = op_dist_attr.process_mesh
target_process_mesh = actual_process_mesh
target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
target_process_group = target_process_mesh.processes
target_process_shape = target_process_mesh.topology
if source_tensor.shape[0] < 0:
new_shape = list(source_tensor.shape)
new_shape[0] = batch_size
source_tensor.desc.set_shape(new_shape)
complete_shape = _compute_complete_shape(
source_tensor.shape, source_process_shape, source_dims_mapping)
op_desc_seq = {}
......@@ -503,7 +564,7 @@ def find_op_desc_seq(dist_tensor, dist_op):
return op_desc_seq
def _insert_send_op(block, idx, tensor, dst):
def _insert_send_op(block, idx, tensor, dst, op_role):
"""Insert send op into block at the given index."""
op_type = 'send_v2'
block._insert_op(
......@@ -514,10 +575,11 @@ def _insert_send_op(block, idx, tensor, dst):
'ring_id': 0,
'peer': dst,
'use_calc_stream': True,
'op_role': op_role
})
def _insert_recv_op(block, idx, tensor, src):
def _insert_recv_op(block, idx, tensor, src, op_role):
"""Insert recv op into block at the given index."""
op_type = 'recv_v2'
block._insert_op(
......@@ -531,14 +593,16 @@ def _insert_recv_op(block, idx, tensor, src):
'out_shape': tensor.shape,
'dtype': tensor.dtype,
'use_calc_stream': True,
'op_role': op_role
})
def _insert_concat_op(block, idx, tensors, axis):
def _insert_concat_op(block, idx, tensors, axis, op_role):
"""Insert concat op into block at the given block."""
inputs = {'X': tensors}
attrs = {}
attrs['axis'] = axis
attrs['op_role'] = op_role
helper = LayerHelper('concat', **locals())
with paddle.static.program_guard(block.program):
out = helper.create_variable_for_type_inference(
......@@ -548,7 +612,8 @@ def _insert_concat_op(block, idx, tensors, axis):
return out
def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name):
def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name,
op_role):
"""Insert slice op into block at the given block."""
inputs = {'Input': tensor}
infer_flags = list(1 for i in range(len(axes)))
......@@ -556,24 +621,23 @@ def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name):
"axes": axes,
"starts": starts,
"ends": ends,
"infer_flags": infer_flags
"infer_flags": infer_flags,
'op_role': op_role
}
helper = LayerHelper('slice', **locals())
out = block.create_var(
name=new_var_name,
dtype=tensor.dtype,
type=core.VarDesc.VarType.LOD_TENSOR)
name=new_var_name, dtype=tensor.dtype, type=tensor.type)
block._insert_op(
idx, type="slice", inputs=inputs, outputs={'Out': [out]}, attrs=attrs)
return out
def _insert_split_op(block, idx, tensor, num_or_sections):
def _insert_split_op(block, idx, tensor, num_or_sections, op_role):
"""Insert split op into block at the given index."""
helper = LayerHelper('split', **locals())
input_shape = tensor.shape
inputs = {'X': tensor}
attrs = {'num': num_or_sections, "axis": 0}
attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role}
with paddle.static.program_guard(block.program):
outs = [
helper.create_variable_for_type_inference(
......@@ -584,7 +648,7 @@ def _insert_split_op(block, idx, tensor, num_or_sections):
return outs
def _insert_allgather_op(block, idx, tensor, ranks):
def _insert_allgather_op(block, idx, tensor, ranks, op_role):
"""Insert allgather op into block at the given index."""
def _insert_fill_constant_op(block, idx):
......@@ -597,6 +661,7 @@ def _insert_allgather_op(block, idx, tensor, ranks):
attrs['str_value'] = str(int("1"))
attrs['value'] = int("1")
attrs['dtype'] = out.dtype
attrs['op_role'] = op_role
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant')
block._insert_op(
......@@ -625,14 +690,16 @@ def _insert_allgather_op(block, idx, tensor, ranks):
inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]},
attrs={'ring_id': 0,
'use_calc_stream': True})
'use_calc_stream': True,
'op_role': op_role})
# insert c_sync_calc_stream op
block._insert_op(
idx + 2,
type="c_sync_calc_stream",
inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]})
outputs={'Out': [fill_constant_out]},
attrs={'op_role': op_role})
idx_offset = 3
# insert c_allgather op
......@@ -649,20 +716,21 @@ def _insert_allgather_op(block, idx, tensor, ranks):
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'nranks': group.nranks
'nranks': group.nranks,
'op_role': op_role
})
idx_offset += 1
# insert split op
split_out = _insert_split_op(block, idx + idx_offset, allgather_out,
group.nranks)
group.nranks, op_role)
idx_offset += 1
tensor_list.extend(split_out)
return tensor_list, idx_offset
def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
block, idx):
block, idx, op_role):
"""Concat the tensors and insert concat op."""
if not partition_tensor_list:
partition_tensor_list.append((tensor, partition_index))
......@@ -674,13 +742,13 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
partition_tensor_list[i][1], partition_index)
if concat_axis != -1:
has_concat = True
_ = _insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis) \
_ = _insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis, op_role) \
if first_order == 0 else \
_insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis)
_insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis, op_role)
partition_tensor_list.pop(i)
idx[0] += 1
_concat_partitions_with_op(partition_tensor_list, _,
new_partition, block, idx)
new_partition, block, idx, op_role)
break
i += 1
if not has_concat:
......@@ -692,8 +760,47 @@ HAS_RECV = {}
HAS_ALLGATHER = {}
def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
dist_context):
def _get_while_op_actual_process_mesh(op, program, rank_id, dist_context):
"""Get the while op actual Process mesh corresponding to rank"""
assert op.type == "while"
while_op_process_mesh = dist_context.get_dist_op_for_program(
op).dist_attr.process_mesh
sub_block = program.blocks[op.attr("sub_block").id]
ops = sub_block.ops
actual_process_mesh = None
for op in ops:
dist_op = dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh == while_op_process_mesh:
continue
if rank_id in process_mesh.processes:
actual_process_mesh = process_mesh
break
if actual_process_mesh is None and rank_id in while_op_process_mesh.processes:
actual_process_mesh = while_op_process_mesh
assert actual_process_mesh is not None
return actual_process_mesh
def _get_var(var_name, block, program):
"""Get var in the parent block if not found in the current block"""
var = None
if var_name in block.vars:
var = block.vars[var_name]
else:
parent_block = program.blocks[block.parent_idx]
if var_name in parent_block.vars:
var = parent_block.vars[var_name]
assert var is not None
return var
def parse_op_desc(block, rank_id, op_desc_seq, var_name, reshard_op,
dist_context, program, actual_process_mesh):
"""Parse op desc sequence and insert op in the block"""
global HAS_SENT
global HAS_RECV
......@@ -703,9 +810,6 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
if rank_id not in op_desc_seq.keys():
return
op_desc_list = op_desc_seq[rank_id]
block = program.global_block()
assert var_name in block.vars.keys(
), "The {} cannot be found in the {} program.".format(var_name, rank_id)
idx = None
for index, op in list(enumerate(block.ops)):
......@@ -716,7 +820,7 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
rank_id)
matched_op = block.ops[idx]
source_tensor = block.vars[var_name]
source_tensor = _get_var(var_name, block, program)
for op_desc in op_desc_list:
if isinstance(op_desc, AllGatherOpDesc): # noqa: F401
if var_name not in HAS_ALLGATHER.keys():
......@@ -724,7 +828,8 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
if not HAS_ALLGATHER[var_name] or op_desc.group not in list(
map(lambda x: x[0], HAS_ALLGATHER[var_name])):
tensor_list, idx_offset = _insert_allgather_op(
block, idx, source_tensor, op_desc.group)
block, idx, source_tensor, op_desc.group,
reshard_op.attr('op_role'))
idx += idx_offset
tensor_name_list = [var.name for var in tensor_list]
HAS_ALLGATHER[var_name].append(
......@@ -743,7 +848,8 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
if var_name not in HAS_SENT.keys():
HAS_SENT[var_name] = []
if op_desc.dst not in HAS_SENT[var_name]:
_insert_send_op(block, idx, source_tensor, op_desc.dst)
_insert_send_op(block, idx, source_tensor, op_desc.dst,
reshard_op.attr('op_role'))
idx += 1
HAS_SENT[var_name].append(op_desc.dst)
......@@ -758,8 +864,10 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
recv_tensor = block.create_var(
name=unique_name.generate(var_name + "@recv"),
shape=shape,
dtype=source_tensor.dtype)
_insert_recv_op(block, idx, recv_tensor, op_desc.src)
dtype=source_tensor.dtype,
type=source_tensor.type)
_insert_recv_op(block, idx, recv_tensor, op_desc.src,
reshard_op.attr('op_role'))
tensor_list.append(recv_tensor)
idx += 1
HAS_RECV[var_name][op_desc.src] = recv_tensor
......@@ -772,7 +880,7 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
for index, tensor in enumerate(tensor_list):
_concat_partitions_with_op(partition_tensor_list, tensor,
partition_index_list[index], block,
idx_list)
idx_list, reshard_op.attr('op_role'))
idx = idx_list[0]
elif isinstance(op_desc, SliceOpDesc):
......@@ -787,11 +895,11 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
starts=op_desc.starts,
ends=op_desc.ends,
axes=op_desc.axes,
new_var_name=new_name)
new_var_name=new_name,
op_role=reshard_op.attr('op_role'))
tensor_attr = TensorDistributedAttribute()
process_mesh = dist_context.get_op_dist_attr_for_program(
matched_op).process_mesh
process_mesh = actual_process_mesh
dims_mapping = dist_context.get_op_dist_attr_for_program(
matched_op).get_input_dims_mapping(var_name)
tensor_attr.dims_mapping = dims_mapping
......@@ -799,11 +907,29 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
dist_context.set_tensor_dist_attr_for_program(target_tensor,
tensor_attr)
if op.type == "while":
global while_block_info
# var_reshard_mapping maps a while op input name to the new name it should be changed to
if "var_reshard_mapping" not in while_block_info[op.attr(
"sub_block").id].keys():
while_block_info[op.attr("sub_block").id][
"var_reshard_mapping"] = {}
while_block_info[op.attr("sub_block").id][
"var_reshard_mapping"][var_name] = target_tensor.name
# rename op input name according to new name
for op in block.ops:
for name in op.input_arg_names:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if name == var_name and op_dist_attr is not None:
if op.desc.id() == matched_op.desc.id():
op.desc._rename_input(name, target_tensor.name)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr.set_input_dist_attr(name, None)
continue
# NOTE: For an op whose process mesh is a union, its input will not be renamed by other ops' reshard results for now, which means it will have more reshard operations.
op_process_mesh = op_dist_attr.process_mesh
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
......@@ -819,102 +945,166 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
not_remove_op_ref = [
"create_py_reader", "create_double_buffer_reader", "read"
]
remove_op_idx = []
block = auto_parallel_main_prog.global_block()
ops = block.ops
vars = block.vars
for idx, op in enumerate(ops):
# handle read op in the pipeline scene specially, it will be removed in the future.
if op.type == "read":
dim_list = []
for var_name in op.output_arg_names:
dim_list.extend(vars[var_name].shape)
for i in range(idx, -1, -1):
if ops[i].type == "create_py_reader":
ops[i]._set_attr("shape_concat", dim_list)
break
continue
# replace the input and output of c_sync_comm_stream op when in pipeline scene.
if op.type == "c_sync_comm_stream":
need_save = []
for var_name in op.input_arg_names:
process_mesh = dist_context.get_tensor_dist_attr_for_program(
vars[var_name]).process_mesh
if rank_id in process_mesh.processes:
need_save.append(var_name)
if not need_save:
remove_op_idx.append(idx)
global while_block_info
# NOTE: Nested sub blocks are not supported now.
remove_block_order = []
for block_idx in while_block_info:
remove_block_order.append(block_idx)
for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
if block_idx not in remove_block_order:
remove_block_order.append(block_idx)
# the sub block should be removed first
for block_idx in remove_block_order:
remove_op_idx = []
block = auto_parallel_main_prog.blocks[block_idx]
ops = block.ops
vars = block.vars
for idx, op in enumerate(ops):
if op.type == "read":
dim_list = []
for var_name in op.output_arg_names:
dim_list.extend(
_get_var(var_name, block, auto_parallel_main_prog)
.shape)
for i in range(idx, -1, -1):
if ops[i].type == "create_py_reader":
ops[i]._set_attr("shape_concat", dim_list)
break
continue
proto = OpProtoHolder.instance().get_op_proto(op.type)
op.desc.set_input(proto.inputs[0].name, need_save)
op.desc.set_output(proto.outputs[0].name, need_save)
continue
# replace the input and output of c_sync_comm_stream op when in pipeline scene.
if op.type == "c_sync_comm_stream":
need_save = []
for var_name in op.input_arg_names:
process_mesh = dist_context.get_tensor_dist_attr_for_program(
_get_var(var_name, block,
auto_parallel_main_prog)).process_mesh
if rank_id in process_mesh.processes:
need_save.append(var_name)
if not need_save:
remove_op_idx.append(idx)
continue
# judge whether the other op should be removed.
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if op_dist_attr is not None:
op_process_mesh = op_dist_attr.process_mesh
if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref:
remove_op_idx.append(idx)
proto = OpProtoHolder.instance().get_op_proto(op.type)
op.desc.set_input(proto.inputs[0].name, need_save)
op.desc.set_output(proto.outputs[0].name, need_save)
continue
for idx in remove_op_idx[::-1]:
block._remove_op(idx)
# judge whether the other op should be removed.
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if op_dist_attr is not None:
op_process_mesh = op_dist_attr.process_mesh
if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref:
remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]:
block._remove_op(idx)
def _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads):
"""Remove no need vars in the main program"""
remove_vars = set()
block = auto_parallel_main_prog.global_block()
ops = block.ops
vars = block.vars
need_vars = set()
for op in ops:
for var_name in op.input_arg_names:
if var_name in vars:
need_vars.add(var_name)
for var_name in op.output_arg_names:
if var_name in vars:
need_vars.add(var_name)
for var in vars:
if var not in need_vars:
remove_vars.add(var)
# change dist_params_grads
param_grad_map = {}
for op in ops:
if int(op.attr('op_role')) == int(OpRole.Optimize):
if "Param" in op.input_names and "Grad" in op.input_names:
param_name = op.input("Param")[0]
grad_name = op.input("Grad")[0]
param_grad_map[param_name] = grad_name
need_remove_idx = []
for idx, item in enumerate(dist_params_grads):
if item[0].name not in param_grad_map.keys():
need_remove_idx.append(idx)
for idx in need_remove_idx[::-1]:
dist_params_grads.pop(idx)
idx = 0
while idx < len(dist_params_grads):
param_name = dist_params_grads[idx][0].name
grad_name = dist_params_grads[idx][1].name
if grad_name != param_grad_map[param_name]:
dist_params_grads[idx] = (vars[param_name],
vars[param_grad_map[param_name]])
idx += 1
for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
remove_vars = set()
ops = block.ops
vars = block.vars
need_vars = set()
for op in ops:
for var_name in op.input_arg_names:
if var_name in vars:
need_vars.add(var_name)
for var_name in op.output_arg_names:
if var_name in vars:
need_vars.add(var_name)
for var in vars:
if var not in need_vars:
remove_vars.add(var)
# change dist_params_grads; the optimize ops are only in block 0.
if block_idx == 0:
param_grad_map = {}
for op in ops:
if int(op.attr('op_role')) == int(OpRole.Optimize):
if "Param" in op.input_names and "Grad" in op.input_names:
param_name = op.input("Param")[0]
grad_name = op.input("Grad")[0]
param_grad_map[param_name] = grad_name
need_remove_idx = []
for idx, item in enumerate(dist_params_grads):
if item[0].name not in param_grad_map.keys():
need_remove_idx.append(idx)
for idx in need_remove_idx[::-1]:
dist_params_grads.pop(idx)
idx = 0
while idx < len(dist_params_grads):
param_name = dist_params_grads[idx][0].name
grad_name = dist_params_grads[idx][1].name
if grad_name != param_grad_map[param_name]:
dist_params_grads[idx] = (vars[param_name],
vars[param_grad_map[param_name]])
idx += 1
for var in remove_vars:
block._remove_var(var)
for var in remove_vars:
block._remove_var(var)
def _change_while_op_input_and_output(auto_parallel_main_prog, dist_context):
"""Change while op input and output after the corresponding sub block ops removed"""
global while_block_info
for sub_block_idx in while_block_info:
sub_block = auto_parallel_main_prog.blocks[sub_block_idx]
parent_while_op_id = while_block_info[sub_block_idx]["op_id"]
parent_block = auto_parallel_main_prog.blocks[sub_block.parent_idx]
sub_block_op_inputs = set()
sub_block_op_outputs = []
for op in sub_block.ops:
# skip the input and output of operators inserted in the reshard phase
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op:
for var_name in op.output_arg_names:
if var_name not in sub_block_op_outputs:
sub_block_op_outputs.append(var_name)
for var_name in op.input_arg_names:
sub_block_op_inputs.add(var_name)
# find the while op
while_op = None
for op in parent_block.ops:
if op.desc.id() == parent_while_op_id and op.type == "while":
while_op = op
break
assert while_op is not None
# find the actual input and output of while op
proto = OpProtoHolder.instance().get_op_proto(while_op.type)
new_X = []
for var_name in while_op.input("X"):
if var_name in sub_block_op_inputs:
new_X.append(var_name)
assert new_X
while_op.desc.set_input(proto.inputs[0].name, new_X)
new_Out = []
for var_name in while_op.output("Out"):
for output_name in sub_block_op_outputs[::-1]:
if output_name.find(var_name) != -1:
new_Out.append(output_name)
assert new_Out
while_op.desc.set_output(proto.outputs[0].name, new_Out)
def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
dist_params_grads):
"""Remove no need vars and ops in the main program."""
_remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id)
_change_while_op_input_and_output(auto_parallel_main_prog, dist_context)
_remove_no_need_vars(auto_parallel_main_prog, dist_params_grads)
......@@ -992,8 +1182,70 @@ def remove_no_need_in_startup(auto_parallel_main_prog,
startup_block._remove_op(idx)
def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
dist_context, dist_params_grads):
def _get_process_meshes(op, program, dist_context):
"""Get all process meshes when op has sub block."""
assert op.has_attr("sub_block")
sub_block = program.blocks[op.attr("sub_block").id]
ops = sub_block.ops
op_process_mesh = dist_context.get_dist_op_for_program(
op).dist_attr.process_mesh
process_meshes = []
for op in ops:
dist_op = dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh not in process_meshes and process_mesh != op_process_mesh:
process_meshes.append(process_mesh)
if not process_meshes:
process_meshes.append(op_process_mesh)
return process_meshes
def _is_condition_replicative(op, program, dist_context):
assert op.type == "while"
sub_block = program.blocks[op.attr("sub_block").id]
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
# the dims mapping of condition tensor should be replicative
for var_name in op.input("Condition"):
var = _get_var(var_name, sub_block, program)
dist_tensor = dist_context.get_dist_tensor_for_program(var)
tensor_dist_attr = dist_tensor.dist_attr
var_dims_mapping = tensor_dist_attr.dims_mapping
for dim in var_dims_mapping:
if dim != -1:
return False
return True
def _get_op_process_meshes(op, dist_context):
process_meshes = []
dist_op = dist_context.get_dist_op_for_program(op)
op_process_mesh = dist_op.dist_attr.process_mesh
for process_mesh in dist_context.process_meshes:
if set(process_mesh.processes) & (
set(op_process_mesh.processes)
) and len(process_mesh.processes) <= len(op_process_mesh.processes):
process_meshes.append(process_mesh)
# if process_meshes is empty, it means the process mesh is not a union
if not process_meshes:
process_meshes.append(op_process_mesh)
return process_meshes
def reshard(auto_parallel_main_prog,
auto_parallel_startup_prog,
rank_id,
dist_context,
dist_params_grads,
batch_size=None):
"""
Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.
......@@ -1019,65 +1271,137 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
return True
return False
block = auto_parallel_main_prog.global_block()
idx = 0
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
global while_block_info
for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
if block_idx in while_block_info:
if "var_reshard_mapping" in while_block_info[block_idx]:
var_reshard_mapping = while_block_info[block_idx][
"var_reshard_mapping"]
for op in block.ops:
for var_name in op.input_arg_names:
if var_name in var_reshard_mapping:
op.desc._rename_input(var_name,
var_reshard_mapping[var_name])
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == while_block_info[
block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
op_dist_attr.set_input_dims_mapping(
var_reshard_mapping[var_name], dims_mapping)
op_dist_attr.set_input_dist_attr(var_name, None)
# the outputs also need to be renamed when the output name is the same as the input name
for var_name in op.output_arg_names:
if var_name in var_reshard_mapping:
op.desc._rename_output(
var_name, var_reshard_mapping[var_name])
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == while_block_info[
block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_output_dims_mapping(
var_name)
op_dist_attr.set_output_dims_mapping(
var_reshard_mapping[var_name], dims_mapping)
op_dist_attr.set_output_dist_attr(var_name,
None)
idx = 0
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
if _is_special_op(op):
idx += 1
continue
if _is_special_op(op):
idx += 1
continue
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op is not None:
process_meshes = []
if op.type == "while":
if not _is_condition_replicative(
op, auto_parallel_main_prog, dist_context):
raise ValueError(
"Please check the condition due to the dims mapping is not replicative."
)
process_meshes = _get_process_meshes(
op, auto_parallel_main_prog, dist_context)
assert process_meshes
if op.attr("sub_block").id not in while_block_info:
while_block_info[op.attr("sub_block").id] = {}
while_block_info[op.attr("sub_block").id][
"op_id"] = op.desc.id()
while_block_info[op.attr("sub_block").id][
"actual_process_mesh"] = _get_while_op_actual_process_mesh(
op, auto_parallel_main_prog, rank_id, dist_context)
else:
process_meshes = _get_op_process_meshes(op, dist_context)
input_vars = None
if op.type == "while":
input_var_names = op.input("X")
else:
input_var_names = op.input_arg_names
idx_offset = 0
for var_name in op.input_arg_names:
# skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0":
continue
var = _get_var(var_name, block, auto_parallel_main_prog)
dist_tensor = dist_context.get_dist_tensor_for_program(var)
for process_mesh in process_meshes:
if dist_tensor is not None and _need_reshard(
dist_tensor, dist_op, process_mesh,
auto_parallel_main_prog, dist_context):
reshard_op_desc = find_op_desc_seq(
dist_tensor, dist_op, process_mesh, batch_size)
parse_op_desc(block, rank_id, reshard_op_desc,
var_name, op, dist_context,
auto_parallel_main_prog, process_mesh)
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
else:
idx += 1
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op is not None:
idx_offset = 0
for var_name in op.input_arg_names:
# skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0":
continue
var = block.vars[var_name]
dist_tensor = dist_context.get_dist_tensor_for_program(var)
if dist_tensor is not None and _need_reshard(dist_tensor,
dist_op):
reshard_op_desc = find_op_desc_seq(dist_tensor, dist_op)
parse_op_desc(auto_parallel_main_prog, rank_id,
reshard_op_desc, var_name, op, dist_context)
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
else:
idx += 1
# insert send and recv op if output process mesh is different from tensor process mesh
idx = 0
skip_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
skip_ops += _g_special_ops
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op is not None and op.type not in skip_ops:
for var_name in op.output_arg_names:
var = block.vars[var_name]
dist_tensor = dist_context.get_dist_tensor_for_program(var)
if dist_tensor is not None and _need_reshard(dist_tensor,
dist_op, False):
for index, item in enumerate(
dist_op.dist_attr.process_mesh.processes):
recv_rank = dist_tensor.dist_attr.process_mesh.processes[
index]
if rank_id == item:
_insert_send_op(block, idx + 1, var, recv_rank)
if rank_id == recv_rank:
_insert_recv_op(block, idx + 1, var, item)
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
else:
idx += 1
# insert send and recv op if output process mesh is different from tensor process mesh
idx = 0
# skip reader and ops whose process mesh is union
skip_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "while",
"write_to_array", "read_from_array"
]
skip_ops += _g_special_ops
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op is not None and op.type not in skip_ops:
for var_name in op.output_arg_names:
var = _get_var(var_name, block, auto_parallel_main_prog)
dist_tensor = dist_context.get_dist_tensor_for_program(var)
process_mesh = dist_op.dist_attr.process_mesh
if dist_tensor is not None and _need_reshard(
dist_tensor, dist_op, process_mesh,
auto_parallel_main_prog, dist_context, False):
for index, item in enumerate(
dist_op.dist_attr.process_mesh.processes):
recv_rank = dist_tensor.dist_attr.process_mesh.processes[
index]
if rank_id == item:
_insert_send_op(block, idx + 1, var, recv_rank,
op.attr('op_role'))
if rank_id == recv_rank:
_insert_recv_op(block, idx + 1, var, item,
op.attr('op_role'))
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
else:
idx += 1
# remove no need vars and ops in the main program
remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
......
......@@ -32,6 +32,7 @@ from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program
from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program
from paddle.distributed.auto_parallel.reshard import HAS_SENT, HAS_RECV, HAS_ALLGATHER
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
paddle.enable_static()
_global_parallel_strategy = None
......@@ -185,6 +186,7 @@ class TestMLPAutoConvert(unittest.TestCase):
str(paddle.distributed.get_rank())))
def test_mlp_mp2pp(self):
set_default_distributed_context(None)
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
......@@ -211,6 +213,7 @@ class TestMLPAutoConvert(unittest.TestCase):
fetch_list=[loss])
last_res = res[0]
set_default_distributed_context(None)
_global_parallel_strategy = "pp"
_global_process_mesh = auto.ProcessMesh([0, 1])
global PP_MESH_0
......@@ -266,6 +269,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
str(paddle.distributed.get_rank())))
def test_mlp_pp2mp(self):
set_default_distributed_context(None)
global _global_parallel_strategy
_global_parallel_strategy = "pp"
global _global_process_mesh
......@@ -302,6 +306,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
if paddle.distributed.get_rank() in [1]:
last_res = res[0]
set_default_distributed_context(None)
_global_parallel_strategy = "mp"
_global_process_mesh = auto.ProcessMesh([0, 1])
......@@ -345,6 +350,7 @@ class TestMLPAutoConvertInvalid(unittest.TestCase):
np.random.seed(2021)
def test_input_invalid(self):
set_default_distributed_context(None)
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
......