Unverified · Commit 8624f3b1 authored by caozhou, committed by GitHub

[Auto Parallel] Update reshard for auto search (#45002)

* update reshard for auto search

* fix unittest bug

* update dist tensor

* update reshard output

* fix unittests bug

* merge develop
Parent: 236ad4fc
@@ -276,8 +276,8 @@ class OperatorDistributedAttribute:
             dist_attr_object.init(dist_attr)
             self._inputs_dist_attrs[name] = dist_attr_object
 
-    # def del_input_dist_attr(self, name):
-    #     del self._inputs_dist_attrs[name]
+    def del_input_dist_attr(self, name):
+        del self._inputs_dist_attrs[name]
 
     def get_output_dist_attr(self, name):
         return self._outputs_dist_attrs.get(name, None)
@@ -287,8 +287,8 @@ class OperatorDistributedAttribute:
             dist_attr_object.init(dist_attr)
             self._outputs_dist_attrs[name] = dist_attr_object
 
-    # def del_output_dist_attr(self, name):
-    #     del self._inputs_dist_attrs[name]
+    def del_output_dist_attr(self, name):
+        del self._outputs_dist_attrs[name]
 
     def get_input_dims_mapping(self, name):
         input_dist_attr = self.get_input_dist_attr(name)
...
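For reference, a minimal sketch (not part of the diff) of how the re-enabled deleters combine with the existing getters/setters when an operator input is renamed during reshard; op_dist_attr stands for any OperatorDistributedAttribute instance:

    # Illustrative only: move the dist attr of a renamed input and drop the stale entry.
    def rename_input_dist_attr(op_dist_attr, old_name, new_name):
        input_dist_attr = op_dist_attr.get_input_dist_attr(old_name)
        op_dist_attr.set_input_dist_attr(new_name, input_dist_attr)
        op_dist_attr.del_input_dist_attr(old_name)  # method re-enabled by this change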
@@ -163,7 +163,6 @@ class DistributedTensor:
         self._batch_dim = 0
         # Reuse the dist_attr setter to initialize _dist_attr
         self.dist_attr = dist_attr
-        self._local_sizes_map = {}
         self._local_offsets_map = {}
         self._local_shard_map = {}
         self._local_tensor_map = {}
@@ -223,20 +222,17 @@ class DistributedTensor:
         return True
 
     def local_sizes(self, rank=None):
-        """Get local sizes of the given rank."""
         rank = paddle.distributed.get_rank() if rank is None else rank
-        local_sizes = None
-        if rank in self._local_sizes_map.keys():
-            local_sizes = self._local_sizes_map[rank]
-        else:
-            global_sizes = self.serial_tensor.shape
-            dims_mapping = self.dist_attr.dims_mapping
-            shard_sizes = self.dist_attr.shard_sizes
-            processes = self.dist_attr.process_mesh.processes
-            topology = self.dist_attr.process_mesh.topology
-            local_sizes = DistributedTensor.get_local_sizes(
-                global_sizes, dims_mapping, topology, processes, rank,
-                shard_sizes)
-            self._local_sizes_map[rank] = local_sizes
+        global_sizes = self.serial_tensor.shape
+        dims_mapping = self.dist_attr.dims_mapping
+        shard_sizes = self.dist_attr.shard_sizes
+        processes = self.dist_attr.process_mesh.processes
+        topology = self.dist_attr.process_mesh.topology
+        local_sizes = DistributedTensor.get_local_sizes(global_sizes,
+                                                        dims_mapping, topology,
+                                                        processes, rank,
+                                                        shard_sizes)
 
         return local_sizes
@@ -282,7 +278,6 @@ class DistributedTensor:
     def new_local_tensor(self, block=None, rank=None, name=None):
         """
         Create a new local tensor of serial tensor corresponding to rank.
         Args:
             block (Block): The block contains the new tensor. Default value is recommend and it will be created in the block of dist main program corresponding to the serial tensor block id. Default: None.
             rank (int): The rank id. Default value is recommend and it will be the current rank. Default: None.
...
@@ -26,6 +26,11 @@ from ..collective import _get_global_env
 from .dist_context import DistributedContext
 from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
 from .process_group import new_process_group, ProcessGroup, _g_process_group_map
+from .cost import build_comm_desc, CommContext
+from .cost import AllgatherOpCost, SendOpCost
+from .cost import SliceOpCost, SplitOpCost, ConcatOpCost
+from .cluster import Cluster
+from .utils import print_program_with_dist_attr
 
 # NOTE: If op in _g_special_ops, it will not be resharded.
 _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
@@ -41,6 +46,7 @@ def get_var_with_recursion(var_name, block, program):
         if var_name in parent_block.vars:
             var = parent_block.vars[var_name]
     assert var is not None
 
     return var
@@ -50,11 +56,19 @@ class AllGatherOpDesc:
     Args:
         group (list): Process group.
+        shape (list): The tensor shape.
+        is_bool (bool): Whether allgather bool data. Default: False.
     """
 
-    def __init__(self, group):
+    def __init__(self, group, shape, is_bool=False):
         self._group = group
         self._desc = "all_gather"
+        self._shape = shape
+        self._is_bool = is_bool
+
+    @property
+    def is_bool(self):
+        return self._is_bool
 
     @property
     def group(self):
@@ -64,8 +78,12 @@ class AllGatherOpDesc:
     def desc(self):
         return self._desc
 
+    @property
+    def shape(self):
+        return self._shape
+
     def __repr__(self):
-        return f"op: {self._desc}, group: {self._group}."
+        return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}."
 
 class SendOpDesc:
@@ -74,13 +92,26 @@ class SendOpDesc:
     Args:
         partition_index (list): The index of partition in complete tensor.
+        src (int): The source process to send.
         dst (int): The destination process to receive.
+        is_bool (bool): Whether send bool data. Default: False.
     """
 
-    def __init__(self, partition_index, dst):
+    def __init__(self, partition_index, src, dst, is_bool=False):
         self._dst = dst
         self._partition_index = partition_index
         self._desc = "send"
+        self._shape = []
+        self._is_bool = is_bool
+        self._src = src
+
+    @property
+    def src(self):
+        return self._src
+
+    @property
+    def is_bool(self):
+        return self._is_bool
 
     @property
     def partition_index(self):
@@ -94,8 +125,15 @@ class SendOpDesc:
     def desc(self):
         return self._desc
 
+    @property
+    def shape(self):
+        if not self._shape:
+            for item in self.partition_index:
+                self._shape.append(item[1] - item[0])
+        return self._shape
+
     def __repr__(self):
-        return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}."
+        return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}, shape: {self._shape}, is_bool: {self._is_bool}."
 
 class RecvOpDesc:
@@ -105,12 +143,25 @@ class RecvOpDesc:
     Args:
         partition_index (list): The index of partition in complete tensor.
         src (int): The source process to send.
+        dst (int): The destination process to receive.
+        is_bool (bool): Whether receive bool data. Default: False.
     """
 
-    def __init__(self, partition_index, src):
+    def __init__(self, partition_index, src, dst, is_bool=False):
         self._src = src
         self._partition_index = partition_index
         self._desc = "recv"
+        self._shape = []
+        self._is_bool = is_bool
+        self._dst = dst
+
+    @property
+    def dst(self):
+        return self._dst
+
+    @property
+    def is_bool(self):
+        return self._is_bool
 
     @property
     def partition_index(self):
@@ -124,8 +175,15 @@ class RecvOpDesc:
     def desc(self):
         return self._desc
 
+    @property
+    def shape(self):
+        if not self._shape:
+            for item in self.partition_index:
+                self._shape.append(item[1] - item[0])
+        return self._shape
+
     def __repr__(self):
-        return f"op: {self._desc}, partition_index: {self._partition_index}, src: {self._src}."
+        return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}, shape: {self._shape}, is_bool: {self._is_bool}."
 
 class SliceOpDesc:
@@ -133,16 +191,18 @@ class SliceOpDesc:
     Describe the slice op in the reshard phase.
     Args:
-        starts (list): It represents starting indices of corresponding axis in ``axes``.
-        ends (list): It represents ending indices of corresponding axis in ``axes``.
-        axes (list): Axes that `starts` and `ends` apply to .
+        starts (list): It represents start indices of corresponding axis in ``axes``.
+        ends (list): It represents end indices of corresponding axis in ``axes``.
+        axes (list): Axes that `starts` and `ends` apply to.
+        shape (list): The shape of the tensor to be sliced.
     """
 
-    def __init__(self, starts, ends, axes):
+    def __init__(self, starts, ends, axes, shape=None):
         self._starts = starts
         self._ends = ends
         self._axes = axes
         self._desc = "slice"
+        self._shape = shape
 
     @property
     def starts(self):
@@ -160,7 +220,14 @@ class SliceOpDesc:
     def desc(self):
         return self._desc
 
+    @property
+    def shape(self):
+        return self._shape
+
     def __repr__(self):
-        return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}."
+        if self._shape is not None:
+            return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}, shape: {self._shape}."
+        else:
+            return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}."
@@ -192,36 +259,84 @@ class Inserter:
     """Insert op required in the reshard process."""
 
     @staticmethod
+    def insert_cast_op(block, idx, tensor, op_role, tensor_type):
+        # to avoid name conflict with framework
+        new_var_name = paddle.fluid.unique_name.generate_with_ignorable_key(
+            ".".join(["cast@RESHARD", 'tmp']))
+        out = block.create_var(name=new_var_name,
+                               dtype=tensor_type,
+                               type=tensor.type,
+                               lod_level=tensor.lod_level)
+        block._insert_op(idx,
+                         type='cast',
+                         inputs={'X': [tensor]},
+                         outputs={'Out': [out]},
+                         attrs={
+                             'in_dtype': tensor.dtype,
+                             'out_dtype': out.dtype,
+                             'op_role': op_role
+                         })
+        return out
+
+    @staticmethod
-    def insert_send_op(block, idx, tensor, dst, op_role):
+    def insert_send_op(block, idx, tensor, src, dst, op_role):
         """Insert send op into block at the given index."""
         op_type = 'send_v2'
+        # use pair comm group
+        process_group = new_process_group([src, dst])
         block._insert_op(idx,
                          type=op_type,
                          inputs={'X': [tensor]},
                          attrs={
-                             'ring_id': 0,
-                             'peer': dst,
+                             'ring_id': process_group.id,
+                             'peer': process_group.ranks.index(dst),
                              'use_calc_stream': True,
-                             'op_role': op_role
+                             'op_role': op_role,
+                             'dynamic_shape': True
                          })
 
     @staticmethod
-    def insert_recv_op(block, idx, tensor, src, op_role):
+    def insert_recv_op(block, idx, tensor, src, dst, op_role):
         """Insert recv op into block at the given index."""
         op_type = 'recv_v2'
+        # use pair group
+        process_group = new_process_group([src, dst])
         block._insert_op(idx,
                          type=op_type,
                          inputs={'X': [tensor]},
                          outputs={'Out': [tensor]},
                          attrs={
-                             'ring_id': 0,
-                             'peer': src,
+                             'ring_id': process_group.id,
+                             'peer': process_group.ranks.index(src),
                              'out_shape': tensor.shape,
                              'dtype': tensor.dtype,
                              'use_calc_stream': True,
-                             'op_role': op_role
+                             'op_role': op_role,
+                             'dynamic_shape': True
                          })
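A small aside on the pair-group change above: peer is now the partner's index inside the freshly created two-rank group rather than its global rank. A toy sketch of that lookup (the group here is just a Python list, not a real ProcessGroup):

    def peer_in_group(group_ranks, partner_global_rank):
        # 'peer' becomes the position of the partner inside the pair group,
        # mirroring process_group.ranks.index(dst) in the hunk above.
        return group_ranks.index(partner_global_rank)

    assert peer_in_group([3, 7], 7) == 1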
+    @staticmethod
+    def insert_reset_lod_op(block, idx, X, Y, op_role):
+        """Insert reset_lod op into block at the given index."""
+        new_var_name = paddle.fluid.unique_name.generate_with_ignorable_key(
+            ".".join(["reset_lod@RESHARD", 'tmp']))
+        reset_lod_out = block.create_var(name=new_var_name,
+                                         shape=X.shape,
+                                         type=X.type,
+                                         dtype=X.dtype,
+                                         lod_level=X.lod_level)
+        block._insert_op(idx,
+                         type="lod_reset",
+                         inputs={
+                             'X': X,
+                             'Y': Y
+                         },
+                         outputs={'Out': reset_lod_out},
+                         attrs={'op_role': op_role})
+        return reset_lod_out
+
     @staticmethod
     def insert_concat_op(block, idx, tensors, axis, op_role):
         """Insert concat op into block at the given block."""
@@ -229,10 +344,18 @@ class Inserter:
         attrs = {}
         attrs['axis'] = axis
         attrs['op_role'] = op_role
-        helper = LayerHelper('concat', **locals())
+        # to avoid name conflict with framework
+        helper = LayerHelper('concat@RESHARD', **locals())
         with paddle.static.program_guard(block.program):
-            out = helper.create_variable_for_type_inference(
-                dtype=helper.input_dtype())
+            out = block.create_var(
+                name=paddle.fluid.unique_name.generate_with_ignorable_key(
+                    ".".join([helper.name, 'tmp'])),
+                dtype=tensors[0].dtype,
+                shape=None,
+                lod_level=tensors[0].lod_level,
+                type=tensors[0].type,
+                persistable=False,
+                stop_gradient=False)
         block._insert_op(idx,
                          type='concat',
                          inputs=inputs,
@@ -244,6 +367,72 @@ class Inserter:
     def insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name,
                         op_role):
         """Insert slice op into block at the given block."""
+        # This is a hack to insert split op to get slice tensor
+        # 1. [128, 128] => [64, 128]: split
+        # 2. [128, 128] => [128, 128]: assign
+        # 3. [128, 128] => [64, 64]: slice, it will replaced by multi split
+        global_shape = tensor.shape
+        slice_shape = [ends[i] - starts[i] for i in range(len(starts))]
+        diff_dims = []
+        for index, item in enumerate(slice_shape):
+            if item != global_shape[index]:
+                diff_dims.append(index)
+
+        # use assign
+        if len(diff_dims) == 0:
+            out = block.create_var(name=new_var_name,
+                                   dtype=tensor.dtype,
+                                   type=tensor.type,
+                                   shape=slice_shape,
+                                   lod_level=tensor.lod_level)
+            inputs = {'X': [tensor]}
+            outputs = {"Out": [out]}
+            attrs = {"in_place": False}
+            block._insert_op(idx,
+                             type="assign",
+                             inputs=inputs,
+                             outputs=outputs,
+                             attrs=attrs)
+            return out
+
+        # use split once
+        elif len(diff_dims) == 1:
+            diff_dim = diff_dims[0]
+            num_or_sections = global_shape[diff_dim] // slice_shape[diff_dim]
+            axis = diff_dim
+            cur_idx = starts[diff_dim] // slice_shape[diff_dim]
+            input_shape = global_shape
+            inputs = {'X': tensor}
+            attrs = {'num': num_or_sections, 'axis': axis, 'op_role': op_role}
+            new_shape = []
+            for index, item in enumerate(tensor.shape):
+                if index != axis:
+                    new_shape.append(item)
+                else:
+                    new_shape.append(item // num_or_sections)
+            with paddle.static.program_guard(block.program):
+                outs = [
+                    block.create_var(name=paddle.fluid.unique_name.
+                                     generate_with_ignorable_key(".".join(
+                                         ['split@RESHARD', 'tmp'])),
+                                     dtype=tensor.dtype,
+                                     shape=None,
+                                     type=tensor.type,
+                                     persistable=False,
+                                     lod_level=tensor.lod_level,
+                                     stop_gradient=False)
+                    for i in range(num_or_sections)
+                ]
+                out = outs[cur_idx]
+            op = block._insert_op(idx,
+                                  type="split",
+                                  inputs=inputs,
+                                  outputs={'Out': outs},
+                                  attrs=attrs)
+            return out
+
+        # use slice
+        else:
             inputs = {'Input': tensor}
             infer_flags = list(1 for i in range(len(axes)))
             attrs = {
@@ -253,28 +442,42 @@ class Inserter:
                 "infer_flags": infer_flags,
                 'op_role': op_role
             }
-            helper = LayerHelper('slice', **locals())
             out = block.create_var(name=new_var_name,
                                    dtype=tensor.dtype,
-                                   type=tensor.type)
+                                   type=tensor.type,
+                                   lod_level=tensor.lod_level)
             block._insert_op(idx,
                              type="slice",
                              inputs=inputs,
                              outputs={'Out': [out]},
                              attrs=attrs)
             return out
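The dispatch rule of the hack above can be summarized by a small standalone sketch (illustrative only, names are made up):

    # Count the axes whose extent changes and pick assign / split / slice.
    def choose_slice_impl(global_shape, starts, ends):
        slice_shape = [e - s for s, e in zip(starts, ends)]
        diff_dims = [i for i, d in enumerate(slice_shape) if d != global_shape[i]]
        if len(diff_dims) == 0:
            return "assign"      # e.g. [128, 128] -> [128, 128]
        elif len(diff_dims) == 1:
            return "split"       # e.g. [128, 128] -> [64, 128]
        else:
            return "slice"       # e.g. [128, 128] -> [64, 64]

    assert choose_slice_impl([128, 128], [0, 0], [64, 128]) == "split"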
     @staticmethod
-    def insert_split_op(block, idx, tensor, num_or_sections, op_role):
+    def insert_split_op(block, idx, tensor, num_or_sections, op_role, axis=0):
         """Insert split op into block at the given index."""
-        helper = LayerHelper('split', **locals())
+        helper = LayerHelper('split@RESHARD', **locals())
         input_shape = tensor.shape
         inputs = {'X': tensor}
-        attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role}
+        attrs = {'num': num_or_sections, 'axis': axis, 'op_role': op_role}
+        new_shape = []
+        for index, item in enumerate(tensor.shape):
+            if index != axis:
+                new_shape.append(item)
+            else:
+                new_shape.append(item // num_or_sections)
         with paddle.static.program_guard(block.program):
             outs = [
-                helper.create_variable_for_type_inference(
-                    dtype=helper.input_dtype()) for i in range(num_or_sections)
+                block.create_var(
+                    name=paddle.fluid.unique_name.generate_with_ignorable_key(
+                        ".".join([helper.name, 'tmp'])),
+                    dtype=tensor.dtype,
+                    shape=None,
+                    lod_level=tensor.lod_level,
+                    type=tensor.type,
+                    persistable=False,
+                    stop_gradient=False) for i in range(num_or_sections)
             ]
         block._insert_op(idx,
                          type="split",
@@ -286,9 +489,18 @@ class Inserter:
     @staticmethod
     def insert_fill_constant_op(block, idx, op_role):
         """Insert fill constant op into block at the given index."""
-        helper = LayerHelper("fill_constant", **locals())
+        # to avoid name conflict with framework
+        helper = LayerHelper('fill_constant@RESHARD', **locals())
+        # use paddle.int64 as dtype
         with paddle.static.program_guard(block.program):
-            out = helper.create_variable_for_type_inference(dtype="int32")
+            out = block.create_var(
+                name=paddle.fluid.unique_name.generate_with_ignorable_key(
+                    ".".join([helper.name, 'tmp'])),
+                dtype=paddle.int64,
+                shape=None,
+                type=core.VarDesc.VarType.LOD_TENSOR,
+                persistable=False,
+                stop_gradient=False)
             inputs = {}
             attrs = {'force_cpu': False}
             attrs['str_value'] = str(int("1"))
@@ -342,10 +554,18 @@ class Inserter:
 
         # insert c_allgather op
         op_type = 'c_allgather'
-        helper = LayerHelper(op_type, **locals())
+        # to avoid name conflict with framework
+        helper = LayerHelper(op_type + "@RESHARD", **locals())
         with paddle.static.program_guard(block.program):
-            allgather_out = helper.create_variable_for_type_inference(
-                dtype=tensor.dtype)
+            allgather_out = block.create_var(
+                name=paddle.fluid.unique_name.generate_with_ignorable_key(
+                    ".".join([helper.name, 'tmp'])),
+                dtype=tensor.dtype,
+                shape=None,
+                lod_level=tensor.lod_level,
+                type=tensor.type,
+                persistable=False,
+                stop_gradient=False)
         block._insert_op(idx + idx_offset,
                          type=op_type,
                          inputs={'X': [tensor]},
@@ -620,12 +840,14 @@ class Resharder:
                  batch_size=None):
         assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
                                             "but got {}.".format(type(auto_parallel_main_prog))
-        assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program, " \
-                                            "but got {}.".format(type(auto_parallel_startup_prog))
+        if auto_parallel_startup_prog is not None:
+            assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program or None, " \
+                                                "but got {}.".format(type(auto_parallel_startup_prog))
         assert isinstance(rank_id, int), "The type of rank_id should be int, " \
                                             "but got {}.".format(type(rank_id))
         assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
                                             "but got {}.".format(type(dist_context))
         if batch_size is not None:
             assert isinstance(batch_size, int), "The type of batch_size should be int, " \
                                                 "but got {}.".format(type(batch_size))
@@ -639,6 +861,8 @@ class Resharder:
         self._has_sent = {}
         self._has_recv = {}
         self._has_allgather = {}
+        # to avoid reshard repeatly
+        self._has_resharded = {}
 
     @property
     def auto_parallel_main_prog(self):
@@ -798,7 +1022,10 @@ class Resharder:
             for op in sub_block.ops:
                 # skip the input and output of operators inserted in the reshard phase
                 dist_op = dist_context.get_dist_op_for_program(op)
-                if dist_op:
+                if dist_op or (op.type == "slice" and not dist_op) or (
+                        op.type == "split"
+                        and not dist_op) or (op.type == "assign"
+                                             and not dist_op):
                     for var_name in op.output_arg_names:
                         if var_name not in sub_block_op_outputs:
                             sub_block_op_outputs.append(var_name)
@@ -812,7 +1039,8 @@ class Resharder:
                 while_op = op
                 break
 
-        assert while_op is not None
+        if while_op is None:
+            continue
 
         # find the actual input and output of while op
         proto = OpProtoHolder.instance().get_op_proto(while_op.type)
@@ -821,12 +1049,14 @@ class Resharder:
             if var_name in sub_block_op_inputs:
                 new_X.append(var_name)
         assert new_X
+        new_X.sort()
         while_op.desc.set_input(proto.inputs[0].name, new_X)
 
         new_Out = []
         for var_name in while_op.output("Out"):
             for output_name in sub_block_op_outputs[::-1]:
                 if output_name.find(var_name) != -1:
+                    if output_name not in new_Out:
                         new_Out.append(output_name)
         assert new_Out
         while_op.desc.set_output(proto.outputs[0].name, new_Out)
@@ -870,120 +1100,72 @@ class Resharder:
         return True
 
-    def need_reshard(self,
-                     dist_tensor,
-                     dist_op,
-                     actual_process_mesh,
-                     op_input=True):
+    def need_reshard(self, dist_tensor, dist_attr, op_input=True, dist_op=None):
         """Judge the tensor whether needs to be resharded."""
         is_reshard = False
+
         tensor_dist_attr = dist_tensor.dist_attr
-        tensor_name = dist_tensor.serial_tensor.name
         tensor_dims_mapping = tensor_dist_attr.dims_mapping
         tensor_process_mesh = tensor_dist_attr.process_mesh
-        op_dist_attr = dist_op.dist_attr
-        op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
-        op_process_mesh = actual_process_mesh
+
+        # dist_attr is [process_mesh, dims_mapping] and process_mesh is not a union
+        op_process_mesh = dist_attr[0]
+
         if op_input:
-            op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
-                tensor_name)
+            op_input_dims_mapping = dist_attr[1]
             if all(
-                    map(lambda x: x is not None, [
+                    map(lambda x: x, [
                         tensor_dims_mapping, tensor_process_mesh,
                         op_input_dims_mapping, op_process_mesh
                     ])):
-                # dims_mapping
+                # judge whether need reshard by dims_mapping
                 if tensor_dims_mapping != op_input_dims_mapping:
-                    if dist_op.serial_op.type == "while":
-                        sub_block = self.auto_parallel_main_prog.blocks[
-                            dist_op.serial_op.attr("sub_block").id]
-                        for op in sub_block.ops:
-                            for var_name in op.input_arg_names:
-                                if var_name == tensor_name:
-                                    dist_op_attr = self.dist_context.get_dist_op_for_program(
-                                        op).dist_attr
-                                    var_dims_mapping = dist_op_attr.get_input_dims_mapping(
-                                        var_name)
-                                    if var_dims_mapping != tensor_dims_mapping:
-                                        is_reshard = True
-                                        break
-                    else:
-                        is_reshard = True
-                # process_mesh
-                if tensor_process_mesh != op_process_mesh:
-                    # when processes length is not the same, the dims mapping must be replicative now
-                    if len(tensor_process_mesh.processes) != len(
-                            op_process_mesh.processes):
-                        assert self.is_unshard(tensor_dims_mapping)
-                        assert self.is_unshard(op_input_dims_mapping)
-                    else:
-                        if dist_tensor.serial_tensor.dtype == paddle.bool:
-                            raise ValueError(
-                                "Bool var is not supported reshard.")
-                    # for while op, it should find the process mesh of op actually used the tensor as input
-                    if dist_op.serial_op.type == "while":
-                        sub_block = self.auto_parallel_main_prog.blocks[
-                            dist_op.serial_op.attr("sub_block").id]
-                        for op in sub_block.ops:
-                            for var_name in op.input_arg_names:
-                                if var_name == tensor_name:
-                                    dist_op_attr = self.dist_context.get_dist_op_for_program(
-                                        op).dist_attr
-                                    process_mesh = dist_op_attr.process_mesh
-                                    if process_mesh == op_process_mesh:
-                                        is_reshard = True
-                                        break
-                    else:
-                        is_reshard = True
+                    if tensor_process_mesh not in self.dist_context.process_meshes:
+                        # assert whether -1 when union.
+                        for item in tensor_dims_mapping:
+                            if item != -1:
+                                raise ValueError(
+                                    "The dim must be -1 when tensor process mesh is a union."
+                                )
+                        # tensor process_mesh: [0, 1, 2, 3], dims_mapping: [-1, -1]
+                        # op process_mesh: [4, 5], dims_mapping: [0, -1]
+                        # reshard is not supported such as above
+                        if not is_reshard:
+                            return is_reshard
+                        else:
+                            raise ValueError(
+                                "it is not supported that tensor process mesh is a union and needs reshard."
+                            )
+                    is_reshard = True
+
+                # judge whether need reshard by process_mesh
+                if tensor_process_mesh != op_process_mesh:
+                    is_reshard = True
         else:
-            op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
-                tensor_name)
+            op_output_dims_mapping = dist_attr[1]
             if all(
-                    map(lambda x: x is not None, [
+                    map(lambda x: x, [
                         tensor_dims_mapping, tensor_process_mesh,
                         op_output_dims_mapping, op_process_mesh
                     ])):
-                if tensor_process_mesh != op_process_mesh:
-                    if dist_tensor.serial_tensor.dtype == paddle.bool:
-                        raise ValueError("Bool var is not supported reshard.")
-                    is_reshard = True
                 if tensor_dims_mapping != op_output_dims_mapping:
                     raise ValueError(
                         "It is not supported that tensor dims mapping is different from op output dims mapping."
                     )
+                if tensor_process_mesh != op_process_mesh:
+                    is_reshard = True
 
         return is_reshard
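Stripped of the union-mesh special cases, the new decision reduces to a comparison against the requested [process_mesh, dims_mapping] pair; a condensed sketch (not the full implementation):

    # The real method also handles union process meshes and the op-output direction.
    def needs_reshard_sketch(tensor_mesh, tensor_dims_mapping, dist_attr):
        op_mesh, op_dims_mapping = dist_attr
        return (tensor_dims_mapping != op_dims_mapping
                or tensor_mesh != op_mesh)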
-    def get_process_meshes(self, op):
-        """Get all process meshes when op has sub block."""
-        assert op.has_attr("sub_block")
-        sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
-        ops = sub_block.ops
-        op_process_mesh = self.dist_context.get_dist_op_for_program(
-            op).dist_attr.process_mesh
-        process_meshes = []
-        for op in ops:
-            dist_op = self.dist_context.get_dist_op_for_program(op)
-            if not dist_op:
-                continue
-            process_mesh = dist_op.dist_attr.process_mesh
-            if process_mesh not in process_meshes and process_mesh != op_process_mesh:
-                process_meshes.append(process_mesh)
-        if not process_meshes:
-            process_meshes.append(op_process_mesh)
-        return process_meshes
-
     def get_op_process_meshes(self, op):
+        """Get sub process meshes of the given op if op process mesh is a union."""
         process_meshes = []
         dist_op = self.dist_context.get_dist_op_for_program(op)
         op_process_mesh = dist_op.dist_attr.process_mesh
         for process_mesh in self.dist_context.process_meshes:
             if set(process_mesh.processes) & (set(
                     op_process_mesh.processes)) and len(
-                        process_mesh.processes) <= len(
+                        process_mesh.processes) < len(
                             op_process_mesh.processes):
                 process_meshes.append(process_mesh)
@@ -993,39 +1175,14 @@ class Resharder:
         return process_meshes
-    def get_while_op_actual_process_mesh(self, op):
-        """Get the while op actual Process mesh corresponding to rank"""
-        assert op.type == "while"
-        while_op_process_mesh = self.dist_context.get_dist_op_for_program(
-            op).dist_attr.process_mesh
-        sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
-        ops = sub_block.ops
-        actual_process_mesh = None
-        for op in ops:
-            dist_op = self.dist_context.get_dist_op_for_program(op)
-            if not dist_op:
-                continue
-            process_mesh = dist_op.dist_attr.process_mesh
-            if process_mesh == while_op_process_mesh:
-                continue
-            if self.rank_id in process_mesh.processes:
-                raw_process_mesh = process_mesh
-                break
-        if actual_process_mesh is None and self.rank_id in while_op_process_mesh.processes:
-            actual_process_mesh = while_op_process_mesh
-        assert actual_process_mesh is not None
-        return actual_process_mesh
-
-    def find_op_desc_seq(self, dist_tensor, dist_op, actual_process_mesh):
+    def find_op_desc_seq(self, dist_tensor, dist_attr, serial=False):
         """
         Find the op description sequence to reshard the source tensor for matching the op requirement.
 
         Args:
             dist_tensor (DistributedTensor): A distributed tensor.
-            dist_op (DistributedOperator): A distributed operator.
-            actual_process_mesh (ProcessMesh): The actual op process mesh.
+            dist_attr (list): A list contains process_mesh and dims_mapping such as [process_mesh, dims_mapping].
+            serial (bool): If serial is true, the dist tensor and dist op come from serial program. Otherwise, they come from auto program.
 
         Returns:
             Dict, the dict represents the required op description sequence corresponding to process, The key of dict is
@@ -1034,24 +1191,26 @@ class Resharder:
         tensor_dist_attr = dist_tensor.dist_attr
         source_tensor = dist_tensor.serial_tensor
         tensor_name = source_tensor.name
+
         source_dims_mapping = tensor_dist_attr.dims_mapping
         source_process_mesh = tensor_dist_attr.process_mesh
         source_process_group = source_process_mesh.processes
         source_process_shape = source_process_mesh.topology
 
-        op_dist_attr = dist_op.dist_attr
-        target_process_mesh = actual_process_mesh
-        target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
+        target_process_mesh = dist_attr[0]
+        target_dims_mapping = dist_attr[1]
         target_process_group = target_process_mesh.processes
         target_process_shape = target_process_mesh.topology
 
         if source_tensor.shape[0] < 0:
+            assert source_tensor.shape[0] == -1
             new_shape = list(source_tensor.shape)
             new_shape[0] = self.batch_size
             source_tensor.desc.set_shape(new_shape)
 
         complete_shape = Resharder.compute_complete_shape(
-            source_tensor.shape, source_process_shape, source_dims_mapping)
+            source_tensor.shape, source_process_shape,
+            source_dims_mapping) if not serial else source_tensor.shape
         op_desc_seq = {}
 
         # TODO: if the target process group has the same process with source process group
@@ -1060,13 +1219,14 @@ class Resharder:
                 set(source_process_group)):
             pass
 
+        # in the different process group, it will use send, recv, concat and slice op
         elif target_process_group != source_process_group:
             partition_process_mapping_list = []
             for source_process in source_process_group:
+                # get partition index of source process
                 source_partition_index = Resharder.compute_partition_index(source_process, complete_shape, source_dims_mapping, \
                                                                             source_process_shape, source_process_group)
                 if not partition_process_mapping_list:
+                    # the item in partition_process_mapping_list is source_partition_index, which processes and whether has been used
                     partition_process_mapping_list.append(
                         [source_partition_index, [source_process], [False]])
                 else:
@@ -1076,6 +1236,7 @@ class Resharder:
                         [item[1] for item in partition_process_mapping_list])
                     has_used = list(
                         [item[2] for item in partition_process_mapping_list])
+
                     if partition_list.count(source_partition_index) == 1:
                         index = partition_list.index(source_partition_index)
                         process_list[index].append(source_process)
@@ -1085,6 +1246,7 @@ class Resharder:
                             [source_partition_index, [source_process], [False]])
 
             for target_process in target_process_group:
+                # has_sent means the source_partition_index has been sent to target_process
                 has_sent = []
                 target_partition_index = Resharder.compute_partition_index(
                     target_process, complete_shape, target_dims_mapping,
@@ -1114,6 +1276,7 @@ class Resharder:
                                 has_used[i] = True
                                 break
                             i += 1
+
                         if i == len(has_used):
                             has_used = list(map(lambda x: False, has_used))
                             to_send_process = process_list[0]
@@ -1127,10 +1290,16 @@ class Resharder:
                         all_partition_index_list.append(source_partition_index)
 
                         # append send and recv op desc
+                        is_bool = (
+                            dist_tensor.serial_tensor.dtype == paddle.bool)
                         send_op_desc = SendOpDesc(source_partition_index,
-                                                  target_process)
+                                                  to_send_process,
+                                                  target_process,
+                                                  is_bool=is_bool)
                         recv_op_desc = RecvOpDesc(source_partition_index,
-                                                  to_send_process)
+                                                  to_send_process,
+                                                  target_process,
+                                                  is_bool=is_bool)
                         op_desc_seq[to_send_process].append(send_op_desc)
                         op_desc_seq[target_process].append(recv_op_desc)
                         has_sent.append(source_partition_index)
@@ -1146,16 +1315,24 @@ class Resharder:
                 slice_ends = []
                 slices_axes = []
                 concatenated_partition_index = partition_index_list[0]
+
                 to_slice_tensor_shape = []
                 for idx, item in enumerate(concatenated_partition_index):
                     slice_starts.append(target_partition_index[idx][0] -
                                         item[0])
                     slice_ends.append(target_partition_index[idx][1] - item[0])
                     slices_axes.append(idx)
+                    to_slice_tensor_shape.append(item[1] - item[0])
+
                 op_desc_seq[target_process].append(
-                    SliceOpDesc(slice_starts, slice_ends, slices_axes))
+                    SliceOpDesc(slice_starts,
+                                slice_ends,
+                                slices_axes,
+                                shape=to_slice_tensor_shape))
 
-        # in the same process group, it will use allgahther and slice op
+        # in the same process group, it will use allgahther and slice op.
         else:
+            # NOTE: It just supports even partition scene.
             partition_index_list = []
             all_partition_index_list = []
             process_index = []
@@ -1191,17 +1368,21 @@ class Resharder:
                     slice_ends.append(item[1])
                     slices_axes.append(idx)
 
+            to_slice_tensor_shape = dist_tensor.global_sizes()
             slice_op_desc = SliceOpDesc(starts=slice_starts,
                                         ends=slice_ends,
-                                        axes=slices_axes)
-            op_desc_seq[process] = [AllGatherOpDesc(group=group),
+                                        axes=slices_axes,
+                                        shape=to_slice_tensor_shape)
+            allgather_shape = None if not serial else dist_tensor.local_sizes(
+                rank=process)
+            op_desc_seq[process] = [AllGatherOpDesc(group=group, shape=allgather_shape, is_bool=(source_tensor.dtype == paddle.bool)),
                                     ConcatOpDesc(partition_index_list=all_partition_index_list), slice_op_desc] \
                 if len(group) > 1 else [slice_op_desc]
 
         return op_desc_seq
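For orientation, an illustrative (made-up) example of the returned op_desc_seq: a dict mapping each process to the ordered reshard steps it must run, built from the descriptor classes defined above:

    # Hypothetical two-rank example; indices and ranks are invented.
    op_desc_seq = {
        0: [SendOpDesc([[0, 64], [0, 128]], src=0, dst=2)],
        2: [RecvOpDesc([[0, 64], [0, 128]], src=0, dst=2),
            ConcatOpDesc(partition_index_list=[[[0, 64], [0, 128]]]),
            SliceOpDesc(starts=[0, 0], ends=[64, 128], axes=[0, 1],
                        shape=[64, 128])],
    }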
     def parse_op_desc(self, block, op_desc_seq, var_name, reshard_op,
-                      actual_process_mesh):
+                      dist_attr):
         """Parse op desc sequence and insert op in the block"""
         tensor_list = []
         partition_tensor_list = []
@@ -1226,6 +1407,25 @@ class Resharder:
                     self.has_allgather[var_name] = []
                 if not self.has_allgather[var_name] or op_desc.group not in list(
                         map(lambda x: x[0], self.has_allgather[var_name])):
+                    if op_desc.is_bool:
+                        # for bool data allgather, cast to int64 -> allgather -> cast bool
+                        out_cast = Inserter.insert_cast_op(
+                            block, idx, source_tensor,
+                            reshard_op.attr('op_role'), paddle.int64)
+                        tensor_list, idx_offset = Inserter.insert_allgather_op(
+                            block, idx + 1, out_cast, op_desc.group,
+                            reshard_op.attr('op_role'))
+                        idx += idx_offset
+                        tensor_name_list = []
+                        for var in tensor_list:
+                            out_cast = Inserter.insert_cast_op(
+                                block, idx, var, reshard_op.attr('op_role'),
+                                paddle.bool)
+                            tensor_name_list.append(out_cast.name)
+                            idx += 1
+                        self.has_allgather[var_name].append(
+                            [op_desc.group, tensor_name_list])
+                    else:
                         tensor_list, idx_offset = Inserter.insert_allgather_op(
                             block, idx, source_tensor, op_desc.group,
                             reshard_op.attr('op_role'))
@@ -1249,8 +1449,17 @@ class Resharder:
                 if var_name not in self.has_sent.keys():
                     self.has_sent[var_name] = []
                 if op_desc.dst not in self.has_sent[var_name]:
+                    if op_desc.is_bool:
+                        out_cast = Inserter.insert_cast_op(
+                            block, idx, source_tensor,
+                            reshard_op.attr('op_role'), paddle.int64)
+                        Inserter.insert_send_op(block, idx + 1, out_cast,
+                                                op_desc.src, op_desc.dst,
+                                                reshard_op.attr('op_role'))
+                        idx += 2
+                    else:
                         Inserter.insert_send_op(block, idx, source_tensor,
-                                                op_desc.dst,
+                                                op_desc.src, op_desc.dst,
                                                 reshard_op.attr('op_role'))
                         idx += 1
                     self.has_sent[var_name].append(op_desc.dst)
@@ -1263,14 +1472,55 @@ class Resharder:
                     shape = []
                     for index in partition_index:
                         shape.append(index[1] - index[0])
+
+                    if op_desc.is_bool:
+                        # for bool data, recv int64 -> cast to bool
+                        recv_tensor = block.create_var(
+                            name=unique_name.generate(var_name + "@recv"),
+                            shape=shape,
+                            lod_level=source_tensor.lod_level,
+                            dtype=paddle.int64,
+                            type=source_tensor.type)
+                        Inserter.insert_recv_op(block, idx, recv_tensor,
+                                                op_desc.src, op_desc.dst,
+                                                reshard_op.attr('op_role'))
+                        out_cast = Inserter.insert_cast_op(
+                            block, idx + 1, recv_tensor,
+                            reshard_op.attr('op_role'), paddle.bool)
+                        tensor_list.append(out_cast)
+                        idx += 2
+                        self.has_recv[var_name][op_desc.src] = out_cast
+                    else:
                         recv_tensor = block.create_var(
                             name=unique_name.generate(var_name + "@recv"),
                             shape=shape,
+                            lod_level=source_tensor.lod_level,
                             dtype=source_tensor.dtype,
                             type=source_tensor.type)
                         Inserter.insert_recv_op(block, idx, recv_tensor,
-                                                op_desc.src,
+                                                op_desc.src, op_desc.dst,
                                                 reshard_op.attr('op_role'))
+
+                        # for lod tensor, need reset lod after received
+                        if recv_tensor.lod_level != 0:
+                            set_lod = False
+                            # use data lod to reset tensor lod
+                            for tmp_block in self.auto_parallel_main_prog.blocks:
+                                for tmp_var_name in tmp_block.vars:
+                                    tmp_var = tmp_block.vars[tmp_var_name]
+                                    if tmp_var.is_data and tmp_var.lod_level == recv_tensor.lod_level:
+                                        reset_lod_out = Inserter.insert_reset_lod_op(
+                                            block, idx + 1, recv_tensor,
+                                            tmp_var, reshard_op.attr('op_role'))
+                                        tensor_list.append(reset_lod_out)
+                                        idx += 2
+                                        self.has_recv[var_name][
+                                            op_desc.src] = reset_lod_out
+                                        set_lod = True
+                                        break
+                                if set_lod:
+                                    break
+                            assert set_lod is True
+                        else:
                             tensor_list.append(recv_tensor)
                             idx += 1
                             self.has_recv[var_name][op_desc.src] = recv_tensor
@@ -1303,90 +1553,226 @@ class Resharder:
                     new_var_name=new_name,
                     op_role=reshard_op.attr('op_role'))
 
+                process_mesh = dist_attr[0]
+                dims_mapping = dist_attr[1]
+
                 tensor_attr = TensorDistributedAttribute()
-                process_mesh = actual_process_mesh
-                dims_mapping = self.dist_context.get_op_dist_attr_for_program(
-                    matched_op).get_input_dims_mapping(var_name)
                 tensor_attr.dims_mapping = dims_mapping
                 tensor_attr.process_mesh = process_mesh
                 self.dist_context.set_tensor_dist_attr_for_program(
                     target_tensor, tensor_attr)
 
-                if op.type == "while":
+                if matched_op.type == "while":
                     # var_reshard_mapping means the while op input need be changed to
                     if "var_reshard_mapping" not in Resharder.while_block_info[
                             op.attr("sub_block").id].keys():
                         Resharder.while_block_info[op.attr(
                             "sub_block").id]["var_reshard_mapping"] = {}
+                    if var_name not in Resharder.while_block_info[op.attr(
+                            "sub_block").id]["var_reshard_mapping"].keys():
+                        Resharder.while_block_info[op.attr("sub_block").id][
+                            "var_reshard_mapping"][var_name] = []
                     Resharder.while_block_info[op.attr("sub_block").id][
-                        "var_reshard_mapping"][var_name] = target_tensor.name
+                        "var_reshard_mapping"][var_name].append(
+                            [dist_attr, target_tensor.name])
 
                 # rename op input name according to new name
                 for op in block.ops:
+                    # just for while op
+                    while_op_X_append = []
                     for name in op.input_arg_names:
                         op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
                             op)
                         if name == var_name and op_dist_attr is not None:
                             if op.desc.id() == matched_op.desc.id():
-                                op.desc._rename_input(name, target_tensor.name)
-                                op_dist_attr.set_input_dims_mapping(
-                                    target_tensor.name, dims_mapping)
-                                op_dist_attr.set_input_dist_attr(name, None)
-                                continue
+                                if matched_op.type == "while":
+                                    old_name = name
+                                    new_name = target_tensor.name
+                                    assert old_name != new_name
+                                    op_input_dist_attr = op_dist_attr.get_input_dist_attr(
+                                        old_name)
+                                    op_dist_attr.set_input_dist_attr(
+                                        new_name, op_input_dist_attr)
+                                    op_dist_attr.set_input_dims_mapping(
+                                        new_name, dims_mapping)
+                                    if old_name in op_dist_attr._inputs_dist_attrs:
+                                        op_dist_attr.del_input_dist_attr(
+                                            old_name)
+                                    while_op_X_append.append(new_name)
+                                    continue
+                                else:
+                                    op.desc._rename_input(
+                                        name, target_tensor.name)
+                                    old_name = name
+                                    new_name = target_tensor.name
+                                    assert old_name != new_name
+                                    op_input_dist_attr = op_dist_attr.get_input_dist_attr(
+                                        old_name)
+                                    op_dist_attr.set_input_dist_attr(
+                                        new_name, op_input_dist_attr)
+                                    op_dist_attr.set_input_dims_mapping(
+                                        new_name, dims_mapping)
+                                    op_dist_attr.del_input_dist_attr(old_name)
+                                    continue
 
+                            # NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation.
                             op_process_mesh = op_dist_attr.process_mesh
                             op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
                                 var_name)
                             if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping:
                                 op.desc._rename_input(name, target_tensor.name)
+                                old_name = name
+                                new_name = target_tensor.name
+                                assert old_name != new_name
+                                op_input_dist_attr = op_dist_attr.get_input_dist_attr(
+                                    old_name)
+                                op_dist_attr.set_input_dist_attr(
+                                    new_name, op_input_dist_attr)
                                 op_dist_attr.set_input_dims_mapping(
-                                    target_tensor.name, dims_mapping)
-                                op_dist_attr.set_input_dist_attr(name, None)
+                                    new_name, dims_mapping)
+                                op_dist_attr.del_input_dist_attr(old_name)
+                    # for while op, the input X should reset
+                    if while_op_X_append:
+                        proto = OpProtoHolder.instance().get_op_proto(op.type)
+                        op.desc.set_input(proto.inputs[0].name,
+                                          op.input("X") + while_op_X_append)
+
-    def reshard(self):
-        for block_idx, block in enumerate(self.auto_parallel_main_prog.blocks):
-            if block_idx in Resharder.while_block_info:
-                if "var_reshard_mapping" in Resharder.while_block_info[
-                        block_idx]:
+    def _get_while_op_input_attrs(self, op, var_name):
+        # NOTE: Multi while loop is not supported
+        assert op.type == "while"
+        sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
+        ops = sub_block.ops
+        input_attrs = []
+        for op in ops:
+            dist_op = self.dist_context.get_dist_op_for_program(op)
+            if not dist_op:
+                continue
+            dist_attr = dist_op.dist_attr
+            for name in op.input_arg_names:
+                if name == var_name:
+                    process_mesh = dist_attr.process_mesh
+                    input_dims_mapping = dist_attr.get_input_dims_mapping(
+                        var_name)
+                    has_exist = False
+                    for input_attr in input_attrs:
+                        if process_mesh == input_attr[
+                                0] and input_dims_mapping == input_attr[1]:
+                            has_exist = True
+                            break
+                    if not has_exist:
+                        input_attrs.append([process_mesh, input_dims_mapping])
+        return input_attrs
+
+    def _get_common_op_input_attrs(self, op, var_name):
+        process_meshes = []
+        dist_op = self.dist_context.get_dist_op_for_program(op)
+        dist_attr = dist_op.dist_attr
+        op_process_mesh = dist_attr.process_mesh
+        for process_mesh in self.dist_context.process_meshes:
+            if set(process_mesh.processes) & (set(
+                    op_process_mesh.processes)) and len(
+                        process_mesh.processes) < len(
+                            op_process_mesh.processes):
+                process_meshes.append(process_mesh)
+
+        # it means that the process mesh is not a union when process meshes is none
+        if not process_meshes:
+            process_meshes.append(op_process_mesh)
+
+        input_dims_mapping = dist_attr.get_input_dims_mapping(var_name)
+        input_attrs = []
+        for process_mesh in process_meshes:
+            input_attrs.append([process_mesh, input_dims_mapping])
+
+        return input_attrs
+
+    def get_op_input_attrs(self, op, var_name):
+        op_input_attrs = []
+        if op.type == "while":
+            op_input_attrs = self._get_while_op_input_attrs(op, var_name)
+        else:
+            op_input_attrs = self._get_common_op_input_attrs(op, var_name)
+        assert op_input_attrs
+        return op_input_attrs
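An illustrative (made-up) example of what get_op_input_attrs returns for one input variable, assuming ProcessMesh is importable from the auto_parallel process_mesh module of that release:

    from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

    # One [process_mesh, dims_mapping] pair per sub-mesh that consumes the tensor.
    input_attrs = [
        [ProcessMesh([0, 1]), [0, -1]],
        [ProcessMesh([2, 3]), [-1, -1]],
    ]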
+    def _remove_global_process_mesh(self):
+        """Remove global process mesh from dist_context.process_meshes"""
+        processes = set()
+        process_mesh_count = len(self.dist_context.process_meshes)
+        if process_mesh_count > 1:
+            global_process_mesh_idx = None
+            for process_mesh in self.dist_context.process_meshes:
+                for process in process_mesh.processes:
+                    processes.add(process)
+            for idx, process_mesh in enumerate(
+                    self.dist_context.process_meshes):
+                if len(set(process_mesh.processes)) == len(processes):
+                    global_process_mesh_idx = idx
+                    break
+            if global_process_mesh_idx is not None:
+                self.dist_context.process_meshes.pop(idx)
+    def _change_subblock_op_input_and_output(self, block_idx, block):
+        if "var_reshard_mapping" in Resharder.while_block_info[block_idx]:
             var_reshard_mapping = Resharder.while_block_info[block_idx][
                 "var_reshard_mapping"]
             for op in block.ops:
                 for var_name in op.input_arg_names:
                     if var_name in var_reshard_mapping:
-                        op.desc._rename_input(
-                            var_name, var_reshard_mapping[var_name])
+                        # in while sub block, the union process mesh is not split before reshard sub block
+                        dist_op = self.dist_context.get_dist_op_for_program(op)
+                        dist_attr = dist_op.dist_attr
+                        target_name = None
+                        for item in var_reshard_mapping[var_name]:
+                            if dist_attr.process_mesh == item[0][
+                                    0] and dist_attr.get_input_dims_mapping(
+                                        var_name) == item[0][1]:
+                                target_name = item[1]
+                                break
+                        if target_name is None:
+                            continue
+                        else:
+                            op.desc._rename_input(var_name, target_name)
                             dist_op = self.dist_context.get_dist_op_for_program(
                                 op)
                             op_dist_attr = dist_op.dist_attr
-                            if op_dist_attr.process_mesh == Resharder.while_block_info[
-                                    block_idx]["actual_process_mesh"]:
-                                dims_mapping = op_dist_attr.get_input_dims_mapping(
-                                    var_name)
-                                op_dist_attr.set_input_dims_mapping(
-                                    var_reshard_mapping[var_name], dims_mapping)
-                                op_dist_attr.set_input_dist_attr(var_name, None)
+                            old_name = var_name
+                            new_name = target_name
+                            assert old_name != new_name
+                            op_input_dist_attr = op_dist_attr.get_input_dist_attr(
+                                old_name)
+                            op_dist_attr.set_input_dist_attr(
+                                new_name, op_input_dist_attr)
+                            op_dist_attr.del_input_dist_attr(old_name)
 
-                # the outputs also need to be renamed when the output name is the same with input name
+                # the outputs also need to be renamed when the output name is the same with input name in inplace op
                 for var_name in op.output_arg_names:
+                    # if the tensor has been resharded multiply, it is not supported now.
                     if var_name in var_reshard_mapping:
-                        op.desc._rename_output(
-                            var_name, var_reshard_mapping[var_name])
-                        dist_op = self.dist_context.get_dist_op_for_program(
-                            op)
-                        op_dist_attr = dist_op.dist_attr
-                        if op_dist_attr.process_mesh == Resharder.while_block_info[
-                                block_idx]["actual_process_mesh"]:
-                            dims_mapping = op_dist_attr.get_output_dims_mapping(
-                                var_name)
-                            op_dist_attr.set_output_dims_mapping(
-                                var_reshard_mapping[var_name], dims_mapping)
-                            op_dist_attr.set_output_dist_attr(var_name, None)
+                        if len(var_reshard_mapping[var_name]) > 1:
+                            raise ValueError(
+                                "The scene is not supported that the output is inplaced and the tensor has been resharded multiply when as input."
+                            )
+                        target_name = var_reshard_mapping[var_name][0][1]
+
+                        op.desc._rename_output(var_name, target_name)
+                        dist_op = self.dist_context.get_dist_op_for_program(op)
+                        op_dist_attr = dist_op.dist_attr
+                        old_name = var_name
+                        new_name = target_name
+                        assert old_name != new_name
+                        op_output_dist_attr = op_dist_attr.get_output_dist_attr(
+                            old_name)
+                        op_dist_attr.set_output_dist_attr(
+                            new_name, op_output_dist_attr)
+                        op_dist_attr.del_output_dist_attr(old_name)
def _reshard_input(self, block):
    idx = 0
    while idx < len(block.ops):
        pre_op_count = len(block.ops)
@@ -1398,46 +1784,62 @@ class Resharder:
        dist_op = self.dist_context.get_dist_op_for_program(op)
        if dist_op is not None:
            op_input_dist_attrs = []  # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)]
            if op.type == "while":
                if not self.is_condition_replicative(op):
                    raise ValueError(
                        "Please check the condition due to the dims mapping is not replicative."
                    )
                if op.attr("sub_block").id not in Resharder.while_block_info:
                    Resharder.while_block_info[op.attr("sub_block").id] = {}
                Resharder.while_block_info[op.attr(
                    "sub_block").id]["op_id"] = op.desc.id()

            if op.type == "while":
                # condition var process mesh is the same with op and dims_mapping is replicative, so it do not need reshard
                input_var_names = op.input("X")
            else:
                input_var_names = op.input_arg_names
            # to avoid while op X order different
            input_var_names.sort()

            idx_offset = 0
            for var_name in input_var_names:
                # skip lod_tensor_blocking_queue_0
                if var_name == "lod_tensor_blocking_queue_0":
                    continue
                var = get_var_with_recursion(var_name, block,
                                             self.auto_parallel_main_prog)
                dist_tensor = self.dist_context.get_dist_tensor_for_program(var)

                # judge whether union tensor dims_mapping all -1
                is_union_process_mesh_tensor = False
                if dist_tensor.dist_attr.process_mesh not in self.dist_context.process_meshes and self.dist_context.process_meshes:
                    is_union_process_mesh_tensor = True
                    assert dist_tensor.dist_attr.dims_mapping.count(-1) == len(
                        dist_tensor.dist_attr.dims_mapping)

                op_input_attrs = self.get_op_input_attrs(op, var_name)
                for input_attr in op_input_attrs:
                    input_process_mesh = None

                    # deal with union tensor
                    if is_union_process_mesh_tensor:
                        # if op process mesh is subset of union tensor process mesh, need no reshard
                        if set(input_attr[0].processes) <= set(
                                dist_tensor.dist_attr.process_mesh.processes):
                            continue

                    if dist_tensor is not None and self.need_reshard(
                            dist_tensor, input_attr):
                        reshard_op_desc = self.find_op_desc_seq(
                            dist_tensor, input_attr)
                        self.parse_op_desc(block, reshard_op_desc, var_name,
                                           op, input_attr)
                        cur_op_count = len(block.ops)
                        idx_offset = idx_offset + cur_op_count - pre_op_count
                        pre_op_count = cur_op_count
@@ -1445,12 +1847,113 @@ class Resharder:
        else:
            idx += 1
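For tensors whose process mesh is a union of several op meshes, the loop above only triggers a reshard when the consuming op's mesh is not fully contained in the tensor's mesh. A small sketch of that subset test, using hypothetical process lists rather than real ProcessMesh objects:

```python
# The tensor lives on the union mesh [0, 1, 2, 3]; one consumer op runs on
# [0, 1] (already covered) and another on [2, 3, 4, 5] (not covered).
tensor_processes = {0, 1, 2, 3}
op_input_attrs = [([0, 1], [-1, -1]), ([2, 3, 4, 5], [-1, -1])]

for processes, dims_mapping in op_input_attrs:
    if set(processes) <= tensor_processes:
        print(processes, "-> skip, op mesh is a subset of the union mesh")
    else:
        print(processes, "-> reshard needed")
```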
def _hadnle_recv(self, block, idx, var, op, send_rank, recv_rank):
if self.rank_id == recv_rank:
# if recv bool data, recv then cast
if var.dtype == paddle.bool:
recv_cast_out = block.create_var(
name=unique_name.generate(var.name + "@recv"),
shape=var.shape,
lod_level=var.lod_level,
dtype=paddle.int64,
type=var.type)
Inserter.insert_recv_op(block, idx + 1,
recv_cast_out, send_rank, recv_rank,
op.attr('op_role'))
reset_lod_out = None
if var.lod_level != 0:
set_lod = False
for tmp_block in self.auto_parallel_main_prog.blocks:
for tmp_var_name in tmp_block.vars:
tmp_var = tmp_block.vars[tmp_var_name]
if tmp_var.is_data and tmp_var.lod_level == var.lod_level:
reset_lod_out = block.create_var(
name=unique_name.generate(var.name +
"@RESETLOD"),
shape=recv_cast_out.shape,
type=recv_cast_out.type,
dtype=recv_cast_out.dtype,
lod_level=recv_cast_out.lod_level)
idx += 1
block._insert_op(
idx,
type="lod_reset",
inputs={
'X': recv_cast_out,
'Y': tmp_var
},
outputs={'Out': reset_lod_out},
attrs={'op_role': op.attr("op_role")})
set_lod = True
break
if set_lod:
break
assert set_lod is True
# cast int64 to bool
block._insert_op(idx + 2,
type='cast',
inputs={
'X': [recv_cast_out] if
reset_lod_out is None else [reset_lod_out]
},
outputs={'Out': [var]},
attrs={
'in_dtype': recv_cast_out.dtype,
'out_dtype': var.dtype,
'op_role': op.attr('op_role')
})
else:
if var.lod_level != 0:
recv_out = block.create_var(
name=unique_name.generate(var.name + "@recv"),
shape=var.shape,
lod_level=var.lod_level,
dtype=var.dtype,
type=var.type)
Inserter.insert_recv_op(block, idx + 1, recv_out, send_rank,
recv_rank, op.attr('op_role'))
set_lod = False
for tmp_block in self.auto_parallel_main_prog.blocks:
for tmp_var_name in tmp_block.vars:
tmp_var = tmp_block.vars[tmp_var_name]
if tmp_var.is_data and tmp_var.lod_level == var.lod_level:
idx += 1
block._insert_op(
idx,
type="lod_reset",
inputs={
'X': recv_out,
'Y': tmp_var
},
outputs={'Out': var},
attrs={'op_role': op.attr("op_role")})
set_lod = True
break
if set_lod:
break
assert set_lod is True
else:
Inserter.insert_recv_op(block, idx + 1, var, send_rank,
recv_rank, op.attr('op_role'))
def _handle_send(self, block, idx, var, op, send_rank, recv_rank):
if var.dtype == paddle.bool:
cast_out = Inserter.insert_cast_op(block, idx + 1, var,
op.attr('op_role'), paddle.int64)
Inserter.insert_send_op(block, idx + 2, cast_out, send_rank,
recv_rank, op.attr('op_role'))
else:
Inserter.insert_send_op(block, idx + 1, var, send_rank, recv_rank,
op.attr('op_role'))
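Both helpers route bool tensors through int64 before the send/recv pair is inserted and cast back on the receiver side, presumably because the communication ops do not carry bool payloads directly. A minimal dynamic-graph illustration of the dtype round trip (independent of the static-graph op insertion done here):

```python
import paddle

# A bool tensor is cast to int64 on the sender side and cast back to bool on
# the receiver side; the values survive the round trip unchanged.
mask = paddle.to_tensor([True, False, True])
payload = paddle.cast(mask, 'int64')     # what would be transmitted
restored = paddle.cast(payload, 'bool')  # what the receiver reconstructs
assert (restored.numpy() == mask.numpy()).all()
```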
def _reshard_output(self, block):
    # insert send and recv op if output process mesh is different from tensor process mesh
    idx = 0
    # skip reader and ops whose process mesh is union
    skip_ops = [
        "create_py_reader", "create_double_buffer_reader", "read", "while",
        "write_to_array", "read_from_array"
    ]
    global _g_special_ops
    skip_ops += _g_special_ops
@@ -1459,33 +1962,98 @@ class Resharder:
        op = block.ops[idx]
        dist_op = self.dist_context.get_dist_op_for_program(op)
        if dist_op is not None and op.type not in skip_ops:
            idx_offset = 0
            for var_name in op.output_arg_names:
                var = get_var_with_recursion(var_name, block,
                                             self.auto_parallel_main_prog)
                dist_tensor = self.dist_context.get_dist_tensor_for_program(var)
                tensor_process_mesh = dist_tensor.dist_attr.process_mesh
                output_attr = [
                    dist_op.dist_attr.process_mesh,
                    dist_op.dist_attr.get_output_dims_mapping(var_name)
                ]
                if dist_tensor is not None and self.need_reshard(
                        dist_tensor, output_attr, False):
                    tensor_processes = set(tensor_process_mesh.processes) - (
                        set(tensor_process_mesh.processes)
                        & set(output_attr[0].processes))
                    if tensor_processes:
                        if len(tensor_processes) != len(
                                output_attr[0].processes):
                            if dist_tensor.dist_attr.dims_mapping.count(
                                    -1) != len(
                                        dist_tensor.dist_attr.dims_mapping
                                    ) or output_attr[1].count(-1) != len(
                                        output_attr[1]):
                                raise ValueError("The dims_mapping must be -1")
                            else:
                                for index, tensor_process in enumerate(
                                        tensor_processes):
                                    recv_rank = tensor_process
                                    actual_index = index
                                    if index >= len(output_attr[0].processes):
                                        actual_index = (
                                            index -
                                            len(output_attr[0].processes)
                                        ) % len(output_attr[0].processes)
                                    item = output_attr[0].processes[
                                        actual_index]
                                    if recv_rank == item:
                                        continue
                                    if self.rank_id == item:
                                        # if send bool data, cast then send
                                        self._handle_send(
                                            block, idx, var, op, item,
                                            recv_rank)
                                    if self.rank_id == recv_rank:
                                        # if recv bool data, recv then cast
                                        self._hadnle_recv(
                                            block, idx, var, op, item,
                                            recv_rank)
                        else:
                            for index, tensor_process in enumerate(
                                    tensor_processes):
                                recv_rank = tensor_process
                                item = output_attr[0].processes[index]
                                if recv_rank == item:
                                    continue
                                if self.rank_id == item:
                                    # if send bool data, cast then send
                                    self._handle_send(block, idx, var, op,
                                                      item, recv_rank)
                                if self.rank_id == recv_rank:
                                    # if recv bool data, recv then cast
                                    self._hadnle_recv(block, idx, var, op,
                                                      item, recv_rank)

                    cur_op_count = len(block.ops)
                    idx_offset = idx_offset + cur_op_count - pre_op_count
                    pre_op_count = cur_op_count

            idx = idx + idx_offset + 1
        else:
            idx += 1
def reshard(self):
self._remove_global_process_mesh()
for block_idx, block in enumerate(self.auto_parallel_main_prog.blocks):
# change the var_name before resharding sub block
if block_idx in Resharder.while_block_info:
self._change_subblock_op_input_and_output(block_idx, block)
# reshard input
self._reshard_input(block)
# reshard output
# NOTE: Only support that insert send and recv op if output process mesh is different from tensor process mesh
self._reshard_output(block)
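`reshard()` is the single entry point of the pass; the unit tests further down drive it roughly as shown below. The partitioned programs, dist_context and params/grads are produced by the partitioner for the given rank, as in those tests, so this is a usage sketch rather than a self-contained script.

```python
from paddle.distributed.auto_parallel.reshard import Resharder

# dist_main_prog, dist_startup_prog, dist_context and dist_params_grads come
# from the partitioner for the given rank_id (see the tests below).
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                      dist_context, dist_params_grads)
resharder.reshard()  # rewrites both programs in place for this rank
```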
# remove no need vars and ops in the main program
Remover.remove_no_need_in_main(self.auto_parallel_main_prog,
                               self.dist_context, self.rank_id,
...
@@ -1419,7 +1419,10 @@ def get_standalone_cost_data(distributed_programs):
    }
    standalone_cost_data = []
    # skip ops
    not_enum_ops = [
        "create_py_reader", "create_double_buffer_reader", "read", "assign"
    ]
    for distributed_program in distributed_programs:
        cost_data = {}
        vars = distributed_program.global_block().vars
...
@@ -27,7 +27,10 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_di
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
_skip_ops = [
    'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split',
    'assign'
]
# update here to support new optimizers
_supported_optimizer_type = [
    "adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
...
@@ -107,7 +107,9 @@ class TestDataUnshard(unittest.TestCase):
input_data = np.array(range(2 * 8)).reshape([2, 8]).astype("float32")
label_data = np.random.randint(0, 10, [2, 8]).astype("float32")
fetchs = [loss.name, 'split@RESHARD.tmp_0'] if worker_index == 0 else [
    loss.name, 'split@RESHARD.tmp_1'
]
loss_np, shard_data_np = exe.run(distributed_main_program,
                                 feed={
                                     "input": input_data,
...
@@ -28,7 +28,7 @@ from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import _g_process_group_map, ProcessGroup
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr

paddle.enable_static()
@@ -307,6 +307,7 @@ class TestMLPReshard(unittest.TestCase):
    train_program, startup_program, dist_context, rank_id)
for key in list(_g_process_group_map.keys()):
    del _g_process_group_map[key]
_g_process_group_map[0] = ProcessGroup(0, [])
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                      dist_context, dist_params_grads)
resharder.reshard()
@@ -326,10 +327,10 @@ class TestMLPReshard(unittest.TestCase):
    train_program, startup_program, dist_context, rank_id, True)
for key in list(_g_process_group_map.keys()):
    del _g_process_group_map[key]
_g_process_group_map[0] = ProcessGroup(0, [])
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                      dist_context, dist_params_grads)
resharder.reshard()

# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
self.assertTrue(check_initialization(dist_startup_prog, rank_id))
...
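The test cases run in one Python process, so each one clears the global process-group registry and re-registers group 0 before building a new Resharder. The pattern, lifted from the tests above with no hypothetical names:

```python
from paddle.distributed.auto_parallel.process_group import (
    _g_process_group_map, ProcessGroup)

# Drop groups left over from a previous test case, then re-register the
# default group 0 with an empty rank list, exactly as the updated tests do.
for key in list(_g_process_group_map.keys()):
    del _g_process_group_map[key]
_g_process_group_map[0] = ProcessGroup(0, [])
```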