Unverified Commit 8624f3b1 authored by: C caozhou Committed by: GitHub

[Auto Parallel] Update reshard for auto search (#45002)

* update reshard for auto search

* fix unittest bug

* update dist tensor

* update reshard output

* fix unittests bug

* merge develop
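A minimal sketch of how the updated Resharder is driven, modeled on the unit test at the end of this diff; dist_main_prog, dist_startup_prog, rank_id, dist_context, and dist_params_grads are placeholders assumed to be produced earlier by the partitioner, so this is illustrative rather than directly runnable.

from paddle.distributed.auto_parallel.reshard import Resharder

# Assumed inputs (see the unit test below): dist_main_prog, dist_startup_prog,
# rank_id, dist_context and dist_params_grads come from the planner/partitioner.
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                      dist_context, dist_params_grads)
# Inserts cast/send/recv/allgather/concat/slice ops wherever the tensor and op
# dist_attrs disagree.
resharder.reshard()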
Parent 236ad4fc
......@@ -276,8 +276,8 @@ class OperatorDistributedAttribute:
dist_attr_object.init(dist_attr)
self._inputs_dist_attrs[name] = dist_attr_object
# def del_input_dist_attr(self, name):
# del self._inputs_dist_attrs[name]
def del_input_dist_attr(self, name):
del self._inputs_dist_attrs[name]
def get_output_dist_attr(self, name):
return self._outputs_dist_attrs.get(name, None)
......@@ -287,8 +287,8 @@ class OperatorDistributedAttribute:
dist_attr_object.init(dist_attr)
self._outputs_dist_attrs[name] = dist_attr_object
# def del_output_dist_attr(self, name):
# del self._inputs_dist_attrs[name]
def del_output_dist_attr(self, name):
del self._outputs_dist_attrs[name]
def get_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name)
......
......@@ -163,7 +163,6 @@ class DistributedTensor:
self._batch_dim = 0
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
self._local_sizes_map = {}
self._local_offsets_map = {}
self._local_shard_map = {}
self._local_tensor_map = {}
......@@ -223,20 +222,17 @@ class DistributedTensor:
return True
def local_sizes(self, rank=None):
"""Get local sizes of the given rank."""
rank = paddle.distributed.get_rank() if rank is None else rank
local_sizes = None
if rank in self._local_sizes_map.keys():
local_sizes = self._local_sizes_map[rank]
else:
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, rank,
shard_sizes)
self._local_sizes_map[rank] = local_sizes
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
local_sizes = DistributedTensor.get_local_sizes(global_sizes,
dims_mapping, topology,
processes, rank,
shard_sizes)
return local_sizes
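# Example (a sketch): with global_sizes [8, 8], dims_mapping [0, -1] and a
# process mesh topology [2, 2], tensor dim 0 is split across the 2 processes
# along mesh axis 0, so every rank gets local_sizes [4, 8] (assuming no custom
# shard_sizes).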
......@@ -282,7 +278,6 @@ class DistributedTensor:
def new_local_tensor(self, block=None, rank=None, name=None):
"""
Create a new local tensor of the serial tensor corresponding to the given rank.
Args:
block (Block): The block that will contain the new tensor. The default value is recommended: the tensor will be created in the block of the dist main program corresponding to the serial tensor's block id. Default: None.
rank (int): The rank id. The default value is recommended: the current rank will be used. Default: None.
......
......@@ -26,6 +26,11 @@ from ..collective import _get_global_env
from .dist_context import DistributedContext
from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .process_group import new_process_group, ProcessGroup, _g_process_group_map
from .cost import build_comm_desc, CommContext
from .cost import AllgatherOpCost, SendOpCost
from .cost import SliceOpCost, SplitOpCost, ConcatOpCost
from .cluster import Cluster
from .utils import print_program_with_dist_attr
# NOTE: If an op is in _g_special_ops, it will not be resharded.
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
......@@ -41,6 +46,7 @@ def get_var_with_recursion(var_name, block, program):
if var_name in parent_block.vars:
var = parent_block.vars[var_name]
assert var is not None
return var
......@@ -50,11 +56,19 @@ class AllGatherOpDesc:
Args:
group (list): Process group.
shape (list): The tensor shape.
is_bool (bool): Whether the data to all-gather is bool. Default: False.
"""
def __init__(self, group):
def __init__(self, group, shape, is_bool=False):
self._group = group
self._desc = "all_gather"
self._shape = shape
self._is_bool = is_bool
@property
def is_bool(self):
return self._is_bool
@property
def group(self):
......@@ -64,8 +78,12 @@ class AllGatherOpDesc:
def desc(self):
return self._desc
@property
def shape(self):
return self._shape
def __repr__(self):
return f"op: {self._desc}, group: {self._group}."
return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}."
class SendOpDesc:
......@@ -74,13 +92,26 @@ class SendOpDesc:
Args:
partition_index (list): The index of partition in complete tensor.
src (int): The source process that sends the data.
dst (int): The destination process that receives the data.
is_bool (bool): Whether the data to send is bool. Default: False.
"""
def __init__(self, partition_index, dst):
def __init__(self, partition_index, src, dst, is_bool=False):
self._dst = dst
self._partition_index = partition_index
self._desc = "send"
self._shape = []
self._is_bool = is_bool
self._src = src
@property
def src(self):
return self._src
@property
def is_bool(self):
return self._is_bool
@property
def partition_index(self):
......@@ -94,8 +125,15 @@ class SendOpDesc:
def desc(self):
return self._desc
@property
def shape(self):
if not self._shape:
for item in self.partition_index:
self._shape.append(item[1] - item[0])
return self._shape
def __repr__(self):
return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}."
return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}, shape: {self._shape}, is_bool: {self._is_bool}."
class RecvOpDesc:
......@@ -105,12 +143,25 @@ class RecvOpDesc:
Args:
partition_index (list): The index of partition in complete tensor.
src (int): The source process that sends the data.
dst (int): The destination process that receives the data.
is_bool (bool): Whether the data to receive is bool. Default: False.
"""
def __init__(self, partition_index, src):
def __init__(self, partition_index, src, dst, is_bool=False):
self._src = src
self._partition_index = partition_index
self._desc = "recv"
self._shape = []
self._is_bool = is_bool
self._dst = dst
@property
def dst(self):
return self._dst
@property
def is_bool(self):
return self._is_bool
@property
def partition_index(self):
......@@ -124,8 +175,15 @@ class RecvOpDesc:
def desc(self):
return self._desc
@property
def shape(self):
if not self._shape:
for item in self.partition_index:
self._shape.append(item[1] - item[0])
return self._shape
def __repr__(self):
return f"op: {self._desc}, partition_index: {self._partition_index}, src: {self._src}."
return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}, shape: {self._shape}, is_bool: {self._is_bool}."
class SliceOpDesc:
......@@ -133,16 +191,18 @@ class SliceOpDesc:
Describe the slice op in the reshard phase.
Args:
starts (list): It represents starting indices of corresponding axis in ``axes``.
ends (list): It represents ending indices of corresponding axis in ``axes``.
axes (list): Axes that `starts` and `ends` apply to .
starts (list): It represents start indices of corresponding axis in ``axes``.
ends (list): It represents end indices of corresponding axis in ``axes``.
axes (list): Axes that `starts` and `ends` apply to.
shape (list): The shape of the tensor to be sliced.
"""
def __init__(self, starts, ends, axes):
def __init__(self, starts, ends, axes, shape=None):
self._starts = starts
self._ends = ends
self._axes = axes
self._desc = "slice"
self._shape = shape
@property
def starts(self):
......@@ -160,8 +220,15 @@ class SliceOpDesc:
def desc(self):
return self._desc
@property
def shape(self):
return self._shape
def __repr__(self):
return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}."
if self._shape is not None:
return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}, shape: {self._shape}."
else:
return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}."
class ConcatOpDesc:
......@@ -192,36 +259,84 @@ class Inserter:
"""Insert op required in the reshard process."""
@staticmethod
def insert_send_op(block, idx, tensor, dst, op_role):
def insert_cast_op(block, idx, tensor, op_role, tensor_type):
# to avoid name conflict with framework
new_var_name = paddle.fluid.unique_name.generate_with_ignorable_key(
".".join(["cast@RESHARD", 'tmp']))
out = block.create_var(name=new_var_name,
dtype=tensor_type,
type=tensor.type,
lod_level=tensor.lod_level)
block._insert_op(idx,
type='cast',
inputs={'X': [tensor]},
outputs={'Out': [out]},
attrs={
'in_dtype': tensor.dtype,
'out_dtype': out.dtype,
'op_role': op_role
})
return out
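# Illustrative use (a sketch; see parse_op_desc below): bool tensors are not
# sent or all-gathered directly, so the reshard pass chains
#   cast(bool -> int64) -> send/recv/allgather -> cast(int64 -> bool)
# e.g. out_cast = Inserter.insert_cast_op(block, idx, tensor, op_role, paddle.int64)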
@staticmethod
def insert_send_op(block, idx, tensor, src, dst, op_role):
"""Insert send op into block at the given index."""
op_type = 'send_v2'
# use pair comm group
process_group = new_process_group([src, dst])
block._insert_op(idx,
type=op_type,
inputs={'X': [tensor]},
attrs={
'ring_id': 0,
'peer': dst,
'ring_id': process_group.id,
'peer': process_group.ranks.index(dst),
'use_calc_stream': True,
'op_role': op_role
'op_role': op_role,
'dynamic_shape': True
})
@staticmethod
def insert_recv_op(block, idx, tensor, src, op_role):
def insert_recv_op(block, idx, tensor, src, dst, op_role):
"""Insert recv op into block at the given index."""
op_type = 'recv_v2'
# use pair group
process_group = new_process_group([src, dst])
block._insert_op(idx,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'ring_id': 0,
'peer': src,
'ring_id': process_group.id,
'peer': process_group.ranks.index(src),
'out_shape': tensor.shape,
'dtype': tensor.dtype,
'use_calc_stream': True,
'op_role': op_role
'op_role': op_role,
'dynamic_shape': True
})
@staticmethod
def insert_reset_lod_op(block, idx, X, Y, op_role):
"""Insert reset_lod op into block at the given index."""
new_var_name = paddle.fluid.unique_name.generate_with_ignorable_key(
".".join(["reset_lod@RESHARD", 'tmp']))
reset_lod_out = block.create_var(name=new_var_name,
shape=X.shape,
type=X.type,
dtype=X.dtype,
lod_level=X.lod_level)
block._insert_op(idx,
type="lod_reset",
inputs={
'X': X,
'Y': Y
},
outputs={'Out': reset_lod_out},
attrs={'op_role': op_role})
return reset_lod_out
@staticmethod
def insert_concat_op(block, idx, tensors, axis, op_role):
"""Insert concat op into block at the given block."""
......@@ -229,10 +344,18 @@ class Inserter:
attrs = {}
attrs['axis'] = axis
attrs['op_role'] = op_role
helper = LayerHelper('concat', **locals())
# to avoid name conflict with framework
helper = LayerHelper('concat@RESHARD', **locals())
with paddle.static.program_guard(block.program):
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
out = block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])),
dtype=tensors[0].dtype,
shape=None,
lod_level=tensors[0].lod_level,
type=tensors[0].type,
persistable=False,
stop_gradient=False)
block._insert_op(idx,
type='concat',
inputs=inputs,
......@@ -244,37 +367,117 @@ class Inserter:
def insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name,
op_role):
"""Insert slice op into block at the given block."""
inputs = {'Input': tensor}
infer_flags = list(1 for i in range(len(axes)))
attrs = {
"axes": axes,
"starts": starts,
"ends": ends,
"infer_flags": infer_flags,
'op_role': op_role
}
helper = LayerHelper('slice', **locals())
out = block.create_var(name=new_var_name,
dtype=tensor.dtype,
type=tensor.type)
block._insert_op(idx,
type="slice",
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
return out
# This is a hack that inserts a split op to get the sliced tensor
# 1. [128, 128] => [64, 128]: split
# 2. [128, 128] => [128, 128]: assign
# 3. [128, 128] => [64, 64]: slice, it will be replaced by multiple splits
global_shape = tensor.shape
slice_shape = [ends[i] - starts[i] for i in range(len(starts))]
diff_dims = []
for index, item in enumerate(slice_shape):
if item != global_shape[index]:
diff_dims.append(index)
# use assign
if len(diff_dims) == 0:
out = block.create_var(name=new_var_name,
dtype=tensor.dtype,
type=tensor.type,
shape=slice_shape,
lod_level=tensor.lod_level)
inputs = {'X': [tensor]}
outputs = {"Out": [out]}
attrs = {"in_place": False}
block._insert_op(idx,
type="assign",
inputs=inputs,
outputs=outputs,
attrs=attrs)
return out
# use split once
elif len(diff_dims) == 1:
diff_dim = diff_dims[0]
num_or_sections = global_shape[diff_dim] // slice_shape[diff_dim]
axis = diff_dim
cur_idx = starts[diff_dim] // slice_shape[diff_dim]
input_shape = global_shape
inputs = {'X': tensor}
attrs = {'num': num_or_sections, 'axis': axis, 'op_role': op_role}
new_shape = []
for index, item in enumerate(tensor.shape):
if index != axis:
new_shape.append(item)
else:
new_shape.append(item // num_or_sections)
with paddle.static.program_guard(block.program):
outs = [
block.create_var(name=paddle.fluid.unique_name.
generate_with_ignorable_key(".".join(
['split@RESHARD', 'tmp'])),
dtype=tensor.dtype,
shape=None,
type=tensor.type,
persistable=False,
lod_level=tensor.lod_level,
stop_gradient=False)
for i in range(num_or_sections)
]
out = outs[cur_idx]
op = block._insert_op(idx,
type="split",
inputs=inputs,
outputs={'Out': outs},
attrs=attrs)
return out
# use slice
else:
inputs = {'Input': tensor}
infer_flags = list(1 for i in range(len(axes)))
attrs = {
"axes": axes,
"starts": starts,
"ends": ends,
"infer_flags": infer_flags,
'op_role': op_role
}
out = block.create_var(name=new_var_name,
dtype=tensor.dtype,
type=tensor.type,
lod_level=tensor.lod_level)
block._insert_op(idx,
type="slice",
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
return out
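# Sketch of the branch selection above for a [128, 128] source tensor:
#   starts=[0, 0], ends=[128, 128] -> no dim differs   -> assign
#   starts=[0, 0], ends=[64, 128]  -> one dim differs  -> split, keep outs[0]
#   starts=[0, 0], ends=[64, 64]   -> two dims differ  -> slice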
@staticmethod
def insert_split_op(block, idx, tensor, num_or_sections, op_role):
def insert_split_op(block, idx, tensor, num_or_sections, op_role, axis=0):
"""Insert split op into block at the given index."""
helper = LayerHelper('split', **locals())
helper = LayerHelper('split@RESHARD', **locals())
input_shape = tensor.shape
inputs = {'X': tensor}
attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role}
attrs = {'num': num_or_sections, 'axis': axis, 'op_role': op_role}
new_shape = []
for index, item in enumerate(tensor.shape):
if index != axis:
new_shape.append(item)
else:
new_shape.append(item // num_or_sections)
with paddle.static.program_guard(block.program):
outs = [
helper.create_variable_for_type_inference(
dtype=helper.input_dtype()) for i in range(num_or_sections)
block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])),
dtype=tensor.dtype,
shape=None,
lod_level=tensor.lod_level,
type=tensor.type,
persistable=False,
stop_gradient=False) for i in range(num_or_sections)
]
block._insert_op(idx,
type="split",
......@@ -286,9 +489,18 @@ class Inserter:
@staticmethod
def insert_fill_constant_op(block, idx, op_role):
"""Insert fill constant op into block at the given index."""
helper = LayerHelper("fill_constant", **locals())
# to avoid name conflict with framework
helper = LayerHelper('fill_constant@RESHARD', **locals())
# use paddle.int64 as dtype
with paddle.static.program_guard(block.program):
out = helper.create_variable_for_type_inference(dtype="int32")
out = block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])),
dtype=paddle.int64,
shape=None,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
inputs = {}
attrs = {'force_cpu': False}
attrs['str_value'] = str(int("1"))
......@@ -342,10 +554,18 @@ class Inserter:
# insert c_allgather op
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
# to avoid name conflict with framework
helper = LayerHelper(op_type + "@RESHARD", **locals())
with paddle.static.program_guard(block.program):
allgather_out = helper.create_variable_for_type_inference(
dtype=tensor.dtype)
allgather_out = block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])),
dtype=tensor.dtype,
shape=None,
lod_level=tensor.lod_level,
type=tensor.type,
persistable=False,
stop_gradient=False)
block._insert_op(idx + idx_offset,
type=op_type,
inputs={'X': [tensor]},
......@@ -620,12 +840,14 @@ class Resharder:
batch_size=None):
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
"but got {}.".format(type(auto_parallel_main_prog))
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program, " \
"but got {}.".format(type(auto_parallel_startup_prog))
if auto_parallel_startup_prog is not None:
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program or None, " \
"but got {}.".format(type(auto_parallel_startup_prog))
assert isinstance(rank_id, int), "The type of rank_id should be int, " \
"but got {}.".format(type(rank_id))
assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
"but got {}.".format(type(dist_context))
if batch_size is not None:
assert isinstance(batch_size, int), "The type of batch_size should be int, " \
"but got {}.".format(type(batch_size))
......@@ -639,6 +861,8 @@ class Resharder:
self._has_sent = {}
self._has_recv = {}
self._has_allgather = {}
# to avoid resharding repeatedly
self._has_resharded = {}
@property
def auto_parallel_main_prog(self):
......@@ -798,7 +1022,10 @@ class Resharder:
for op in sub_block.ops:
# skip the input and output of operators inserted in the reshard phase
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op:
if dist_op or (op.type == "slice" and not dist_op) or (
op.type == "split"
and not dist_op) or (op.type == "assign"
and not dist_op):
for var_name in op.output_arg_names:
if var_name not in sub_block_op_outputs:
sub_block_op_outputs.append(var_name)
......@@ -812,7 +1039,8 @@ class Resharder:
while_op = op
break
assert while_op is not None
if while_op is None:
continue
# find the actual input and output of while op
proto = OpProtoHolder.instance().get_op_proto(while_op.type)
......@@ -821,13 +1049,15 @@ class Resharder:
if var_name in sub_block_op_inputs:
new_X.append(var_name)
assert new_X
new_X.sort()
while_op.desc.set_input(proto.inputs[0].name, new_X)
new_Out = []
for var_name in while_op.output("Out"):
for output_name in sub_block_op_outputs[::-1]:
if output_name.find(var_name) != -1:
new_Out.append(output_name)
if output_name not in new_Out:
new_Out.append(output_name)
assert new_Out
while_op.desc.set_output(proto.outputs[0].name, new_Out)
......@@ -870,120 +1100,72 @@ class Resharder:
return True
def need_reshard(self,
dist_tensor,
dist_op,
actual_process_mesh,
op_input=True):
def need_reshard(self, dist_tensor, dist_attr, op_input=True, dist_op=None):
"""Judge the tensor whether needs to be resharded."""
is_reshard = False
tensor_dist_attr = dist_tensor.dist_attr
tensor_name = dist_tensor.serial_tensor.name
tensor_dims_mapping = tensor_dist_attr.dims_mapping
tensor_process_mesh = tensor_dist_attr.process_mesh
op_dist_attr = dist_op.dist_attr
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
op_process_mesh = actual_process_mesh
# dist_attr is [process_mesh, dims_mapping] and process_mesh is not a union
op_process_mesh = dist_attr[0]
if op_input:
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_name)
op_input_dims_mapping = dist_attr[1]
if all(
map(lambda x: x is not None, [
map(lambda x: x, [
tensor_dims_mapping, tensor_process_mesh,
op_input_dims_mapping, op_process_mesh
])):
# dims_mapping
# judge whether reshard is needed by comparing dims_mapping
if tensor_dims_mapping != op_input_dims_mapping:
if dist_op.serial_op.type == "while":
sub_block = self.auto_parallel_main_prog.blocks[
dist_op.serial_op.attr("sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = self.dist_context.get_dist_op_for_program(
op).dist_attr
var_dims_mapping = dist_op_attr.get_input_dims_mapping(
var_name)
if var_dims_mapping != tensor_dims_mapping:
is_reshard = True
break
else:
is_reshard = True
# process_mesh
if tensor_process_mesh != op_process_mesh:
# when processes length is not the same, the dims mapping must be replicative now
if len(tensor_process_mesh.processes) != len(
op_process_mesh.processes):
assert self.is_unshard(tensor_dims_mapping)
assert self.is_unshard(op_input_dims_mapping)
else:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError(
"Bool var is not supported reshard.")
# for while op, it should find the process mesh of op actually used the tensor as input
if dist_op.serial_op.type == "while":
sub_block = self.auto_parallel_main_prog.blocks[
dist_op.serial_op.attr("sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = self.dist_context.get_dist_op_for_program(
op).dist_attr
process_mesh = dist_op_attr.process_mesh
if process_mesh == op_process_mesh:
is_reshard = True
break
if tensor_process_mesh not in self.dist_context.process_meshes:
# assert whether -1 when union.
for item in tensor_dims_mapping:
if item != -1:
raise ValueError(
"The dim must be -1 when tensor process mesh is a union."
)
# tensor process_mesh: [0, 1, 2, 3], dims_mapping: [-1, -1]
# op process_mesh: [4, 5], dims_mapping: [0, -1]
# reshard is not supported such as above
if not is_reshard:
return is_reshard
else:
is_reshard = True
raise ValueError(
"it is not supported that tensor process mesh is a union and needs reshard."
)
is_reshard = True
# judge whether reshard is needed by comparing process_mesh
if tensor_process_mesh != op_process_mesh:
is_reshard = True
else:
op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_name)
op_output_dims_mapping = dist_attr[1]
if all(
map(lambda x: x is not None, [
map(lambda x: x, [
tensor_dims_mapping, tensor_process_mesh,
op_output_dims_mapping, op_process_mesh
])):
if tensor_process_mesh != op_process_mesh:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError("Bool var is not supported reshard.")
is_reshard = True
if tensor_dims_mapping != op_output_dims_mapping:
raise ValueError(
"It is not supported that tensor dims mapping is different from op output dims mapping."
)
if tensor_process_mesh != op_process_mesh:
is_reshard = True
return is_reshard
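# Example (a sketch): a tensor with dims_mapping [-1, 0] checked against an op
# input dist_attr of [process_mesh, [-1, -1]] differs in dims_mapping, so
# reshard is needed; if both the process_mesh and the dims_mapping match, no
# reshard is inserted.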
def get_process_meshes(self, op):
"""Get all process meshes when op has sub block."""
assert op.has_attr("sub_block")
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
op_process_mesh = self.dist_context.get_dist_op_for_program(
op).dist_attr.process_mesh
process_meshes = []
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh not in process_meshes and process_mesh != op_process_mesh:
process_meshes.append(process_mesh)
if not process_meshes:
process_meshes.append(op_process_mesh)
return process_meshes
def get_op_process_meshes(self, op):
"""Get sub process meshes of the given op if op process mesh is a union."""
process_meshes = []
dist_op = self.dist_context.get_dist_op_for_program(op)
op_process_mesh = dist_op.dist_attr.process_mesh
for process_mesh in self.dist_context.process_meshes:
if set(process_mesh.processes) & (set(
op_process_mesh.processes)) and len(
process_mesh.processes) <= len(
process_mesh.processes) < len(
op_process_mesh.processes):
process_meshes.append(process_mesh)
......@@ -993,39 +1175,14 @@ class Resharder:
return process_meshes
def get_while_op_actual_process_mesh(self, op):
"""Get the while op actual Process mesh corresponding to rank"""
assert op.type == "while"
while_op_process_mesh = self.dist_context.get_dist_op_for_program(
op).dist_attr.process_mesh
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
actual_process_mesh = None
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh == while_op_process_mesh:
continue
if self.rank_id in process_mesh.processes:
raw_process_mesh = process_mesh
break
if actual_process_mesh is None and self.rank_id in while_op_process_mesh.processes:
actual_process_mesh = while_op_process_mesh
assert actual_process_mesh is not None
return actual_process_mesh
def find_op_desc_seq(self, dist_tensor, dist_op, actual_process_mesh):
def find_op_desc_seq(self, dist_tensor, dist_attr, serial=False):
"""
Find the op description sequence needed to reshard the source tensor so that it matches the op requirement.
Args:
dist_tensor (DistributedTensor): A distributed tensor.
dist_op (DistributedOperator): A distributed operator.
actual_process_mesh (ProcessMesh): The actual op process mesh.
dist_attr (list): A list containing process_mesh and dims_mapping, such as [process_mesh, dims_mapping].
serial (bool): If serial is true, the dist tensor and dist op come from the serial program. Otherwise, they come from the auto program.
Returns:
Dict, the dict represents the required op description sequence corresponding to each process. The key of the dict is
......@@ -1034,24 +1191,26 @@ class Resharder:
tensor_dist_attr = dist_tensor.dist_attr
source_tensor = dist_tensor.serial_tensor
tensor_name = source_tensor.name
source_dims_mapping = tensor_dist_attr.dims_mapping
source_process_mesh = tensor_dist_attr.process_mesh
source_process_group = source_process_mesh.processes
source_process_shape = source_process_mesh.topology
op_dist_attr = dist_op.dist_attr
target_process_mesh = actual_process_mesh
target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
target_process_mesh = dist_attr[0]
target_dims_mapping = dist_attr[1]
target_process_group = target_process_mesh.processes
target_process_shape = target_process_mesh.topology
if source_tensor.shape[0] < 0:
assert source_tensor.shape[0] == -1
new_shape = list(source_tensor.shape)
new_shape[0] = self.batch_size
source_tensor.desc.set_shape(new_shape)
complete_shape = Resharder.compute_complete_shape(
source_tensor.shape, source_process_shape, source_dims_mapping)
source_tensor.shape, source_process_shape,
source_dims_mapping) if not serial else source_tensor.shape
op_desc_seq = {}
# TODO: if the target process group has the same processes as the source process group
......@@ -1060,13 +1219,14 @@ class Resharder:
set(source_process_group)):
pass
# in different process groups, it will use send, recv, concat and slice ops
elif target_process_group != source_process_group:
partition_process_mapping_list = []
for source_process in source_process_group:
# get partition index of source process
source_partition_index = Resharder.compute_partition_index(source_process, complete_shape, source_dims_mapping, \
source_process_shape, source_process_group)
if not partition_process_mapping_list:
# each item in partition_process_mapping_list is [source_partition_index, the processes holding it, whether it has been used]
partition_process_mapping_list.append(
[source_partition_index, [source_process], [False]])
else:
......@@ -1076,6 +1236,7 @@ class Resharder:
[item[1] for item in partition_process_mapping_list])
has_used = list(
[item[2] for item in partition_process_mapping_list])
if partition_list.count(source_partition_index) == 1:
index = partition_list.index(source_partition_index)
process_list[index].append(source_process)
......@@ -1085,6 +1246,7 @@ class Resharder:
[source_partition_index, [source_process], [False]])
for target_process in target_process_group:
# has_sent means the source_partition_index has been sent to target_process
has_sent = []
target_partition_index = Resharder.compute_partition_index(
target_process, complete_shape, target_dims_mapping,
......@@ -1114,6 +1276,7 @@ class Resharder:
has_used[i] = True
break
i += 1
if i == len(has_used):
has_used = list(map(lambda x: False, has_used))
to_send_process = process_list[0]
......@@ -1127,10 +1290,16 @@ class Resharder:
all_partition_index_list.append(source_partition_index)
# append send and recv op desc
is_bool = (
dist_tensor.serial_tensor.dtype == paddle.bool)
send_op_desc = SendOpDesc(source_partition_index,
target_process)
to_send_process,
target_process,
is_bool=is_bool)
recv_op_desc = RecvOpDesc(source_partition_index,
to_send_process)
to_send_process,
target_process,
is_bool=is_bool)
op_desc_seq[to_send_process].append(send_op_desc)
op_desc_seq[target_process].append(recv_op_desc)
has_sent.append(source_partition_index)
......@@ -1146,16 +1315,24 @@ class Resharder:
slice_ends = []
slices_axes = []
concatenated_partition_index = partition_index_list[0]
to_slice_tensor_shape = []
for idx, item in enumerate(concatenated_partition_index):
slice_starts.append(target_partition_index[idx][0] -
item[0])
slice_ends.append(target_partition_index[idx][1] - item[0])
slices_axes.append(idx)
to_slice_tensor_shape.append(item[1] - item[0])
op_desc_seq[target_process].append(
SliceOpDesc(slice_starts, slice_ends, slices_axes))
SliceOpDesc(slice_starts,
slice_ends,
slices_axes,
shape=to_slice_tensor_shape))
# in the same process group, it will use allgather and slice op
# in the same process group, it will use allgather and slice op.
else:
# NOTE: Only the even partition scene is supported.
partition_index_list = []
all_partition_index_list = []
process_index = []
......@@ -1191,17 +1368,21 @@ class Resharder:
slice_ends.append(item[1])
slices_axes.append(idx)
to_slice_tensor_shape = dist_tensor.global_sizes()
slice_op_desc = SliceOpDesc(starts=slice_starts,
ends=slice_ends,
axes=slices_axes)
op_desc_seq[process] = [AllGatherOpDesc(group=group),
axes=slices_axes,
shape=to_slice_tensor_shape)
allgather_shape = None if not serial else dist_tensor.local_sizes(
rank=process)
op_desc_seq[process] = [AllGatherOpDesc(group=group, shape=allgather_shape, is_bool=(source_tensor.dtype == paddle.bool)),
ConcatOpDesc(partition_index_list=all_partition_index_list), slice_op_desc] \
if len(group) > 1 else [slice_op_desc]
return op_desc_seq
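# Sketch of the structure returned by find_op_desc_seq (ranks are illustrative only):
#   cross-mesh case: {src_rank: [SendOpDesc(...)],
#                     dst_rank: [RecvOpDesc(...), ..., SliceOpDesc(...)]}
#   same-mesh case:  {rank: [AllGatherOpDesc(...), ConcatOpDesc(...), SliceOpDesc(...)]}
#                    or {rank: [SliceOpDesc(...)]} when the group has a single process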
def parse_op_desc(self, block, op_desc_seq, var_name, reshard_op,
actual_process_mesh):
dist_attr):
"""Parse op desc sequence and insert op in the block"""
tensor_list = []
partition_tensor_list = []
......@@ -1226,13 +1407,32 @@ class Resharder:
self.has_allgather[var_name] = []
if not self.has_allgather[var_name] or op_desc.group not in list(
map(lambda x: x[0], self.has_allgather[var_name])):
tensor_list, idx_offset = Inserter.insert_allgather_op(
block, idx, source_tensor, op_desc.group,
reshard_op.attr('op_role'))
idx += idx_offset
tensor_name_list = [var.name for var in tensor_list]
self.has_allgather[var_name].append(
[op_desc.group, tensor_name_list])
if op_desc.is_bool:
# for bool data allgather: cast to int64 -> allgather -> cast back to bool
out_cast = Inserter.insert_cast_op(
block, idx, source_tensor,
reshard_op.attr('op_role'), paddle.int64)
tensor_list, idx_offset = Inserter.insert_allgather_op(
block, idx + 1, out_cast, op_desc.group,
reshard_op.attr('op_role'))
idx += idx_offset
tensor_name_list = []
for var in tensor_list:
out_cast = Inserter.insert_cast_op(
block, idx, var, reshard_op.attr('op_role'),
paddle.bool)
tensor_name_list.append(out_cast.name)
idx += 1
self.has_allgather[var_name].append(
[op_desc.group, tensor_name_list])
else:
tensor_list, idx_offset = Inserter.insert_allgather_op(
block, idx, source_tensor, op_desc.group,
reshard_op.attr('op_role'))
idx += idx_offset
tensor_name_list = [var.name for var in tensor_list]
self.has_allgather[var_name].append(
[op_desc.group, tensor_name_list])
else:
for item in self.has_allgather[var_name]:
if op_desc.group == item[0]:
......@@ -1249,10 +1449,19 @@ class Resharder:
if var_name not in self.has_sent.keys():
self.has_sent[var_name] = []
if op_desc.dst not in self.has_sent[var_name]:
Inserter.insert_send_op(block, idx, source_tensor,
op_desc.dst,
reshard_op.attr('op_role'))
idx += 1
if op_desc.is_bool:
out_cast = Inserter.insert_cast_op(
block, idx, source_tensor,
reshard_op.attr('op_role'), paddle.int64)
Inserter.insert_send_op(block, idx + 1, out_cast,
op_desc.src, op_desc.dst,
reshard_op.attr('op_role'))
idx += 2
else:
Inserter.insert_send_op(block, idx, source_tensor,
op_desc.src, op_desc.dst,
reshard_op.attr('op_role'))
idx += 1
self.has_sent[var_name].append(op_desc.dst)
elif isinstance(op_desc, RecvOpDesc):
......@@ -1263,17 +1472,58 @@ class Resharder:
shape = []
for index in partition_index:
shape.append(index[1] - index[0])
recv_tensor = block.create_var(
name=unique_name.generate(var_name + "@recv"),
shape=shape,
dtype=source_tensor.dtype,
type=source_tensor.type)
Inserter.insert_recv_op(block, idx, recv_tensor,
op_desc.src,
reshard_op.attr('op_role'))
tensor_list.append(recv_tensor)
idx += 1
self.has_recv[var_name][op_desc.src] = recv_tensor
if op_desc.is_bool:
# for bool data, recv int64 -> cast to bool
recv_tensor = block.create_var(
name=unique_name.generate(var_name + "@recv"),
shape=shape,
lod_level=source_tensor.lod_level,
dtype=paddle.int64,
type=source_tensor.type)
Inserter.insert_recv_op(block, idx, recv_tensor,
op_desc.src, op_desc.dst,
reshard_op.attr('op_role'))
out_cast = Inserter.insert_cast_op(
block, idx + 1, recv_tensor,
reshard_op.attr('op_role'), paddle.bool)
tensor_list.append(out_cast)
idx += 2
self.has_recv[var_name][op_desc.src] = out_cast
else:
recv_tensor = block.create_var(
name=unique_name.generate(var_name + "@recv"),
shape=shape,
lod_level=source_tensor.lod_level,
dtype=source_tensor.dtype,
type=source_tensor.type)
Inserter.insert_recv_op(block, idx, recv_tensor,
op_desc.src, op_desc.dst,
reshard_op.attr('op_role'))
# for a lod tensor, the lod needs to be reset after it is received
if recv_tensor.lod_level != 0:
set_lod = False
# use data lod to reset tensor lod
for tmp_block in self.auto_parallel_main_prog.blocks:
for tmp_var_name in tmp_block.vars:
tmp_var = tmp_block.vars[tmp_var_name]
if tmp_var.is_data and tmp_var.lod_level == recv_tensor.lod_level:
reset_lod_out = Inserter.insert_reset_lod_op(
block, idx + 1, recv_tensor,
tmp_var, reshard_op.attr('op_role'))
tensor_list.append(reset_lod_out)
idx += 2
self.has_recv[var_name][
op_desc.src] = reset_lod_out
set_lod = True
break
if set_lod:
break
assert set_lod is True
else:
tensor_list.append(recv_tensor)
idx += 1
self.has_recv[var_name][op_desc.src] = recv_tensor
else:
tensor_list.append(self.has_recv[var_name][op_desc.src])
......@@ -1303,188 +1553,506 @@ class Resharder:
new_var_name=new_name,
op_role=reshard_op.attr('op_role'))
process_mesh = dist_attr[0]
dims_mapping = dist_attr[1]
tensor_attr = TensorDistributedAttribute()
process_mesh = actual_process_mesh
dims_mapping = self.dist_context.get_op_dist_attr_for_program(
matched_op).get_input_dims_mapping(var_name)
tensor_attr.dims_mapping = dims_mapping
tensor_attr.process_mesh = process_mesh
self.dist_context.set_tensor_dist_attr_for_program(
target_tensor, tensor_attr)
if op.type == "while":
if matched_op.type == "while":
# var_reshard_mapping records what the while op input needs to be changed to
if "var_reshard_mapping" not in Resharder.while_block_info[
op.attr("sub_block").id].keys():
Resharder.while_block_info[op.attr(
"sub_block").id]["var_reshard_mapping"] = {}
if var_name not in Resharder.while_block_info[op.attr(
"sub_block").id]["var_reshard_mapping"].keys():
Resharder.while_block_info[op.attr("sub_block").id][
"var_reshard_mapping"][var_name] = []
Resharder.while_block_info[op.attr("sub_block").id][
"var_reshard_mapping"][var_name] = target_tensor.name
"var_reshard_mapping"][var_name].append(
[dist_attr, target_tensor.name])
# rename op input name according to new name
for op in block.ops:
# just for while op
while_op_X_append = []
for name in op.input_arg_names:
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
op)
if name == var_name and op_dist_attr is not None:
if op.desc.id() == matched_op.desc.id():
op.desc._rename_input(name, target_tensor.name)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr.set_input_dist_attr(name, None)
continue
if matched_op.type == "while":
old_name = name
new_name = target_tensor.name
assert old_name != new_name
op_input_dist_attr = op_dist_attr.get_input_dist_attr(
old_name)
op_dist_attr.set_input_dist_attr(
new_name, op_input_dist_attr)
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping)
if old_name in op_dist_attr._inputs_dist_attrs:
op_dist_attr.del_input_dist_attr(
old_name)
while_op_X_append.append(new_name)
continue
else:
op.desc._rename_input(
name, target_tensor.name)
old_name = name
new_name = target_tensor.name
assert old_name != new_name
op_input_dist_attr = op_dist_attr.get_input_dist_attr(
old_name)
op_dist_attr.set_input_dist_attr(
new_name, op_input_dist_attr)
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping)
op_dist_attr.del_input_dist_attr(old_name)
continue
# NOTE: For an op whose process mesh is a union, its input will not be renamed by other ops' reshard results for now, which means it will have more reshard operations.
op_process_mesh = op_dist_attr.process_mesh
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
# NOTE: For an op whose process mesh is a union, its input will not be renamed by other ops' reshard results for now, which means it will have more reshard operations.
if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping:
op.desc._rename_input(name, target_tensor.name)
old_name = name
new_name = target_tensor.name
assert old_name != new_name
op_input_dist_attr = op_dist_attr.get_input_dist_attr(
old_name)
op_dist_attr.set_input_dist_attr(
new_name, op_input_dist_attr)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr.set_input_dist_attr(name, None)
new_name, dims_mapping)
op_dist_attr.del_input_dist_attr(old_name)
def reshard(self):
for block_idx, block in enumerate(self.auto_parallel_main_prog.blocks):
if block_idx in Resharder.while_block_info:
if "var_reshard_mapping" in Resharder.while_block_info[
block_idx]:
var_reshard_mapping = Resharder.while_block_info[block_idx][
"var_reshard_mapping"]
for op in block.ops:
for var_name in op.input_arg_names:
if var_name in var_reshard_mapping:
op.desc._rename_input(
var_name, var_reshard_mapping[var_name])
dist_op = self.dist_context.get_dist_op_for_program(
op)
op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == Resharder.while_block_info[
block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
op_dist_attr.set_input_dims_mapping(
var_reshard_mapping[var_name],
dims_mapping)
op_dist_attr.set_input_dist_attr(
var_name, None)
# the outputs also need to be renamed when the output name is the same as the input name
for var_name in op.output_arg_names:
if var_name in var_reshard_mapping:
op.desc._rename_output(
var_name, var_reshard_mapping[var_name])
dist_op = self.dist_context.get_dist_op_for_program(
op)
op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == Resharder.while_block_info[
block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_output_dims_mapping(
var_name)
op_dist_attr.set_output_dims_mapping(
var_reshard_mapping[var_name],
dims_mapping)
op_dist_attr.set_output_dist_attr(
var_name, None)
idx = 0
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
if self.is_special_op(op):
idx += 1
continue
# for the while op, the input X should be reset
if while_op_X_append:
proto = OpProtoHolder.instance().get_op_proto(op.type)
op.desc.set_input(proto.inputs[0].name,
op.input("X") + while_op_X_append)
def _get_while_op_input_attrs(self, op, var_name):
# NOTE: Multiple while loops are not supported
assert op.type == "while"
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
input_attrs = []
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
dist_attr = dist_op.dist_attr
for name in op.input_arg_names:
if name == var_name:
process_mesh = dist_attr.process_mesh
input_dims_mapping = dist_attr.get_input_dims_mapping(
var_name)
has_exist = False
for input_attr in input_attrs:
if process_mesh == input_attr[
0] and input_dims_mapping == input_attr[1]:
has_exist = True
break
if not has_exist:
input_attrs.append([process_mesh, input_dims_mapping])
return input_attrs
def _get_common_op_input_attrs(self, op, var_name):
process_meshes = []
dist_op = self.dist_context.get_dist_op_for_program(op)
dist_attr = dist_op.dist_attr
op_process_mesh = dist_attr.process_mesh
for process_mesh in self.dist_context.process_meshes:
if set(process_mesh.processes) & (set(
op_process_mesh.processes)) and len(
process_mesh.processes) < len(
op_process_mesh.processes):
process_meshes.append(process_mesh)
# the process mesh is not a union when process_meshes is empty
if not process_meshes:
process_meshes.append(op_process_mesh)
input_dims_mapping = dist_attr.get_input_dims_mapping(var_name)
input_attrs = []
for process_mesh in process_meshes:
input_attrs.append([process_mesh, input_dims_mapping])
return input_attrs
def get_op_input_attrs(self, op, var_name):
op_input_attrs = []
dist_op = self.dist_context.get_dist_op_for_program(op)
if dist_op is not None:
process_meshes = []
if op.type == "while":
if not self.is_condition_replicative(op):
if op.type == "while":
op_input_attrs = self._get_while_op_input_attrs(op, var_name)
else:
op_input_attrs = self._get_common_op_input_attrs(op, var_name)
assert op_input_attrs
return op_input_attrs
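# Sketch: get_op_input_attrs returns one [process_mesh, dims_mapping] pair per
# sub process mesh the op runs on, e.g.
#   [[ProcessMesh([0, 1]), [0, -1]], [ProcessMesh([2, 3]), [0, -1]]]
# (the meshes and mappings here are illustrative only).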
def _remove_global_process_mesh(self):
"""Remove global process mesh from dist_context.process_meshes"""
processes = set()
process_mesh_count = len(self.dist_context.process_meshes)
if process_mesh_count > 1:
global_process_mesh_idx = None
for process_mesh in self.dist_context.process_meshes:
for process in process_mesh.processes:
processes.add(process)
for idx, process_mesh in enumerate(
self.dist_context.process_meshes):
if len(set(process_mesh.processes)) == len(processes):
global_process_mesh_idx = idx
break
if global_process_mesh_idx is not None:
self.dist_context.process_meshes.pop(idx)
def _change_subblock_op_input_and_output(self, block_idx, block):
if "var_reshard_mapping" in Resharder.while_block_info[block_idx]:
var_reshard_mapping = Resharder.while_block_info[block_idx][
"var_reshard_mapping"]
for op in block.ops:
for var_name in op.input_arg_names:
if var_name in var_reshard_mapping:
# in the while sub block, the union process mesh is not split before resharding the sub block
dist_op = self.dist_context.get_dist_op_for_program(op)
dist_attr = dist_op.dist_attr
target_name = None
for item in var_reshard_mapping[var_name]:
if dist_attr.process_mesh == item[0][
0] and dist_attr.get_input_dims_mapping(
var_name) == item[0][1]:
target_name = item[1]
break
if target_name is None:
continue
else:
op.desc._rename_input(var_name, target_name)
dist_op = self.dist_context.get_dist_op_for_program(
op)
op_dist_attr = dist_op.dist_attr
old_name = var_name
new_name = target_name
assert old_name != new_name
op_input_dist_attr = op_dist_attr.get_input_dist_attr(
old_name)
op_dist_attr.set_input_dist_attr(
new_name, op_input_dist_attr)
op_dist_attr.del_input_dist_attr(old_name)
# the outputs also need to be renamed when the output name is the same as the input name in an inplace op
for var_name in op.output_arg_names:
# if the tensor has been resharded multiple times, it is not supported now.
if var_name in var_reshard_mapping:
if len(var_reshard_mapping[var_name]) > 1:
raise ValueError(
"Please check the condition due to the dims mapping is not replicative."
"The scene is not supported that the output is inplaced and the tensor has been resharded multiply when as input."
)
process_meshes = self.get_process_meshes(op)
assert process_meshes
if op.attr("sub_block"
).id not in Resharder.while_block_info:
Resharder.while_block_info[op.attr(
"sub_block").id] = {}
Resharder.while_block_info[op.attr(
"sub_block").id]["op_id"] = op.desc.id()
Resharder.while_block_info[op.attr("sub_block").id][
"actual_process_mesh"] = self.get_while_op_actual_process_mesh(
op)
else:
process_meshes = self.get_op_process_meshes(op)
input_vars = None
if op.type == "while":
input_var_names = op.input("X")
else:
input_var_names = op.input_arg_names
idx_offset = 0
for var_name in op.input_arg_names:
# skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0":
continue
var = get_var_with_recursion(
var_name, block, self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(
var)
for process_mesh in process_meshes:
if dist_tensor is not None and self.need_reshard(
dist_tensor, dist_op, process_mesh):
reshard_op_desc = self.find_op_desc_seq(
dist_tensor, dist_op, process_mesh)
self.parse_op_desc(block, reshard_op_desc,
var_name, op, process_mesh)
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
target_name = var_reshard_mapping[var_name][0][1]
op.desc._rename_output(var_name, target_name)
dist_op = self.dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
old_name = var_name
new_name = target_name
assert old_name != new_name
op_output_dist_attr = op_dist_attr.get_output_dist_attr(
old_name)
op_dist_attr.set_output_dist_attr(
new_name, op_output_dist_attr)
op_dist_attr.del_output_dist_attr(old_name)
def _reshard_input(self, block):
idx = 0
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
if self.is_special_op(op):
idx += 1
continue
dist_op = self.dist_context.get_dist_op_for_program(op)
if dist_op is not None:
op_input_dist_attrs = [
] # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)]
if op.type == "while":
if not self.is_condition_replicative(op):
raise ValueError(
"Please check the condition due to the dims mapping is not replicative."
)
if op.attr(
"sub_block").id not in Resharder.while_block_info:
Resharder.while_block_info[op.attr("sub_block").id] = {}
Resharder.while_block_info[op.attr(
"sub_block").id]["op_id"] = op.desc.id()
if op.type == "while":
# the condition var's process mesh is the same as the op's and its dims_mapping is replicative, so it does not need reshard
input_var_names = op.input("X")
else:
idx += 1
input_var_names = op.input_arg_names
# to avoid a different order of the while op's X
input_var_names.sort()
idx_offset = 0
for var_name in input_var_names:
# skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0":
continue
var = get_var_with_recursion(var_name, block,
self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(
var)
# judge whether the union tensor's dims_mapping is all -1
is_union_process_mesh_tensor = False
if dist_tensor.dist_attr.process_mesh not in self.dist_context.process_meshes and self.dist_context.process_meshes:
is_union_process_mesh_tensor = True
assert dist_tensor.dist_attr.dims_mapping.count(
-1) == len(dist_tensor.dist_attr.dims_mapping)
op_input_attrs = self.get_op_input_attrs(op, var_name)
for input_attr in op_input_attrs:
input_process_mesh = None
# deal with union tensor
if is_union_process_mesh_tensor:
# if the op process mesh is a subset of the union tensor process mesh, no reshard is needed
if set(input_attr[0].processes) <= set(
dist_tensor.dist_attr.process_mesh.processes
):
continue
# insert send and recv op if output process mesh is different from tensor process mesh
idx = 0
# skip reader and ops whose process mesh is union
skip_ops = [
"create_py_reader", "create_double_buffer_reader", "read",
"while", "write_to_array", "read_from_array"
]
global _g_special_ops
skip_ops += _g_special_ops
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
dist_op = self.dist_context.get_dist_op_for_program(op)
if dist_op is not None and op.type not in skip_ops:
for var_name in op.output_arg_names:
var = get_var_with_recursion(
var_name, block, self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(
var)
process_mesh = dist_op.dist_attr.process_mesh
if dist_tensor is not None and self.need_reshard(
dist_tensor, dist_op, process_mesh, False):
for index, item in enumerate(
dist_op.dist_attr.process_mesh.processes):
recv_rank = dist_tensor.dist_attr.process_mesh.processes[
index]
if self.rank_id == item:
Inserter.insert_send_op(
block, idx + 1, var, recv_rank,
op.attr('op_role'))
if self.rank_id == recv_rank:
Inserter.insert_recv_op(
block, idx + 1, var, item,
op.attr('op_role'))
dist_tensor, input_attr):
reshard_op_desc = self.find_op_desc_seq(
dist_tensor, input_attr)
self.parse_op_desc(block, reshard_op_desc, var_name,
op, input_attr)
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
idx = idx + idx_offset + 1
else:
idx += 1
def _handle_recv(self, block, idx, var, op, send_rank, recv_rank):
if self.rank_id == recv_rank:
# if recv bool data, recv then cast
if var.dtype == paddle.bool:
recv_cast_out = block.create_var(
name=unique_name.generate(var.name + "@recv"),
shape=var.shape,
lod_level=var.lod_level,
dtype=paddle.int64,
type=var.type)
Inserter.insert_recv_op(block, idx + 1,
recv_cast_out, send_rank, recv_rank,
op.attr('op_role'))
reset_lod_out = None
if var.lod_level != 0:
set_lod = False
for tmp_block in self.auto_parallel_main_prog.blocks:
for tmp_var_name in tmp_block.vars:
tmp_var = tmp_block.vars[tmp_var_name]
if tmp_var.is_data and tmp_var.lod_level == var.lod_level:
reset_lod_out = block.create_var(
name=unique_name.generate(var.name +
"@RESETLOD"),
shape=recv_cast_out.shape,
type=recv_cast_out.type,
dtype=recv_cast_out.dtype,
lod_level=recv_cast_out.lod_level)
idx += 1
block._insert_op(
idx,
type="lod_reset",
inputs={
'X': recv_cast_out,
'Y': tmp_var
},
outputs={'Out': reset_lod_out},
attrs={'op_role': op.attr("op_role")})
set_lod = True
break
if set_lod:
break
assert set_lod is True
# cast int64 to bool
block._insert_op(idx + 2,
type='cast',
inputs={
'X': [recv_cast_out] if
reset_lod_out is None else [reset_lod_out]
},
outputs={'Out': [var]},
attrs={
'in_dtype': recv_cast_out.dtype,
'out_dtype': var.dtype,
'op_role': op.attr('op_role')
})
else:
if var.lod_level != 0:
recv_out = block.create_var(
name=unique_name.generate(var.name + "@recv"),
shape=var.shape,
lod_level=var.lod_level,
dtype=var.dtype,
type=var.type)
Inserter.insert_recv_op(block, idx + 1, recv_out, send_rank,
recv_rank, op.attr('op_role'))
set_lod = False
for tmp_block in self.auto_parallel_main_prog.blocks:
for tmp_var_name in tmp_block.vars:
tmp_var = tmp_block.vars[tmp_var_name]
if tmp_var.is_data and tmp_var.lod_level == var.lod_level:
idx += 1
block._insert_op(
idx,
type="lod_reset",
inputs={
'X': recv_out,
'Y': tmp_var
},
outputs={'Out': var},
attrs={'op_role': op.attr("op_role")})
set_lod = True
break
if set_lod:
break
assert set_lod is True
else:
idx += 1
Inserter.insert_recv_op(block, idx + 1, var, send_rank,
recv_rank, op.attr('op_role'))
def _handle_send(self, block, idx, var, op, send_rank, recv_rank):
if var.dtype == paddle.bool:
cast_out = Inserter.insert_cast_op(block, idx + 1, var,
op.attr('op_role'), paddle.int64)
Inserter.insert_send_op(block, idx + 2, cast_out, send_rank,
recv_rank, op.attr('op_role'))
else:
Inserter.insert_send_op(block, idx + 1, var, send_rank, recv_rank,
op.attr('op_role'))
def _reshard_output(self, block):
# insert send and recv op if output process mesh is different from tensor process mesh
idx = 0
# skip readers and ops whose process mesh is a union
skip_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "while",
"write_to_array", "read_from_array"
]
global _g_special_ops
skip_ops += _g_special_ops
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
dist_op = self.dist_context.get_dist_op_for_program(op)
if dist_op is not None and op.type not in skip_ops:
idx_offset = 0
for var_name in op.output_arg_names:
var = get_var_with_recursion(var_name, block,
self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(
var)
tensor_process_mesh = dist_tensor.dist_attr.process_mesh
output_attr = [
dist_op.dist_attr.process_mesh,
dist_op.dist_attr.get_output_dims_mapping(var_name)
]
if dist_tensor is not None and self.need_reshard(
dist_tensor, output_attr, False):
tensor_processes = set(
tensor_process_mesh.processes) - (
set(tensor_process_mesh.processes)
& set(output_attr[0].processes))
if tensor_processes:
if len(tensor_processes) != len(
output_attr[0].processes):
if dist_tensor.dist_attr.dims_mapping.count(
-1) != len(
dist_tensor.dist_attr.dims_mapping
) or output_attr[1].count(-1) != len(
output_attr[1]):
raise ValueError(
"The dims_mapping must be -1")
else:
for index, tensor_process in enumerate(
tensor_processes):
recv_rank = tensor_process
actual_index = index
if index >= len(
output_attr[0].processes):
actual_index = (
index -
len(output_attr[0].processes)
) % len(output_attr[0].processes)
item = output_attr[0].processes[
actual_index]
if recv_rank == item:
continue
if self.rank_id == item:
# if send bool data, cast then send
self._handle_send(
block, idx, var, op, item,
recv_rank)
if self.rank_id == recv_rank:
# if recv bool data, recv then cast
self._handle_recv(
block, idx, var, op, item,
recv_rank)
else:
for index, tensor_process in enumerate(
tensor_processes):
recv_rank = tensor_process
item = output_attr[0].processes[index]
if recv_rank == item:
continue
if self.rank_id == item:
# if send bool data, cast then send
self._handle_send(
block, idx, var, op, item,
recv_rank)
if self.rank_id == recv_rank:
# if recv bool data, recv then cast
self._handle_recv(
block, idx, var, op, item,
recv_rank)
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
else:
idx += 1
def reshard(self):
self._remove_global_process_mesh()
for block_idx, block in enumerate(self.auto_parallel_main_prog.blocks):
# change the var_name before resharding the sub block
if block_idx in Resharder.while_block_info:
self._change_subblock_op_input_and_output(block_idx, block)
# reshard input
self._reshard_input(block)
# reshard output
# NOTE: Only inserting send and recv ops when the output process mesh is different from the tensor process mesh is supported
self._reshard_output(block)
# remove no need vars and ops in the main program
Remover.remove_no_need_in_main(self.auto_parallel_main_prog,
......
......@@ -1419,7 +1419,10 @@ def get_standalone_cost_data(distributed_programs):
}
standalone_cost_data = []
not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
# skip ops
not_enum_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "assign"
]
for distributed_program in distributed_programs:
cost_data = {}
vars = distributed_program.global_block().vars
......
......@@ -27,7 +27,10 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_di
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read', 'slice']
_skip_ops = [
'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split',
'assign'
]
# update here to support new optimizers
_supported_optimizer_type = [
"adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
......
......@@ -107,7 +107,9 @@ class TestDataUnshard(unittest.TestCase):
input_data = np.array(range(2 * 8)).reshape([2, 8]).astype("float32")
label_data = np.random.randint(0, 10, [2, 8]).astype("float32")
fetchs = [loss.name, 'input@RESHARD_0']
fetchs = [loss.name, 'split@RESHARD.tmp_0'] if worker_index == 0 else [
loss.name, 'split@RESHARD.tmp_1'
]
loss_np, shard_data_np = exe.run(distributed_main_program,
feed={
"input": input_data,
......
......@@ -28,7 +28,7 @@ from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import _g_process_group_map
from paddle.distributed.auto_parallel.process_group import _g_process_group_map, ProcessGroup
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
......@@ -307,6 +307,7 @@ class TestMLPReshard(unittest.TestCase):
train_program, startup_program, dist_context, rank_id)
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
_g_process_group_map[0] = ProcessGroup(0, [])
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
......@@ -326,10 +327,10 @@ class TestMLPReshard(unittest.TestCase):
train_program, startup_program, dist_context, rank_id, True)
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
_g_process_group_map[0] = ProcessGroup(0, [])
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
self.assertTrue(check_initialization(dist_startup_prog, rank_id))
......