Unverified commit d101334c, authored by caozhou, committed by GitHub

[Auto Parallel] Update reshard (#40865)

* fix code style

* update unit test
Parent: 86554d91
@@ -15,7 +15,7 @@
 from .interface import shard_tensor  # noqa: F401
 from .interface import shard_op  # noqa: F401
 from .process_mesh import ProcessMesh
-from .reshard import reshard  # noqa: F401
+from .reshard import Resharder  # noqa: F401
 from .cost_model import estimate_cost
 __all__ = []
@@ -235,19 +235,19 @@ class Converter(object):
     @staticmethod
     def merge_with_dist_attr(tensor_list, dist_attr):
         """ Merge tensor with distributed attribute """
-        from .reshard import _compute_complete_shape, _compute_partition_index
+        from .reshard import Resharder
         dims_mapping = dist_attr["dims_mapping"]
         process_shape = dist_attr["process_shape"]
         process_group = dist_attr["process_group"]
         # get the complete shape of the tensor
-        complete_shape = _compute_complete_shape(tensor_list[0].shape,
-                                                 process_shape, dims_mapping)
+        complete_shape = Resharder.compute_complete_shape(
+            tensor_list[0].shape, process_shape, dims_mapping)
         # merge the tensor with dist_attr
         partition_tensor_list = []
         merged_partiton = []
         for process in process_group:
-            partition_index = _compute_partition_index(
+            partition_index = Resharder.compute_partition_index(
                 process, complete_shape, dims_mapping, process_shape,
                 process_group)
             index = process_group.index(process)
@@ -302,7 +302,7 @@ class Converter(object):
             _merge_tensor(partition_tensor_list, tensor, partition_index)
             # partition_tensor_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
         """
-        from .reshard import _compute_concat_info
+        from .reshard import Resharder
         if len(partition_tensor_list) == 1:
             is_complete_data = True
@@ -318,7 +318,7 @@ class Converter(object):
         else:
             i = 0
             while i < len(partition_tensor_list):
-                concat_axis, first_order, new_partition = _compute_concat_info(
+                concat_axis, first_order, new_partition = Resharder.compute_concat_info(
                     partition_tensor_list[i][1], partition_index)
                 if concat_axis != -1:
                     if first_order == 0:
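The hunk above now delegates the adjacency check in `Converter.merge` to `Resharder.compute_concat_info`. As a rough standalone sketch of the rule that method implements (mirrored from the logic that appears later in this diff, not imported from the Paddle API): two partition indices can be concatenated only if they differ in exactly one dimension and are adjacent along that dimension.

```python
# Standalone sketch of the compute_concat_info rule; a partition index is a
# list of [start, end) ranges, one per tensor dimension.
def compute_concat_info(partition_index_x, partition_index_y):
    differ_count = 0
    concat_axis = -1
    first_order = 0          # 0: x comes before y along concat_axis, 1: y before x
    new_partition = []
    for idx, item in enumerate(partition_index_x):
        if item != partition_index_y[idx]:
            differ_count += 1
            if item[1] == partition_index_y[idx][0] and item[0] < partition_index_y[idx][1]:
                concat_axis = idx
                new_partition.append([item[0], partition_index_y[idx][1]])
            elif item[0] == partition_index_y[idx][1] and item[1] > partition_index_y[idx][0]:
                first_order = 1
                concat_axis = idx
                new_partition.append([partition_index_y[idx][0], item[1]])
        else:
            new_partition.append(item)
    # concatenable only when exactly one dimension differs (and is adjacent)
    if differ_count == 1:
        return concat_axis, first_order, new_partition
    return -1, first_order, new_partition

# [[0, 2], [0, 4]] and [[2, 4], [0, 4]] touch along axis 0, so they merge:
print(compute_concat_info([[0, 2], [0, 4]], [[2, 4], [0, 4]]))
# -> (0, 0, [[0, 4], [0, 4]])
```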
@@ -391,11 +391,11 @@ class Converter(object):
             index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
             # index: [[], [], [2, 4]]
         """
-        from .reshard import _compute_partition_index
+        from .reshard import Resharder
         split_indices_list = []
         for process in process_group:
-            partition_index = _compute_partition_index(
+            partition_index = Resharder.compute_partition_index(
                 process, complete_shape, dims_mapping, process_shape,
                 process_group)
             if split_indices_list:
@@ -437,9 +437,9 @@ class Converter(object):
                                           process_shape, process_group)
             # index: 2
         """
-        from .reshard import _compute_partition_index
-        partition_index = _compute_partition_index(
-            rank_id, complete_shape, dims_mapping, process_shape, process_group)
+        from .reshard import Resharder
+        partition_index = Resharder.compute_partition_index(
+            rank_id, complete_shape, dims_mapping, process_shape, process_group)
         sliced_index = 0
         for i, shape in enumerate(complete_shape):
...
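Both converter hunks above lean on `Resharder.compute_partition_index`. The following standalone sketch mirrors the partition-index math that appears later in this diff (local helper names, not the Paddle API), with a small worked example:

```python
# Sketch of the index math: given a complete shape, a dims_mapping
# (tensor dim -> process-mesh dim, -1 = not sharded), the mesh topology
# (process_shape) and the flat process_group, compute which slice of the
# complete tensor a given rank owns.
from functools import reduce

def compute_complete_shape(slice_shape, process_shape, dims_mapping):
    # a sharded dim is process_shape[dims_mapping[i]] times larger in the complete tensor
    return [item if dims_mapping[i] == -1 else item * process_shape[dims_mapping[i]]
            for i, item in enumerate(slice_shape)]

def compute_partition_shape(complete_shape, dims_mapping, process_shape):
    return [item if dims_mapping[i] == -1 else item // process_shape[dims_mapping[i]]
            for i, item in enumerate(complete_shape)]

def compute_process_index(process, process_group, process_shape):
    # unflatten the rank's position in process_group into mesh coordinates
    relative_process = process_group.index(process)
    process_index = []
    product = reduce(lambda x, y: x * y, process_shape)
    for i in range(len(process_shape)):
        idx = relative_process // (product // process_shape[i])
        product = product // process_shape[i]
        relative_process = relative_process - relative_process // product * product
        process_index.append(idx)
    return process_index

def compute_partition_index(process, complete_shape, dims_mapping,
                            process_shape, process_group):
    partition_shape = compute_partition_shape(complete_shape, dims_mapping,
                                              process_shape)
    process_index = compute_process_index(process, process_group, process_shape)
    partition_index = []
    for i in range(len(complete_shape)):
        if dims_mapping[i] == -1:
            partition_index.append([0, partition_shape[i]])
        else:
            start = process_index[dims_mapping[i]] * partition_shape[i]
            partition_index.append([start, start + partition_shape[i]])
    return partition_index

# A [4, 8] slice sharded on dim 1 over a 1x2 mesh has complete shape [4, 16];
# rank 1 owns columns 8..16.
complete = compute_complete_shape([4, 8], [1, 2], [-1, 1])       # -> [4, 16]
print(compute_partition_index(1, complete, [-1, 1], [1, 2], [0, 1]))
# -> [[0, 4], [8, 16]]
```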
@@ -32,7 +32,7 @@ from paddle.distributed.utils import get_logger
 from .mapper import mapping
 from .cluster import Cluster
-from .reshard import reshard
+from .reshard import Resharder
 from .planner import Planner
 from .completion import Completer
 from .partitioner import Partitioner
@@ -187,8 +187,9 @@ class Engine:
         # Do reshard process
         set_grad_var_shape(dist_main_prog, dist_context)
         make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
-        reshard(dist_main_prog, dist_startup_prog, rank, dist_context,
-                dist_params_grads)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
+                              dist_context, dist_params_grads)
+        resharder.reshard()
         # Apply post optimization passes
         self._apply_post_optimization(dist_main_prog, dist_startup_prog,
                                       rank, dist_params_grads)
@@ -199,8 +200,9 @@ class Engine:
             serial_main_program, serial_startup_program, [])
         # Do reshard process
         make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
-        reshard(dist_main_prog, dist_startup_prog, rank, dist_context, [],
-                1)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
+                              dist_context, [], 1)
+        resharder.reshard()
         # clone program for test
         if mode != 'train':
...
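The two Engine hunks above show the shape of the API change: the old functional `reshard(...)` call is replaced by constructing a `Resharder` and invoking its `reshard()` method. A minimal sketch of the new call pattern, assuming the programs, rank, context, and param/grad list are already prepared as in the hunks above:

```python
# Sketch of the call pattern introduced by this PR; dist_main_prog,
# dist_startup_prog, rank, dist_context and dist_params_grads are assumed
# to be produced by the surrounding Engine code, as shown above.
from paddle.distributed.auto_parallel.reshard import Resharder

resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                      dist_context, dist_params_grads)
resharder.reshard()

# The evaluation path passes an empty param/grad list and batch_size=1:
# Resharder(dist_main_prog, dist_startup_prog, rank, dist_context, [], 1).reshard()
```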
@@ -42,7 +42,7 @@ from .utils import make_data_unshard
 from .utils import set_grad_var_shape
 from .utils import print_program_with_dist_attr
 from .utils import SerialProgramInfo
-from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
+from .reshard import Resharder
 from .cluster import Cluster
 from .mapper import mapping
 from .dist_op import DistributedOperator
@@ -213,17 +213,15 @@ class AutoParallelizer:
         make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
-        reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context,
-                dist_params_grads)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
+                              self._dist_context, dist_params_grads)
+        resharder.reshard()
         self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog,
                                              rank, dist_params_grads)
         g_process_group_map = None
         if not relaunch_phase:
             g_process_group_map = copy.deepcopy(_g_process_group_map)
-            HAS_SENT.clear()
-            HAS_RECV.clear()
-            HAS_ALLGATHER.clear()
             _g_process_group_map.clear()
             _g_process_group_map[0] = ProcessGroup(0, [])
             for process_mesh in dist_context._process_meshes:
...
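The deleted `HAS_SENT.clear()` / `HAS_RECV.clear()` / `HAS_ALLGATHER.clear()` calls track a design change visible later in this diff: the module-level caches become per-instance dictionaries (`self._has_sent`, `self._has_recv`, `self._has_allgather`) on `Resharder`, so each pass gets fresh state from a new object instead of clearing globals. A toy sketch of the difference, with illustrative names only:

```python
# Before: module-level cache; every caller had to remember to clear it.
HAS_SENT = {}

def reshard_old(var_name, dst):
    HAS_SENT.setdefault(var_name, []).append(dst)  # shared, sticky state

HAS_SENT.clear()  # caller's responsibility between passes


# After: the cache lives on the instance, so a fresh Resharder means fresh state.
class ResharderSketch:
    def __init__(self):
        self._has_sent = {}  # created per instance; no global clear() needed

    def reshard(self, var_name, dst):
        self._has_sent.setdefault(var_name, []).append(dst)
```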
@@ -29,7 +29,19 @@ from .process_group import new_process_group, ProcessGroup, _g_process_group_map
 # NOTE: If op in _g_special_ops, it will not be resharded.
 _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
-while_block_info = {}
+
+
+def get_var_with_recursion(var_name, block, program):
+    """Get var in the parent block if not found in the current block"""
+    var = None
+    if var_name in block.vars:
+        var = block.vars[var_name]
+    else:
+        parent_block = program.blocks[block.parent_idx]
+        if var_name in parent_block.vars:
+            var = parent_block.vars[var_name]
+    assert var is not None
+    return var


 class AllGatherOpDesc:
@@ -157,7 +169,7 @@ class ConcatOpDesc:
     Describe the concat op in the reshard phase.

     Args:
-        partition_index_list (list): A list contains all partition index.
+        partition_index_list (list): The list contains all partition index.
     """

     def __init__(self, partition_index_list):
@@ -176,1135 +188,1202 @@ class ConcatOpDesc:
         return f"op: {self._desc}, partition_index_list: {self._partition_index_list}."
def _compute_partition_shape(complete_shape, dims_mapping, process_shape): class Inserter:
"""Compute the shape of partition.""" """Insert op required in the reshard process."""
partition_shape = []
for idx, item in enumerate(complete_shape):
if dims_mapping[idx] == -1:
partition_shape.append(item)
else:
partition_shape.append(item // process_shape[dims_mapping[idx]])
return partition_shape
def _compute_process_index(process, process_group, process_shape): @staticmethod
"""Compute the index of process_shape corresponding to the process.""" def insert_send_op(block, idx, tensor, dst, op_role):
relative_process = process_group.index(process) """Insert send op into block at the given index."""
process_index = [] op_type = 'send_v2'
product = reduce(lambda x, y: x * y, process_shape) block._insert_op(
idx,
type=op_type,
inputs={'X': [tensor]},
attrs={
'ring_id': 0,
'peer': dst,
'use_calc_stream': True,
'op_role': op_role
})
for i in range(len(process_shape)): @staticmethod
idx = relative_process // (product // process_shape[i]) def insert_recv_op(block, idx, tensor, src, op_role):
product = product // process_shape[i] """Insert recv op into block at the given index."""
relative_process = relative_process - relative_process // product * product op_type = 'recv_v2'
process_index.append(idx) block._insert_op(
idx,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'ring_id': 0,
'peer': src,
'out_shape': tensor.shape,
'dtype': tensor.dtype,
'use_calc_stream': True,
'op_role': op_role
})
return process_index @staticmethod
def insert_concat_op(block, idx, tensors, axis, op_role):
"""Insert concat op into block at the given block."""
inputs = {'X': tensors}
attrs = {}
attrs['axis'] = axis
attrs['op_role'] = op_role
helper = LayerHelper('concat', **locals())
with paddle.static.program_guard(block.program):
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
block._insert_op(
idx,
type='concat',
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
return out
@staticmethod
def insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name,
op_role):
"""Insert slice op into block at the given block."""
inputs = {'Input': tensor}
infer_flags = list(1 for i in range(len(axes)))
attrs = {
"axes": axes,
"starts": starts,
"ends": ends,
"infer_flags": infer_flags,
'op_role': op_role
}
helper = LayerHelper('slice', **locals())
out = block.create_var(
name=new_var_name, dtype=tensor.dtype, type=tensor.type)
block._insert_op(
idx,
type="slice",
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
return out
def _compute_partition_index(process, complete_shape, dims_mapping, @staticmethod
process_shape, process_group): def insert_split_op(block, idx, tensor, num_or_sections, op_role):
"""Compute the partition index in complete tensor.""" """Insert split op into block at the given index."""
partition_shape = _compute_partition_shape(complete_shape, dims_mapping, helper = LayerHelper('split', **locals())
process_shape) input_shape = tensor.shape
process_index = _compute_process_index(process, process_group, inputs = {'X': tensor}
process_shape) attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role}
partition_index = [] with paddle.static.program_guard(block.program):
outs = [
helper.create_variable_for_type_inference(
dtype=helper.input_dtype()) for i in range(num_or_sections)
]
block._insert_op(
idx,
type="split",
inputs=inputs,
outputs={'Out': outs},
attrs=attrs)
return outs
for i in range(len(complete_shape)): @staticmethod
if dims_mapping[i] == -1: def insert_fill_constant_op(block, idx, op_role):
partition_index.append([0, partition_shape[i]]) """Insert fill constant op into block at the given index."""
else: helper = LayerHelper("fill_constant", **locals())
partition_index.append([ with paddle.static.program_guard(block.program):
process_index[dims_mapping[i]] * partition_shape[i], out = helper.create_variable_for_type_inference(dtype="int32")
(process_index[dims_mapping[i]] + 1) * partition_shape[i] inputs = {}
]) attrs = {'force_cpu': False}
attrs['str_value'] = str(int("1"))
attrs['value'] = int("1")
attrs['dtype'] = out.dtype
attrs['op_role'] = op_role
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant')
block._insert_op(
idx,
type='fill_constant',
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
out.stop_gradient = True
return out
return partition_index @staticmethod
def insert_allgather_op(block, idx, tensor, ranks, op_role):
"""Insert allgather op into block at the given index."""
tensor_list = []
group = new_process_group(ranks)
idx_offset = 0
# instant process group before insert allgather op.
if not group.is_instantiate():
# insert fill_constant op
fill_constant_out = Inserter.insert_fill_constant_op(block, idx,
op_role)
fill_constant_out.stop_gradient = True
def _compute_concat_info(partition_index_x, partition_index_y): # insert c_allreduce_sum op
"""Judge whether two partition can be concatenated and compute concatenated partition index.""" block._insert_op(
differ_count = 0 idx + 1,
concat_axis = -1 type="c_allreduce_sum",
first_order = 0 inputs={'X': [fill_constant_out]},
new_partition = [] outputs={'Out': [fill_constant_out]},
attrs={
'ring_id': 0,
'use_calc_stream': True,
'op_role': op_role
})
for idx, item in enumerate(partition_index_x): # insert c_sync_calc_stream op
if item != partition_index_y[idx]: block._insert_op(
differ_count += 1 idx + 2,
if item[1] == partition_index_y[idx][0] and item[ type="c_sync_calc_stream",
0] < partition_index_y[idx][1]: inputs={'X': [fill_constant_out]},
concat_axis = idx outputs={'Out': [fill_constant_out]},
new_partition.append([item[0], partition_index_y[idx][1]]) attrs={'op_role': op_role})
elif item[0] == partition_index_y[idx][1] and item[ idx_offset = 3
1] > partition_index_y[idx][0]:
first_order = 1
concat_axis = idx
new_partition.append([partition_index_y[idx][0], item[1]])
else:
new_partition.append(item)
if differ_count == 1: # insert c_allgather op
return concat_axis, first_order, new_partition op_type = 'c_allgather'
else: helper = LayerHelper(op_type, **locals())
return -1, first_order, new_partition with paddle.static.program_guard(block.program):
allgather_out = helper.create_variable_for_type_inference(
dtype=tensor.dtype)
block._insert_op(
idx + idx_offset,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [allgather_out]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'nranks': group.nranks,
'op_role': op_role
})
idx_offset += 1
# insert split op
split_out = Inserter.insert_split_op(
block, idx + idx_offset, allgather_out, group.nranks, op_role)
idx_offset += 1
tensor_list.extend(split_out)
return tensor_list, idx_offset
def _concat_partitions(partition_index_list, partition_index): @staticmethod
"""Concat the given partitions without inserting concat op.""" def concat_partitions_with_op(partition_tensor_list, tensor,
if not partition_index_list: partition_index, block, idx, op_role):
partition_index_list.append(partition_index) """Concat the tensors and insert concat op."""
if not partition_tensor_list:
partition_tensor_list.append((tensor, partition_index))
else: else:
i = 0 i = 0
has_concat = False has_concat = False
while i < len(partition_index_list): while i < len(partition_tensor_list):
concat_axis, _, new_partition = _compute_concat_info( concat_axis, first_order, new_partition = Resharder.compute_concat_info(
partition_index_list[i], partition_index) partition_tensor_list[i][1], partition_index)
if concat_axis != -1: if concat_axis != -1:
has_concat = True has_concat = True
partition_index_list.pop(i) _ = Inserter.insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis, op_role) \
_concat_partitions(partition_index_list, new_partition) if first_order == 0 else \
Inserter.insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis, op_role)
partition_tensor_list.pop(i)
idx[0] += 1
Inserter.concat_partitions_with_op(partition_tensor_list, _,
new_partition, block,
idx, op_role)
break break
i += 1 i += 1
if not has_concat: if not has_concat:
partition_index_list.append(partition_index) partition_tensor_list.append((tensor, partition_index))
def _is_overlapped(shape_x, shape_y): class Remover:
"""Judge whether two partitions intersect on the specified dimension.""" """Remove var and op in the reshard process."""
overlapped = False
if (shape_y[0] <= shape_x[0] < shape_y[1]) or (
shape_x[0] <= shape_y[0] < shape_x[1]):
overlapped = True
return overlapped
@staticmethod
def remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
"""Remove no need ops in the main program"""
not_remove_op_ref = [
"create_py_reader", "create_double_buffer_reader", "read"
]
def _need_reshard(dist_tensor, # NOTE: The nested sub block is not be supported now.
dist_op, remove_block_order = []
actual_process_mesh, for block_idx in Resharder.while_block_info:
program, remove_block_order.append(block_idx)
dist_context,
op_input=True):
"""Judge the tensor whether needs to be resharded."""
def _is_unshard(dims_mapping): for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
for dim in dims_mapping: if block_idx not in remove_block_order:
if dim != -1: remove_block_order.append(block_idx)
return False
return True
is_reshard = False # the sub block should be removed first
tensor_dist_attr = dist_tensor.dist_attr for block_idx in remove_block_order:
tensor_name = dist_tensor.serial_tensor.name remove_op_idx = []
tensor_dims_mapping = tensor_dist_attr.dims_mapping block = auto_parallel_main_prog.blocks[block_idx]
tensor_process_mesh = tensor_dist_attr.process_mesh ops = block.ops
op_dist_attr = dist_op.dist_attr vars = block.vars
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name) for idx, op in enumerate(ops):
op_process_mesh = actual_process_mesh if op.type == "read":
if op_input: dim_list = []
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name) for var_name in op.output_arg_names:
if all( dim_list.extend(
map(lambda x: x is not None, [ get_var_with_recursion(
tensor_dims_mapping, tensor_process_mesh, var_name, block, auto_parallel_main_prog).shape)
op_input_dims_mapping, op_process_mesh for i in range(idx, -1, -1):
])): if ops[i].type == "create_py_reader":
# dims_mapping ops[i]._set_attr("shape_concat", dim_list)
if tensor_dims_mapping != op_input_dims_mapping:
if dist_op.serial_op.type == "while":
sub_block = program.blocks[dist_op.serial_op.attr(
"sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = dist_context.get_dist_op_for_program(
op).dist_attr
var_dims_mapping = dist_op_attr.get_input_dims_mapping(
var_name)
if var_dims_mapping != tensor_dims_mapping:
is_reshard = True
break break
else: continue
is_reshard = True
# process_mesh
if tensor_process_mesh != op_process_mesh:
# when processes length is not the same, the dims mapping must be replicative now
if len(tensor_process_mesh.processes) != len(
op_process_mesh.processes):
assert _is_unshard(tensor_dims_mapping)
assert _is_unshard(op_input_dims_mapping)
else:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError("Bool var is not supported reshard.")
# for while op, it should find the process mesh of op actually used the tensor as input # replace the input and output of c_sync_comm_stream op when in pipeline scene.
if dist_op.serial_op.type == "while": if op.type == "c_sync_comm_stream":
sub_block = program.blocks[dist_op.serial_op.attr( need_save = []
"sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
if var_name == tensor_name: process_mesh = dist_context.get_tensor_dist_attr_for_program(
dist_op_attr = dist_context.get_dist_op_for_program( get_var_with_recursion(
op).dist_attr var_name, block,
process_mesh = dist_op_attr.process_mesh auto_parallel_main_prog)).process_mesh
if process_mesh == op_process_mesh: if rank_id in process_mesh.processes:
is_reshard = True need_save.append(var_name)
break if not need_save:
else: remove_op_idx.append(idx)
is_reshard = True continue
else:
op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_name)
if all(
map(lambda x: x is not None, [
tensor_dims_mapping, tensor_process_mesh,
op_output_dims_mapping, op_process_mesh
])):
if tensor_process_mesh != op_process_mesh:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError("Bool var is not supported reshard.")
is_reshard = True
if tensor_dims_mapping != op_output_dims_mapping:
raise ValueError(
"It is not supported that tensor dims mapping is different from op output dims mapping."
)
return is_reshard
def _compute_complete_shape(slice_shape, process_shape, dims_mapping):
"""compute the complete shape of the slice tensor with its process mesh and dims mapping"""
complete_shape = []
for idx, item in enumerate(slice_shape):
if dims_mapping[idx] == -1:
complete_shape.append(item)
else:
complete_shape.append(item * process_shape[dims_mapping[idx]])
return complete_shape
def find_op_desc_seq(dist_tensor, dist_op, actual_process_mesh, batch_size):
"""
Find the op description sequence to reshard the source tensor for matching the op requirement.
Args:
dist_tensor (DistributedTensor): A distributed tensor.
dist_op (DistributedOperator): A distributed operator.
actual_process_mesh (ProcessMesh): The actual op process mesh.
Returns:
Dict, the dict represents the required op description sequence corresponding to process, The key of dict is
process and value is a list containing op description.
"""
tensor_dist_attr = dist_tensor.dist_attr
source_tensor = dist_tensor.serial_tensor
tensor_name = source_tensor.name
source_dims_mapping = tensor_dist_attr.dims_mapping
source_process_mesh = tensor_dist_attr.process_mesh
source_process_group = source_process_mesh.processes
source_process_shape = source_process_mesh.topology
op_dist_attr = dist_op.dist_attr
target_process_mesh = actual_process_mesh
target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
target_process_group = target_process_mesh.processes
target_process_shape = target_process_mesh.topology
if source_tensor.shape[0] < 0:
new_shape = list(source_tensor.shape)
new_shape[0] = batch_size
source_tensor.desc.set_shape(new_shape)
complete_shape = _compute_complete_shape(
source_tensor.shape, source_process_shape, source_dims_mapping)
op_desc_seq = {}
# TODO: if the target process group has the same process with source process group
if set(target_process_group).intersection(set(
source_process_group)) and set(target_process_group).difference(
set(source_process_group)):
pass
# in the different process group, it will use send, recv, concat and slice op
elif target_process_group != source_process_group:
partition_process_mapping_list = []
for source_process in source_process_group:
source_partition_index = _compute_partition_index(source_process, complete_shape, source_dims_mapping, \
source_process_shape, source_process_group)
if not partition_process_mapping_list:
partition_process_mapping_list.append(
[source_partition_index, [source_process], [False]])
else:
partition_list = list(
[item[0] for item in partition_process_mapping_list])
process_list = list(
[item[1] for item in partition_process_mapping_list])
has_used = list(
[item[2] for item in partition_process_mapping_list])
if partition_list.count(source_partition_index) == 1:
index = partition_list.index(source_partition_index)
process_list[index].append(source_process)
has_used[index].append(False)
else:
partition_process_mapping_list.append(
[source_partition_index, [source_process], [False]])
for target_process in target_process_group:
has_sent = []
target_partition_index = _compute_partition_index(
target_process, complete_shape, target_dims_mapping,
target_process_shape, target_process_group)
partition_index_list = []
all_partition_index_list = []
for source_process in source_process_group:
source_partition_index = _compute_partition_index(
source_process, complete_shape, source_dims_mapping,
source_process_shape, source_process_group)
to_send_process = None
if all(_ for _ in list(map(_is_overlapped, source_partition_index, target_partition_index))) \
and source_partition_index not in has_sent:
idx = list([
item[0] for item in partition_process_mapping_list
]).index(source_partition_index)
has_used = list(
[item[2]
for item in partition_process_mapping_list])[idx]
process_list = list(
[item[1]
for item in partition_process_mapping_list])[idx]
i = 0
while i < len(has_used):
if not has_used[i]:
to_send_process = process_list[i]
has_used[i] = True
break
i += 1
if i == len(has_used):
has_used = list(map(lambda x: False, has_used))
to_send_process = process_list[0]
has_used[0] = True
assert to_send_process is not None, "Failed to find the send process."
if to_send_process not in op_desc_seq.keys():
op_desc_seq[to_send_process] = []
if target_process not in op_desc_seq.keys():
op_desc_seq[target_process] = []
all_partition_index_list.append(source_partition_index)
# append send and recv op desc
send_op_desc = SendOpDesc(source_partition_index,
target_process)
recv_op_desc = RecvOpDesc(source_partition_index,
to_send_process)
op_desc_seq[to_send_process].append(send_op_desc)
op_desc_seq[target_process].append(recv_op_desc)
has_sent.append(source_partition_index)
_concat_partitions(partition_index_list,
source_partition_index)
# append concat op desc
op_desc_seq[target_process].append(
ConcatOpDesc(all_partition_index_list))
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
concatenated_partition_index = partition_index_list[0]
for idx, item in enumerate(concatenated_partition_index):
slice_starts.append(target_partition_index[idx][0] - item[0])
slice_ends.append(target_partition_index[idx][1] - item[0])
slices_axes.append(idx)
op_desc_seq[target_process].append(
SliceOpDesc(slice_starts, slice_ends, slices_axes))
# in the same process group, it will use allgahther and slice op
else:
partition_index_list = []
all_partition_index_list = []
process_index = []
for source_process in source_process_group:
source_partition_index = _compute_partition_index(
source_process, complete_shape, source_dims_mapping,
source_process_shape, source_process_group)
if source_partition_index not in partition_index_list:
partition_index_list.append(source_partition_index)
process_index.append(
[[source_process, ], source_partition_index])
else:
process_index[partition_index_list.index(
source_partition_index)][0].append(source_process)
for i in range(len(process_index[0][0])):
group = []
for j in range(len(process_index)):
group.append(process_index[j][0][i])
if i == 0:
all_partition_index_list.append(process_index[j][1])
for process in group:
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
target_partition_index = _compute_partition_index(
process, complete_shape, target_dims_mapping,
target_process_shape, target_process_group)
for idx, item in enumerate(target_partition_index):
slice_starts.append(item[0])
slice_ends.append(item[1])
slices_axes.append(idx)
slice_op_desc = SliceOpDesc( proto = OpProtoHolder.instance().get_op_proto(op.type)
starts=slice_starts, ends=slice_ends, axes=slices_axes) op.desc.set_input(proto.inputs[0].name, need_save)
op_desc_seq[process] = [AllGatherOpDesc(group=group), op.desc.set_output(proto.outputs[0].name, need_save)
ConcatOpDesc(partition_index_list=all_partition_index_list), slice_op_desc] \ continue
if len(group) > 1 else [slice_op_desc]
return op_desc_seq # judge the other op whether should be removed.
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if op_dist_attr is not None:
op_process_mesh = op_dist_attr.process_mesh
if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref:
remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]:
block._remove_op(idx)
def _insert_send_op(block, idx, tensor, dst, op_role): @staticmethod
"""Insert send op into block at the given index.""" def remove_no_need_vars(auto_parallel_main_prog, dist_params_grads):
op_type = 'send_v2' """Remove no need vars in the main program"""
block._insert_op( for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
idx, remove_vars = set()
type=op_type, ops = block.ops
inputs={'X': [tensor]}, vars = block.vars
attrs={ need_vars = set()
'ring_id': 0, for op in ops:
'peer': dst, for var_name in op.input_arg_names:
'use_calc_stream': True, if var_name in vars:
'op_role': op_role need_vars.add(var_name)
}) for var_name in op.output_arg_names:
if var_name in vars:
need_vars.add(var_name)
for var in vars:
if var not in need_vars:
remove_vars.add(var)
# change dist_params_grads, the optimize op just in block 0.
if block_idx == 0:
param_grad_map = {}
for op in ops:
if int(op.attr('op_role')) == int(OpRole.Optimize):
if "Param" in op.input_names and "Grad" in op.input_names:
param_name = op.input("Param")[0]
grad_name = op.input("Grad")[0]
param_grad_map[param_name] = grad_name
def _insert_recv_op(block, idx, tensor, src, op_role): need_remove_idx = []
"""Insert recv op into block at the given index.""" for idx, item in enumerate(dist_params_grads):
op_type = 'recv_v2' if item[0].name not in param_grad_map.keys():
block._insert_op( need_remove_idx.append(idx)
idx,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'ring_id': 0,
'peer': src,
'out_shape': tensor.shape,
'dtype': tensor.dtype,
'use_calc_stream': True,
'op_role': op_role
})
for idx in need_remove_idx[::-1]:
dist_params_grads.pop(idx)
def _insert_concat_op(block, idx, tensors, axis, op_role): idx = 0
"""Insert concat op into block at the given block.""" while idx < len(dist_params_grads):
inputs = {'X': tensors} param_name = dist_params_grads[idx][0].name
attrs = {} grad_name = dist_params_grads[idx][1].name
attrs['axis'] = axis if grad_name != param_grad_map[param_name]:
attrs['op_role'] = op_role dist_params_grads[idx] = (
helper = LayerHelper('concat', **locals()) vars[param_name], vars[param_grad_map[param_name]])
with paddle.static.program_guard(block.program): idx += 1
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
block._insert_op(
idx, type='concat', inputs=inputs, outputs={'Out': [out]}, attrs=attrs)
return out
for var in remove_vars:
block._remove_var(var)
def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name, @staticmethod
op_role): def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
"""Insert slice op into block at the given block.""" dist_params_grads):
inputs = {'Input': tensor} """Remove no need vars and ops in the main program."""
infer_flags = list(1 for i in range(len(axes))) Remover.remove_no_need_ops(auto_parallel_main_prog, dist_context,
attrs = { rank_id)
"axes": axes, Resharder.change_while_op_input_and_output(auto_parallel_main_prog,
"starts": starts, dist_context)
"ends": ends, Remover.remove_no_need_vars(auto_parallel_main_prog, dist_params_grads)
"infer_flags": infer_flags,
'op_role': op_role
}
helper = LayerHelper('slice', **locals())
out = block.create_var(
name=new_var_name, dtype=tensor.dtype, type=tensor.type)
block._insert_op(
idx, type="slice", inputs=inputs, outputs={'Out': [out]}, attrs=attrs)
return out
@staticmethod
def remove_no_need_in_startup(auto_parallel_main_prog,
auto_parallel_startup_prog):
"""Remove no need vars and ops in the startup program."""
main_input_vars = set()
main_ops = auto_parallel_main_prog.global_block().ops
for op in main_ops:
for var_name in op.input_arg_names:
main_input_vars.add(var_name)
def _insert_split_op(block, idx, tensor, num_or_sections, op_role): startup_block = auto_parallel_startup_prog.global_block()
"""Insert split op into block at the given index.""" startup_output_vars = set()
helper = LayerHelper('split', **locals()) startup_ops = startup_block.ops
input_shape = tensor.shape for op in startup_ops:
inputs = {'X': tensor} # skip c_sync_comm_stream op
attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role} if op.type == "c_sync_comm_stream":
with paddle.static.program_guard(block.program): continue
outs = [ for var_name in op.output_arg_names:
helper.create_variable_for_type_inference( startup_output_vars.add(var_name)
dtype=helper.input_dtype()) for i in range(num_or_sections)
]
block._insert_op(
idx, type="split", inputs=inputs, outputs={'Out': outs}, attrs=attrs)
return outs
need_vars = set()
for var_name in startup_output_vars:
if var_name in main_input_vars:
need_vars.add(var_name)
def _insert_allgather_op(block, idx, tensor, ranks, op_role): startup_ops = startup_block.ops
"""Insert allgather op into block at the given index.""" actual_need_vars = set()
for idx, op in enumerate(startup_ops):
is_need_op = False
if op.type == "c_sync_comm_stream":
continue
for var_name in op.output_arg_names:
if var_name in need_vars:
is_need_op = True
break
if is_need_op:
for var_name in op.output_arg_names:
actual_need_vars.add(var_name)
for var_name in op.input_arg_names:
actual_need_vars.add(var_name)
def _insert_fill_constant_op(block, idx): remove_vars = set()
"""Insert fill constant op into block at the given index.""" for var_name in startup_block.vars:
helper = LayerHelper("fill_constant", **locals()) if var_name not in actual_need_vars:
with paddle.static.program_guard(block.program): remove_vars.add(var_name)
out = helper.create_variable_for_type_inference(dtype="int32") for var in remove_vars:
inputs = {} startup_block._remove_var(var)
attrs = {'force_cpu': False}
attrs['str_value'] = str(int("1"))
attrs['value'] = int("1")
attrs['dtype'] = out.dtype
attrs['op_role'] = op_role
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant')
block._insert_op(
idx,
type='fill_constant',
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
out.stop_gradient = True
return out
tensor_list = [] remove_op_idx = []
group = new_process_group(ranks) vars = startup_block.vars
idx_offset = 0 for idx, op in enumerate(startup_block.ops):
is_no_need_op = False
if op.type == "c_sync_comm_stream":
var_names = []
for var_name in op.input_arg_names:
if var_name in vars:
var_names.append(var_name)
if not var_names:
remove_op_idx.append(idx)
else:
proto = OpProtoHolder.instance().get_op_proto(op.type)
op.desc.set_input(proto.inputs[0].name, var_names)
op.desc.set_output(proto.outputs[0].name, var_names)
continue
# instant process group before insert allgather op. for var_name in op.output_arg_names:
if not group.is_instantiate(): if var_name not in vars:
# insert fill_constant op is_no_need_op = True
fill_constant_out = _insert_fill_constant_op(block, idx) break
fill_constant_out.stop_gradient = True if is_no_need_op:
remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]:
startup_block._remove_op(idx)
# insert c_allreduce_sum op
block._insert_op(
idx + 1,
type="c_allreduce_sum",
inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]},
attrs={'ring_id': 0,
'use_calc_stream': True,
'op_role': op_role})
# insert c_sync_calc_stream op class Resharder:
block._insert_op( """
idx + 2, Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.
type="c_sync_calc_stream",
inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]},
attrs={'op_role': op_role})
idx_offset = 3
# insert c_allgather op Args:
op_type = 'c_allgather' auto_parallel_main_prog (Program): An auto parallel main program.
helper = LayerHelper(op_type, **locals()) auto_parallel_startup_prog (Program): An auto parallel startup program.
with paddle.static.program_guard(block.program): rank_id (int): The process id.
allgather_out = helper.create_variable_for_type_inference( dist_context (DistributedContext): The distributed context of this rank.
dtype=tensor.dtype) dist_params_grads (list): The list contains the tuple of param and grad.
block._insert_op( batch_size (int): The batch size. Default: None.
idx + idx_offset, """
type=op_type, while_block_info = {}
inputs={'X': [tensor]},
outputs={'Out': [allgather_out]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'nranks': group.nranks,
'op_role': op_role
})
idx_offset += 1
# insert split op def __init__(self,
split_out = _insert_split_op(block, idx + idx_offset, allgather_out, auto_parallel_main_prog,
group.nranks, op_role) auto_parallel_startup_prog,
idx_offset += 1 rank_id,
tensor_list.extend(split_out) dist_context,
return tensor_list, idx_offset dist_params_grads,
batch_size=None):
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
"but got {}.".format(type(auto_parallel_main_prog))
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program, " \
"but got {}.".format(type(auto_parallel_startup_prog))
assert isinstance(rank_id, int), "The type of rank_id should be int, " \
"but got {}.".format(type(rank_id))
assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
"but got {}.".format(type(dist_context))
if batch_size is not None:
assert isinstance(batch_size, int), "The type of batch_size should be int, " \
"but got {}.".format(type(batch_size))
self._auto_parallel_main_prog = auto_parallel_main_prog
self._auto_parallel_startup_prog = auto_parallel_startup_prog
self._rank_id = rank_id
self._dist_context = dist_context
self._dist_params_grads = dist_params_grads
self._batch_size = batch_size
self._has_sent = {}
self._has_recv = {}
self._has_allgather = {}
@property
def auto_parallel_main_prog(self):
return self._auto_parallel_main_prog
def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index, @property
block, idx, op_role): def auto_parallel_startup_prog(self):
"""Concat the tensors and insert concat op.""" return self._auto_parallel_startup_prog
if not partition_tensor_list:
partition_tensor_list.append((tensor, partition_index))
else:
i = 0
has_concat = False
while i < len(partition_tensor_list):
concat_axis, first_order, new_partition = _compute_concat_info(
partition_tensor_list[i][1], partition_index)
if concat_axis != -1:
has_concat = True
_ = _insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis, op_role) \
if first_order == 0 else \
_insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis, op_role)
partition_tensor_list.pop(i)
idx[0] += 1
_concat_partitions_with_op(partition_tensor_list, _,
new_partition, block, idx, op_role)
break
i += 1
if not has_concat:
partition_tensor_list.append((tensor, partition_index))
@property
def rank_id(self):
return self._rank_id
HAS_SENT = {} @property
HAS_RECV = {} def dist_context(self):
HAS_ALLGATHER = {} return self._dist_context
@property
def dist_params_grads(self):
return self._dist_params_grads
def _get_while_op_actual_process_mesh(op, program, rank_id, dist_context): @property
"""Get the while op actual Process mesh corresponding to rank""" def batch_size(self):
assert op.type == "while" return self._batch_size
while_op_process_mesh = dist_context.get_dist_op_for_program(
op).dist_attr.process_mesh
sub_block = program.blocks[op.attr("sub_block").id]
ops = sub_block.ops
actual_process_mesh = None
for op in ops:
dist_op = dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh == while_op_process_mesh:
continue
if rank_id in process_mesh.processes:
raw_process_mesh = process_mesh
break
if actual_process_mesh is None and rank_id in while_op_process_mesh.processes: @property
actual_process_mesh = while_op_process_mesh def has_sent(self):
return self._has_sent
assert actual_process_mesh is not None @property
return actual_process_mesh def has_recv(self):
return self._has_recv
@property
def has_allgather(self):
return self._has_allgather
def _get_var(var_name, block, program): @staticmethod
"""Get var in the parent block if not found in the current block""" def compute_partition_shape(complete_shape, dims_mapping, process_shape):
var = None """Compute the shape of partition."""
if var_name in block.vars: partition_shape = []
var = block.vars[var_name] for idx, item in enumerate(complete_shape):
if dims_mapping[idx] == -1:
partition_shape.append(item)
else: else:
parent_block = program.blocks[block.parent_idx] partition_shape.append(item // process_shape[dims_mapping[idx]])
if var_name in parent_block.vars:
var = parent_block.vars[var_name]
assert var is not None
return var
return partition_shape
def parse_op_desc(block, rank_id, op_desc_seq, var_name, reshard_op, @staticmethod
dist_context, program, actual_process_mesh): def compute_process_index(process, process_group, process_shape):
"""Parse op desc sequence and insert op in the block""" """Compute the index of process_shape corresponding to the process."""
global HAS_SENT relative_process = process_group.index(process)
global HAS_RECV process_index = []
global HAS_ALLGATHER product = reduce(lambda x, y: x * y, process_shape)
tensor_list = []
partition_tensor_list = []
if rank_id not in op_desc_seq.keys():
return
op_desc_list = op_desc_seq[rank_id]
idx = None for i in range(len(process_shape)):
for index, op in list(enumerate(block.ops)): idx = relative_process // (product // process_shape[i])
if op.desc.id == reshard_op.desc.id: product = product // process_shape[i]
idx = index relative_process = relative_process - relative_process // product * product
break process_index.append(idx)
assert idx is not None, "The op for reshard cannot be found in the rank {} program.".format(
rank_id)
matched_op = block.ops[idx] return process_index
source_tensor = _get_var(var_name, block, program)
for op_desc in op_desc_list: @staticmethod
if isinstance(op_desc, AllGatherOpDesc): # noqa: F401 def compute_partition_index(process, complete_shape, dims_mapping,
if var_name not in HAS_ALLGATHER.keys(): process_shape, process_group):
HAS_ALLGATHER[var_name] = [] """Compute the partition index in complete tensor."""
if not HAS_ALLGATHER[var_name] or op_desc.group not in list( partition_shape = Resharder.compute_partition_shape(
map(lambda x: x[0], HAS_ALLGATHER[var_name])): complete_shape, dims_mapping, process_shape)
tensor_list, idx_offset = _insert_allgather_op( process_index = Resharder.compute_process_index(process, process_group,
block, idx, source_tensor, op_desc.group, process_shape)
reshard_op.attr('op_role')) partition_index = []
idx += idx_offset
tensor_name_list = [var.name for var in tensor_list] for i in range(len(complete_shape)):
HAS_ALLGATHER[var_name].append( if dims_mapping[i] == -1:
[op_desc.group, tensor_name_list]) partition_index.append([0, partition_shape[i]])
else: else:
for item in HAS_ALLGATHER[var_name]: partition_index.append([
if op_desc.group == item[0]: process_index[dims_mapping[i]] * partition_shape[i],
tensor_list = [ (process_index[dims_mapping[i]] + 1) * partition_shape[i]
program.global_block().vars[var_name] ])
for var_name in item[1]
]
break
assert tensor_list, "The result of parsing allgather op should not be None."
elif isinstance(op_desc, SendOpDesc): return partition_index
if var_name not in HAS_SENT.keys():
HAS_SENT[var_name] = []
if op_desc.dst not in HAS_SENT[var_name]:
_insert_send_op(block, idx, source_tensor, op_desc.dst,
reshard_op.attr('op_role'))
idx += 1
HAS_SENT[var_name].append(op_desc.dst)
elif isinstance(op_desc, RecvOpDesc): @staticmethod
if var_name not in HAS_RECV.keys(): def compute_concat_info(partition_index_x, partition_index_y):
HAS_RECV[var_name] = {} """Judge whether two partition can be concatenated and compute concatenated partition index."""
if op_desc.src not in HAS_RECV[var_name].keys(): differ_count = 0
partition_index = op_desc.partition_index concat_axis = -1
shape = [] first_order = 0
for index in partition_index: new_partition = []
shape.append(index[1] - index[0])
recv_tensor = block.create_var( for idx, item in enumerate(partition_index_x):
name=unique_name.generate(var_name + "@recv"), if item != partition_index_y[idx]:
shape=shape, differ_count += 1
dtype=source_tensor.dtype, if item[1] == partition_index_y[idx][0] and item[
type=source_tensor.type) 0] < partition_index_y[idx][1]:
_insert_recv_op(block, idx, recv_tensor, op_desc.src, concat_axis = idx
reshard_op.attr('op_role')) new_partition.append([item[0], partition_index_y[idx][1]])
tensor_list.append(recv_tensor) elif item[0] == partition_index_y[idx][1] and item[
idx += 1 1] > partition_index_y[idx][0]:
HAS_RECV[var_name][op_desc.src] = recv_tensor first_order = 1
concat_axis = idx
new_partition.append([partition_index_y[idx][0], item[1]])
else: else:
tensor_list.append(HAS_RECV[var_name][op_desc.src]) new_partition.append(item)
elif isinstance(op_desc, ConcatOpDesc): if differ_count == 1:
partition_index_list = op_desc.partition_index_list return concat_axis, first_order, new_partition
idx_list = [idx] else:
for index, tensor in enumerate(tensor_list): return -1, first_order, new_partition
_concat_partitions_with_op(partition_tensor_list, tensor,
partition_index_list[index], block,
idx_list, reshard_op.attr('op_role'))
idx = idx_list[0]
elif isinstance(op_desc, SliceOpDesc): @staticmethod
assert len(partition_tensor_list) == 1 or not partition_tensor_list def compute_complete_shape(slice_shape, process_shape, dims_mapping):
to_slice_tensor = partition_tensor_list[0][0] if len( """compute the complete shape of the slice tensor with its process mesh and dims mapping"""
partition_tensor_list) == 1 else source_tensor complete_shape = []
new_name = unique_name.generate(var_name + "@RESHARD") for idx, item in enumerate(slice_shape):
target_tensor = _insert_slice_op( if dims_mapping[idx] == -1:
block, complete_shape.append(item)
idx, else:
to_slice_tensor, complete_shape.append(item * process_shape[dims_mapping[idx]])
starts=op_desc.starts, return complete_shape
ends=op_desc.ends,
axes=op_desc.axes,
new_var_name=new_name,
op_role=reshard_op.attr('op_role'))
tensor_attr = TensorDistributedAttribute() @staticmethod
process_mesh = actual_process_mesh def concat_partitions(partition_index_list, partition_index):
dims_mapping = dist_context.get_op_dist_attr_for_program( """Concat the given partitions without inserting concat op."""
matched_op).get_input_dims_mapping(var_name) if not partition_index_list:
tensor_attr.dims_mapping = dims_mapping partition_index_list.append(partition_index)
tensor_attr.process_mesh = process_mesh else:
dist_context.set_tensor_dist_attr_for_program(target_tensor, i = 0
tensor_attr) has_concat = False
while i < len(partition_index_list):
concat_axis, _, new_partition = Resharder.compute_concat_info(
partition_index_list[i], partition_index)
if concat_axis != -1:
has_concat = True
partition_index_list.pop(i)
Resharder.concat_partitions(partition_index_list,
new_partition)
break
i += 1
if not has_concat:
partition_index_list.append(partition_index)
if op.type == "while": @staticmethod
global while_block_info def change_while_op_input_and_output(auto_parallel_main_prog, dist_context):
# var_reshard_mapping means the while op input need be changed to """Change while op input and output after the corresponding sub block ops removed"""
if "var_reshard_mapping" not in while_block_info[op.attr( for sub_block_idx in Resharder.while_block_info:
"sub_block").id].keys(): sub_block = auto_parallel_main_prog.blocks[sub_block_idx]
while_block_info[op.attr("sub_block").id][ parent_while_op_id = Resharder.while_block_info[sub_block_idx][
"var_reshard_mapping"] = {} "op_id"]
while_block_info[op.attr("sub_block").id][ parent_block = auto_parallel_main_prog.blocks[sub_block.parent_idx]
"var_reshard_mapping"][var_name] = target_tensor.name
# rename op input name according to new name sub_block_op_inputs = set()
for op in block.ops: sub_block_op_outputs = []
for name in op.input_arg_names: for op in sub_block.ops:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) # skip the input and output of operators inserted in the reshard phase
if name == var_name and op_dist_attr is not None: dist_op = dist_context.get_dist_op_for_program(op)
if op.desc.id() == matched_op.desc.id(): if dist_op:
op.desc._rename_input(name, target_tensor.name) for var_name in op.output_arg_names:
op_dist_attr.set_input_dims_mapping( if var_name not in sub_block_op_outputs:
target_tensor.name, dims_mapping) sub_block_op_outputs.append(var_name)
op_dist_attr.set_input_dist_attr(name, None) for var_name in op.input_arg_names:
continue sub_block_op_inputs.add(var_name)
# NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation. # find the while op
op_process_mesh = op_dist_attr.process_mesh while_op = None
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping( for op in parent_block.ops:
var_name) if op.desc.id() == parent_while_op_id and op.type == "while":
if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping: while_op = op
op.desc._rename_input(name, target_tensor.name) break
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr.set_input_dist_attr(name, None)
assert while_op is not None
def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id): # find the actual input and output of while op
"""Remove no need ops in the main program""" proto = OpProtoHolder.instance().get_op_proto(while_op.type)
not_remove_op_ref = [ new_X = []
"create_py_reader", "create_double_buffer_reader", "read" for var_name in while_op.input("X"):
] if var_name in sub_block_op_inputs:
global while_block_info new_X.append(var_name)
assert new_X
while_op.desc.set_input(proto.inputs[0].name, new_X)
# NOTE: The nested sub block is not be supported now. new_Out = []
remove_block_order = [] for var_name in while_op.output("Out"):
for block_idx in while_block_info: for output_name in sub_block_op_outputs[::-1]:
remove_block_order.append(block_idx) if output_name.find(var_name) != -1:
new_Out.append(output_name)
assert new_Out
while_op.desc.set_output(proto.outputs[0].name, new_Out)
for block_idx, block in enumerate(auto_parallel_main_prog.blocks): def is_overlapped(self, shape_x, shape_y):
if block_idx not in remove_block_order: """Judge whether two partitions intersect on the specified dimension."""
remove_block_order.append(block_idx) overlapped = False
if (shape_y[0] <= shape_x[0] < shape_y[1]) or (
shape_x[0] <= shape_y[0] < shape_x[1]):
overlapped = True
return overlapped
# the sub block should be removed first def is_unshard(self, dims_mapping):
for block_idx in remove_block_order: for dim in dims_mapping:
remove_op_idx = [] if dim != -1:
block = auto_parallel_main_prog.blocks[block_idx] return False
ops = block.ops return True
vars = block.vars
for idx, op in enumerate(ops):
if op.type == "read":
dim_list = []
for var_name in op.output_arg_names:
dim_list.extend(
_get_var(var_name, block, auto_parallel_main_prog)
.shape)
for i in range(idx, -1, -1):
if ops[i].type == "create_py_reader":
ops[i]._set_attr("shape_concat", dim_list)
break
continue
# replace the input and output of c_sync_comm_stream op when in pipeline scene. def is_special_op(self, op):
if op.type == "c_sync_comm_stream": global _g_special_ops
need_save = [] if op.type in _g_special_ops:
for var_name in op.input_arg_names: return True
process_mesh = dist_context.get_tensor_dist_attr_for_program( return False
_get_var(var_name, block,
auto_parallel_main_prog)).process_mesh
if rank_id in process_mesh.processes:
need_save.append(var_name)
if not need_save:
remove_op_idx.append(idx)
continue
proto = OpProtoHolder.instance().get_op_proto(op.type) def is_condition_replicative(self, op):
op.desc.set_input(proto.inputs[0].name, need_save) assert op.type == "while"
op.desc.set_output(proto.outputs[0].name, need_save) sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
continue dist_op = self.dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
# judge the other op whether should be removed. # the dims mapping of condition tensor should be replicative
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) for var_name in op.input("Condition"):
if op_dist_attr is not None: var = get_var_with_recursion(var_name, sub_block,
op_process_mesh = op_dist_attr.process_mesh self.auto_parallel_main_prog)
if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref: dist_tensor = self.dist_context.get_dist_tensor_for_program(var)
remove_op_idx.append(idx) tensor_dist_attr = dist_tensor.dist_attr
var_dims_mapping = tensor_dist_attr.dims_mapping
for dim in var_dims_mapping:
if dim != -1:
return False
for idx in remove_op_idx[::-1]: return True
block._remove_op(idx)
def need_reshard(self,
dist_tensor,
dist_op,
actual_process_mesh,
op_input=True):
"""Judge the tensor whether needs to be resharded."""
is_reshard = False
tensor_dist_attr = dist_tensor.dist_attr
tensor_name = dist_tensor.serial_tensor.name
tensor_dims_mapping = tensor_dist_attr.dims_mapping
tensor_process_mesh = tensor_dist_attr.process_mesh
op_dist_attr = dist_op.dist_attr
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
op_process_mesh = actual_process_mesh
if op_input:
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_name)
if all(
map(lambda x: x is not None, [
tensor_dims_mapping, tensor_process_mesh,
op_input_dims_mapping, op_process_mesh
])):
# dims_mapping
if tensor_dims_mapping != op_input_dims_mapping:
if dist_op.serial_op.type == "while":
sub_block = self.auto_parallel_main_prog.blocks[
dist_op.serial_op.attr("sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = self.dist_context.get_dist_op_for_program(
op).dist_attr
var_dims_mapping = dist_op_attr.get_input_dims_mapping(
var_name)
if var_dims_mapping != tensor_dims_mapping:
is_reshard = True
break
else:
is_reshard = True
# process_mesh
if tensor_process_mesh != op_process_mesh:
# when processes length is not the same, the dims mapping must be replicative now
if len(tensor_process_mesh.processes) != len(
op_process_mesh.processes):
assert self.is_unshard(tensor_dims_mapping)
assert self.is_unshard(op_input_dims_mapping)
else:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError(
"Bool var is not supported reshard.")
def _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads): # for while op, it should find the process mesh of op actually used the tensor as input
"""Remove no need vars in the main program""" if dist_op.serial_op.type == "while":
for block_idx, block in enumerate(auto_parallel_main_prog.blocks): sub_block = self.auto_parallel_main_prog.blocks[
remove_vars = set() dist_op.serial_op.attr("sub_block").id]
ops = block.ops for op in sub_block.ops:
vars = block.vars
need_vars = set()
for op in ops:
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
if var_name in vars: if var_name == tensor_name:
need_vars.add(var_name) dist_op_attr = self.dist_context.get_dist_op_for_program(
for var_name in op.output_arg_names: op).dist_attr
if var_name in vars: process_mesh = dist_op_attr.process_mesh
need_vars.add(var_name) if process_mesh == op_process_mesh:
for var in vars: is_reshard = True
if var not in need_vars: break
remove_vars.add(var) else:
is_reshard = True
# change dist_params_grads, the optimize op just in block 0. else:
if block_idx == 0: op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
param_grad_map = {} tensor_name)
for op in ops: if all(
if int(op.attr('op_role')) == int(OpRole.Optimize): map(lambda x: x is not None, [
if "Param" in op.input_names and "Grad" in op.input_names: tensor_dims_mapping, tensor_process_mesh,
param_name = op.input("Param")[0] op_output_dims_mapping, op_process_mesh
grad_name = op.input("Grad")[0] ])):
param_grad_map[param_name] = grad_name if tensor_process_mesh != op_process_mesh:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError("Bool var is not supported reshard.")
is_reshard = True
if tensor_dims_mapping != op_output_dims_mapping:
raise ValueError(
"It is not supported that tensor dims mapping is different from op output dims mapping."
)
need_remove_idx = [] return is_reshard
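The decision above reduces to comparing the layout the tensor currently has with the layout the op expects; a deliberately simplified sketch, ignoring the while-op sub-block, union process mesh and bool-tensor special cases handled in the method:

    def needs_reshard_simplified(tensor_dims_mapping, tensor_process_mesh,
                                 op_dims_mapping, op_process_mesh):
        # Reshard is required when the op expects a different dims mapping than
        # the tensor currently carries, or when tensor and op are placed on
        # different process meshes.
        return (tensor_dims_mapping != op_dims_mapping
                or tensor_process_mesh != op_process_mesh)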
for idx, item in enumerate(dist_params_grads):
if item[0].name not in param_grad_map.keys():
need_remove_idx.append(idx)
for idx in need_remove_idx[::-1]: def get_process_meshes(self, op):
dist_params_grads.pop(idx) """Get all process meshes when op has sub block."""
assert op.has_attr("sub_block")
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
op_process_mesh = self.dist_context.get_dist_op_for_program(
op).dist_attr.process_mesh
process_meshes = []
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh not in process_meshes and process_mesh != op_process_mesh:
process_meshes.append(process_mesh)
idx = 0 if not process_meshes:
while idx < len(dist_params_grads): process_meshes.append(op_process_mesh)
param_name = dist_params_grads[idx][0].name
grad_name = dist_params_grads[idx][1].name
if grad_name != param_grad_map[param_name]:
dist_params_grads[idx] = (vars[param_name],
vars[param_grad_map[param_name]])
idx += 1
for var in remove_vars: return process_meshes
block._remove_var(var)
def get_op_process_meshes(self, op):
process_meshes = []
dist_op = self.dist_context.get_dist_op_for_program(op)
op_process_mesh = dist_op.dist_attr.process_mesh
for process_mesh in self.dist_context.process_meshes:
if set(process_mesh.processes) & (
set(op_process_mesh.processes)
) and len(process_mesh.processes) <= len(op_process_mesh.processes):
process_meshes.append(process_mesh)
def _change_while_op_input_and_output(auto_parallel_main_prog, dist_context): # it means the process mesh is not a union when process meshes is null
"""Change while op input and output after the corresponding sub block ops removed""" if not process_meshes:
global while_block_info process_meshes.append(op_process_mesh)
for sub_block_idx in while_block_info:
sub_block = auto_parallel_main_prog.blocks[sub_block_idx]
parent_while_op_id = while_block_info[sub_block_idx]["op_id"]
parent_block = auto_parallel_main_prog.blocks[sub_block.parent_idx]
sub_block_op_inputs = set() return process_meshes
sub_block_op_outputs = []
for op in sub_block.ops:
# skip the input and output of operators inserted in the reshard phase
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op:
for var_name in op.output_arg_names:
if var_name not in sub_block_op_outputs:
sub_block_op_outputs.append(var_name)
for var_name in op.input_arg_names:
sub_block_op_inputs.add(var_name)
# find the while op def get_while_op_actual_process_mesh(self, op):
while_op = None """Get the while op actual Process mesh corresponding to rank"""
for op in parent_block.ops: assert op.type == "while"
if op.desc.id() == parent_while_op_id and op.type == "while": while_op_process_mesh = self.dist_context.get_dist_op_for_program(
while_op = op op).dist_attr.process_mesh
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
actual_process_mesh = None
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh == while_op_process_mesh:
continue
if self.rank_id in process_mesh.processes:
raw_process_mesh = process_mesh
break break
assert while_op is not None if actual_process_mesh is None and self.rank_id in while_op_process_mesh.processes:
actual_process_mesh = while_op_process_mesh
# find the actual input and output of while op assert actual_process_mesh is not None
proto = OpProtoHolder.instance().get_op_proto(while_op.type) return actual_process_mesh
new_X = []
for var_name in while_op.input("X"):
if var_name in sub_block_op_inputs:
new_X.append(var_name)
assert new_X
while_op.desc.set_input(proto.inputs[0].name, new_X)
new_Out = [] def find_op_desc_seq(self, dist_tensor, dist_op, actual_process_mesh):
for var_name in while_op.output("Out"): """
for output_name in sub_block_op_outputs[::-1]: Find the op description sequence to reshard the source tensor for matching the op requirement.
if output_name.find(var_name) != -1:
new_Out.append(output_name)
assert new_Out
while_op.desc.set_output(proto.outputs[0].name, new_Out)
Args:
dist_tensor (DistributedTensor): A distributed tensor.
dist_op (DistributedOperator): A distributed operator.
actual_process_mesh (ProcessMesh): The actual op process mesh.
def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id, Returns:
dist_params_grads): Dict, the dict represents the required op description sequence corresponding to process, The key of dict is
"""Remove no need vars and ops in the main program.""" process and value is a list containing op description.
_remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id) """
_change_while_op_input_and_output(auto_parallel_main_prog, dist_context) tensor_dist_attr = dist_tensor.dist_attr
_remove_no_need_vars(auto_parallel_main_prog, dist_params_grads) source_tensor = dist_tensor.serial_tensor
tensor_name = source_tensor.name
source_dims_mapping = tensor_dist_attr.dims_mapping
source_process_mesh = tensor_dist_attr.process_mesh
source_process_group = source_process_mesh.processes
source_process_shape = source_process_mesh.topology
op_dist_attr = dist_op.dist_attr
target_process_mesh = actual_process_mesh
target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
target_process_group = target_process_mesh.processes
target_process_shape = target_process_mesh.topology
def remove_no_need_in_startup(auto_parallel_main_prog, if source_tensor.shape[0] < 0:
auto_parallel_startup_prog): new_shape = list(source_tensor.shape)
"""Remove no need vars and ops in the startup program.""" new_shape[0] = self.batch_size
main_input_vars = set() source_tensor.desc.set_shape(new_shape)
main_ops = auto_parallel_main_prog.global_block().ops
for op in main_ops:
for var_name in op.input_arg_names:
main_input_vars.add(var_name)
startup_block = auto_parallel_startup_prog.global_block() complete_shape = Resharder.compute_complete_shape(
startup_output_vars = set() source_tensor.shape, source_process_shape, source_dims_mapping)
startup_ops = startup_block.ops op_desc_seq = {}
for op in startup_ops:
# skip c_sync_comm_stream op
if op.type == "c_sync_comm_stream":
continue
for var_name in op.output_arg_names:
startup_output_vars.add(var_name)
need_vars = set() # TODO: if the target process group has the same process with source process group
for var_name in startup_output_vars: if set(target_process_group).intersection(set(
if var_name in main_input_vars: source_process_group)) and set(target_process_group).difference(
need_vars.add(var_name) set(source_process_group)):
pass
startup_ops = startup_block.ops # in the different process group, it will use send, recv, concat and slice op
actual_need_vars = set() elif target_process_group != source_process_group:
for idx, op in enumerate(startup_ops): partition_process_mapping_list = []
is_need_op = False for source_process in source_process_group:
if op.type == "c_sync_comm_stream": source_partition_index = Resharder.compute_partition_index(source_process, complete_shape, source_dims_mapping, \
continue source_process_shape, source_process_group)
for var_name in op.output_arg_names: if not partition_process_mapping_list:
if var_name in need_vars: partition_process_mapping_list.append(
is_need_op = True [source_partition_index, [source_process], [False]])
else:
partition_list = list(
[item[0] for item in partition_process_mapping_list])
process_list = list(
[item[1] for item in partition_process_mapping_list])
has_used = list(
[item[2] for item in partition_process_mapping_list])
if partition_list.count(source_partition_index) == 1:
index = partition_list.index(source_partition_index)
process_list[index].append(source_process)
has_used[index].append(False)
else:
partition_process_mapping_list.append([
source_partition_index, [source_process], [False]
])
for target_process in target_process_group:
has_sent = []
target_partition_index = Resharder.compute_partition_index(
target_process, complete_shape, target_dims_mapping,
target_process_shape, target_process_group)
partition_index_list = []
all_partition_index_list = []
for source_process in source_process_group:
source_partition_index = Resharder.compute_partition_index(
source_process, complete_shape, source_dims_mapping,
source_process_shape, source_process_group)
to_send_process = None
if all(_ for _ in list(map(self.is_overlapped, source_partition_index, target_partition_index))) \
and source_partition_index not in has_sent:
idx = list([
item[0] for item in partition_process_mapping_list
]).index(source_partition_index)
has_used = list([
item[2] for item in partition_process_mapping_list
])[idx]
process_list = list([
item[1] for item in partition_process_mapping_list
])[idx]
i = 0
while i < len(has_used):
if not has_used[i]:
to_send_process = process_list[i]
has_used[i] = True
break break
if is_need_op: i += 1
for var_name in op.output_arg_names: if i == len(has_used):
actual_need_vars.add(var_name) has_used = list(map(lambda x: False, has_used))
for var_name in op.input_arg_names: to_send_process = process_list[0]
actual_need_vars.add(var_name) has_used[0] = True
assert to_send_process is not None, "Failed to find the send process."
remove_vars = set() if to_send_process not in op_desc_seq.keys():
for var_name in startup_block.vars: op_desc_seq[to_send_process] = []
if var_name not in actual_need_vars: if target_process not in op_desc_seq.keys():
remove_vars.add(var_name) op_desc_seq[target_process] = []
for var in remove_vars: all_partition_index_list.append(source_partition_index)
startup_block._remove_var(var)
remove_op_idx = [] # append send and recv op desc
vars = startup_block.vars send_op_desc = SendOpDesc(source_partition_index,
for idx, op in enumerate(startup_block.ops): target_process)
is_no_need_op = False recv_op_desc = RecvOpDesc(source_partition_index,
if op.type == "c_sync_comm_stream": to_send_process)
var_names = [] op_desc_seq[to_send_process].append(send_op_desc)
for var_name in op.input_arg_names: op_desc_seq[target_process].append(recv_op_desc)
if var_name in vars: has_sent.append(source_partition_index)
var_names.append(var_name) Resharder.concat_partitions(partition_index_list,
if not var_names: source_partition_index)
remove_op_idx.append(idx)
else:
proto = OpProtoHolder.instance().get_op_proto(op.type)
op.desc.set_input(proto.inputs[0].name, var_names)
op.desc.set_output(proto.outputs[0].name, var_names)
continue
for var_name in op.output_arg_names: # append concat op desc
if var_name not in vars: op_desc_seq[target_process].append(
is_no_need_op = True ConcatOpDesc(all_partition_index_list))
break
if is_no_need_op:
remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]:
startup_block._remove_op(idx)
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
concatenated_partition_index = partition_index_list[0]
for idx, item in enumerate(concatenated_partition_index):
slice_starts.append(target_partition_index[idx][0] - item[
0])
slice_ends.append(target_partition_index[idx][1] - item[0])
slices_axes.append(idx)
op_desc_seq[target_process].append(
SliceOpDesc(slice_starts, slice_ends, slices_axes))
def _get_process_meshes(op, program, dist_context): # in the same process group, it will use allgahther and slice op
"""Get all process meshes when op has sub block.""" else:
assert op.has_attr("sub_block") partition_index_list = []
sub_block = program.blocks[op.attr("sub_block").id] all_partition_index_list = []
ops = sub_block.ops process_index = []
op_process_mesh = dist_context.get_dist_op_for_program( for source_process in source_process_group:
op).dist_attr.process_mesh source_partition_index = Resharder.compute_partition_index(
process_meshes = [] source_process, complete_shape, source_dims_mapping,
for op in ops: source_process_shape, source_process_group)
dist_op = dist_context.get_dist_op_for_program(op) if source_partition_index not in partition_index_list:
if not dist_op: partition_index_list.append(source_partition_index)
continue process_index.append(
process_mesh = dist_op.dist_attr.process_mesh [[source_process, ], source_partition_index])
if process_mesh not in process_meshes and process_mesh != op_process_mesh: else:
process_meshes.append(process_mesh) process_index[partition_index_list.index(
source_partition_index)][0].append(source_process)
if not process_meshes: for i in range(len(process_index[0][0])):
process_meshes.append(op_process_mesh) group = []
for j in range(len(process_index)):
group.append(process_index[j][0][i])
if i == 0:
all_partition_index_list.append(process_index[j][1])
for process in group:
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
target_partition_index = Resharder.compute_partition_index(
process, complete_shape, target_dims_mapping,
target_process_shape, target_process_group)
for idx, item in enumerate(target_partition_index):
slice_starts.append(item[0])
slice_ends.append(item[1])
slices_axes.append(idx)
return process_meshes slice_op_desc = SliceOpDesc(
starts=slice_starts, ends=slice_ends, axes=slices_axes)
op_desc_seq[process] = [AllGatherOpDesc(group=group),
ConcatOpDesc(partition_index_list=all_partition_index_list), slice_op_desc] \
if len(group) > 1 else [slice_op_desc]
return op_desc_seq
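For orientation, a hypothetical example of the dictionary this method builds when source and target process groups are disjoint; the process ids, tensor shape and partition indices are invented for illustration (the descriptor classes are the ones referenced above in reshard.py):

    from paddle.distributed.auto_parallel.reshard import (
        SendOpDesc, RecvOpDesc, ConcatOpDesc, SliceOpDesc)

    # A [4, 4] tensor row-sharded over source processes 0 and 1 that process 2 needs whole:
    op_desc_seq = {
        0: [SendOpDesc([[0, 2], [0, 4]], 2)],   # send the upper half to process 2
        1: [SendOpDesc([[2, 4], [0, 4]], 2)],   # send the lower half to process 2
        2: [RecvOpDesc([[0, 2], [0, 4]], 0),    # receive both halves,
            RecvOpDesc([[2, 4], [0, 4]], 1),
            ConcatOpDesc(partition_index_list=[[[0, 2], [0, 4]], [[2, 4], [0, 4]]]),
            SliceOpDesc(starts=[0, 0], ends=[4, 4], axes=[0, 1])],  # then concat and slice
    }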
def _is_condition_replicative(op, program, dist_context): def parse_op_desc(self, block, op_desc_seq, var_name, reshard_op,
assert op.type == "while" actual_process_mesh):
sub_block = program.blocks[op.attr("sub_block").id] """Parse op desc sequence and insert op in the block"""
dist_op = dist_context.get_dist_op_for_program(op) tensor_list = []
op_dist_attr = dist_op.dist_attr partition_tensor_list = []
if self.rank_id not in op_desc_seq.keys():
return
op_desc_list = op_desc_seq[self.rank_id]
# the dims mapping of condition tensor should be replicative idx = None
for var_name in op.input("Condition"): for index, op in list(enumerate(block.ops)):
var = _get_var(var_name, sub_block, program) if op.desc.id == reshard_op.desc.id:
dist_tensor = dist_context.get_dist_tensor_for_program(var) idx = index
tensor_dist_attr = dist_tensor.dist_attr break
var_dims_mapping = tensor_dist_attr.dims_mapping assert idx is not None, "The op for reshard cannot be found in the rank {} program.".format(
for dim in var_dims_mapping: self.rank_id)
if dim != -1:
return False
return True matched_op = block.ops[idx]
source_tensor = get_var_with_recursion(var_name, block,
self.auto_parallel_main_prog)
for op_desc in op_desc_list:
if isinstance(op_desc, AllGatherOpDesc): # noqa: F401
if var_name not in self.has_allgather.keys():
self.has_allgather[var_name] = []
if not self.has_allgather[
var_name] or op_desc.group not in list(
map(lambda x: x[0], self.has_allgather[var_name])):
tensor_list, idx_offset = Inserter.insert_allgather_op(
block, idx, source_tensor, op_desc.group,
reshard_op.attr('op_role'))
idx += idx_offset
tensor_name_list = [var.name for var in tensor_list]
self.has_allgather[var_name].append(
[op_desc.group, tensor_name_list])
else:
for item in self.has_allgather[var_name]:
if op_desc.group == item[0]:
tensor_list = [
program.global_block().vars[var_name]
for var_name in item[1]
]
break
assert tensor_list, "The result of parsing allgather op should not be None."
elif isinstance(op_desc, SendOpDesc):
if var_name not in self.has_sent.keys():
self.has_sent[var_name] = []
if op_desc.dst not in self.has_sent[var_name]:
Inserter.insert_send_op(block, idx, source_tensor,
op_desc.dst,
reshard_op.attr('op_role'))
idx += 1
self.has_sent[var_name].append(op_desc.dst)
def _get_op_process_meshes(op, dist_context): elif isinstance(op_desc, RecvOpDesc):
process_meshes = [] if var_name not in self.has_recv.keys():
dist_op = dist_context.get_dist_op_for_program(op) self.has_recv[var_name] = {}
op_process_mesh = dist_op.dist_attr.process_mesh if op_desc.src not in self.has_recv[var_name].keys():
for process_mesh in dist_context.process_meshes: partition_index = op_desc.partition_index
if set(process_mesh.processes) & ( shape = []
set(op_process_mesh.processes) for index in partition_index:
) and len(process_mesh.processes) <= len(op_process_mesh.processes): shape.append(index[1] - index[0])
process_meshes.append(process_mesh) recv_tensor = block.create_var(
name=unique_name.generate(var_name + "@recv"),
shape=shape,
dtype=source_tensor.dtype,
type=source_tensor.type)
Inserter.insert_recv_op(block, idx, recv_tensor,
op_desc.src,
reshard_op.attr('op_role'))
tensor_list.append(recv_tensor)
idx += 1
self.has_recv[var_name][op_desc.src] = recv_tensor
else:
tensor_list.append(self.has_recv[var_name][op_desc.src])
# it means the process mesh is not a union when process meshes is null elif isinstance(op_desc, ConcatOpDesc):
if not process_meshes: partition_index_list = op_desc.partition_index_list
process_meshes.append(op_process_mesh) idx_list = [idx]
for index, tensor in enumerate(tensor_list):
Inserter.concat_partitions_with_op(
partition_tensor_list, tensor,
partition_index_list[index], block, idx_list,
reshard_op.attr('op_role'))
idx = idx_list[0]
return process_meshes elif isinstance(op_desc, SliceOpDesc):
assert len(
partition_tensor_list) == 1 or not partition_tensor_list
to_slice_tensor = partition_tensor_list[0][0] if len(
partition_tensor_list) == 1 else source_tensor
new_name = unique_name.generate(var_name + "@RESHARD")
target_tensor = Inserter.insert_slice_op(
block,
idx,
to_slice_tensor,
starts=op_desc.starts,
ends=op_desc.ends,
axes=op_desc.axes,
new_var_name=new_name,
op_role=reshard_op.attr('op_role'))
tensor_attr = TensorDistributedAttribute()
process_mesh = actual_process_mesh
dims_mapping = self.dist_context.get_op_dist_attr_for_program(
matched_op).get_input_dims_mapping(var_name)
tensor_attr.dims_mapping = dims_mapping
tensor_attr.process_mesh = process_mesh
self.dist_context.set_tensor_dist_attr_for_program(
target_tensor, tensor_attr)
def reshard(auto_parallel_main_prog, if op.type == "while":
auto_parallel_startup_prog, # var_reshard_mapping means the while op input need be changed to
rank_id, if "var_reshard_mapping" not in Resharder.while_block_info[
dist_context, op.attr("sub_block").id].keys():
dist_params_grads, Resharder.while_block_info[op.attr("sub_block").id][
batch_size=None): "var_reshard_mapping"] = {}
""" Resharder.while_block_info[op.attr("sub_block").id][
Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute. "var_reshard_mapping"][var_name] = target_tensor.name
Args: # rename op input name according to new name
auto_parallel_main_prog (Program): An auto parallel main program. for op in block.ops:
auto_parallel_startup_prog (Program): An auto parallel startup program. for name in op.input_arg_names:
rank_id (int): The process id. op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
dist_context (DistributedContext): The distributed context of this rank. op)
dist_params_grads (list): The list contains the tuple of param and grad. if name == var_name and op_dist_attr is not None:
""" if op.desc.id() == matched_op.desc.id():
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \ op.desc._rename_input(name, target_tensor.name)
"but got {}.".format(type(auto_parallel_main_prog)) op_dist_attr.set_input_dims_mapping(
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program, " \ target_tensor.name, dims_mapping)
"but got {}.".format(type(auto_parallel_startup_prog)) op_dist_attr.set_input_dist_attr(name, None)
assert isinstance(rank_id, int), "The type of rank_id should be int, " \ continue
"but got {}.".format(type(rank_id))
assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
"but got {}.".format(type(dist_context))
def _is_special_op(op): # NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation.
global _g_special_ops op_process_mesh = op_dist_attr.process_mesh
if op.type in _g_special_ops: op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
return True var_name)
return False if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping:
op.desc._rename_input(name, target_tensor.name)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr.set_input_dist_attr(name, None)
global while_block_info def reshard(self):
for block_idx, block in enumerate(auto_parallel_main_prog.blocks): for block_idx, block in enumerate(self.auto_parallel_main_prog.blocks):
if block_idx in while_block_info: if block_idx in Resharder.while_block_info:
if "var_reshard_mapping" in while_block_info[block_idx]: if "var_reshard_mapping" in Resharder.while_block_info[
var_reshard_mapping = while_block_info[block_idx][ block_idx]:
var_reshard_mapping = Resharder.while_block_info[block_idx][
"var_reshard_mapping"] "var_reshard_mapping"]
for op in block.ops: for op in block.ops:
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
if var_name in var_reshard_mapping: if var_name in var_reshard_mapping:
op.desc._rename_input(var_name, op.desc._rename_input(
var_reshard_mapping[var_name]) var_name, var_reshard_mapping[var_name])
dist_op = dist_context.get_dist_op_for_program(op) dist_op = self.dist_context.get_dist_op_for_program(
op)
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == while_block_info[ if op_dist_attr.process_mesh == Resharder.while_block_info[
block_idx]["actual_process_mesh"]: block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_input_dims_mapping( dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name) var_name)
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
var_reshard_mapping[var_name], dims_mapping) var_reshard_mapping[var_name],
op_dist_attr.set_input_dist_attr(var_name, None) dims_mapping)
op_dist_attr.set_input_dist_attr(var_name,
None)
# the outputs also need to be renamed when the output name is the same with input name # the outputs also need to be renamed when the output name is the same with input name
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
if var_name in var_reshard_mapping: if var_name in var_reshard_mapping:
op.desc._rename_output( op.desc._rename_output(
var_name, var_reshard_mapping[var_name]) var_name, var_reshard_mapping[var_name])
dist_op = dist_context.get_dist_op_for_program(op) dist_op = self.dist_context.get_dist_op_for_program(
op)
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == while_block_info[ if op_dist_attr.process_mesh == Resharder.while_block_info[
block_idx]["actual_process_mesh"]: block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_output_dims_mapping( dims_mapping = op_dist_attr.get_output_dims_mapping(
var_name) var_name)
op_dist_attr.set_output_dims_mapping( op_dist_attr.set_output_dims_mapping(
var_reshard_mapping[var_name], dims_mapping) var_reshard_mapping[var_name],
dims_mapping)
op_dist_attr.set_output_dist_attr(var_name, op_dist_attr.set_output_dist_attr(var_name,
None) None)
...@@ -1313,31 +1392,31 @@ def reshard(auto_parallel_main_prog, ...@@ -1313,31 +1392,31 @@ def reshard(auto_parallel_main_prog,
pre_op_count = len(block.ops) pre_op_count = len(block.ops)
op = block.ops[idx] op = block.ops[idx]
if _is_special_op(op): if self.is_special_op(op):
idx += 1 idx += 1
continue continue
dist_op = dist_context.get_dist_op_for_program(op) dist_op = self.dist_context.get_dist_op_for_program(op)
if dist_op is not None: if dist_op is not None:
process_meshes = [] process_meshes = []
if op.type == "while": if op.type == "while":
if not _is_condition_replicative( if not self.is_condition_replicative(op):
op, auto_parallel_main_prog, dist_context):
raise ValueError( raise ValueError(
"Please check the condition due to the dims mapping is not replicative." "Please check the condition due to the dims mapping is not replicative."
) )
process_meshes = _get_process_meshes( process_meshes = self.get_process_meshes(op)
op, auto_parallel_main_prog, dist_context)
assert process_meshes assert process_meshes
if op.attr("sub_block").id not in while_block_info: if op.attr("sub_block"
while_block_info[op.attr("sub_block").id] = {} ).id not in Resharder.while_block_info:
while_block_info[op.attr("sub_block").id][ Resharder.while_block_info[op.attr("sub_block")
.id] = {}
Resharder.while_block_info[op.attr("sub_block").id][
"op_id"] = op.desc.id() "op_id"] = op.desc.id()
while_block_info[op.attr("sub_block").id][ Resharder.while_block_info[op.attr("sub_block").id][
"actual_process_mesh"] = _get_while_op_actual_process_mesh( "actual_process_mesh"] = self.get_while_op_actual_process_mesh(
op, auto_parallel_main_prog, rank_id, dist_context) op)
else: else:
process_meshes = _get_op_process_meshes(op, dist_context) process_meshes = self.get_op_process_meshes(op)
input_vars = None input_vars = None
if op.type == "while": if op.type == "while":
input_var_names = op.input("X") input_var_names = op.input("X")
...@@ -1348,17 +1427,17 @@ def reshard(auto_parallel_main_prog, ...@@ -1348,17 +1427,17 @@ def reshard(auto_parallel_main_prog,
# skip lod_tensor_blocking_queue_0 # skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0": if var_name == "lod_tensor_blocking_queue_0":
continue continue
var = _get_var(var_name, block, auto_parallel_main_prog) var = get_var_with_recursion(
dist_tensor = dist_context.get_dist_tensor_for_program(var) var_name, block, self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(
var)
for process_mesh in process_meshes: for process_mesh in process_meshes:
if dist_tensor is not None and _need_reshard( if dist_tensor is not None and self.need_reshard(
dist_tensor, dist_op, process_mesh, dist_tensor, dist_op, process_mesh):
auto_parallel_main_prog, dist_context): reshard_op_desc = self.find_op_desc_seq(
reshard_op_desc = find_op_desc_seq( dist_tensor, dist_op, process_mesh)
dist_tensor, dist_op, process_mesh, batch_size) self.parse_op_desc(block, reshard_op_desc,
parse_op_desc(block, rank_id, reshard_op_desc, var_name, op, process_mesh)
var_name, op, dist_context,
auto_parallel_main_prog, process_mesh)
cur_op_count = len(block.ops) cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count pre_op_count = cur_op_count
...@@ -1370,31 +1449,35 @@ def reshard(auto_parallel_main_prog, ...@@ -1370,31 +1449,35 @@ def reshard(auto_parallel_main_prog,
idx = 0 idx = 0
# skip reader and ops whose process mesh is union # skip reader and ops whose process mesh is union
skip_ops = [ skip_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "while", "create_py_reader", "create_double_buffer_reader", "read",
"write_to_array", "read_from_array" "while", "write_to_array", "read_from_array"
] ]
global _g_special_ops
skip_ops += _g_special_ops skip_ops += _g_special_ops
while idx < len(block.ops): while idx < len(block.ops):
pre_op_count = len(block.ops) pre_op_count = len(block.ops)
op = block.ops[idx] op = block.ops[idx]
dist_op = dist_context.get_dist_op_for_program(op) dist_op = self.dist_context.get_dist_op_for_program(op)
if dist_op is not None and op.type not in skip_ops: if dist_op is not None and op.type not in skip_ops:
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
var = _get_var(var_name, block, auto_parallel_main_prog) var = get_var_with_recursion(
dist_tensor = dist_context.get_dist_tensor_for_program(var) var_name, block, self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(
var)
process_mesh = dist_op.dist_attr.process_mesh process_mesh = dist_op.dist_attr.process_mesh
if dist_tensor is not None and _need_reshard( if dist_tensor is not None and self.need_reshard(
dist_tensor, dist_op, process_mesh, dist_tensor, dist_op, process_mesh, False):
auto_parallel_main_prog, dist_context, False):
for index, item in enumerate( for index, item in enumerate(
dist_op.dist_attr.process_mesh.processes): dist_op.dist_attr.process_mesh.processes):
recv_rank = dist_tensor.dist_attr.process_mesh.processes[ recv_rank = dist_tensor.dist_attr.process_mesh.processes[
index] index]
if rank_id == item: if self.rank_id == item:
_insert_send_op(block, idx + 1, var, recv_rank, Inserter.insert_send_op(block, idx + 1, var,
recv_rank,
op.attr('op_role')) op.attr('op_role'))
if rank_id == recv_rank: if self.rank_id == recv_rank:
_insert_recv_op(block, idx + 1, var, item, Inserter.insert_recv_op(block, idx + 1, var,
item,
op.attr('op_role')) op.attr('op_role'))
cur_op_count = len(block.ops) cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count idx_offset = idx_offset + cur_op_count - pre_op_count
...@@ -1404,9 +1487,13 @@ def reshard(auto_parallel_main_prog, ...@@ -1404,9 +1487,13 @@ def reshard(auto_parallel_main_prog,
idx += 1 idx += 1
# remove no need vars and ops in the main program # remove no need vars and ops in the main program
remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id, Remover.remove_no_need_in_main(self.auto_parallel_main_prog,
dist_params_grads) self.dist_context, self.rank_id,
self.dist_params_grads)
# remove no need vars and ops in the startip program # remove no need vars and ops in the startip program
remove_no_need_in_startup(auto_parallel_main_prog, Remover.remove_no_need_in_startup(self.auto_parallel_main_prog,
auto_parallel_startup_prog) self.auto_parallel_startup_prog)
# reset some variable when remove operation ended
Resharder.while_block_info = {}
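After this refactor the reshard pass is driven through the Resharder class rather than the removed module-level reshard function; a minimal usage sketch mirroring the test updates further below (the main/startup programs, dist_context and dist_params_grads are assumed to come from the partitioning step, as in get_dist_prog):

    from paddle.distributed.auto_parallel.reshard import Resharder

    resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                          dist_context, dist_params_grads)
    resharder.reshard()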
...@@ -775,19 +775,19 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr): ...@@ -775,19 +775,19 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
def _merge_parameter_with_dist_attr(param_list, dist_attr): def _merge_parameter_with_dist_attr(param_list, dist_attr):
""" Merge parameter with distributed attribute """ """ Merge parameter with distributed attribute """
from .reshard import _compute_complete_shape, _compute_partition_index from .reshard import Resharder
dims_mapping = dist_attr["dims_mapping"] dims_mapping = dist_attr["dims_mapping"]
process_shape = dist_attr["process_shape"] process_shape = dist_attr["process_shape"]
process_group = dist_attr["process_group"] process_group = dist_attr["process_group"]
# get the complete shape of the parameter # get the complete shape of the parameter
complete_shape = _compute_complete_shape(param_list[0].shape, process_shape, complete_shape = Resharder.compute_complete_shape(
dims_mapping) param_list[0].shape, process_shape, dims_mapping)
# merge the parameter with dist_attr # merge the parameter with dist_attr
partition_param_list = [] partition_param_list = []
merged_partiton = [] merged_partiton = []
for process in process_group: for process in process_group:
partition_index = _compute_partition_index( partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape, process_group) process, complete_shape, dims_mapping, process_shape, process_group)
index = process_group.index(process) index = process_group.index(process)
if partition_index not in merged_partiton: if partition_index not in merged_partiton:
...@@ -840,7 +840,7 @@ def _merge_parameter(partition_param_list, param, partition_index, ...@@ -840,7 +840,7 @@ def _merge_parameter(partition_param_list, param, partition_index,
_merge_parameter(partition_param_list, param, partition_index) _merge_parameter(partition_param_list, param, partition_index)
# partition_param_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])] # partition_param_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
""" """
from .reshard import _compute_concat_info from .reshard import Resharder
if len(partition_param_list) == 1: if len(partition_param_list) == 1:
is_complete_data = True is_complete_data = True
...@@ -856,7 +856,7 @@ def _merge_parameter(partition_param_list, param, partition_index, ...@@ -856,7 +856,7 @@ def _merge_parameter(partition_param_list, param, partition_index,
else: else:
i = 0 i = 0
while i < len(partition_param_list): while i < len(partition_param_list):
concat_axis, first_order, new_partition = _compute_concat_info( concat_axis, first_order, new_partition = Resharder.compute_concat_info(
partition_param_list[i][1], partition_index) partition_param_list[i][1], partition_index)
if concat_axis != -1: if concat_axis != -1:
if first_order == 0: if first_order == 0:
...@@ -933,9 +933,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape, ...@@ -933,9 +933,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
process_shape, process_group) process_shape, process_group)
# index: 2 # index: 2
""" """
from .reshard import _compute_partition_index from .reshard import Resharder
partition_index = _compute_partition_index( partition_index = Resharder.compute_partition_index(
rank, complete_shape, dims_mapping, process_shape, process_group) rank, complete_shape, dims_mapping, process_shape, process_group)
sliced_param_index = 0 sliced_param_index = 0
for i, shape in enumerate(complete_shape): for i, shape in enumerate(complete_shape):
...@@ -972,11 +972,11 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape, ...@@ -972,11 +972,11 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape,
index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group) index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
# index: [[], [], [2, 4]] # index: [[], [], [2, 4]]
""" """
from .reshard import _compute_partition_index from .reshard import Resharder
split_indices_list = [] split_indices_list = []
for process in process_group: for process in process_group:
partition_index = _compute_partition_index( partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape, process_group) process, complete_shape, dims_mapping, process_shape, process_group)
if split_indices_list: if split_indices_list:
for dim in range(len(partition_index)): for dim in range(len(partition_index)):
......
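The helpers used by utils.py above are now exposed as static methods of Resharder, so they can be called without building an instance. A small illustrative example; the shapes, dims mapping and process group are invented, and the commented results assume the partition-index format shown in the docstrings above:

    from paddle.distributed.auto_parallel.reshard import Resharder

    # a [2, 4] local slice, sharded along dim 0 over a 1-D mesh of two processes
    complete_shape = Resharder.compute_complete_shape([2, 4], [2], [0, -1])
    # expected: [4, 4]

    partition_index = Resharder.compute_partition_index(0, [4, 4], [0, -1], [2], [0, 1])
    # expected for process 0: [[0, 2], [0, 4]]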
...@@ -31,7 +31,6 @@ from paddle.distributed import fleet ...@@ -31,7 +31,6 @@ from paddle.distributed import fleet
from paddle.fluid.initializer import NumpyArrayInitializer from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program
from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program
from paddle.distributed.auto_parallel.reshard import HAS_SENT, HAS_RECV, HAS_ALLGATHER
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
paddle.enable_static() paddle.enable_static()
...@@ -258,9 +257,6 @@ class TestMLPAutoConvert2(unittest.TestCase): ...@@ -258,9 +257,6 @@ class TestMLPAutoConvert2(unittest.TestCase):
paddle.seed(2021) paddle.seed(2021)
random.seed(2021) random.seed(2021)
np.random.seed(2021) np.random.seed(2021)
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
def tearDown(self): def tearDown(self):
os.remove("./model_state_rank{}.pdmodel".format( os.remove("./model_state_rank{}.pdmodel".format(
......
...@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext ...@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.cost_model import estimate_cost from paddle.distributed.auto_parallel.cost_model import estimate_cost
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
...@@ -232,8 +232,9 @@ class TestCostModel(unittest.TestCase): ...@@ -232,8 +232,9 @@ class TestCostModel(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog( distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
reshard(distributed_program, dist_startup_prog, rank_id, resharder = Resharder(distributed_program, dist_startup_prog,
dist_context, dist_params_grads) rank_id, dist_context, dist_params_grads)
resharder.reshard()
dist_program.append(distributed_program) dist_program.append(distributed_program)
cluster = None cluster = None
cost = estimate_cost( cost = estimate_cost(
......
...@@ -40,7 +40,7 @@ from paddle.distributed.auto_parallel.completion import Completer ...@@ -40,7 +40,7 @@ from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import get_all_process_groups from paddle.distributed.auto_parallel.process_group import get_all_process_groups
from paddle.distributed.auto_parallel.process_group import new_process_group from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.cluster import Cluster
...@@ -502,8 +502,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): ...@@ -502,8 +502,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
partitioned_optimize_ops = parallelizer._apply_optimize( partitioned_optimize_ops = parallelizer._apply_optimize(
dist_train_program, dist_startup_prog, dist_params_grads) dist_train_program, dist_startup_prog, dist_params_grads)
reshard(dist_train_program, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_train_program, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
return dist_train_program, dist_startup_prog return dist_train_program, dist_startup_prog
......
...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext ...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import _g_process_group_map from paddle.distributed.auto_parallel.process_group import _g_process_group_map
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
...@@ -310,8 +310,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -310,8 +310,9 @@ class TestMLPReshard(unittest.TestCase):
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
for key in list(_g_process_group_map.keys()): for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key] del _g_process_group_map[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
# check send and recv result # check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
...@@ -320,9 +321,6 @@ class TestMLPReshard(unittest.TestCase): ...@@ -320,9 +321,6 @@ class TestMLPReshard(unittest.TestCase):
self.assertTrue(check_initialization(dist_startup_prog, rank_id)) self.assertTrue(check_initialization(dist_startup_prog, rank_id))
def test_mlp_pp_diff_process_mesh(self): def test_mlp_pp_diff_process_mesh(self):
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
train_program = paddle.static.Program() train_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
dist_context = DistributedContext() dist_context = DistributedContext()
...@@ -331,8 +329,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -331,8 +329,9 @@ class TestMLPReshard(unittest.TestCase):
train_program, startup_program, dist_context, rank_id, True) train_program, startup_program, dist_context, rank_id, True)
for key in list(_g_process_group_map.keys()): for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key] del _g_process_group_map[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
print_program_with_dist_attr(dist_main_prog, dist_context) print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result # check send and recv result
...@@ -351,8 +350,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -351,8 +350,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 0 rank_id = 0
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
# send and recv should not exist in dp scene. # send and recv should not exist in dp scene.
self.assertFalse(check_send_recv_result(dist_main_prog, rank_id)) self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))
......
...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext ...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
...@@ -179,8 +179,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -179,8 +179,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 2 rank_id = 2
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
# print_program_with_dist_attr(dist_main_prog, dist_context) # print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result # check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext ...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
...@@ -213,8 +213,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -213,8 +213,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 2 rank_id = 2
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
# check send and recv result # check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
...@@ -272,8 +273,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -272,8 +273,9 @@ class TestMLPReshard(unittest.TestCase):
dist_context.block_state.parse_forward_blocks(complete_train_program) dist_context.block_state.parse_forward_blocks(complete_train_program)
partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition( partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
complete_train_program, startup_program, []) complete_train_program, startup_program, [])
reshard(partitioned_main_prog, partitioned_startup_prog, rank_id, resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
dist_context, partitioned_params_grads) rank_id, dist_context, partitioned_params_grads)
resharder.reshard()
# the x should not be slice # the x should not be slice
self.assertTrue(check_allgather(partitioned_main_prog)) self.assertTrue(check_allgather(partitioned_main_prog))
......
...@@ -29,7 +29,7 @@ import paddle.distributed.auto_parallel as auto ...@@ -29,7 +29,7 @@ import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import new_process_group from paddle.distributed.auto_parallel.process_group import new_process_group
paddle.enable_static() paddle.enable_static()
......