Unverified · commit d101334c · authored by caozhou · committed by GitHub

[Auto Parallel] Update reshard (#40865)

* fix code style

* update unit test

Parent: 86554d91
@@ -15,7 +15,7 @@
 from .interface import shard_tensor  # noqa: F401
 from .interface import shard_op  # noqa: F401
 from .process_mesh import ProcessMesh
-from .reshard import reshard  # noqa: F401
+from .reshard import Resharder  # noqa: F401
 from .cost_model import estimate_cost
 __all__ = []
@@ -235,19 +235,19 @@ class Converter(object):
     @staticmethod
     def merge_with_dist_attr(tensor_list, dist_attr):
         """ Merge tensor with distributed attribute """
-        from .reshard import _compute_complete_shape, _compute_partition_index
+        from .reshard import Resharder
         dims_mapping = dist_attr["dims_mapping"]
         process_shape = dist_attr["process_shape"]
         process_group = dist_attr["process_group"]
         # get the complete shape of the tensor
-        complete_shape = _compute_complete_shape(tensor_list[0].shape,
-                                                 process_shape, dims_mapping)
+        complete_shape = Resharder.compute_complete_shape(
+            tensor_list[0].shape, process_shape, dims_mapping)
         # merge the tensor with dist_attr
         partition_tensor_list = []
         merged_partiton = []
         for process in process_group:
-            partition_index = _compute_partition_index(
+            partition_index = Resharder.compute_partition_index(
                 process, complete_shape, dims_mapping, process_shape,
                 process_group)
             index = process_group.index(process)
@@ -302,7 +302,7 @@ class Converter(object):
             _merge_tensor(partition_tensor_list, tensor, partition_index)
             # partition_tensor_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
         """
-        from .reshard import _compute_concat_info
+        from .reshard import Resharder
         if len(partition_tensor_list) == 1:
             is_complete_data = True
@@ -318,7 +318,7 @@ class Converter(object):
         else:
             i = 0
             while i < len(partition_tensor_list):
-                concat_axis, first_order, new_partition = _compute_concat_info(
+                concat_axis, first_order, new_partition = Resharder.compute_concat_info(
                     partition_tensor_list[i][1], partition_index)
                 if concat_axis != -1:
                     if first_order == 0:
@@ -391,11 +391,11 @@ class Converter(object):
             index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
             # index: [[], [], [2, 4]]
         """
-        from .reshard import _compute_partition_index
+        from .reshard import Resharder
         split_indices_list = []
         for process in process_group:
-            partition_index = _compute_partition_index(
+            partition_index = Resharder.compute_partition_index(
                 process, complete_shape, dims_mapping, process_shape,
                 process_group)
             if split_indices_list:
@@ -437,9 +437,9 @@ class Converter(object):
                                         process_shape, process_group)
             # index: 2
         """
-        from .reshard import _compute_partition_index
-        partition_index = _compute_partition_index(
+        from .reshard import Resharder
+        partition_index = Resharder.compute_partition_index(
             rank_id, complete_shape, dims_mapping, process_shape, process_group)
         sliced_index = 0
         for i, shape in enumerate(complete_shape):
...
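For intuition, a minimal numpy sketch (not part of the commit; shapes and values are made up) of what the merge path above does: each process holds a partition plus its partition index within the complete tensor, and partitions whose indices are adjacent on exactly one dimension get concatenated back together.

    import numpy as np

    # complete tensor of shape [4, 2], dims_mapping [0, -1] over two processes:
    # process 0 owns rows [0, 2), process 1 owns rows [2, 4)
    complete = np.arange(8).reshape(4, 2)
    partitions = [
        (complete[0:2], [[0, 2], [0, 2]]),  # (tensor slice, partition index)
        (complete[2:4], [[2, 4], [0, 2]]),
    ]
    # the indices differ only on dim 0 and the first ends where the second
    # starts, so merging concatenates along axis 0
    merged = np.concatenate([partitions[0][0], partitions[1][0]], axis=0)
    assert (merged == complete).all()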
@@ -32,7 +32,7 @@ from paddle.distributed.utils import get_logger
 from .mapper import mapping
 from .cluster import Cluster
-from .reshard import reshard
+from .reshard import Resharder
 from .planner import Planner
 from .completion import Completer
 from .partitioner import Partitioner
@@ -187,8 +187,9 @@ class Engine:
         # Do reshard process
         set_grad_var_shape(dist_main_prog, dist_context)
         make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
-        reshard(dist_main_prog, dist_startup_prog, rank, dist_context,
-                dist_params_grads)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
+                              dist_context, dist_params_grads)
+        resharder.reshard()
         # Apply post optimization passes
         self._apply_post_optimization(dist_main_prog, dist_startup_prog,
                                       rank, dist_params_grads)
@@ -199,8 +200,9 @@ class Engine:
             serial_main_program, serial_startup_program, [])
         # Do reshard process
         make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
-        reshard(dist_main_prog, dist_startup_prog, rank, dist_context, [],
-                1)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
+                              dist_context, [], 1)
+        resharder.reshard()
         # clone program for test
         if mode != 'train':
...
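The call-site pattern this hunk introduces, summarized as a sketch (variable names as used in the surrounding code, not a runnable snippet on its own):

    # before: module-level function backed by global bookkeeping dicts
    # reshard(dist_main_prog, dist_startup_prog, rank, dist_context, dist_params_grads)

    # after: per-invocation state lives on the Resharder instance
    resharder = Resharder(dist_main_prog, dist_startup_prog, rank, dist_context,
                          dist_params_grads)
    resharder.reshard()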
@@ -42,7 +42,7 @@ from .utils import make_data_unshard
 from .utils import set_grad_var_shape
 from .utils import print_program_with_dist_attr
 from .utils import SerialProgramInfo
-from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
+from .reshard import Resharder
 from .cluster import Cluster
 from .mapper import mapping
 from .dist_op import DistributedOperator
@@ -213,17 +213,15 @@ class AutoParallelizer:
         make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
-        reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context,
-                dist_params_grads)
+        resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
+                              self._dist_context, dist_params_grads)
+        resharder.reshard()
         self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog,
                                              rank, dist_params_grads)
         g_process_group_map = None
         if not relaunch_phase:
             g_process_group_map = copy.deepcopy(_g_process_group_map)
-            HAS_SENT.clear()
-            HAS_RECV.clear()
-            HAS_ALLGATHER.clear()
             _g_process_group_map.clear()
             _g_process_group_map[0] = ProcessGroup(0, [])
             for process_mesh in dist_context._process_meshes:
...
@@ -29,7 +29,19 @@ from .process_group import new_process_group, ProcessGroup, _g_process_group_map
 # NOTE: If op in _g_special_ops, it will not be resharded.
 _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
-while_block_info = {}
+def get_var_with_recursion(var_name, block, program):
+    """Get var in the parent block if not found in the current block"""
+    var = None
+    if var_name in block.vars:
+        var = block.vars[var_name]
+    else:
+        parent_block = program.blocks[block.parent_idx]
+        if var_name in parent_block.vars:
+            var = parent_block.vars[var_name]
+    assert var is not None
+    return var
 class AllGatherOpDesc:
@@ -157,7 +169,7 @@ class ConcatOpDesc:
     Describe the concat op in the reshard phase.
     Args:
-        partition_index_list (list): A list contains all partition index.
+        partition_index_list (list): The list contains all partition index.
     """
     def __init__(self, partition_index_list):
@@ -176,482 +188,107 @@ class ConcatOpDesc:
         return f"op: {self._desc}, partition_index_list: {self._partition_index_list}."
+class Inserter:
+    """Insert op required in the reshard process."""
-def _compute_partition_shape(complete_shape, dims_mapping, process_shape):
-    """Compute the shape of partition."""
partition_shape = []
for idx, item in enumerate(complete_shape):
if dims_mapping[idx] == -1:
partition_shape.append(item)
else:
partition_shape.append(item // process_shape[dims_mapping[idx]])
return partition_shape
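A standalone sketch of the rule above (plain Python, illustrative values): a dimension mapped to a mesh axis is divided by that axis's size, a dimension mapped to -1 stays whole.

    def partition_shape(complete_shape, dims_mapping, process_shape):
        # -1 means replicated; otherwise divide by the mapped mesh axis size
        return [
            d if dims_mapping[i] == -1 else d // process_shape[dims_mapping[i]]
            for i, d in enumerate(complete_shape)
        ]

    assert partition_shape([8, 4], [0, -1], [2, 2]) == [4, 4]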
def _compute_process_index(process, process_group, process_shape):
"""Compute the index of process_shape corresponding to the process."""
relative_process = process_group.index(process)
process_index = []
product = reduce(lambda x, y: x * y, process_shape)
for i in range(len(process_shape)):
idx = relative_process // (product // process_shape[i])
product = product // process_shape[i]
relative_process = relative_process - relative_process // product * product
process_index.append(idx)
return process_index
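Equivalently, a sketch with hypothetical values: the rank's position in the process group is decomposed row-major into mesh coordinates.

    def process_index(process, process_group, process_shape):
        relative = process_group.index(process)
        coords = []
        for i, _ in enumerate(process_shape):
            stride = 1
            for s in process_shape[i + 1:]:
                stride *= s          # ranks per step along mesh axis i
            coords.append(relative // stride)
            relative %= stride
        return coords

    assert process_index(3, [0, 1, 2, 3], [2, 2]) == [1, 1]
    assert process_index(5, [0, 1, 2, 3, 4, 5], [2, 3]) == [1, 2]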
def _compute_partition_index(process, complete_shape, dims_mapping,
process_shape, process_group):
"""Compute the partition index in complete tensor."""
partition_shape = _compute_partition_shape(complete_shape, dims_mapping,
process_shape)
process_index = _compute_process_index(process, process_group,
process_shape)
partition_index = []
for i in range(len(complete_shape)):
if dims_mapping[i] == -1:
partition_index.append([0, partition_shape[i]])
else:
partition_index.append([
process_index[dims_mapping[i]] * partition_shape[i],
(process_index[dims_mapping[i]] + 1) * partition_shape[i]
])
return partition_index
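A worked example of the two steps combined (hypothetical values): a tensor of shape [8, 4] on a 2x2 mesh with dims_mapping [0, -1].

    # partition shape per rank: [8 // 2, 4] == [4, 4]
    # mesh coordinates: rank 0 -> [0, 0], rank 1 -> [0, 1],
    #                   rank 2 -> [1, 0], rank 3 -> [1, 1]
    # rank 2 sits at coordinate 1 on mesh axis 0, so it owns rows [4, 8);
    # dim 1 is replicated, so it owns columns [0, 4)
    partition_index_of_rank_2 = [[4, 8], [0, 4]]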
def _compute_concat_info(partition_index_x, partition_index_y):
"""Judge whether two partition can be concatenated and compute concatenated partition index."""
differ_count = 0
concat_axis = -1
first_order = 0
new_partition = []
for idx, item in enumerate(partition_index_x):
if item != partition_index_y[idx]:
differ_count += 1
if item[1] == partition_index_y[idx][0] and item[
0] < partition_index_y[idx][1]:
concat_axis = idx
new_partition.append([item[0], partition_index_y[idx][1]])
elif item[0] == partition_index_y[idx][1] and item[
1] > partition_index_y[idx][0]:
first_order = 1
concat_axis = idx
new_partition.append([partition_index_y[idx][0], item[1]])
else:
new_partition.append(item)
if differ_count == 1:
return concat_axis, first_order, new_partition
else:
return -1, first_order, new_partition
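For example (hypothetical indices), two row blocks of the same [*, 4] tensor that differ only on dim 0 and touch there can be concatenated along axis 0:

    x = [[0, 2], [0, 4]]   # rows 0..2, all columns
    y = [[2, 4], [0, 4]]   # rows 2..4, all columns
    # x[0][1] == y[0][0], so: concat_axis == 0, first_order == 0 (x comes first),
    # and the merged partition index is [[0, 4], [0, 4]]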
def _concat_partitions(partition_index_list, partition_index):
"""Concat the given partitions without inserting concat op."""
if not partition_index_list:
partition_index_list.append(partition_index)
else:
i = 0
has_concat = False
while i < len(partition_index_list):
concat_axis, _, new_partition = _compute_concat_info(
partition_index_list[i], partition_index)
if concat_axis != -1:
has_concat = True
partition_index_list.pop(i)
_concat_partitions(partition_index_list, new_partition)
break
i += 1
if not has_concat:
partition_index_list.append(partition_index)
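A small trace of the recursive merge above (hypothetical indices): pieces that become adjacent after an earlier merge are collapsed as well.

    # add [[0, 2], [0, 4]]  -> [ [[0, 2], [0, 4]] ]
    # add [[4, 6], [0, 4]]  -> [ [[0, 2], [0, 4]], [[4, 6], [0, 4]] ]  (not adjacent yet)
    # add [[2, 4], [0, 4]]  -> merges with the first piece into [[0, 4], [0, 4]],
    #                          which then merges with the second piece
    # final list            -> [ [[0, 6], [0, 4]] ]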
def _is_overlapped(shape_x, shape_y):
"""Judge whether two partitions intersect on the specified dimension."""
overlapped = False
if (shape_y[0] <= shape_x[0] < shape_y[1]) or (
shape_x[0] <= shape_y[0] < shape_x[1]):
overlapped = True
return overlapped
def _need_reshard(dist_tensor,
dist_op,
actual_process_mesh,
program,
dist_context,
op_input=True):
"""Judge the tensor whether needs to be resharded."""
def _is_unshard(dims_mapping):
for dim in dims_mapping:
if dim != -1:
return False
return True
is_reshard = False
tensor_dist_attr = dist_tensor.dist_attr
tensor_name = dist_tensor.serial_tensor.name
tensor_dims_mapping = tensor_dist_attr.dims_mapping
tensor_process_mesh = tensor_dist_attr.process_mesh
op_dist_attr = dist_op.dist_attr
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
op_process_mesh = actual_process_mesh
if op_input:
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
if all(
map(lambda x: x is not None, [
tensor_dims_mapping, tensor_process_mesh,
op_input_dims_mapping, op_process_mesh
])):
# dims_mapping
if tensor_dims_mapping != op_input_dims_mapping:
if dist_op.serial_op.type == "while":
sub_block = program.blocks[dist_op.serial_op.attr(
"sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = dist_context.get_dist_op_for_program(
op).dist_attr
var_dims_mapping = dist_op_attr.get_input_dims_mapping(
var_name)
if var_dims_mapping != tensor_dims_mapping:
is_reshard = True
break
else:
is_reshard = True
# process_mesh
if tensor_process_mesh != op_process_mesh:
# when processes length is not the same, the dims mapping must be replicative now
if len(tensor_process_mesh.processes) != len(
op_process_mesh.processes):
assert _is_unshard(tensor_dims_mapping)
assert _is_unshard(op_input_dims_mapping)
else:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError("Bool var is not supported reshard.")
# for while op, it should find the process mesh of op actually used the tensor as input
if dist_op.serial_op.type == "while":
sub_block = program.blocks[dist_op.serial_op.attr(
"sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = dist_context.get_dist_op_for_program(
op).dist_attr
process_mesh = dist_op_attr.process_mesh
if process_mesh == op_process_mesh:
is_reshard = True
break
else:
is_reshard = True
else:
op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_name)
if all(
map(lambda x: x is not None, [
tensor_dims_mapping, tensor_process_mesh,
op_output_dims_mapping, op_process_mesh
])):
if tensor_process_mesh != op_process_mesh:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError("Bool var is not supported reshard.")
is_reshard = True
if tensor_dims_mapping != op_output_dims_mapping:
raise ValueError(
"It is not supported that tensor dims mapping is different from op output dims mapping."
)
return is_reshard
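Stripped of the while-op and union-mesh special cases, the decision above reduces to the following sketch (names are illustrative, not the real API):

    def needs_reshard(tensor_dims_mapping, tensor_mesh, op_dims_mapping, op_mesh):
        # reshard when the producer layout or mesh differs from what the consumer expects
        return tensor_dims_mapping != op_dims_mapping or tensor_mesh != op_mesh

    # produced as [0, -1] (row-sharded) but consumed as [-1, -1] (replicated)
    assert needs_reshard([0, -1], "mesh_A", [-1, -1], "mesh_A")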
def _compute_complete_shape(slice_shape, process_shape, dims_mapping):
"""compute the complete shape of the slice tensor with its process mesh and dims mapping"""
complete_shape = []
for idx, item in enumerate(slice_shape):
if dims_mapping[idx] == -1:
complete_shape.append(item)
else:
complete_shape.append(item * process_shape[dims_mapping[idx]])
return complete_shape
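This is the inverse of the partition-shape computation: sharded dimensions are multiplied back up (illustrative values):

    slice_shape, process_shape, dims_mapping = [4, 4], [2, 2], [0, -1]
    complete_shape = [
        d if dims_mapping[i] == -1 else d * process_shape[dims_mapping[i]]
        for i, d in enumerate(slice_shape)
    ]
    assert complete_shape == [8, 4]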
def find_op_desc_seq(dist_tensor, dist_op, actual_process_mesh, batch_size):
"""
Find the op description sequence to reshard the source tensor for matching the op requirement.
Args:
dist_tensor (DistributedTensor): A distributed tensor.
dist_op (DistributedOperator): A distributed operator.
actual_process_mesh (ProcessMesh): The actual op process mesh.
Returns:
Dict, the dict represents the required op description sequence corresponding to process, The key of dict is
process and value is a list containing op description.
"""
tensor_dist_attr = dist_tensor.dist_attr
source_tensor = dist_tensor.serial_tensor
tensor_name = source_tensor.name
source_dims_mapping = tensor_dist_attr.dims_mapping
source_process_mesh = tensor_dist_attr.process_mesh
source_process_group = source_process_mesh.processes
source_process_shape = source_process_mesh.topology
op_dist_attr = dist_op.dist_attr
target_process_mesh = actual_process_mesh
target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
target_process_group = target_process_mesh.processes
target_process_shape = target_process_mesh.topology
if source_tensor.shape[0] < 0:
new_shape = list(source_tensor.shape)
new_shape[0] = batch_size
source_tensor.desc.set_shape(new_shape)
complete_shape = _compute_complete_shape(
source_tensor.shape, source_process_shape, source_dims_mapping)
op_desc_seq = {}
# TODO: if the target process group has the same process with source process group
if set(target_process_group).intersection(set(
source_process_group)) and set(target_process_group).difference(
set(source_process_group)):
pass
# in the different process group, it will use send, recv, concat and slice op
elif target_process_group != source_process_group:
partition_process_mapping_list = []
for source_process in source_process_group:
source_partition_index = _compute_partition_index(source_process, complete_shape, source_dims_mapping, \
source_process_shape, source_process_group)
if not partition_process_mapping_list:
partition_process_mapping_list.append(
[source_partition_index, [source_process], [False]])
else:
partition_list = list(
[item[0] for item in partition_process_mapping_list])
process_list = list(
[item[1] for item in partition_process_mapping_list])
has_used = list(
[item[2] for item in partition_process_mapping_list])
if partition_list.count(source_partition_index) == 1:
index = partition_list.index(source_partition_index)
process_list[index].append(source_process)
has_used[index].append(False)
else:
partition_process_mapping_list.append(
[source_partition_index, [source_process], [False]])
for target_process in target_process_group: @staticmethod
has_sent = [] def insert_send_op(block, idx, tensor, dst, op_role):
target_partition_index = _compute_partition_index( """Insert send op into block at the given index."""
target_process, complete_shape, target_dims_mapping, op_type = 'send_v2'
target_process_shape, target_process_group) block._insert_op(
partition_index_list = [] idx,
all_partition_index_list = [] type=op_type,
for source_process in source_process_group: inputs={'X': [tensor]},
source_partition_index = _compute_partition_index( attrs={
source_process, complete_shape, source_dims_mapping, 'ring_id': 0,
source_process_shape, source_process_group) 'peer': dst,
to_send_process = None 'use_calc_stream': True,
if all(_ for _ in list(map(_is_overlapped, source_partition_index, target_partition_index))) \ 'op_role': op_role
and source_partition_index not in has_sent: })
idx = list([
item[0] for item in partition_process_mapping_list @staticmethod
]).index(source_partition_index) def insert_recv_op(block, idx, tensor, src, op_role):
has_used = list( """Insert recv op into block at the given index."""
[item[2] op_type = 'recv_v2'
for item in partition_process_mapping_list])[idx] block._insert_op(
process_list = list( idx,
[item[1] type=op_type,
for item in partition_process_mapping_list])[idx] inputs={'X': [tensor]},
i = 0 outputs={'Out': [tensor]},
while i < len(has_used): attrs={
if not has_used[i]: 'ring_id': 0,
to_send_process = process_list[i] 'peer': src,
has_used[i] = True 'out_shape': tensor.shape,
break 'dtype': tensor.dtype,
i += 1 'use_calc_stream': True,
if i == len(has_used): 'op_role': op_role
has_used = list(map(lambda x: False, has_used)) })
to_send_process = process_list[0]
has_used[0] = True @staticmethod
assert to_send_process is not None, "Failed to find the send process." def insert_concat_op(block, idx, tensors, axis, op_role):
"""Insert concat op into block at the given block."""
if to_send_process not in op_desc_seq.keys(): inputs = {'X': tensors}
op_desc_seq[to_send_process] = [] attrs = {}
if target_process not in op_desc_seq.keys(): attrs['axis'] = axis
op_desc_seq[target_process] = [] attrs['op_role'] = op_role
all_partition_index_list.append(source_partition_index) helper = LayerHelper('concat', **locals())
with paddle.static.program_guard(block.program):
# append send and recv op desc out = helper.create_variable_for_type_inference(
send_op_desc = SendOpDesc(source_partition_index, dtype=helper.input_dtype())
target_process) block._insert_op(
recv_op_desc = RecvOpDesc(source_partition_index, idx,
to_send_process) type='concat',
op_desc_seq[to_send_process].append(send_op_desc) inputs=inputs,
op_desc_seq[target_process].append(recv_op_desc) outputs={'Out': [out]},
has_sent.append(source_partition_index) attrs=attrs)
_concat_partitions(partition_index_list, return out
source_partition_index)
# append concat op desc
op_desc_seq[target_process].append(
ConcatOpDesc(all_partition_index_list))
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
concatenated_partition_index = partition_index_list[0]
for idx, item in enumerate(concatenated_partition_index):
slice_starts.append(target_partition_index[idx][0] - item[0])
slice_ends.append(target_partition_index[idx][1] - item[0])
slices_axes.append(idx)
op_desc_seq[target_process].append(
SliceOpDesc(slice_starts, slice_ends, slices_axes))
# in the same process group, it will use allgahther and slice op
else:
partition_index_list = []
all_partition_index_list = []
process_index = []
for source_process in source_process_group:
source_partition_index = _compute_partition_index(
source_process, complete_shape, source_dims_mapping,
source_process_shape, source_process_group)
if source_partition_index not in partition_index_list:
partition_index_list.append(source_partition_index)
process_index.append(
[[source_process, ], source_partition_index])
else:
process_index[partition_index_list.index(
source_partition_index)][0].append(source_process)
for i in range(len(process_index[0][0])):
group = []
for j in range(len(process_index)):
group.append(process_index[j][0][i])
if i == 0:
all_partition_index_list.append(process_index[j][1])
for process in group:
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
target_partition_index = _compute_partition_index(
process, complete_shape, target_dims_mapping,
target_process_shape, target_process_group)
for idx, item in enumerate(target_partition_index):
slice_starts.append(item[0])
slice_ends.append(item[1])
slices_axes.append(idx)
slice_op_desc = SliceOpDesc( @staticmethod
starts=slice_starts, ends=slice_ends, axes=slices_axes) def insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name,
op_desc_seq[process] = [AllGatherOpDesc(group=group), op_role):
ConcatOpDesc(partition_index_list=all_partition_index_list), slice_op_desc] \ """Insert slice op into block at the given block."""
if len(group) > 1 else [slice_op_desc] inputs = {'Input': tensor}
infer_flags = list(1 for i in range(len(axes)))
return op_desc_seq attrs = {
"axes": axes,
"starts": starts,
def _insert_send_op(block, idx, tensor, dst, op_role): "ends": ends,
"""Insert send op into block at the given index.""" "infer_flags": infer_flags,
op_type = 'send_v2'
block._insert_op(
idx,
type=op_type,
inputs={'X': [tensor]},
attrs={
'ring_id': 0,
'peer': dst,
'use_calc_stream': True,
'op_role': op_role
})
def _insert_recv_op(block, idx, tensor, src, op_role):
"""Insert recv op into block at the given index."""
op_type = 'recv_v2'
block._insert_op(
idx,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'ring_id': 0,
'peer': src,
'out_shape': tensor.shape,
'dtype': tensor.dtype,
'use_calc_stream': True,
'op_role': op_role 'op_role': op_role
}) }
helper = LayerHelper('slice', **locals())
out = block.create_var(
def _insert_concat_op(block, idx, tensors, axis, op_role): name=new_var_name, dtype=tensor.dtype, type=tensor.type)
"""Insert concat op into block at the given block.""" block._insert_op(
inputs = {'X': tensors} idx,
attrs = {} type="slice",
attrs['axis'] = axis inputs=inputs,
attrs['op_role'] = op_role outputs={'Out': [out]},
helper = LayerHelper('concat', **locals()) attrs=attrs)
with paddle.static.program_guard(block.program): return out
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
block._insert_op(
idx, type='concat', inputs=inputs, outputs={'Out': [out]}, attrs=attrs)
return out
def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name,
op_role):
"""Insert slice op into block at the given block."""
inputs = {'Input': tensor}
infer_flags = list(1 for i in range(len(axes)))
attrs = {
"axes": axes,
"starts": starts,
"ends": ends,
"infer_flags": infer_flags,
'op_role': op_role
}
helper = LayerHelper('slice', **locals())
out = block.create_var(
name=new_var_name, dtype=tensor.dtype, type=tensor.type)
block._insert_op(
idx, type="slice", inputs=inputs, outputs={'Out': [out]}, attrs=attrs)
return out
def _insert_split_op(block, idx, tensor, num_or_sections, op_role):
"""Insert split op into block at the given index."""
helper = LayerHelper('split', **locals())
input_shape = tensor.shape
inputs = {'X': tensor}
attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role}
with paddle.static.program_guard(block.program):
outs = [
helper.create_variable_for_type_inference(
dtype=helper.input_dtype()) for i in range(num_or_sections)
]
block._insert_op(
idx, type="split", inputs=inputs, outputs={'Out': outs}, attrs=attrs)
return outs
def _insert_allgather_op(block, idx, tensor, ranks, op_role): @staticmethod
"""Insert allgather op into block at the given index.""" def insert_split_op(block, idx, tensor, num_or_sections, op_role):
"""Insert split op into block at the given index."""
helper = LayerHelper('split', **locals())
input_shape = tensor.shape
inputs = {'X': tensor}
attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role}
with paddle.static.program_guard(block.program):
outs = [
helper.create_variable_for_type_inference(
dtype=helper.input_dtype()) for i in range(num_or_sections)
]
block._insert_op(
idx,
type="split",
inputs=inputs,
outputs={'Out': outs},
attrs=attrs)
return outs
def _insert_fill_constant_op(block, idx): @staticmethod
def insert_fill_constant_op(block, idx, op_role):
"""Insert fill constant op into block at the given index.""" """Insert fill constant op into block at the given index."""
helper = LayerHelper("fill_constant", **locals()) helper = LayerHelper("fill_constant", **locals())
with paddle.static.program_guard(block.program): with paddle.static.program_guard(block.program):
@@ -673,740 +310,1190 @@ def _insert_allgather_op(block, idx, tensor, ranks, op_role):
out.stop_gradient = True out.stop_gradient = True
return out return out
tensor_list = [] @staticmethod
group = new_process_group(ranks) def insert_allgather_op(block, idx, tensor, ranks, op_role):
idx_offset = 0 """Insert allgather op into block at the given index."""
tensor_list = []
# instant process group before insert allgather op. group = new_process_group(ranks)
if not group.is_instantiate(): idx_offset = 0
# insert fill_constant op
fill_constant_out = _insert_fill_constant_op(block, idx) # instant process group before insert allgather op.
fill_constant_out.stop_gradient = True if not group.is_instantiate():
# insert fill_constant op
# insert c_allreduce_sum op fill_constant_out = Inserter.insert_fill_constant_op(block, idx,
block._insert_op( op_role)
idx + 1, fill_constant_out.stop_gradient = True
type="c_allreduce_sum",
inputs={'X': [fill_constant_out]}, # insert c_allreduce_sum op
outputs={'Out': [fill_constant_out]}, block._insert_op(
attrs={'ring_id': 0, idx + 1,
'use_calc_stream': True, type="c_allreduce_sum",
'op_role': op_role}) inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]},
# insert c_sync_calc_stream op attrs={
'ring_id': 0,
'use_calc_stream': True,
'op_role': op_role
})
# insert c_sync_calc_stream op
block._insert_op(
idx + 2,
type="c_sync_calc_stream",
inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]},
attrs={'op_role': op_role})
idx_offset = 3
# insert c_allgather op
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
with paddle.static.program_guard(block.program):
allgather_out = helper.create_variable_for_type_inference(
dtype=tensor.dtype)
block._insert_op( block._insert_op(
idx + 2, idx + idx_offset,
type="c_sync_calc_stream", type=op_type,
inputs={'X': [fill_constant_out]}, inputs={'X': [tensor]},
outputs={'Out': [fill_constant_out]}, outputs={'Out': [allgather_out]},
attrs={'op_role': op_role}) attrs={
idx_offset = 3 'ring_id': group.id,
'use_calc_stream': True,
# insert c_allgather op 'nranks': group.nranks,
op_type = 'c_allgather' 'op_role': op_role
helper = LayerHelper(op_type, **locals()) })
with paddle.static.program_guard(block.program): idx_offset += 1
allgather_out = helper.create_variable_for_type_inference(
dtype=tensor.dtype) # insert split op
block._insert_op( split_out = Inserter.insert_split_op(
idx + idx_offset, block, idx + idx_offset, allgather_out, group.nranks, op_role)
type=op_type, idx_offset += 1
inputs={'X': [tensor]}, tensor_list.extend(split_out)
outputs={'Out': [allgather_out]}, return tensor_list, idx_offset
attrs={
'ring_id': group.id, @staticmethod
'use_calc_stream': True, def concat_partitions_with_op(partition_tensor_list, tensor,
'nranks': group.nranks, partition_index, block, idx, op_role):
'op_role': op_role """Concat the tensors and insert concat op."""
}) if not partition_tensor_list:
idx_offset += 1
# insert split op
split_out = _insert_split_op(block, idx + idx_offset, allgather_out,
group.nranks, op_role)
idx_offset += 1
tensor_list.extend(split_out)
return tensor_list, idx_offset
def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
block, idx, op_role):
"""Concat the tensors and insert concat op."""
if not partition_tensor_list:
partition_tensor_list.append((tensor, partition_index))
else:
i = 0
has_concat = False
while i < len(partition_tensor_list):
concat_axis, first_order, new_partition = _compute_concat_info(
partition_tensor_list[i][1], partition_index)
if concat_axis != -1:
has_concat = True
_ = _insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis, op_role) \
if first_order == 0 else \
_insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis, op_role)
partition_tensor_list.pop(i)
idx[0] += 1
_concat_partitions_with_op(partition_tensor_list, _,
new_partition, block, idx, op_role)
break
i += 1
if not has_concat:
partition_tensor_list.append((tensor, partition_index)) partition_tensor_list.append((tensor, partition_index))
else:
i = 0
has_concat = False
while i < len(partition_tensor_list):
concat_axis, first_order, new_partition = Resharder.compute_concat_info(
partition_tensor_list[i][1], partition_index)
if concat_axis != -1:
has_concat = True
_ = Inserter.insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis, op_role) \
if first_order == 0 else \
Inserter.insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis, op_role)
partition_tensor_list.pop(i)
idx[0] += 1
Inserter.concat_partitions_with_op(partition_tensor_list, _,
new_partition, block,
idx, op_role)
break
i += 1
if not has_concat:
partition_tensor_list.append((tensor, partition_index))
class Remover:
"""Remove var and op in the reshard process."""
@staticmethod
def remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
"""Remove no need ops in the main program"""
not_remove_op_ref = [
"create_py_reader", "create_double_buffer_reader", "read"
]
# NOTE: The nested sub block is not be supported now.
remove_block_order = []
for block_idx in Resharder.while_block_info:
remove_block_order.append(block_idx)
HAS_SENT = {} for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
HAS_RECV = {} if block_idx not in remove_block_order:
HAS_ALLGATHER = {} remove_block_order.append(block_idx)
# the sub block should be removed first
def _get_while_op_actual_process_mesh(op, program, rank_id, dist_context): for block_idx in remove_block_order:
"""Get the while op actual Process mesh corresponding to rank""" remove_op_idx = []
assert op.type == "while" block = auto_parallel_main_prog.blocks[block_idx]
while_op_process_mesh = dist_context.get_dist_op_for_program( ops = block.ops
op).dist_attr.process_mesh vars = block.vars
sub_block = program.blocks[op.attr("sub_block").id] for idx, op in enumerate(ops):
ops = sub_block.ops if op.type == "read":
actual_process_mesh = None dim_list = []
for op in ops: for var_name in op.output_arg_names:
dist_op = dist_context.get_dist_op_for_program(op) dim_list.extend(
if not dist_op: get_var_with_recursion(
continue var_name, block, auto_parallel_main_prog).shape)
process_mesh = dist_op.dist_attr.process_mesh for i in range(idx, -1, -1):
if process_mesh == while_op_process_mesh: if ops[i].type == "create_py_reader":
continue ops[i]._set_attr("shape_concat", dim_list)
if rank_id in process_mesh.processes: break
raw_process_mesh = process_mesh continue
break
if actual_process_mesh is None and rank_id in while_op_process_mesh.processes:
actual_process_mesh = while_op_process_mesh
assert actual_process_mesh is not None # replace the input and output of c_sync_comm_stream op when in pipeline scene.
return actual_process_mesh if op.type == "c_sync_comm_stream":
need_save = []
for var_name in op.input_arg_names:
process_mesh = dist_context.get_tensor_dist_attr_for_program(
get_var_with_recursion(
var_name, block,
auto_parallel_main_prog)).process_mesh
if rank_id in process_mesh.processes:
need_save.append(var_name)
if not need_save:
remove_op_idx.append(idx)
continue
proto = OpProtoHolder.instance().get_op_proto(op.type)
op.desc.set_input(proto.inputs[0].name, need_save)
op.desc.set_output(proto.outputs[0].name, need_save)
continue
def _get_var(var_name, block, program): # judge the other op whether should be removed.
"""Get var in the parent block if not found in the current block""" op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
var = None if op_dist_attr is not None:
if var_name in block.vars: op_process_mesh = op_dist_attr.process_mesh
var = block.vars[var_name] if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref:
else: remove_op_idx.append(idx)
parent_block = program.blocks[block.parent_idx]
if var_name in parent_block.vars: for idx in remove_op_idx[::-1]:
var = parent_block.vars[var_name] block._remove_op(idx)
assert var is not None
return var @staticmethod
def remove_no_need_vars(auto_parallel_main_prog, dist_params_grads):
"""Remove no need vars in the main program"""
for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
remove_vars = set()
ops = block.ops
vars = block.vars
need_vars = set()
for op in ops:
for var_name in op.input_arg_names:
if var_name in vars:
need_vars.add(var_name)
for var_name in op.output_arg_names:
if var_name in vars:
need_vars.add(var_name)
for var in vars:
if var not in need_vars:
remove_vars.add(var)
# change dist_params_grads, the optimize op just in block 0.
if block_idx == 0:
param_grad_map = {}
for op in ops:
if int(op.attr('op_role')) == int(OpRole.Optimize):
if "Param" in op.input_names and "Grad" in op.input_names:
param_name = op.input("Param")[0]
grad_name = op.input("Grad")[0]
param_grad_map[param_name] = grad_name
need_remove_idx = []
for idx, item in enumerate(dist_params_grads):
if item[0].name not in param_grad_map.keys():
need_remove_idx.append(idx)
for idx in need_remove_idx[::-1]:
dist_params_grads.pop(idx)
idx = 0
while idx < len(dist_params_grads):
param_name = dist_params_grads[idx][0].name
grad_name = dist_params_grads[idx][1].name
if grad_name != param_grad_map[param_name]:
dist_params_grads[idx] = (
vars[param_name], vars[param_grad_map[param_name]])
idx += 1
for var in remove_vars:
block._remove_var(var)
@staticmethod
def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
dist_params_grads):
"""Remove no need vars and ops in the main program."""
Remover.remove_no_need_ops(auto_parallel_main_prog, dist_context,
rank_id)
Resharder.change_while_op_input_and_output(auto_parallel_main_prog,
dist_context)
Remover.remove_no_need_vars(auto_parallel_main_prog, dist_params_grads)
@staticmethod
def remove_no_need_in_startup(auto_parallel_main_prog,
auto_parallel_startup_prog):
"""Remove no need vars and ops in the startup program."""
main_input_vars = set()
main_ops = auto_parallel_main_prog.global_block().ops
for op in main_ops:
for var_name in op.input_arg_names:
main_input_vars.add(var_name)
startup_block = auto_parallel_startup_prog.global_block()
startup_output_vars = set()
startup_ops = startup_block.ops
for op in startup_ops:
# skip c_sync_comm_stream op
if op.type == "c_sync_comm_stream":
continue
for var_name in op.output_arg_names:
startup_output_vars.add(var_name)
def parse_op_desc(block, rank_id, op_desc_seq, var_name, reshard_op, need_vars = set()
dist_context, program, actual_process_mesh): for var_name in startup_output_vars:
"""Parse op desc sequence and insert op in the block""" if var_name in main_input_vars:
global HAS_SENT need_vars.add(var_name)
global HAS_RECV
global HAS_ALLGATHER startup_ops = startup_block.ops
tensor_list = [] actual_need_vars = set()
partition_tensor_list = [] for idx, op in enumerate(startup_ops):
if rank_id not in op_desc_seq.keys(): is_need_op = False
return if op.type == "c_sync_comm_stream":
op_desc_list = op_desc_seq[rank_id] continue
for var_name in op.output_arg_names:
idx = None if var_name in need_vars:
for index, op in list(enumerate(block.ops)): is_need_op = True
if op.desc.id == reshard_op.desc.id: break
idx = index if is_need_op:
break for var_name in op.output_arg_names:
assert idx is not None, "The op for reshard cannot be found in the rank {} program.".format( actual_need_vars.add(var_name)
rank_id) for var_name in op.input_arg_names:
actual_need_vars.add(var_name)
matched_op = block.ops[idx]
source_tensor = _get_var(var_name, block, program)
for op_desc in op_desc_list:
if isinstance(op_desc, AllGatherOpDesc): # noqa: F401
if var_name not in HAS_ALLGATHER.keys():
HAS_ALLGATHER[var_name] = []
if not HAS_ALLGATHER[var_name] or op_desc.group not in list(
map(lambda x: x[0], HAS_ALLGATHER[var_name])):
tensor_list, idx_offset = _insert_allgather_op(
block, idx, source_tensor, op_desc.group,
reshard_op.attr('op_role'))
idx += idx_offset
tensor_name_list = [var.name for var in tensor_list]
HAS_ALLGATHER[var_name].append(
[op_desc.group, tensor_name_list])
else:
for item in HAS_ALLGATHER[var_name]:
if op_desc.group == item[0]:
tensor_list = [
program.global_block().vars[var_name]
for var_name in item[1]
]
break
assert tensor_list, "The result of parsing allgather op should not be None."
elif isinstance(op_desc, SendOpDesc):
if var_name not in HAS_SENT.keys():
HAS_SENT[var_name] = []
if op_desc.dst not in HAS_SENT[var_name]:
_insert_send_op(block, idx, source_tensor, op_desc.dst,
reshard_op.attr('op_role'))
idx += 1
HAS_SENT[var_name].append(op_desc.dst)
elif isinstance(op_desc, RecvOpDesc):
if var_name not in HAS_RECV.keys():
HAS_RECV[var_name] = {}
if op_desc.src not in HAS_RECV[var_name].keys():
partition_index = op_desc.partition_index
shape = []
for index in partition_index:
shape.append(index[1] - index[0])
recv_tensor = block.create_var(
name=unique_name.generate(var_name + "@recv"),
shape=shape,
dtype=source_tensor.dtype,
type=source_tensor.type)
_insert_recv_op(block, idx, recv_tensor, op_desc.src,
reshard_op.attr('op_role'))
tensor_list.append(recv_tensor)
idx += 1
HAS_RECV[var_name][op_desc.src] = recv_tensor
else:
tensor_list.append(HAS_RECV[var_name][op_desc.src])
elif isinstance(op_desc, ConcatOpDesc):
partition_index_list = op_desc.partition_index_list
idx_list = [idx]
for index, tensor in enumerate(tensor_list):
_concat_partitions_with_op(partition_tensor_list, tensor,
partition_index_list[index], block,
idx_list, reshard_op.attr('op_role'))
idx = idx_list[0]
elif isinstance(op_desc, SliceOpDesc):
assert len(partition_tensor_list) == 1 or not partition_tensor_list
to_slice_tensor = partition_tensor_list[0][0] if len(
partition_tensor_list) == 1 else source_tensor
new_name = unique_name.generate(var_name + "@RESHARD")
target_tensor = _insert_slice_op(
block,
idx,
to_slice_tensor,
starts=op_desc.starts,
ends=op_desc.ends,
axes=op_desc.axes,
new_var_name=new_name,
op_role=reshard_op.attr('op_role'))
tensor_attr = TensorDistributedAttribute()
process_mesh = actual_process_mesh
dims_mapping = dist_context.get_op_dist_attr_for_program(
matched_op).get_input_dims_mapping(var_name)
tensor_attr.dims_mapping = dims_mapping
tensor_attr.process_mesh = process_mesh
dist_context.set_tensor_dist_attr_for_program(target_tensor,
tensor_attr)
if op.type == "while":
global while_block_info
# var_reshard_mapping means the while op input need be changed to
if "var_reshard_mapping" not in while_block_info[op.attr(
"sub_block").id].keys():
while_block_info[op.attr("sub_block").id][
"var_reshard_mapping"] = {}
while_block_info[op.attr("sub_block").id][
"var_reshard_mapping"][var_name] = target_tensor.name
# rename op input name according to new name
for op in block.ops:
for name in op.input_arg_names:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if name == var_name and op_dist_attr is not None:
if op.desc.id() == matched_op.desc.id():
op.desc._rename_input(name, target_tensor.name)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr.set_input_dist_attr(name, None)
continue
# NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation. remove_vars = set()
op_process_mesh = op_dist_attr.process_mesh for var_name in startup_block.vars:
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping( if var_name not in actual_need_vars:
var_name) remove_vars.add(var_name)
if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping: for var in remove_vars:
op.desc._rename_input(name, target_tensor.name) startup_block._remove_var(var)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr.set_input_dist_attr(name, None)
def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
"""Remove no need ops in the main program"""
not_remove_op_ref = [
"create_py_reader", "create_double_buffer_reader", "read"
]
global while_block_info
# NOTE: The nested sub block is not be supported now.
remove_block_order = []
for block_idx in while_block_info:
remove_block_order.append(block_idx)
for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
if block_idx not in remove_block_order:
remove_block_order.append(block_idx)
# the sub block should be removed first
for block_idx in remove_block_order:
remove_op_idx = [] remove_op_idx = []
block = auto_parallel_main_prog.blocks[block_idx] vars = startup_block.vars
ops = block.ops for idx, op in enumerate(startup_block.ops):
vars = block.vars is_no_need_op = False
for idx, op in enumerate(ops):
if op.type == "read":
dim_list = []
for var_name in op.output_arg_names:
dim_list.extend(
_get_var(var_name, block, auto_parallel_main_prog)
.shape)
for i in range(idx, -1, -1):
if ops[i].type == "create_py_reader":
ops[i]._set_attr("shape_concat", dim_list)
break
continue
# replace the input and output of c_sync_comm_stream op when in pipeline scene.
if op.type == "c_sync_comm_stream": if op.type == "c_sync_comm_stream":
need_save = [] var_names = []
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
process_mesh = dist_context.get_tensor_dist_attr_for_program( if var_name in vars:
_get_var(var_name, block, var_names.append(var_name)
auto_parallel_main_prog)).process_mesh if not var_names:
if rank_id in process_mesh.processes:
need_save.append(var_name)
if not need_save:
remove_op_idx.append(idx) remove_op_idx.append(idx)
continue else:
proto = OpProtoHolder.instance().get_op_proto(op.type)
proto = OpProtoHolder.instance().get_op_proto(op.type) op.desc.set_input(proto.inputs[0].name, var_names)
op.desc.set_input(proto.inputs[0].name, need_save) op.desc.set_output(proto.outputs[0].name, var_names)
op.desc.set_output(proto.outputs[0].name, need_save)
continue continue
# judge the other op whether should be removed. for var_name in op.output_arg_names:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) if var_name not in vars:
if op_dist_attr is not None: is_no_need_op = True
op_process_mesh = op_dist_attr.process_mesh break
if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref: if is_no_need_op:
remove_op_idx.append(idx) remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]: for idx in remove_op_idx[::-1]:
block._remove_op(idx) startup_block._remove_op(idx)
def _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads): class Resharder:
"""Remove no need vars in the main program""" """
for block_idx, block in enumerate(auto_parallel_main_prog.blocks): Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.
remove_vars = set()
ops = block.ops
vars = block.vars
need_vars = set()
for op in ops:
for var_name in op.input_arg_names:
if var_name in vars:
need_vars.add(var_name)
for var_name in op.output_arg_names:
if var_name in vars:
need_vars.add(var_name)
for var in vars:
if var not in need_vars:
remove_vars.add(var)
# change dist_params_grads, the optimize op just in block 0.
if block_idx == 0:
param_grad_map = {}
for op in ops:
if int(op.attr('op_role')) == int(OpRole.Optimize):
if "Param" in op.input_names and "Grad" in op.input_names:
param_name = op.input("Param")[0]
grad_name = op.input("Grad")[0]
param_grad_map[param_name] = grad_name
need_remove_idx = [] Args:
for idx, item in enumerate(dist_params_grads): auto_parallel_main_prog (Program): An auto parallel main program.
if item[0].name not in param_grad_map.keys(): auto_parallel_startup_prog (Program): An auto parallel startup program.
need_remove_idx.append(idx) rank_id (int): The process id.
dist_context (DistributedContext): The distributed context of this rank.
dist_params_grads (list): The list contains the tuple of param and grad.
batch_size (int): The batch size. Default: None.
"""
while_block_info = {}
def __init__(self,
auto_parallel_main_prog,
auto_parallel_startup_prog,
rank_id,
dist_context,
dist_params_grads,
batch_size=None):
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
"but got {}.".format(type(auto_parallel_main_prog))
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program, " \
"but got {}.".format(type(auto_parallel_startup_prog))
assert isinstance(rank_id, int), "The type of rank_id should be int, " \
"but got {}.".format(type(rank_id))
assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
"but got {}.".format(type(dist_context))
if batch_size is not None:
assert isinstance(batch_size, int), "The type of batch_size should be int, " \
"but got {}.".format(type(batch_size))
self._auto_parallel_main_prog = auto_parallel_main_prog
self._auto_parallel_startup_prog = auto_parallel_startup_prog
self._rank_id = rank_id
self._dist_context = dist_context
self._dist_params_grads = dist_params_grads
self._batch_size = batch_size
self._has_sent = {}
self._has_recv = {}
self._has_allgather = {}
for idx in need_remove_idx[::-1]: @property
dist_params_grads.pop(idx) def auto_parallel_main_prog(self):
return self._auto_parallel_main_prog
idx = 0 @property
while idx < len(dist_params_grads): def auto_parallel_startup_prog(self):
param_name = dist_params_grads[idx][0].name return self._auto_parallel_startup_prog
grad_name = dist_params_grads[idx][1].name
if grad_name != param_grad_map[param_name]:
dist_params_grads[idx] = (vars[param_name],
vars[param_grad_map[param_name]])
idx += 1
for var in remove_vars: @property
block._remove_var(var) def rank_id(self):
return self._rank_id
def _change_while_op_input_and_output(auto_parallel_main_prog, dist_context):
"""Change while op input and output after the corresponding sub block ops removed"""
global while_block_info
for sub_block_idx in while_block_info:
sub_block = auto_parallel_main_prog.blocks[sub_block_idx]
parent_while_op_id = while_block_info[sub_block_idx]["op_id"]
parent_block = auto_parallel_main_prog.blocks[sub_block.parent_idx]
sub_block_op_inputs = set()
sub_block_op_outputs = []
for op in sub_block.ops:
# skip the input and output of operators inserted in the reshard phase
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op:
for var_name in op.output_arg_names:
if var_name not in sub_block_op_outputs:
sub_block_op_outputs.append(var_name)
for var_name in op.input_arg_names:
sub_block_op_inputs.add(var_name)
# find the while op @property
while_op = None def dist_context(self):
for op in parent_block.ops: return self._dist_context
if op.desc.id() == parent_while_op_id and op.type == "while":
while_op = op
break
assert while_op is not None @property
def dist_params_grads(self):
# find the actual input and output of while op return self._dist_params_grads
proto = OpProtoHolder.instance().get_op_proto(while_op.type)
new_X = []
for var_name in while_op.input("X"):
if var_name in sub_block_op_inputs:
new_X.append(var_name)
assert new_X
while_op.desc.set_input(proto.inputs[0].name, new_X)
new_Out = []
for var_name in while_op.output("Out"):
for output_name in sub_block_op_outputs[::-1]:
if output_name.find(var_name) != -1:
new_Out.append(output_name)
assert new_Out
while_op.desc.set_output(proto.outputs[0].name, new_Out)
def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
dist_params_grads):
"""Remove no need vars and ops in the main program."""
_remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id)
_change_while_op_input_and_output(auto_parallel_main_prog, dist_context)
_remove_no_need_vars(auto_parallel_main_prog, dist_params_grads)
def remove_no_need_in_startup(auto_parallel_main_prog,
auto_parallel_startup_prog):
"""Remove no need vars and ops in the startup program."""
main_input_vars = set()
main_ops = auto_parallel_main_prog.global_block().ops
for op in main_ops:
for var_name in op.input_arg_names:
main_input_vars.add(var_name)
startup_block = auto_parallel_startup_prog.global_block()
startup_output_vars = set()
startup_ops = startup_block.ops
for op in startup_ops:
# skip c_sync_comm_stream op
if op.type == "c_sync_comm_stream":
continue
for var_name in op.output_arg_names:
startup_output_vars.add(var_name)
need_vars = set()
for var_name in startup_output_vars:
if var_name in main_input_vars:
need_vars.add(var_name)
startup_ops = startup_block.ops
actual_need_vars = set()
for idx, op in enumerate(startup_ops):
is_need_op = False
if op.type == "c_sync_comm_stream":
continue
for var_name in op.output_arg_names:
if var_name in need_vars:
is_need_op = True
break
if is_need_op:
for var_name in op.output_arg_names:
actual_need_vars.add(var_name)
for var_name in op.input_arg_names:
actual_need_vars.add(var_name)
remove_vars = set()
for var_name in startup_block.vars:
if var_name not in actual_need_vars:
remove_vars.add(var_name)
for var in remove_vars:
startup_block._remove_var(var)
remove_op_idx = []
vars = startup_block.vars
for idx, op in enumerate(startup_block.ops):
is_no_need_op = False
if op.type == "c_sync_comm_stream":
var_names = []
for var_name in op.input_arg_names:
if var_name in vars:
var_names.append(var_name)
if not var_names:
remove_op_idx.append(idx)
else:
proto = OpProtoHolder.instance().get_op_proto(op.type)
op.desc.set_input(proto.inputs[0].name, var_names)
op.desc.set_output(proto.outputs[0].name, var_names)
continue
for var_name in op.output_arg_names:
if var_name not in vars:
is_no_need_op = True
break
if is_no_need_op:
remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]:
startup_block._remove_op(idx)
def _get_process_meshes(op, program, dist_context):
"""Get all process meshes when op has sub block."""
assert op.has_attr("sub_block")
sub_block = program.blocks[op.attr("sub_block").id]
ops = sub_block.ops
op_process_mesh = dist_context.get_dist_op_for_program(
op).dist_attr.process_mesh
process_meshes = []
for op in ops:
dist_op = dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh not in process_meshes and process_mesh != op_process_mesh:
process_meshes.append(process_mesh)
if not process_meshes:
process_meshes.append(op_process_mesh)
return process_meshes
def _is_condition_replicative(op, program, dist_context):
assert op.type == "while"
sub_block = program.blocks[op.attr("sub_block").id]
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
# the dims mapping of condition tensor should be replicative
for var_name in op.input("Condition"):
var = _get_var(var_name, sub_block, program)
dist_tensor = dist_context.get_dist_tensor_for_program(var)
tensor_dist_attr = dist_tensor.dist_attr
var_dims_mapping = tensor_dist_attr.dims_mapping
for dim in var_dims_mapping:
if dim != -1:
return False
return True @property
def batch_size(self):
return self._batch_size
@property
def has_sent(self):
return self._has_sent
def _get_op_process_meshes(op, dist_context): @property
process_meshes = [] def has_recv(self):
dist_op = dist_context.get_dist_op_for_program(op) return self._has_recv
op_process_mesh = dist_op.dist_attr.process_mesh
for process_mesh in dist_context.process_meshes:
if set(process_mesh.processes) & (
set(op_process_mesh.processes)
) and len(process_mesh.processes) <= len(op_process_mesh.processes):
process_meshes.append(process_mesh)
# it means the process mesh is not a union when process meshes is null @property
if not process_meshes: def has_allgather(self):
process_meshes.append(op_process_mesh) return self._has_allgather
@staticmethod
def compute_partition_shape(complete_shape, dims_mapping, process_shape):
"""Compute the shape of partition."""
partition_shape = []
for idx, item in enumerate(complete_shape):
if dims_mapping[idx] == -1:
partition_shape.append(item)
else:
partition_shape.append(item // process_shape[dims_mapping[idx]])
return process_meshes return partition_shape
@staticmethod
def compute_process_index(process, process_group, process_shape):
"""Compute the index of process_shape corresponding to the process."""
relative_process = process_group.index(process)
process_index = []
product = reduce(lambda x, y: x * y, process_shape)
for i in range(len(process_shape)):
idx = relative_process // (product // process_shape[i])
product = product // process_shape[i]
relative_process = relative_process - relative_process // product * product
process_index.append(idx)
return process_index
@staticmethod
def compute_partition_index(process, complete_shape, dims_mapping,
process_shape, process_group):
"""Compute the partition index in complete tensor."""
partition_shape = Resharder.compute_partition_shape(
complete_shape, dims_mapping, process_shape)
process_index = Resharder.compute_process_index(process, process_group,
process_shape)
partition_index = []
for i in range(len(complete_shape)):
if dims_mapping[i] == -1:
partition_index.append([0, partition_shape[i]])
else:
partition_index.append([
process_index[dims_mapping[i]] * partition_shape[i],
(process_index[dims_mapping[i]] + 1) * partition_shape[i]
])
return partition_index
@staticmethod
def compute_concat_info(partition_index_x, partition_index_y):
"""Judge whether two partition can be concatenated and compute concatenated partition index."""
differ_count = 0
concat_axis = -1
first_order = 0
new_partition = []
for idx, item in enumerate(partition_index_x):
if item != partition_index_y[idx]:
differ_count += 1
if item[1] == partition_index_y[idx][0] and item[
0] < partition_index_y[idx][1]:
concat_axis = idx
new_partition.append([item[0], partition_index_y[idx][1]])
elif item[0] == partition_index_y[idx][1] and item[
1] > partition_index_y[idx][0]:
first_order = 1
concat_axis = idx
new_partition.append([partition_index_y[idx][0], item[1]])
else:
new_partition.append(item)
if differ_count == 1:
return concat_axis, first_order, new_partition
else:
return -1, first_order, new_partition

def reshard(auto_parallel_main_prog,
auto_parallel_startup_prog,
rank_id,
dist_context,
dist_params_grads,
batch_size=None):
"""
Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.

Args:
auto_parallel_main_prog (Program): An auto parallel main program.
auto_parallel_startup_prog (Program): An auto parallel startup program.
rank_id (int): The process id.
dist_context (DistributedContext): The distributed context of this rank.
dist_params_grads (list): The list contains the tuple of param and grad.
"""
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
"but got {}.".format(type(auto_parallel_main_prog))
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program, " \
"but got {}.".format(type(auto_parallel_startup_prog))
assert isinstance(rank_id, int), "The type of rank_id should be int, " \
"but got {}.".format(type(rank_id))
assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
"but got {}.".format(type(dist_context))

def _is_special_op(op):

@staticmethod
def compute_complete_shape(slice_shape, process_shape, dims_mapping):
"""compute the complete shape of the slice tensor with its process mesh and dims mapping"""
complete_shape = []
for idx, item in enumerate(slice_shape):
if dims_mapping[idx] == -1:
complete_shape.append(item)
else:
complete_shape.append(item * process_shape[dims_mapping[idx]])
return complete_shape

@staticmethod
def concat_partitions(partition_index_list, partition_index):
"""Concat the given partitions without inserting concat op."""
if not partition_index_list:
partition_index_list.append(partition_index)
else:
i = 0
has_concat = False
while i < len(partition_index_list):
concat_axis, _, new_partition = Resharder.compute_concat_info(
partition_index_list[i], partition_index)
if concat_axis != -1:
has_concat = True
partition_index_list.pop(i)
Resharder.concat_partitions(partition_index_list,
new_partition)
break
i += 1
if not has_concat:
partition_index_list.append(partition_index)
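# Usage sketch (illustrative, Resharder imported as above): two adjacent
# column slices of a [4, 8] tensor merge back into one partition, and the
# complete shape can be recovered from a slice shape.
assert Resharder.compute_complete_shape([4, 4], [2], [-1, 0]) == [4, 8]

parts = []
Resharder.concat_partitions(parts, [[0, 4], [0, 4]])
Resharder.concat_partitions(parts, [[0, 4], [4, 8]])
assert parts == [[[0, 4], [0, 8]]]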
@staticmethod
def change_while_op_input_and_output(auto_parallel_main_prog, dist_context):
"""Change while op input and output after the corresponding sub block ops removed"""
for sub_block_idx in Resharder.while_block_info:
sub_block = auto_parallel_main_prog.blocks[sub_block_idx]
parent_while_op_id = Resharder.while_block_info[sub_block_idx][
"op_id"]
parent_block = auto_parallel_main_prog.blocks[sub_block.parent_idx]
sub_block_op_inputs = set()
sub_block_op_outputs = []
for op in sub_block.ops:
# skip the input and output of operators inserted in the reshard phase
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op:
for var_name in op.output_arg_names:
if var_name not in sub_block_op_outputs:
sub_block_op_outputs.append(var_name)
for var_name in op.input_arg_names:
sub_block_op_inputs.add(var_name)
# find the while op
while_op = None
for op in parent_block.ops:
if op.desc.id() == parent_while_op_id and op.type == "while":
while_op = op
break
assert while_op is not None
# find the actual input and output of while op
proto = OpProtoHolder.instance().get_op_proto(while_op.type)
new_X = []
for var_name in while_op.input("X"):
if var_name in sub_block_op_inputs:
new_X.append(var_name)
assert new_X
while_op.desc.set_input(proto.inputs[0].name, new_X)
new_Out = []
for var_name in while_op.output("Out"):
for output_name in sub_block_op_outputs[::-1]:
if output_name.find(var_name) != -1:
new_Out.append(output_name)
assert new_Out
while_op.desc.set_output(proto.outputs[0].name, new_Out)
def is_overlapped(self, shape_x, shape_y):
"""Judge whether two partitions intersect on the specified dimension."""
overlapped = False
if (shape_y[0] <= shape_x[0] < shape_y[1]) or (
shape_x[0] <= shape_y[0] < shape_x[1]):
overlapped = True
return overlapped
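# The overlap test above only looks at two [start, end) intervals; a
# standalone sketch of the same predicate (hypothetical helper, mirroring the
# method, not part of the module):
def intervals_overlap(x, y):
    return (y[0] <= x[0] < y[1]) or (x[0] <= y[0] < x[1])

assert intervals_overlap([0, 4], [2, 6])
assert not intervals_overlap([0, 4], [4, 8])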
def is_unshard(self, dims_mapping):
for dim in dims_mapping:
if dim != -1:
return False
return True
def is_special_op(self, op):
global _g_special_ops
if op.type in _g_special_ops:
return True
return False
global while_block_info
for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
if block_idx in while_block_info:
if "var_reshard_mapping" in while_block_info[block_idx]:
var_reshard_mapping = while_block_info[block_idx][
"var_reshard_mapping"]
for op in block.ops:
for var_name in op.input_arg_names:
if var_name in var_reshard_mapping:
op.desc._rename_input(var_name,
var_reshard_mapping[var_name])
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == while_block_info[
block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
op_dist_attr.set_input_dims_mapping(
var_reshard_mapping[var_name], dims_mapping)
op_dist_attr.set_input_dist_attr(var_name, None)
# the outputs also need to be renamed when the output name is the same with input name
for var_name in op.output_arg_names:
if var_name in var_reshard_mapping:
op.desc._rename_output(
var_name, var_reshard_mapping[var_name])
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == while_block_info[
block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_output_dims_mapping(
var_name)
op_dist_attr.set_output_dims_mapping(
var_reshard_mapping[var_name], dims_mapping)
op_dist_attr.set_output_dist_attr(var_name,
None)

idx = 0
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]

if _is_special_op(op):
idx += 1

def is_condition_replicative(self, op):
assert op.type == "while"
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
dist_op = self.dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr

# the dims mapping of condition tensor should be replicative
for var_name in op.input("Condition"):
var = get_var_with_recursion(var_name, sub_block,
self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(var)
tensor_dist_attr = dist_tensor.dist_attr
var_dims_mapping = tensor_dist_attr.dims_mapping
for dim in var_dims_mapping:
if dim != -1:
return False

return True

def need_reshard(self,
dist_tensor,
dist_op,
actual_process_mesh,
op_input=True):
"""Judge the tensor whether needs to be resharded."""
is_reshard = False
tensor_dist_attr = dist_tensor.dist_attr
tensor_name = dist_tensor.serial_tensor.name
tensor_dims_mapping = tensor_dist_attr.dims_mapping
tensor_process_mesh = tensor_dist_attr.process_mesh
op_dist_attr = dist_op.dist_attr
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
op_process_mesh = actual_process_mesh
if op_input:
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_name)
if all(
map(lambda x: x is not None, [
tensor_dims_mapping, tensor_process_mesh,
op_input_dims_mapping, op_process_mesh
])):
# dims_mapping
if tensor_dims_mapping != op_input_dims_mapping:
if dist_op.serial_op.type == "while":
sub_block = self.auto_parallel_main_prog.blocks[
dist_op.serial_op.attr("sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = self.dist_context.get_dist_op_for_program(
op).dist_attr
var_dims_mapping = dist_op_attr.get_input_dims_mapping(
var_name)
if var_dims_mapping != tensor_dims_mapping:
is_reshard = True
break
else:
is_reshard = True
# process_mesh
if tensor_process_mesh != op_process_mesh:
# when processes length is not the same, the dims mapping must be replicative now
if len(tensor_process_mesh.processes) != len(
op_process_mesh.processes):
assert self.is_unshard(tensor_dims_mapping)
assert self.is_unshard(op_input_dims_mapping)
else:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError(
"Bool var is not supported reshard.")
# for while op, it should find the process mesh of op actually used the tensor as input
if dist_op.serial_op.type == "while":
sub_block = self.auto_parallel_main_prog.blocks[
dist_op.serial_op.attr("sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = self.dist_context.get_dist_op_for_program(
op).dist_attr
process_mesh = dist_op_attr.process_mesh
if process_mesh == op_process_mesh:
is_reshard = True
break
else:
is_reshard = True
else:
op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_name)
if all(
map(lambda x: x is not None, [
tensor_dims_mapping, tensor_process_mesh,
op_output_dims_mapping, op_process_mesh
])):
if tensor_process_mesh != op_process_mesh:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError("Bool var is not supported reshard.")
is_reshard = True
if tensor_dims_mapping != op_output_dims_mapping:
raise ValueError(
"It is not supported that tensor dims mapping is different from op output dims mapping."
)
return is_reshard
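# The decision above essentially reduces to two comparisons: the tensor's
# dims mapping vs. what the op expects, and the two process meshes. A
# plain-data sketch with hypothetical values (no framework objects involved):
tensor_dims_mapping = [-1, 0]        # tensor is column-sharded
op_input_dims_mapping = [-1, -1]     # op expects a replicated input
tensor_processes = [0, 1]
op_processes = [0, 1]

needs_reshard = (tensor_dims_mapping != op_input_dims_mapping
                 or tensor_processes != op_processes)
assert needs_reshard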
def get_process_meshes(self, op):
"""Get all process meshes when op has sub block."""
assert op.has_attr("sub_block")
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
op_process_mesh = self.dist_context.get_dist_op_for_program(
op).dist_attr.process_mesh
process_meshes = []
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh not in process_meshes and process_mesh != op_process_mesh:
process_meshes.append(process_mesh)
if not process_meshes:
process_meshes.append(op_process_mesh)
return process_meshes
def get_op_process_meshes(self, op):
process_meshes = []
dist_op = self.dist_context.get_dist_op_for_program(op)
op_process_mesh = dist_op.dist_attr.process_mesh
for process_mesh in self.dist_context.process_meshes:
if set(process_mesh.processes) & (
set(op_process_mesh.processes)
) and len(process_mesh.processes) <= len(op_process_mesh.processes):
process_meshes.append(process_mesh)
# it means the process mesh is not a union when process meshes is null
if not process_meshes:
process_meshes.append(op_process_mesh)
return process_meshes
def get_while_op_actual_process_mesh(self, op):
"""Get the while op actual Process mesh corresponding to rank"""
assert op.type == "while"
while_op_process_mesh = self.dist_context.get_dist_op_for_program(
op).dist_attr.process_mesh
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
actual_process_mesh = None
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh == while_op_process_mesh:
continue
if self.rank_id in process_mesh.processes:
raw_process_mesh = process_mesh
break

if actual_process_mesh is None and self.rank_id in while_op_process_mesh.processes:
actual_process_mesh = while_op_process_mesh

assert actual_process_mesh is not None
return actual_process_mesh

continue

dist_op = dist_context.get_dist_op_for_program(op)
if dist_op is not None:
process_meshes = []
if op.type == "while":
if not _is_condition_replicative(
op, auto_parallel_main_prog, dist_context):
raise ValueError(
"Please check the condition due to the dims mapping is not replicative."
)
process_meshes = _get_process_meshes(
op, auto_parallel_main_prog, dist_context)
assert process_meshes
if op.attr("sub_block").id not in while_block_info:
while_block_info[op.attr("sub_block").id] = {}
while_block_info[op.attr("sub_block").id][
"op_id"] = op.desc.id()
while_block_info[op.attr("sub_block").id][
"actual_process_mesh"] = _get_while_op_actual_process_mesh(
op, auto_parallel_main_prog, rank_id, dist_context)

def find_op_desc_seq(self, dist_tensor, dist_op, actual_process_mesh):
"""
Find the op description sequence to reshard the source tensor for matching the op requirement.

Args:
dist_tensor (DistributedTensor): A distributed tensor.
dist_op (DistributedOperator): A distributed operator.
actual_process_mesh (ProcessMesh): The actual op process mesh.

Returns:
Dict, the dict represents the required op description sequence corresponding to process, The key of dict is
process and value is a list containing op description.
"""
tensor_dist_attr = dist_tensor.dist_attr
source_tensor = dist_tensor.serial_tensor
tensor_name = source_tensor.name
source_dims_mapping = tensor_dist_attr.dims_mapping
source_process_mesh = tensor_dist_attr.process_mesh
source_process_group = source_process_mesh.processes
source_process_shape = source_process_mesh.topology
op_dist_attr = dist_op.dist_attr
target_process_mesh = actual_process_mesh
target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
target_process_group = target_process_mesh.processes
target_process_shape = target_process_mesh.topology
if source_tensor.shape[0] < 0:
new_shape = list(source_tensor.shape)
new_shape[0] = self.batch_size
source_tensor.desc.set_shape(new_shape)
complete_shape = Resharder.compute_complete_shape(
source_tensor.shape, source_process_shape, source_dims_mapping)
op_desc_seq = {}
# TODO: if the target process group has the same process with source process group
if set(target_process_group).intersection(set(
source_process_group)) and set(target_process_group).difference(
set(source_process_group)):
pass
# in the different process group, it will use send, recv, concat and slice op
elif target_process_group != source_process_group:
partition_process_mapping_list = []
for source_process in source_process_group:
source_partition_index = Resharder.compute_partition_index(source_process, complete_shape, source_dims_mapping, \
source_process_shape, source_process_group)
if not partition_process_mapping_list:
partition_process_mapping_list.append(
[source_partition_index, [source_process], [False]])
else:
process_meshes = _get_op_process_meshes(op, dist_context)
input_vars = None

else:
partition_list = list(
[item[0] for item in partition_process_mapping_list])
process_list = list(
[item[1] for item in partition_process_mapping_list])
has_used = list(
[item[2] for item in partition_process_mapping_list])
if partition_list.count(source_partition_index) == 1:
index = partition_list.index(source_partition_index)
process_list[index].append(source_process)
has_used[index].append(False)
else:
partition_process_mapping_list.append([
source_partition_index, [source_process], [False]
])
for target_process in target_process_group:
has_sent = []
target_partition_index = Resharder.compute_partition_index(
target_process, complete_shape, target_dims_mapping,
target_process_shape, target_process_group)
partition_index_list = []
all_partition_index_list = []
for source_process in source_process_group:
source_partition_index = Resharder.compute_partition_index(
source_process, complete_shape, source_dims_mapping,
source_process_shape, source_process_group)
to_send_process = None
if all(_ for _ in list(map(self.is_overlapped, source_partition_index, target_partition_index))) \
and source_partition_index not in has_sent:
idx = list([
item[0] for item in partition_process_mapping_list
]).index(source_partition_index)
has_used = list([
item[2] for item in partition_process_mapping_list
])[idx]
process_list = list([
item[1] for item in partition_process_mapping_list
])[idx]
i = 0
while i < len(has_used):
if not has_used[i]:
to_send_process = process_list[i]
has_used[i] = True
break
i += 1
if i == len(has_used):
has_used = list(map(lambda x: False, has_used))
to_send_process = process_list[0]
has_used[0] = True
assert to_send_process is not None, "Failed to find the send process."
if to_send_process not in op_desc_seq.keys():
op_desc_seq[to_send_process] = []
if target_process not in op_desc_seq.keys():
op_desc_seq[target_process] = []
all_partition_index_list.append(source_partition_index)
# append send and recv op desc
send_op_desc = SendOpDesc(source_partition_index,
target_process)
recv_op_desc = RecvOpDesc(source_partition_index,
to_send_process)
op_desc_seq[to_send_process].append(send_op_desc)
op_desc_seq[target_process].append(recv_op_desc)
has_sent.append(source_partition_index)
Resharder.concat_partitions(partition_index_list,
source_partition_index)
# append concat op desc
op_desc_seq[target_process].append(
ConcatOpDesc(all_partition_index_list))
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
concatenated_partition_index = partition_index_list[0]
for idx, item in enumerate(concatenated_partition_index):
slice_starts.append(target_partition_index[idx][0] - item[
0])
slice_ends.append(target_partition_index[idx][1] - item[0])
slices_axes.append(idx)
op_desc_seq[target_process].append(
SliceOpDesc(slice_starts, slice_ends, slices_axes))
# in the same process group, it will use allgahther and slice op
else:
partition_index_list = []
all_partition_index_list = []
process_index = []
for source_process in source_process_group:
source_partition_index = Resharder.compute_partition_index(
source_process, complete_shape, source_dims_mapping,
source_process_shape, source_process_group)
if source_partition_index not in partition_index_list:
partition_index_list.append(source_partition_index)
process_index.append(
[[source_process, ], source_partition_index])
else:
process_index[partition_index_list.index(
source_partition_index)][0].append(source_process)
for i in range(len(process_index[0][0])):
group = []
for j in range(len(process_index)):
group.append(process_index[j][0][i])
if i == 0:
all_partition_index_list.append(process_index[j][1])
for process in group:
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
target_partition_index = Resharder.compute_partition_index(
process, complete_shape, target_dims_mapping,
target_process_shape, target_process_group)
for idx, item in enumerate(target_partition_index):
slice_starts.append(item[0])
slice_ends.append(item[1])
slices_axes.append(idx)
slice_op_desc = SliceOpDesc(
starts=slice_starts, ends=slice_ends, axes=slices_axes)
op_desc_seq[process] = [AllGatherOpDesc(group=group),
ConcatOpDesc(partition_index_list=all_partition_index_list), slice_op_desc] \
if len(group) > 1 else [slice_op_desc]
return op_desc_seq
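# Illustrative sketch of the cross-mesh case handled above (values are made
# up): a [4, 8] tensor column-sharded over processes [0, 1] is resharded to
# processes [2, 3] with the same dims mapping, so each target process receives
# exactly one source partition and then concats and slices it.
from paddle.distributed.auto_parallel.reshard import Resharder

complete_shape = [4, 8]
src_group, dst_group = [0, 1], [2, 3]
src_parts = [
    Resharder.compute_partition_index(p, complete_shape, [-1, 0], [2], src_group)
    for p in src_group
]
dst_parts = [
    Resharder.compute_partition_index(p, complete_shape, [-1, 0], [2], dst_group)
    for p in dst_group
]
# src_parts == dst_parts == [[[0, 4], [0, 4]], [[0, 4], [4, 8]]], so the
# resulting op_desc_seq pairs process 0 with 2 and process 1 with 3 via
# send/recv, followed by concat and slice on the receiver side.
assert src_parts == dst_parts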
def parse_op_desc(self, block, op_desc_seq, var_name, reshard_op,
actual_process_mesh):
"""Parse op desc sequence and insert op in the block"""
tensor_list = []
partition_tensor_list = []
if self.rank_id not in op_desc_seq.keys():
return
op_desc_list = op_desc_seq[self.rank_id]
idx = None
for index, op in list(enumerate(block.ops)):
if op.desc.id == reshard_op.desc.id:
idx = index
break
assert idx is not None, "The op for reshard cannot be found in the rank {} program.".format(
self.rank_id)
matched_op = block.ops[idx]
source_tensor = get_var_with_recursion(var_name, block,
self.auto_parallel_main_prog)
for op_desc in op_desc_list:
if isinstance(op_desc, AllGatherOpDesc): # noqa: F401
if var_name not in self.has_allgather.keys():
self.has_allgather[var_name] = []
if not self.has_allgather[
var_name] or op_desc.group not in list(
map(lambda x: x[0], self.has_allgather[var_name])):
tensor_list, idx_offset = Inserter.insert_allgather_op(
block, idx, source_tensor, op_desc.group,
reshard_op.attr('op_role'))
idx += idx_offset
tensor_name_list = [var.name for var in tensor_list]
self.has_allgather[var_name].append(
[op_desc.group, tensor_name_list])
else:
for item in self.has_allgather[var_name]:
if op_desc.group == item[0]:
tensor_list = [
program.global_block().vars[var_name]
for var_name in item[1]
]
break
assert tensor_list, "The result of parsing allgather op should not be None."
elif isinstance(op_desc, SendOpDesc):
if var_name not in self.has_sent.keys():
self.has_sent[var_name] = []
if op_desc.dst not in self.has_sent[var_name]:
Inserter.insert_send_op(block, idx, source_tensor,
op_desc.dst,
reshard_op.attr('op_role'))
idx += 1
self.has_sent[var_name].append(op_desc.dst)
elif isinstance(op_desc, RecvOpDesc):
if var_name not in self.has_recv.keys():
self.has_recv[var_name] = {}
if op_desc.src not in self.has_recv[var_name].keys():
partition_index = op_desc.partition_index
shape = []
for index in partition_index:
shape.append(index[1] - index[0])
recv_tensor = block.create_var(
name=unique_name.generate(var_name + "@recv"),
shape=shape,
dtype=source_tensor.dtype,
type=source_tensor.type)
Inserter.insert_recv_op(block, idx, recv_tensor,
op_desc.src,
reshard_op.attr('op_role'))
tensor_list.append(recv_tensor)
idx += 1
self.has_recv[var_name][op_desc.src] = recv_tensor
else:
tensor_list.append(self.has_recv[var_name][op_desc.src])
elif isinstance(op_desc, ConcatOpDesc):
partition_index_list = op_desc.partition_index_list
idx_list = [idx]
for index, tensor in enumerate(tensor_list):
Inserter.concat_partitions_with_op(
partition_tensor_list, tensor,
partition_index_list[index], block, idx_list,
reshard_op.attr('op_role'))
idx = idx_list[0]
elif isinstance(op_desc, SliceOpDesc):
assert len(
partition_tensor_list) == 1 or not partition_tensor_list
to_slice_tensor = partition_tensor_list[0][0] if len(
partition_tensor_list) == 1 else source_tensor
new_name = unique_name.generate(var_name + "@RESHARD")
target_tensor = Inserter.insert_slice_op(
block,
idx,
to_slice_tensor,
starts=op_desc.starts,
ends=op_desc.ends,
axes=op_desc.axes,
new_var_name=new_name,
op_role=reshard_op.attr('op_role'))
tensor_attr = TensorDistributedAttribute()
process_mesh = actual_process_mesh
dims_mapping = self.dist_context.get_op_dist_attr_for_program(
matched_op).get_input_dims_mapping(var_name)
tensor_attr.dims_mapping = dims_mapping
tensor_attr.process_mesh = process_mesh
self.dist_context.set_tensor_dist_attr_for_program(
target_tensor, tensor_attr)
if op.type == "while": if op.type == "while":
input_var_names = op.input("X") # var_reshard_mapping means the while op input need be changed to
if "var_reshard_mapping" not in Resharder.while_block_info[
op.attr("sub_block").id].keys():
Resharder.while_block_info[op.attr("sub_block").id][
"var_reshard_mapping"] = {}
Resharder.while_block_info[op.attr("sub_block").id][
"var_reshard_mapping"][var_name] = target_tensor.name
# rename op input name according to new name
for op in block.ops:
for name in op.input_arg_names:
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
op)
if name == var_name and op_dist_attr is not None:
if op.desc.id() == matched_op.desc.id():
op.desc._rename_input(name, target_tensor.name)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr.set_input_dist_attr(name, None)
continue
# NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation.
op_process_mesh = op_dist_attr.process_mesh
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping:
op.desc._rename_input(name, target_tensor.name)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr.set_input_dist_attr(name, None)
def reshard(self):
for block_idx, block in enumerate(self.auto_parallel_main_prog.blocks):
if block_idx in Resharder.while_block_info:
if "var_reshard_mapping" in Resharder.while_block_info[
block_idx]:
var_reshard_mapping = Resharder.while_block_info[block_idx][
"var_reshard_mapping"]
for op in block.ops:
for var_name in op.input_arg_names:
if var_name in var_reshard_mapping:
op.desc._rename_input(
var_name, var_reshard_mapping[var_name])
dist_op = self.dist_context.get_dist_op_for_program(
op)
op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == Resharder.while_block_info[
block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
op_dist_attr.set_input_dims_mapping(
var_reshard_mapping[var_name],
dims_mapping)
op_dist_attr.set_input_dist_attr(var_name,
None)
# the outputs also need to be renamed when the output name is the same with input name
for var_name in op.output_arg_names:
if var_name in var_reshard_mapping:
op.desc._rename_output(
var_name, var_reshard_mapping[var_name])
dist_op = self.dist_context.get_dist_op_for_program(
op)
op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == Resharder.while_block_info[
block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_output_dims_mapping(
var_name)
op_dist_attr.set_output_dims_mapping(
var_reshard_mapping[var_name],
dims_mapping)
op_dist_attr.set_output_dist_attr(var_name,
None)
idx = 0
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
if self.is_special_op(op):
idx += 1
continue
dist_op = self.dist_context.get_dist_op_for_program(op)
if dist_op is not None:
process_meshes = []
if op.type == "while":
if not self.is_condition_replicative(op):
raise ValueError(
"Please check the condition due to the dims mapping is not replicative."
)
process_meshes = self.get_process_meshes(op)
assert process_meshes
if op.attr("sub_block"
).id not in Resharder.while_block_info:
Resharder.while_block_info[op.attr("sub_block")
.id] = {}
Resharder.while_block_info[op.attr("sub_block").id][
"op_id"] = op.desc.id()
Resharder.while_block_info[op.attr("sub_block").id][
"actual_process_mesh"] = self.get_while_op_actual_process_mesh(
op)
else:
process_meshes = self.get_op_process_meshes(op)
input_vars = None
if op.type == "while":
input_var_names = op.input("X")
else:
input_var_names = op.input_arg_names
idx_offset = 0
for var_name in op.input_arg_names:
# skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0":
continue
var = get_var_with_recursion(
var_name, block, self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(
var)
for process_mesh in process_meshes:
if dist_tensor is not None and self.need_reshard(
dist_tensor, dist_op, process_mesh):
reshard_op_desc = self.find_op_desc_seq(
dist_tensor, dist_op, process_mesh)
self.parse_op_desc(block, reshard_op_desc,
var_name, op, process_mesh)
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
else:
idx += 1

# insert send and recv op if output process mesh is different from tensor process mesh
idx = 0
# skip reader and ops whose process mesh is union
skip_ops = [
"create_py_reader", "create_double_buffer_reader", "read",
"while", "write_to_array", "read_from_array"
]
global _g_special_ops
skip_ops += _g_special_ops
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
dist_op = self.dist_context.get_dist_op_for_program(op)
if dist_op is not None and op.type not in skip_ops:
for var_name in op.output_arg_names:
var = get_var_with_recursion(
var_name, block, self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(
var)
process_mesh = dist_op.dist_attr.process_mesh
if dist_tensor is not None and self.need_reshard(
dist_tensor, dist_op, process_mesh, False):
for index, item in enumerate(
dist_op.dist_attr.process_mesh.processes):
recv_rank = dist_tensor.dist_attr.process_mesh.processes[
index]
if self.rank_id == item:
Inserter.insert_send_op(block, idx + 1, var,
recv_rank,
op.attr('op_role'))
if self.rank_id == recv_rank:
Inserter.insert_recv_op(block, idx + 1, var,
item,
op.attr('op_role'))
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
else:
idx += 1

# remove no need vars and ops in the main program
Remover.remove_no_need_in_main(self.auto_parallel_main_prog,
self.dist_context, self.rank_id,
self.dist_params_grads)

# remove no need vars and ops in the startip program
Remover.remove_no_need_in_startup(self.auto_parallel_main_prog,
self.auto_parallel_startup_prog)

# reset some variable when remove operation ended
Resharder.while_block_info = {}

else:
input_var_names = op.input_arg_names

idx_offset = 0
for var_name in op.input_arg_names:
# skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0":
continue
var = _get_var(var_name, block, auto_parallel_main_prog)
dist_tensor = dist_context.get_dist_tensor_for_program(var)
for process_mesh in process_meshes:
if dist_tensor is not None and _need_reshard(
dist_tensor, dist_op, process_mesh,
auto_parallel_main_prog, dist_context):
reshard_op_desc = find_op_desc_seq(
dist_tensor, dist_op, process_mesh, batch_size)
parse_op_desc(block, rank_id, reshard_op_desc,
var_name, op, dist_context,
auto_parallel_main_prog, process_mesh)
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
else:
idx += 1

# insert send and recv op if output process mesh is different from tensor process mesh
idx = 0
# skip reader and ops whose process mesh is union
skip_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "while",
"write_to_array", "read_from_array"
]
skip_ops += _g_special_ops
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op is not None and op.type not in skip_ops:
for var_name in op.output_arg_names:
var = _get_var(var_name, block, auto_parallel_main_prog)
dist_tensor = dist_context.get_dist_tensor_for_program(var)
process_mesh = dist_op.dist_attr.process_mesh
if dist_tensor is not None and _need_reshard(
dist_tensor, dist_op, process_mesh,
auto_parallel_main_prog, dist_context, False):
for index, item in enumerate(
dist_op.dist_attr.process_mesh.processes):
recv_rank = dist_tensor.dist_attr.process_mesh.processes[
index]
if rank_id == item:
_insert_send_op(block, idx + 1, var, recv_rank,
op.attr('op_role'))
if rank_id == recv_rank:
_insert_recv_op(block, idx + 1, var, item,
op.attr('op_role'))
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
else:
idx += 1

# remove no need vars and ops in the main program
remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
dist_params_grads)
# remove no need vars and ops in the startip program
remove_no_need_in_startup(auto_parallel_main_prog,
auto_parallel_startup_prog)
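# Driver-side usage, mirroring the updated unit tests below (a sketch: the
# programs, rank_id, dist_context and dist_params_grads are assumed to come
# from the partitioning pipeline):
from paddle.distributed.auto_parallel.reshard import Resharder

resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                      dist_context, dist_params_grads)
resharder.reshard()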
...@@ -775,19 +775,19 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr): ...@@ -775,19 +775,19 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
def _merge_parameter_with_dist_attr(param_list, dist_attr): def _merge_parameter_with_dist_attr(param_list, dist_attr):
""" Merge parameter with distributed attribute """ """ Merge parameter with distributed attribute """
from .reshard import _compute_complete_shape, _compute_partition_index from .reshard import Resharder
dims_mapping = dist_attr["dims_mapping"] dims_mapping = dist_attr["dims_mapping"]
process_shape = dist_attr["process_shape"] process_shape = dist_attr["process_shape"]
process_group = dist_attr["process_group"] process_group = dist_attr["process_group"]
# get the complete shape of the parameter # get the complete shape of the parameter
complete_shape = _compute_complete_shape(param_list[0].shape, process_shape, complete_shape = Resharder.compute_complete_shape(
dims_mapping) param_list[0].shape, process_shape, dims_mapping)
# merge the parameter with dist_attr # merge the parameter with dist_attr
partition_param_list = [] partition_param_list = []
merged_partiton = [] merged_partiton = []
for process in process_group: for process in process_group:
partition_index = _compute_partition_index( partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape, process_group) process, complete_shape, dims_mapping, process_shape, process_group)
index = process_group.index(process) index = process_group.index(process)
if partition_index not in merged_partiton: if partition_index not in merged_partiton:
...@@ -840,7 +840,7 @@ def _merge_parameter(partition_param_list, param, partition_index, ...@@ -840,7 +840,7 @@ def _merge_parameter(partition_param_list, param, partition_index,
_merge_parameter(partition_param_list, param, partition_index) _merge_parameter(partition_param_list, param, partition_index)
# partition_param_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])] # partition_param_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
""" """
from .reshard import _compute_concat_info from .reshard import Resharder
if len(partition_param_list) == 1: if len(partition_param_list) == 1:
is_complete_data = True is_complete_data = True
...@@ -856,7 +856,7 @@ def _merge_parameter(partition_param_list, param, partition_index, ...@@ -856,7 +856,7 @@ def _merge_parameter(partition_param_list, param, partition_index,
else: else:
i = 0 i = 0
while i < len(partition_param_list): while i < len(partition_param_list):
concat_axis, first_order, new_partition = _compute_concat_info( concat_axis, first_order, new_partition = Resharder.compute_concat_info(
partition_param_list[i][1], partition_index) partition_param_list[i][1], partition_index)
if concat_axis != -1: if concat_axis != -1:
if first_order == 0: if first_order == 0:
...@@ -933,9 +933,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape, ...@@ -933,9 +933,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
process_shape, process_group) process_shape, process_group)
# index: 2 # index: 2
""" """
from .reshard import _compute_partition_index from .reshard import Resharder
partition_index = _compute_partition_index( partition_index = Resharder.compute_partition_index(
rank, complete_shape, dims_mapping, process_shape, process_group) rank, complete_shape, dims_mapping, process_shape, process_group)
sliced_param_index = 0 sliced_param_index = 0
for i, shape in enumerate(complete_shape): for i, shape in enumerate(complete_shape):
...@@ -972,11 +972,11 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape, ...@@ -972,11 +972,11 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape,
index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group) index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
# index: [[], [], [2, 4]] # index: [[], [], [2, 4]]
""" """
from .reshard import _compute_partition_index from .reshard import Resharder
split_indices_list = [] split_indices_list = []
for process in process_group: for process in process_group:
partition_index = _compute_partition_index( partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape, process_group) process, complete_shape, dims_mapping, process_shape, process_group)
if split_indices_list: if split_indices_list:
for dim in range(len(partition_index)): for dim in range(len(partition_index)):
......
...@@ -31,7 +31,6 @@ from paddle.distributed import fleet ...@@ -31,7 +31,6 @@ from paddle.distributed import fleet
from paddle.fluid.initializer import NumpyArrayInitializer from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program
from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program
from paddle.distributed.auto_parallel.reshard import HAS_SENT, HAS_RECV, HAS_ALLGATHER
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
paddle.enable_static() paddle.enable_static()
...@@ -258,9 +257,6 @@ class TestMLPAutoConvert2(unittest.TestCase): ...@@ -258,9 +257,6 @@ class TestMLPAutoConvert2(unittest.TestCase):
paddle.seed(2021) paddle.seed(2021)
random.seed(2021) random.seed(2021)
np.random.seed(2021) np.random.seed(2021)
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
def tearDown(self): def tearDown(self):
os.remove("./model_state_rank{}.pdmodel".format( os.remove("./model_state_rank{}.pdmodel".format(
......
...@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext ...@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.cost_model import estimate_cost from paddle.distributed.auto_parallel.cost_model import estimate_cost
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
...@@ -232,8 +232,9 @@ class TestCostModel(unittest.TestCase): ...@@ -232,8 +232,9 @@ class TestCostModel(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog( distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
reshard(distributed_program, dist_startup_prog, rank_id, resharder = Resharder(distributed_program, dist_startup_prog,
dist_context, dist_params_grads) rank_id, dist_context, dist_params_grads)
resharder.reshard()
dist_program.append(distributed_program) dist_program.append(distributed_program)
cluster = None cluster = None
cost = estimate_cost( cost = estimate_cost(
......
...@@ -40,7 +40,7 @@ from paddle.distributed.auto_parallel.completion import Completer ...@@ -40,7 +40,7 @@ from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import get_all_process_groups from paddle.distributed.auto_parallel.process_group import get_all_process_groups
from paddle.distributed.auto_parallel.process_group import new_process_group from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.cluster import Cluster
...@@ -502,8 +502,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): ...@@ -502,8 +502,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
partitioned_optimize_ops = parallelizer._apply_optimize( partitioned_optimize_ops = parallelizer._apply_optimize(
dist_train_program, dist_startup_prog, dist_params_grads) dist_train_program, dist_startup_prog, dist_params_grads)
reshard(dist_train_program, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_train_program, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
return dist_train_program, dist_startup_prog return dist_train_program, dist_startup_prog
......
...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext ...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import _g_process_group_map from paddle.distributed.auto_parallel.process_group import _g_process_group_map
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
...@@ -310,8 +310,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -310,8 +310,9 @@ class TestMLPReshard(unittest.TestCase):
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
for key in list(_g_process_group_map.keys()): for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key] del _g_process_group_map[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
# check send and recv result # check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
...@@ -320,9 +321,6 @@ class TestMLPReshard(unittest.TestCase): ...@@ -320,9 +321,6 @@ class TestMLPReshard(unittest.TestCase):
self.assertTrue(check_initialization(dist_startup_prog, rank_id)) self.assertTrue(check_initialization(dist_startup_prog, rank_id))
def test_mlp_pp_diff_process_mesh(self): def test_mlp_pp_diff_process_mesh(self):
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
train_program = paddle.static.Program() train_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
dist_context = DistributedContext() dist_context = DistributedContext()
...@@ -331,8 +329,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -331,8 +329,9 @@ class TestMLPReshard(unittest.TestCase):
train_program, startup_program, dist_context, rank_id, True) train_program, startup_program, dist_context, rank_id, True)
for key in list(_g_process_group_map.keys()): for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key] del _g_process_group_map[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
print_program_with_dist_attr(dist_main_prog, dist_context) print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result # check send and recv result
...@@ -351,8 +350,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -351,8 +350,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 0 rank_id = 0
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
# send and recv should not exist in dp scene. # send and recv should not exist in dp scene.
self.assertFalse(check_send_recv_result(dist_main_prog, rank_id)) self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))
......
...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext ...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
...@@ -179,8 +179,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -179,8 +179,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 2 rank_id = 2
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
# print_program_with_dist_attr(dist_main_prog, dist_context) # print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result # check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext ...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
...@@ -213,8 +213,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -213,8 +213,9 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 2 rank_id = 2
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_params_grads) dist_context, dist_params_grads)
resharder.reshard()
# check send and recv result # check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
...@@ -272,8 +273,9 @@ class TestMLPReshard(unittest.TestCase): ...@@ -272,8 +273,9 @@ class TestMLPReshard(unittest.TestCase):
dist_context.block_state.parse_forward_blocks(complete_train_program) dist_context.block_state.parse_forward_blocks(complete_train_program)
partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition( partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
complete_train_program, startup_program, []) complete_train_program, startup_program, [])
reshard(partitioned_main_prog, partitioned_startup_prog, rank_id, resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
dist_context, partitioned_params_grads) rank_id, dist_context, partitioned_params_grads)
resharder.reshard()
# the x should not be slice # the x should not be slice
self.assertTrue(check_allgather(partitioned_main_prog)) self.assertTrue(check_allgather(partitioned_main_prog))
......
...@@ -29,7 +29,7 @@ import paddle.distributed.auto_parallel as auto ...@@ -29,7 +29,7 @@ import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import new_process_group from paddle.distributed.auto_parallel.process_group import new_process_group
paddle.enable_static() paddle.enable_static()
......