未验证 提交 8624f3b1 编写于 作者: C caozhou 提交者: GitHub

[Auto Parallel] Update reshard for auto search (#45002)

* update reshard for auto search

* fix unittest bug

* update dist tensor

* update reshard output

* fix unittests bug

* merge develop
上级 236ad4fc
......@@ -276,8 +276,8 @@ class OperatorDistributedAttribute:
dist_attr_object.init(dist_attr)
self._inputs_dist_attrs[name] = dist_attr_object
# def del_input_dist_attr(self, name):
# del self._inputs_dist_attrs[name]
def del_input_dist_attr(self, name):
del self._inputs_dist_attrs[name]
def get_output_dist_attr(self, name):
return self._outputs_dist_attrs.get(name, None)
......@@ -287,8 +287,8 @@ class OperatorDistributedAttribute:
dist_attr_object.init(dist_attr)
self._outputs_dist_attrs[name] = dist_attr_object
# def del_output_dist_attr(self, name):
# del self._inputs_dist_attrs[name]
def del_output_dist_attr(self, name):
del self._outputs_dist_attrs[name]
def get_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name)
......
......@@ -163,7 +163,6 @@ class DistributedTensor:
self._batch_dim = 0
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
self._local_sizes_map = {}
self._local_offsets_map = {}
self._local_shard_map = {}
self._local_tensor_map = {}
......@@ -223,20 +222,17 @@ class DistributedTensor:
return True
def local_sizes(self, rank=None):
"""Get local sizes of the given rank."""
rank = paddle.distributed.get_rank() if rank is None else rank
local_sizes = None
if rank in self._local_sizes_map.keys():
local_sizes = self._local_sizes_map[rank]
else:
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, rank,
local_sizes = DistributedTensor.get_local_sizes(global_sizes,
dims_mapping, topology,
processes, rank,
shard_sizes)
self._local_sizes_map[rank] = local_sizes
return local_sizes
......@@ -282,7 +278,6 @@ class DistributedTensor:
def new_local_tensor(self, block=None, rank=None, name=None):
"""
Create a new local tensor of serial tensor corresponding to rank.
Args:
block (Block): The block contains the new tensor. Default value is recommend and it will be created in the block of dist main program corresponding to the serial tensor block id. Default: None.
rank (int): The rank id. Default value is recommend and it will be the current rank. Default: None.
......
......@@ -26,6 +26,11 @@ from ..collective import _get_global_env
from .dist_context import DistributedContext
from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .process_group import new_process_group, ProcessGroup, _g_process_group_map
from .cost import build_comm_desc, CommContext
from .cost import AllgatherOpCost, SendOpCost
from .cost import SliceOpCost, SplitOpCost, ConcatOpCost
from .cluster import Cluster
from .utils import print_program_with_dist_attr
# NOTE: If op in _g_special_ops, it will not be resharded.
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
......@@ -41,6 +46,7 @@ def get_var_with_recursion(var_name, block, program):
if var_name in parent_block.vars:
var = parent_block.vars[var_name]
assert var is not None
return var
......@@ -50,11 +56,19 @@ class AllGatherOpDesc:
Args:
group (list): Process group.
shape (list): The tensor shape.
is_bool (bool): Whether allgather bool data. Default: False.
"""
def __init__(self, group):
def __init__(self, group, shape, is_bool=False):
self._group = group
self._desc = "all_gather"
self._shape = shape
self._is_bool = is_bool
@property
def is_bool(self):
return self._is_bool
@property
def group(self):
......@@ -64,8 +78,12 @@ class AllGatherOpDesc:
def desc(self):
return self._desc
@property
def shape(self):
return self._shape
def __repr__(self):
return f"op: {self._desc}, group: {self._group}."
return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}."
class SendOpDesc:
......@@ -74,13 +92,26 @@ class SendOpDesc:
Args:
partition_index (list): The index of partition in complete tensor.
src (int): The source process to send.
dst (int): The destination process to receive.
is_bool (bool): Whether send bool data. Default: False.
"""
def __init__(self, partition_index, dst):
def __init__(self, partition_index, src, dst, is_bool=False):
self._dst = dst
self._partition_index = partition_index
self._desc = "send"
self._shape = []
self._is_bool = is_bool
self._src = src
@property
def src(self):
return self._src
@property
def is_bool(self):
return self._is_bool
@property
def partition_index(self):
......@@ -94,8 +125,15 @@ class SendOpDesc:
def desc(self):
return self._desc
@property
def shape(self):
if not self._shape:
for item in self.partition_index:
self._shape.append(item[1] - item[0])
return self._shape
def __repr__(self):
return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}."
return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}, shape: {self._shape}, is_bool: {self._is_bool}."
class RecvOpDesc:
......@@ -105,12 +143,25 @@ class RecvOpDesc:
Args:
partition_index (list): The index of partition in complete tensor.
src (int): The source process to send.
dst (int): The destination process to receive.
is_bool (bool): Whether receive bool data. Default: False.
"""
def __init__(self, partition_index, src):
def __init__(self, partition_index, src, dst, is_bool=False):
self._src = src
self._partition_index = partition_index
self._desc = "recv"
self._shape = []
self._is_bool = is_bool
self._dst = dst
@property
def dst(self):
return self._dst
@property
def is_bool(self):
return self._is_bool
@property
def partition_index(self):
......@@ -124,8 +175,15 @@ class RecvOpDesc:
def desc(self):
return self._desc
@property
def shape(self):
if not self._shape:
for item in self.partition_index:
self._shape.append(item[1] - item[0])
return self._shape
def __repr__(self):
return f"op: {self._desc}, partition_index: {self._partition_index}, src: {self._src}."
return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}, shape: {self._shape}, is_bool: {self._is_bool}."
class SliceOpDesc:
......@@ -133,16 +191,18 @@ class SliceOpDesc:
Describe the slice op in the reshard phase.
Args:
starts (list): It represents starting indices of corresponding axis in ``axes``.
ends (list): It represents ending indices of corresponding axis in ``axes``.
axes (list): Axes that `starts` and `ends` apply to .
starts (list): It represents start indices of corresponding axis in ``axes``.
ends (list): It represents end indices of corresponding axis in ``axes``.
axes (list): Axes that `starts` and `ends` apply to.
shape (list): The shape of the tensor to be sliced.
"""
def __init__(self, starts, ends, axes):
def __init__(self, starts, ends, axes, shape=None):
self._starts = starts
self._ends = ends
self._axes = axes
self._desc = "slice"
self._shape = shape
@property
def starts(self):
......@@ -160,7 +220,14 @@ class SliceOpDesc:
def desc(self):
return self._desc
@property
def shape(self):
return self._shape
def __repr__(self):
if self._shape is not None:
return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}, shape: {self._shape}."
else:
return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}."
......@@ -192,36 +259,84 @@ class Inserter:
"""Insert op required in the reshard process."""
@staticmethod
def insert_send_op(block, idx, tensor, dst, op_role):
def insert_cast_op(block, idx, tensor, op_role, tensor_type):
# to avoid name conflict with framework
new_var_name = paddle.fluid.unique_name.generate_with_ignorable_key(
".".join(["cast@RESHARD", 'tmp']))
out = block.create_var(name=new_var_name,
dtype=tensor_type,
type=tensor.type,
lod_level=tensor.lod_level)
block._insert_op(idx,
type='cast',
inputs={'X': [tensor]},
outputs={'Out': [out]},
attrs={
'in_dtype': tensor.dtype,
'out_dtype': out.dtype,
'op_role': op_role
})
return out
@staticmethod
def insert_send_op(block, idx, tensor, src, dst, op_role):
"""Insert send op into block at the given index."""
op_type = 'send_v2'
# use pair comm group
process_group = new_process_group([src, dst])
block._insert_op(idx,
type=op_type,
inputs={'X': [tensor]},
attrs={
'ring_id': 0,
'peer': dst,
'ring_id': process_group.id,
'peer': process_group.ranks.index(dst),
'use_calc_stream': True,
'op_role': op_role
'op_role': op_role,
'dynamic_shape': True
})
@staticmethod
def insert_recv_op(block, idx, tensor, src, op_role):
def insert_recv_op(block, idx, tensor, src, dst, op_role):
"""Insert recv op into block at the given index."""
op_type = 'recv_v2'
# use pair group
process_group = new_process_group([src, dst])
block._insert_op(idx,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'ring_id': 0,
'peer': src,
'ring_id': process_group.id,
'peer': process_group.ranks.index(src),
'out_shape': tensor.shape,
'dtype': tensor.dtype,
'use_calc_stream': True,
'op_role': op_role
'op_role': op_role,
'dynamic_shape': True
})
@staticmethod
def insert_reset_lod_op(block, idx, X, Y, op_role):
"""Insert reset_lod op into block at the given index."""
new_var_name = paddle.fluid.unique_name.generate_with_ignorable_key(
".".join(["reset_lod@RESHARD", 'tmp']))
reset_lod_out = block.create_var(name=new_var_name,
shape=X.shape,
type=X.type,
dtype=X.dtype,
lod_level=X.lod_level)
block._insert_op(idx,
type="lod_reset",
inputs={
'X': X,
'Y': Y
},
outputs={'Out': reset_lod_out},
attrs={'op_role': op_role})
return reset_lod_out
@staticmethod
def insert_concat_op(block, idx, tensors, axis, op_role):
"""Insert concat op into block at the given block."""
......@@ -229,10 +344,18 @@ class Inserter:
attrs = {}
attrs['axis'] = axis
attrs['op_role'] = op_role
helper = LayerHelper('concat', **locals())
# to avoid name conflict with framework
helper = LayerHelper('concat@RESHARD', **locals())
with paddle.static.program_guard(block.program):
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
out = block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])),
dtype=tensors[0].dtype,
shape=None,
lod_level=tensors[0].lod_level,
type=tensors[0].type,
persistable=False,
stop_gradient=False)
block._insert_op(idx,
type='concat',
inputs=inputs,
......@@ -244,6 +367,72 @@ class Inserter:
def insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name,
op_role):
"""Insert slice op into block at the given block."""
# This is a hack to insert split op to get slice tensor
# 1. [128, 128] => [64, 128]: split
# 2. [128, 128] => [128, 128]: assign
# 3. [128, 128] => [64, 64]: slice, it will replaced by multi split
global_shape = tensor.shape
slice_shape = [ends[i] - starts[i] for i in range(len(starts))]
diff_dims = []
for index, item in enumerate(slice_shape):
if item != global_shape[index]:
diff_dims.append(index)
# use assign
if len(diff_dims) == 0:
out = block.create_var(name=new_var_name,
dtype=tensor.dtype,
type=tensor.type,
shape=slice_shape,
lod_level=tensor.lod_level)
inputs = {'X': [tensor]}
outputs = {"Out": [out]}
attrs = {"in_place": False}
block._insert_op(idx,
type="assign",
inputs=inputs,
outputs=outputs,
attrs=attrs)
return out
# use split once
elif len(diff_dims) == 1:
diff_dim = diff_dims[0]
num_or_sections = global_shape[diff_dim] // slice_shape[diff_dim]
axis = diff_dim
cur_idx = starts[diff_dim] // slice_shape[diff_dim]
input_shape = global_shape
inputs = {'X': tensor}
attrs = {'num': num_or_sections, 'axis': axis, 'op_role': op_role}
new_shape = []
for index, item in enumerate(tensor.shape):
if index != axis:
new_shape.append(item)
else:
new_shape.append(item // num_or_sections)
with paddle.static.program_guard(block.program):
outs = [
block.create_var(name=paddle.fluid.unique_name.
generate_with_ignorable_key(".".join(
['split@RESHARD', 'tmp'])),
dtype=tensor.dtype,
shape=None,
type=tensor.type,
persistable=False,
lod_level=tensor.lod_level,
stop_gradient=False)
for i in range(num_or_sections)
]
out = outs[cur_idx]
op = block._insert_op(idx,
type="split",
inputs=inputs,
outputs={'Out': outs},
attrs=attrs)
return out
# use slice
else:
inputs = {'Input': tensor}
infer_flags = list(1 for i in range(len(axes)))
attrs = {
......@@ -253,28 +442,42 @@ class Inserter:
"infer_flags": infer_flags,
'op_role': op_role
}
helper = LayerHelper('slice', **locals())
out = block.create_var(name=new_var_name,
dtype=tensor.dtype,
type=tensor.type)
type=tensor.type,
lod_level=tensor.lod_level)
block._insert_op(idx,
type="slice",
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
return out
@staticmethod
def insert_split_op(block, idx, tensor, num_or_sections, op_role):
def insert_split_op(block, idx, tensor, num_or_sections, op_role, axis=0):
"""Insert split op into block at the given index."""
helper = LayerHelper('split', **locals())
helper = LayerHelper('split@RESHARD', **locals())
input_shape = tensor.shape
inputs = {'X': tensor}
attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role}
attrs = {'num': num_or_sections, 'axis': axis, 'op_role': op_role}
new_shape = []
for index, item in enumerate(tensor.shape):
if index != axis:
new_shape.append(item)
else:
new_shape.append(item // num_or_sections)
with paddle.static.program_guard(block.program):
outs = [
helper.create_variable_for_type_inference(
dtype=helper.input_dtype()) for i in range(num_or_sections)
block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])),
dtype=tensor.dtype,
shape=None,
lod_level=tensor.lod_level,
type=tensor.type,
persistable=False,
stop_gradient=False) for i in range(num_or_sections)
]
block._insert_op(idx,
type="split",
......@@ -286,9 +489,18 @@ class Inserter:
@staticmethod
def insert_fill_constant_op(block, idx, op_role):
"""Insert fill constant op into block at the given index."""
helper = LayerHelper("fill_constant", **locals())
# to avoid name conflict with framework
helper = LayerHelper('fill_constant@RESHARD', **locals())
# use paddle.int64 as dtype
with paddle.static.program_guard(block.program):
out = helper.create_variable_for_type_inference(dtype="int32")
out = block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])),
dtype=paddle.int64,
shape=None,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
inputs = {}
attrs = {'force_cpu': False}
attrs['str_value'] = str(int("1"))
......@@ -342,10 +554,18 @@ class Inserter:
# insert c_allgather op
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
# to avoid name conflict with framework
helper = LayerHelper(op_type + "@RESHARD", **locals())
with paddle.static.program_guard(block.program):
allgather_out = helper.create_variable_for_type_inference(
dtype=tensor.dtype)
allgather_out = block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])),
dtype=tensor.dtype,
shape=None,
lod_level=tensor.lod_level,
type=tensor.type,
persistable=False,
stop_gradient=False)
block._insert_op(idx + idx_offset,
type=op_type,
inputs={'X': [tensor]},
......@@ -620,12 +840,14 @@ class Resharder:
batch_size=None):
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
"but got {}.".format(type(auto_parallel_main_prog))
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program, " \
if auto_parallel_startup_prog is not None:
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program or None, " \
"but got {}.".format(type(auto_parallel_startup_prog))
assert isinstance(rank_id, int), "The type of rank_id should be int, " \
"but got {}.".format(type(rank_id))
assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
"but got {}.".format(type(dist_context))
if batch_size is not None:
assert isinstance(batch_size, int), "The type of batch_size should be int, " \
"but got {}.".format(type(batch_size))
......@@ -639,6 +861,8 @@ class Resharder:
self._has_sent = {}
self._has_recv = {}
self._has_allgather = {}
# to avoid reshard repeatly
self._has_resharded = {}
@property
def auto_parallel_main_prog(self):
......@@ -798,7 +1022,10 @@ class Resharder:
for op in sub_block.ops:
# skip the input and output of operators inserted in the reshard phase
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op:
if dist_op or (op.type == "slice" and not dist_op) or (
op.type == "split"
and not dist_op) or (op.type == "assign"
and not dist_op):
for var_name in op.output_arg_names:
if var_name not in sub_block_op_outputs:
sub_block_op_outputs.append(var_name)
......@@ -812,7 +1039,8 @@ class Resharder:
while_op = op
break
assert while_op is not None
if while_op is None:
continue
# find the actual input and output of while op
proto = OpProtoHolder.instance().get_op_proto(while_op.type)
......@@ -821,12 +1049,14 @@ class Resharder:
if var_name in sub_block_op_inputs:
new_X.append(var_name)
assert new_X
new_X.sort()
while_op.desc.set_input(proto.inputs[0].name, new_X)
new_Out = []
for var_name in while_op.output("Out"):
for output_name in sub_block_op_outputs[::-1]:
if output_name.find(var_name) != -1:
if output_name not in new_Out:
new_Out.append(output_name)
assert new_Out
while_op.desc.set_output(proto.outputs[0].name, new_Out)
......@@ -870,120 +1100,72 @@ class Resharder:
return True
def need_reshard(self,
dist_tensor,
dist_op,
actual_process_mesh,
op_input=True):
def need_reshard(self, dist_tensor, dist_attr, op_input=True, dist_op=None):
"""Judge the tensor whether needs to be resharded."""
is_reshard = False
tensor_dist_attr = dist_tensor.dist_attr
tensor_name = dist_tensor.serial_tensor.name
tensor_dims_mapping = tensor_dist_attr.dims_mapping
tensor_process_mesh = tensor_dist_attr.process_mesh
op_dist_attr = dist_op.dist_attr
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
op_process_mesh = actual_process_mesh
# dist_attr is [process_mesh, dims_mapping] and process_mesh is not a union
op_process_mesh = dist_attr[0]
if op_input:
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_name)
op_input_dims_mapping = dist_attr[1]
if all(
map(lambda x: x is not None, [
map(lambda x: x, [
tensor_dims_mapping, tensor_process_mesh,
op_input_dims_mapping, op_process_mesh
])):
# dims_mapping
# judge whether need reshard by dims_mapping
if tensor_dims_mapping != op_input_dims_mapping:
if dist_op.serial_op.type == "while":
sub_block = self.auto_parallel_main_prog.blocks[
dist_op.serial_op.attr("sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = self.dist_context.get_dist_op_for_program(
op).dist_attr
var_dims_mapping = dist_op_attr.get_input_dims_mapping(
var_name)
if var_dims_mapping != tensor_dims_mapping:
is_reshard = True
break
else:
is_reshard = True
# process_mesh
if tensor_process_mesh != op_process_mesh:
# when processes length is not the same, the dims mapping must be replicative now
if len(tensor_process_mesh.processes) != len(
op_process_mesh.processes):
assert self.is_unshard(tensor_dims_mapping)
assert self.is_unshard(op_input_dims_mapping)
if tensor_process_mesh not in self.dist_context.process_meshes:
# assert whether -1 when union.
for item in tensor_dims_mapping:
if item != -1:
raise ValueError(
"The dim must be -1 when tensor process mesh is a union."
)
# tensor process_mesh: [0, 1, 2, 3], dims_mapping: [-1, -1]
# op process_mesh: [4, 5], dims_mapping: [0, -1]
# reshard is not supported such as above
if not is_reshard:
return is_reshard
else:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError(
"Bool var is not supported reshard.")
# for while op, it should find the process mesh of op actually used the tensor as input
if dist_op.serial_op.type == "while":
sub_block = self.auto_parallel_main_prog.blocks[
dist_op.serial_op.attr("sub_block").id]
for op in sub_block.ops:
for var_name in op.input_arg_names:
if var_name == tensor_name:
dist_op_attr = self.dist_context.get_dist_op_for_program(
op).dist_attr
process_mesh = dist_op_attr.process_mesh
if process_mesh == op_process_mesh:
"it is not supported that tensor process mesh is a union and needs reshard."
)
is_reshard = True
break
else:
# judge whether need reshard by process_mesh
if tensor_process_mesh != op_process_mesh:
is_reshard = True
else:
op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_name)
op_output_dims_mapping = dist_attr[1]
if all(
map(lambda x: x is not None, [
map(lambda x: x, [
tensor_dims_mapping, tensor_process_mesh,
op_output_dims_mapping, op_process_mesh
])):
if tensor_process_mesh != op_process_mesh:
if dist_tensor.serial_tensor.dtype == paddle.bool:
raise ValueError("Bool var is not supported reshard.")
is_reshard = True
if tensor_dims_mapping != op_output_dims_mapping:
raise ValueError(
"It is not supported that tensor dims mapping is different from op output dims mapping."
)
if tensor_process_mesh != op_process_mesh:
is_reshard = True
return is_reshard
def get_process_meshes(self, op):
"""Get all process meshes when op has sub block."""
assert op.has_attr("sub_block")
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
op_process_mesh = self.dist_context.get_dist_op_for_program(
op).dist_attr.process_mesh
process_meshes = []
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh not in process_meshes and process_mesh != op_process_mesh:
process_meshes.append(process_mesh)
if not process_meshes:
process_meshes.append(op_process_mesh)
return process_meshes
def get_op_process_meshes(self, op):
"""Get sub process meshes of the given op if op process mesh is a union."""
process_meshes = []
dist_op = self.dist_context.get_dist_op_for_program(op)
op_process_mesh = dist_op.dist_attr.process_mesh
for process_mesh in self.dist_context.process_meshes:
if set(process_mesh.processes) & (set(
op_process_mesh.processes)) and len(
process_mesh.processes) <= len(
process_mesh.processes) < len(
op_process_mesh.processes):
process_meshes.append(process_mesh)
......@@ -993,39 +1175,14 @@ class Resharder:
return process_meshes
def get_while_op_actual_process_mesh(self, op):
"""Get the while op actual Process mesh corresponding to rank"""
assert op.type == "while"
while_op_process_mesh = self.dist_context.get_dist_op_for_program(
op).dist_attr.process_mesh
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
actual_process_mesh = None
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
if process_mesh == while_op_process_mesh:
continue
if self.rank_id in process_mesh.processes:
raw_process_mesh = process_mesh
break
if actual_process_mesh is None and self.rank_id in while_op_process_mesh.processes:
actual_process_mesh = while_op_process_mesh
assert actual_process_mesh is not None
return actual_process_mesh
def find_op_desc_seq(self, dist_tensor, dist_op, actual_process_mesh):
def find_op_desc_seq(self, dist_tensor, dist_attr, serial=False):
"""
Find the op description sequence to reshard the source tensor for matching the op requirement.
Args:
dist_tensor (DistributedTensor): A distributed tensor.
dist_op (DistributedOperator): A distributed operator.
actual_process_mesh (ProcessMesh): The actual op process mesh.
dist_attr (list): A list contains process_mesh and dims_mapping such as [process_mesh, dims_mapping].
serial (bool): If serial is true, the dist tensor and dist op come from serial program. Otherwise, they come from auto program.
Returns:
Dict, the dict represents the required op description sequence corresponding to process, The key of dict is
......@@ -1034,24 +1191,26 @@ class Resharder:
tensor_dist_attr = dist_tensor.dist_attr
source_tensor = dist_tensor.serial_tensor
tensor_name = source_tensor.name
source_dims_mapping = tensor_dist_attr.dims_mapping
source_process_mesh = tensor_dist_attr.process_mesh
source_process_group = source_process_mesh.processes
source_process_shape = source_process_mesh.topology
op_dist_attr = dist_op.dist_attr
target_process_mesh = actual_process_mesh
target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
target_process_mesh = dist_attr[0]
target_dims_mapping = dist_attr[1]
target_process_group = target_process_mesh.processes
target_process_shape = target_process_mesh.topology
if source_tensor.shape[0] < 0:
assert source_tensor.shape[0] == -1
new_shape = list(source_tensor.shape)
new_shape[0] = self.batch_size
source_tensor.desc.set_shape(new_shape)
complete_shape = Resharder.compute_complete_shape(
source_tensor.shape, source_process_shape, source_dims_mapping)
source_tensor.shape, source_process_shape,
source_dims_mapping) if not serial else source_tensor.shape
op_desc_seq = {}
# TODO: if the target process group has the same process with source process group
......@@ -1060,13 +1219,14 @@ class Resharder:
set(source_process_group)):
pass
# in the different process group, it will use send, recv, concat and slice op
elif target_process_group != source_process_group:
partition_process_mapping_list = []
for source_process in source_process_group:
# get partition index of source process
source_partition_index = Resharder.compute_partition_index(source_process, complete_shape, source_dims_mapping, \
source_process_shape, source_process_group)
if not partition_process_mapping_list:
# the item in partition_process_mapping_list is source_partition_index, which processes and whether has been used
partition_process_mapping_list.append(
[source_partition_index, [source_process], [False]])
else:
......@@ -1076,6 +1236,7 @@ class Resharder:
[item[1] for item in partition_process_mapping_list])
has_used = list(
[item[2] for item in partition_process_mapping_list])
if partition_list.count(source_partition_index) == 1:
index = partition_list.index(source_partition_index)
process_list[index].append(source_process)
......@@ -1085,6 +1246,7 @@ class Resharder:
[source_partition_index, [source_process], [False]])
for target_process in target_process_group:
# has_sent means the source_partition_index has been sent to target_process
has_sent = []
target_partition_index = Resharder.compute_partition_index(
target_process, complete_shape, target_dims_mapping,
......@@ -1114,6 +1276,7 @@ class Resharder:
has_used[i] = True
break
i += 1
if i == len(has_used):
has_used = list(map(lambda x: False, has_used))
to_send_process = process_list[0]
......@@ -1127,10 +1290,16 @@ class Resharder:
all_partition_index_list.append(source_partition_index)
# append send and recv op desc
is_bool = (
dist_tensor.serial_tensor.dtype == paddle.bool)
send_op_desc = SendOpDesc(source_partition_index,
target_process)
to_send_process,
target_process,
is_bool=is_bool)
recv_op_desc = RecvOpDesc(source_partition_index,
to_send_process)
to_send_process,
target_process,
is_bool=is_bool)
op_desc_seq[to_send_process].append(send_op_desc)
op_desc_seq[target_process].append(recv_op_desc)
has_sent.append(source_partition_index)
......@@ -1146,16 +1315,24 @@ class Resharder:
slice_ends = []
slices_axes = []
concatenated_partition_index = partition_index_list[0]
to_slice_tensor_shape = []
for idx, item in enumerate(concatenated_partition_index):
slice_starts.append(target_partition_index[idx][0] -
item[0])
slice_ends.append(target_partition_index[idx][1] - item[0])
slices_axes.append(idx)
to_slice_tensor_shape.append(item[1] - item[0])
op_desc_seq[target_process].append(
SliceOpDesc(slice_starts, slice_ends, slices_axes))
SliceOpDesc(slice_starts,
slice_ends,
slices_axes,
shape=to_slice_tensor_shape))
# in the same process group, it will use allgahther and slice op
# in the same process group, it will use allgahther and slice op.
else:
# NOTE: It just supports even partition scene.
partition_index_list = []
all_partition_index_list = []
process_index = []
......@@ -1191,17 +1368,21 @@ class Resharder:
slice_ends.append(item[1])
slices_axes.append(idx)
to_slice_tensor_shape = dist_tensor.global_sizes()
slice_op_desc = SliceOpDesc(starts=slice_starts,
ends=slice_ends,
axes=slices_axes)
op_desc_seq[process] = [AllGatherOpDesc(group=group),
axes=slices_axes,
shape=to_slice_tensor_shape)
allgather_shape = None if not serial else dist_tensor.local_sizes(
rank=process)
op_desc_seq[process] = [AllGatherOpDesc(group=group, shape=allgather_shape, is_bool=(source_tensor.dtype == paddle.bool)),
ConcatOpDesc(partition_index_list=all_partition_index_list), slice_op_desc] \
if len(group) > 1 else [slice_op_desc]
return op_desc_seq
def parse_op_desc(self, block, op_desc_seq, var_name, reshard_op,
actual_process_mesh):
dist_attr):
"""Parse op desc sequence and insert op in the block"""
tensor_list = []
partition_tensor_list = []
......@@ -1226,6 +1407,25 @@ class Resharder:
self.has_allgather[var_name] = []
if not self.has_allgather[var_name] or op_desc.group not in list(
map(lambda x: x[0], self.has_allgather[var_name])):
if op_desc.is_bool:
# for bool data allgather, cast to int64 -> allgather -> cast bool
out_cast = Inserter.insert_cast_op(
block, idx, source_tensor,
reshard_op.attr('op_role'), paddle.int64)
tensor_list, idx_offset = Inserter.insert_allgather_op(
block, idx + 1, out_cast, op_desc.group,
reshard_op.attr('op_role'))
idx += idx_offset
tensor_name_list = []
for var in tensor_list:
out_cast = Inserter.insert_cast_op(
block, idx, var, reshard_op.attr('op_role'),
paddle.bool)
tensor_name_list.append(out_cast.name)
idx += 1
self.has_allgather[var_name].append(
[op_desc.group, tensor_name_list])
else:
tensor_list, idx_offset = Inserter.insert_allgather_op(
block, idx, source_tensor, op_desc.group,
reshard_op.attr('op_role'))
......@@ -1249,8 +1449,17 @@ class Resharder:
if var_name not in self.has_sent.keys():
self.has_sent[var_name] = []
if op_desc.dst not in self.has_sent[var_name]:
if op_desc.is_bool:
out_cast = Inserter.insert_cast_op(
block, idx, source_tensor,
reshard_op.attr('op_role'), paddle.int64)
Inserter.insert_send_op(block, idx + 1, out_cast,
op_desc.src, op_desc.dst,
reshard_op.attr('op_role'))
idx += 2
else:
Inserter.insert_send_op(block, idx, source_tensor,
op_desc.dst,
op_desc.src, op_desc.dst,
reshard_op.attr('op_role'))
idx += 1
self.has_sent[var_name].append(op_desc.dst)
......@@ -1263,14 +1472,55 @@ class Resharder:
shape = []
for index in partition_index:
shape.append(index[1] - index[0])
if op_desc.is_bool:
# for bool data, recv int64 -> cast to bool
recv_tensor = block.create_var(
name=unique_name.generate(var_name + "@recv"),
shape=shape,
lod_level=source_tensor.lod_level,
dtype=paddle.int64,
type=source_tensor.type)
Inserter.insert_recv_op(block, idx, recv_tensor,
op_desc.src, op_desc.dst,
reshard_op.attr('op_role'))
out_cast = Inserter.insert_cast_op(
block, idx + 1, recv_tensor,
reshard_op.attr('op_role'), paddle.bool)
tensor_list.append(out_cast)
idx += 2
self.has_recv[var_name][op_desc.src] = out_cast
else:
recv_tensor = block.create_var(
name=unique_name.generate(var_name + "@recv"),
shape=shape,
lod_level=source_tensor.lod_level,
dtype=source_tensor.dtype,
type=source_tensor.type)
Inserter.insert_recv_op(block, idx, recv_tensor,
op_desc.src,
op_desc.src, op_desc.dst,
reshard_op.attr('op_role'))
# for lod tensor, need reset lod after received
if recv_tensor.lod_level != 0:
set_lod = False
# use data lod to reset tensor lod
for tmp_block in self.auto_parallel_main_prog.blocks:
for tmp_var_name in tmp_block.vars:
tmp_var = tmp_block.vars[tmp_var_name]
if tmp_var.is_data and tmp_var.lod_level == recv_tensor.lod_level:
reset_lod_out = Inserter.insert_reset_lod_op(
block, idx + 1, recv_tensor,
tmp_var, reshard_op.attr('op_role'))
tensor_list.append(reset_lod_out)
idx += 2
self.has_recv[var_name][
op_desc.src] = reset_lod_out
set_lod = True
break
if set_lod:
break
assert set_lod is True
else:
tensor_list.append(recv_tensor)
idx += 1
self.has_recv[var_name][op_desc.src] = recv_tensor
......@@ -1303,90 +1553,226 @@ class Resharder:
new_var_name=new_name,
op_role=reshard_op.attr('op_role'))
process_mesh = dist_attr[0]
dims_mapping = dist_attr[1]
tensor_attr = TensorDistributedAttribute()
process_mesh = actual_process_mesh
dims_mapping = self.dist_context.get_op_dist_attr_for_program(
matched_op).get_input_dims_mapping(var_name)
tensor_attr.dims_mapping = dims_mapping
tensor_attr.process_mesh = process_mesh
self.dist_context.set_tensor_dist_attr_for_program(
target_tensor, tensor_attr)
if op.type == "while":
if matched_op.type == "while":
# var_reshard_mapping means the while op input need be changed to
if "var_reshard_mapping" not in Resharder.while_block_info[
op.attr("sub_block").id].keys():
Resharder.while_block_info[op.attr(
"sub_block").id]["var_reshard_mapping"] = {}
if var_name not in Resharder.while_block_info[op.attr(
"sub_block").id]["var_reshard_mapping"].keys():
Resharder.while_block_info[op.attr("sub_block").id][
"var_reshard_mapping"][var_name] = []
Resharder.while_block_info[op.attr("sub_block").id][
"var_reshard_mapping"][var_name] = target_tensor.name
"var_reshard_mapping"][var_name].append(
[dist_attr, target_tensor.name])
# rename op input name according to new name
for op in block.ops:
# just for while op
while_op_X_append = []
for name in op.input_arg_names:
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
op)
if name == var_name and op_dist_attr is not None:
if op.desc.id() == matched_op.desc.id():
op.desc._rename_input(name, target_tensor.name)
if matched_op.type == "while":
old_name = name
new_name = target_tensor.name
assert old_name != new_name
op_input_dist_attr = op_dist_attr.get_input_dist_attr(
old_name)
op_dist_attr.set_input_dist_attr(
new_name, op_input_dist_attr)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr.set_input_dist_attr(name, None)
new_name, dims_mapping)
if old_name in op_dist_attr._inputs_dist_attrs:
op_dist_attr.del_input_dist_attr(
old_name)
while_op_X_append.append(new_name)
continue
else:
op.desc._rename_input(
name, target_tensor.name)
old_name = name
new_name = target_tensor.name
assert old_name != new_name
op_input_dist_attr = op_dist_attr.get_input_dist_attr(
old_name)
op_dist_attr.set_input_dist_attr(
new_name, op_input_dist_attr)
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping)
op_dist_attr.del_input_dist_attr(old_name)
continue
# NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation.
op_process_mesh = op_dist_attr.process_mesh
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
# NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation.
if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping:
op.desc._rename_input(name, target_tensor.name)
old_name = name
new_name = target_tensor.name
assert old_name != new_name
op_input_dist_attr = op_dist_attr.get_input_dist_attr(
old_name)
op_dist_attr.set_input_dist_attr(
new_name, op_input_dist_attr)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr.set_input_dist_attr(name, None)
new_name, dims_mapping)
op_dist_attr.del_input_dist_attr(old_name)
def reshard(self):
for block_idx, block in enumerate(self.auto_parallel_main_prog.blocks):
if block_idx in Resharder.while_block_info:
if "var_reshard_mapping" in Resharder.while_block_info[
block_idx]:
# for while op, the input X should reset
if while_op_X_append:
proto = OpProtoHolder.instance().get_op_proto(op.type)
op.desc.set_input(proto.inputs[0].name,
op.input("X") + while_op_X_append)
def _get_while_op_input_attrs(self, op, var_name):
# NOTE: Multi while loop is not supported
assert op.type == "while"
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
input_attrs = []
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
dist_attr = dist_op.dist_attr
for name in op.input_arg_names:
if name == var_name:
process_mesh = dist_attr.process_mesh
input_dims_mapping = dist_attr.get_input_dims_mapping(
var_name)
has_exist = False
for input_attr in input_attrs:
if process_mesh == input_attr[
0] and input_dims_mapping == input_attr[1]:
has_exist = True
break
if not has_exist:
input_attrs.append([process_mesh, input_dims_mapping])
return input_attrs
def _get_common_op_input_attrs(self, op, var_name):
process_meshes = []
dist_op = self.dist_context.get_dist_op_for_program(op)
dist_attr = dist_op.dist_attr
op_process_mesh = dist_attr.process_mesh
for process_mesh in self.dist_context.process_meshes:
if set(process_mesh.processes) & (set(
op_process_mesh.processes)) and len(
process_mesh.processes) < len(
op_process_mesh.processes):
process_meshes.append(process_mesh)
# it means that the process mesh is not a union when process meshes is none
if not process_meshes:
process_meshes.append(op_process_mesh)
input_dims_mapping = dist_attr.get_input_dims_mapping(var_name)
input_attrs = []
for process_mesh in process_meshes:
input_attrs.append([process_mesh, input_dims_mapping])
return input_attrs
def get_op_input_attrs(self, op, var_name):
op_input_attrs = []
if op.type == "while":
op_input_attrs = self._get_while_op_input_attrs(op, var_name)
else:
op_input_attrs = self._get_common_op_input_attrs(op, var_name)
assert op_input_attrs
return op_input_attrs
def _remove_global_process_mesh(self):
"""Remove global process mesh from dist_context.process_meshes"""
processes = set()
process_mesh_count = len(self.dist_context.process_meshes)
if process_mesh_count > 1:
global_process_mesh_idx = None
for process_mesh in self.dist_context.process_meshes:
for process in process_mesh.processes:
processes.add(process)
for idx, process_mesh in enumerate(
self.dist_context.process_meshes):
if len(set(process_mesh.processes)) == len(processes):
global_process_mesh_idx = idx
break
if global_process_mesh_idx is not None:
self.dist_context.process_meshes.pop(idx)
def _change_subblock_op_input_and_output(self, block_idx, block):
if "var_reshard_mapping" in Resharder.while_block_info[block_idx]:
var_reshard_mapping = Resharder.while_block_info[block_idx][
"var_reshard_mapping"]
for op in block.ops:
for var_name in op.input_arg_names:
if var_name in var_reshard_mapping:
op.desc._rename_input(
var_name, var_reshard_mapping[var_name])
# in while sub block, the union process mesh is not split before reshard sub block
dist_op = self.dist_context.get_dist_op_for_program(op)
dist_attr = dist_op.dist_attr
target_name = None
for item in var_reshard_mapping[var_name]:
if dist_attr.process_mesh == item[0][
0] and dist_attr.get_input_dims_mapping(
var_name) == item[0][1]:
target_name = item[1]
break
if target_name is None:
continue
else:
op.desc._rename_input(var_name, target_name)
dist_op = self.dist_context.get_dist_op_for_program(
op)
op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == Resharder.while_block_info[
block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
op_dist_attr.set_input_dims_mapping(
var_reshard_mapping[var_name],
dims_mapping)
old_name = var_name
new_name = target_name
assert old_name != new_name
op_input_dist_attr = op_dist_attr.get_input_dist_attr(
old_name)
op_dist_attr.set_input_dist_attr(
var_name, None)
new_name, op_input_dist_attr)
op_dist_attr.del_input_dist_attr(old_name)
# the outputs also need to be renamed when the output name is the same with input name
# the outputs also need to be renamed when the output name is the same with input name in inplace op
for var_name in op.output_arg_names:
# if the tensor has been resharded multiply, it is not supported now.
if var_name in var_reshard_mapping:
op.desc._rename_output(
var_name, var_reshard_mapping[var_name])
dist_op = self.dist_context.get_dist_op_for_program(
op)
if len(var_reshard_mapping[var_name]) > 1:
raise ValueError(
"The scene is not supported that the output is inplaced and the tensor has been resharded multiply when as input."
)
target_name = var_reshard_mapping[var_name][0][1]
op.desc._rename_output(var_name, target_name)
dist_op = self.dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
if op_dist_attr.process_mesh == Resharder.while_block_info[
block_idx]["actual_process_mesh"]:
dims_mapping = op_dist_attr.get_output_dims_mapping(
var_name)
op_dist_attr.set_output_dims_mapping(
var_reshard_mapping[var_name],
dims_mapping)
old_name = var_name
new_name = target_name
assert old_name != new_name
op_output_dist_attr = op_dist_attr.get_output_dist_attr(
old_name)
op_dist_attr.set_output_dist_attr(
var_name, None)
new_name, op_output_dist_attr)
op_dist_attr.del_output_dist_attr(old_name)
def _reshard_input(self, block):
idx = 0
while idx < len(block.ops):
pre_op_count = len(block.ops)
......@@ -1398,46 +1784,62 @@ class Resharder:
dist_op = self.dist_context.get_dist_op_for_program(op)
if dist_op is not None:
process_meshes = []
op_input_dist_attrs = [
] # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)]
if op.type == "while":
if not self.is_condition_replicative(op):
raise ValueError(
"Please check the condition due to the dims mapping is not replicative."
)
process_meshes = self.get_process_meshes(op)
assert process_meshes
if op.attr("sub_block"
).id not in Resharder.while_block_info:
Resharder.while_block_info[op.attr(
"sub_block").id] = {}
if op.attr(
"sub_block").id not in Resharder.while_block_info:
Resharder.while_block_info[op.attr("sub_block").id] = {}
Resharder.while_block_info[op.attr(
"sub_block").id]["op_id"] = op.desc.id()
Resharder.while_block_info[op.attr("sub_block").id][
"actual_process_mesh"] = self.get_while_op_actual_process_mesh(
op)
else:
process_meshes = self.get_op_process_meshes(op)
input_vars = None
if op.type == "while":
# condition var process mesh is the same with op and dims_mapping is replicative, so it do not need reshard
input_var_names = op.input("X")
else:
input_var_names = op.input_arg_names
# to avoid while op X order different
input_var_names.sort()
idx_offset = 0
for var_name in op.input_arg_names:
for var_name in input_var_names:
# skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0":
continue
var = get_var_with_recursion(
var_name, block, self.auto_parallel_main_prog)
var = get_var_with_recursion(var_name, block,
self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(
var)
for process_mesh in process_meshes:
# judge whether union tensor dims_mapping all -1
is_union_process_mesh_tensor = False
if dist_tensor.dist_attr.process_mesh not in self.dist_context.process_meshes and self.dist_context.process_meshes:
is_union_process_mesh_tensor = True
assert dist_tensor.dist_attr.dims_mapping.count(
-1) == len(dist_tensor.dist_attr.dims_mapping)
op_input_attrs = self.get_op_input_attrs(op, var_name)
for input_attr in op_input_attrs:
input_process_mesh = None
# deal with union tensor
if is_union_process_mesh_tensor:
# if op process mesh is subset of union tensor process mesh, need no reshard
if set(input_attr[0].processes) <= set(
dist_tensor.dist_attr.process_mesh.processes
):
continue
if dist_tensor is not None and self.need_reshard(
dist_tensor, dist_op, process_mesh):
dist_tensor, input_attr):
reshard_op_desc = self.find_op_desc_seq(
dist_tensor, dist_op, process_mesh)
self.parse_op_desc(block, reshard_op_desc,
var_name, op, process_mesh)
dist_tensor, input_attr)
self.parse_op_desc(block, reshard_op_desc, var_name,
op, input_attr)
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
......@@ -1445,12 +1847,113 @@ class Resharder:
else:
idx += 1
def _hadnle_recv(self, block, idx, var, op, send_rank, recv_rank):
if self.rank_id == recv_rank:
# if recv bool data, recv then cast
if var.dtype == paddle.bool:
recv_cast_out = block.create_var(
name=unique_name.generate(var.name + "@recv"),
shape=var.shape,
lod_level=var.lod_level,
dtype=paddle.int64,
type=var.type)
Inserter.insert_recv_op(block, idx + 1,
recv_cast_out, send_rank, recv_rank,
op.attr('op_role'))
reset_lod_out = None
if var.lod_level != 0:
set_lod = False
for tmp_block in self.auto_parallel_main_prog.blocks:
for tmp_var_name in tmp_block.vars:
tmp_var = tmp_block.vars[tmp_var_name]
if tmp_var.is_data and tmp_var.lod_level == var.lod_level:
reset_lod_out = block.create_var(
name=unique_name.generate(var.name +
"@RESETLOD"),
shape=recv_cast_out.shape,
type=recv_cast_out.type,
dtype=recv_cast_out.dtype,
lod_level=recv_cast_out.lod_level)
idx += 1
block._insert_op(
idx,
type="lod_reset",
inputs={
'X': recv_cast_out,
'Y': tmp_var
},
outputs={'Out': reset_lod_out},
attrs={'op_role': op.attr("op_role")})
set_lod = True
break
if set_lod:
break
assert set_lod is True
# cast int64 to bool
block._insert_op(idx + 2,
type='cast',
inputs={
'X': [recv_cast_out] if
reset_lod_out is None else [reset_lod_out]
},
outputs={'Out': [var]},
attrs={
'in_dtype': recv_cast_out.dtype,
'out_dtype': var.dtype,
'op_role': op.attr('op_role')
})
else:
if var.lod_level != 0:
recv_out = block.create_var(
name=unique_name.generate(var.name + "@recv"),
shape=var.shape,
lod_level=var.lod_level,
dtype=var.int64,
type=var.type)
Inserter.insert_recv_op(block, idx + 1, recv_out, send_rank,
recv_rank, op.attr('op_role'))
set_lod = False
for tmp_block in self.auto_parallel_main_prog.blocks:
for tmp_var_name in tmp_block.vars:
tmp_var = tmp_block.vars[tmp_var_name]
if tmp_var.is_data and tmp_var.lod_level == var.lod_level:
idx += 1
block._insert_op(
idx,
type="lod_reset",
inputs={
'X': recv_out,
'Y': tmp_var
},
outputs={'Out': var},
attrs={'op_role': op.attr("op_role")})
set_lod = True
break
if set_lod:
break
assert set_lod is True
else:
Inserter.insert_recv_op(block, idx + 1, var, send_rank,
recv_rank, op.attr('op_role'))
def _handle_send(self, block, idx, var, op, send_rank, recv_rank):
if var.dtype == paddle.bool:
cast_out = Inserter.insert_cast_op(block, idx + 1, var,
op.attr('op_role'), paddle.int64)
Inserter.insert_send_op(block, idx + 2, cast_out, send_rank,
recv_rank, op.attr('op_role'))
else:
Inserter.insert_send_op(block, idx + 1, var, send_rank, recv_rank,
op.attr('op_role'))
def _reshard_output(self, block):
# insert send and recv op if output process mesh is different from tensor process mesh
idx = 0
# skip reader and ops whose process mesh is union
skip_ops = [
"create_py_reader", "create_double_buffer_reader", "read",
"while", "write_to_array", "read_from_array"
"create_py_reader", "create_double_buffer_reader", "read", "while",
"write_to_array", "read_from_array"
]
global _g_special_ops
skip_ops += _g_special_ops
......@@ -1459,33 +1962,98 @@ class Resharder:
op = block.ops[idx]
dist_op = self.dist_context.get_dist_op_for_program(op)
if dist_op is not None and op.type not in skip_ops:
idx_offset = 0
for var_name in op.output_arg_names:
var = get_var_with_recursion(
var_name, block, self.auto_parallel_main_prog)
var = get_var_with_recursion(var_name, block,
self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(
var)
process_mesh = dist_op.dist_attr.process_mesh
tensor_process_mesh = dist_tensor.dist_attr.process_mesh
output_attr = [
dist_op.dist_attr.process_mesh,
dist_op.dist_attr.get_output_dims_mapping(var_name)
]
if dist_tensor is not None and self.need_reshard(
dist_tensor, dist_op, process_mesh, False):
for index, item in enumerate(
dist_op.dist_attr.process_mesh.processes):
recv_rank = dist_tensor.dist_attr.process_mesh.processes[
index]
dist_tensor, output_attr, False):
tensor_processes = set(
tensor_process_mesh.processes) - (
set(tensor_process_mesh.processes)
& set(output_attr[0].processes))
if tensor_processes:
if len(tensor_processes) != len(
output_attr[0].processes):
if dist_tensor.dist_attr.dims_mapping.count(
-1) != len(
dist_tensor.dist_attr.dims_mapping
) or output_attr[1].count(-1) != len(
output_attr[1]):
raise ValueError(
"The dims_mapping must be -1")
else:
for index, tensor_process in enumerate(
tensor_processes):
recv_rank = tensor_process
actual_index = index
if index >= len(
output_attr[0].processes):
actual_index = (
index -
len(output_attr[0].processes)
) % len(output_attr[0].processes)
item = output_attr[0].processes[
actual_index]
if recv_rank == item:
continue
if self.rank_id == item:
Inserter.insert_send_op(
block, idx + 1, var, recv_rank,
op.attr('op_role'))
# if send bool data, cast then send
self._handle_send(
block, idx, var, op, item,
recv_rank)
if self.rank_id == recv_rank:
Inserter.insert_recv_op(
block, idx + 1, var, item,
op.attr('op_role'))
# if recv bool data, recv then cast
self._hadnle_recv(
block, idx, var, op, item,
recv_rank)
else:
for index, tensor_process in enumerate(
tensor_processes):
recv_rank = tensor_process
item = output_attr[0].processes[index]
if recv_rank == item:
continue
if self.rank_id == item:
# if send bool data, cast then send
self._handle_send(
block, idx, var, op, item,
recv_rank)
if self.rank_id == recv_rank:
# if recv bool data, recv then cast
self._hadnle_recv(
block, idx, var, op, item,
recv_rank)
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
else:
idx += 1
def reshard(self):
self._remove_global_process_mesh()
for block_idx, block in enumerate(self.auto_parallel_main_prog.blocks):
# change the var_name before resharding sub block
if block_idx in Resharder.while_block_info:
self._change_subblock_op_input_and_output(block_idx, block)
# reshard input
self._reshard_input(block)
# reshard output
# NOTE: Only support that insert send and recv op if output process mesh is different from tensor process mesh
self._reshard_output(block)
# remove no need vars and ops in the main program
Remover.remove_no_need_in_main(self.auto_parallel_main_prog,
self.dist_context, self.rank_id,
......
......@@ -1419,7 +1419,10 @@ def get_standalone_cost_data(distributed_programs):
}
standalone_cost_data = []
not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
# skip ops
not_enum_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "assign"
]
for distributed_program in distributed_programs:
cost_data = {}
vars = distributed_program.global_block().vars
......
......@@ -27,7 +27,10 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_di
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read', 'slice']
_skip_ops = [
'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split',
'assign'
]
# update here to support new optimizers
_supported_optimizer_type = [
"adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
......
......@@ -107,7 +107,9 @@ class TestDataUnshard(unittest.TestCase):
input_data = np.array(range(2 * 8)).reshape([2, 8]).astype("float32")
label_data = np.random.randint(0, 10, [2, 8]).astype("float32")
fetchs = [loss.name, 'input@RESHARD_0']
fetchs = [loss.name, 'split@RESHARD.tmp_0'] if worker_index == 0 else [
loss.name, 'split@RESHARD.tmp_1'
]
loss_np, shard_data_np = exe.run(distributed_main_program,
feed={
"input": input_data,
......
......@@ -28,7 +28,7 @@ from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import _g_process_group_map
from paddle.distributed.auto_parallel.process_group import _g_process_group_map, ProcessGroup
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
......@@ -307,6 +307,7 @@ class TestMLPReshard(unittest.TestCase):
train_program, startup_program, dist_context, rank_id)
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
_g_process_group_map[0] = ProcessGroup(0, [])
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
......@@ -326,10 +327,10 @@ class TestMLPReshard(unittest.TestCase):
train_program, startup_program, dist_context, rank_id, True)
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
_g_process_group_map[0] = ProcessGroup(0, [])
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
self.assertTrue(check_initialization(dist_startup_prog, rank_id))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册