未验证 提交 831db343 编写于 作者: C caozhou 提交者: GitHub

[Auto Parallel]Add c_concat pass for reshard (#47809)

* add c_concat pass for reshard

* add unittest
上级 d01109fc
......@@ -92,6 +92,42 @@ class AllGatherOpDesc:
return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}."
class AllGatherConcatOpDesc:
"""
Describe the c_concat op in the reshard phase.
Args:
group (list): Process group.
shape (list): The tensor shape.
is_bool (bool): Whether c_concat bool data. Default: False.
"""
def __init__(self, group, shape, is_bool=False):
self._group = group
self._desc = "c_concat"
self._shape = shape
self._is_bool = is_bool
@property
def is_bool(self):
return self._is_bool
@property
def group(self):
return self._group
@property
def desc(self):
return self._desc
@property
def shape(self):
return self._shape
def __repr__(self):
return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}."
class SendOpDesc:
"""
Describe the send op in the reshard phase.
......@@ -640,6 +676,46 @@ class Inserter:
tensor_list.extend(split_out)
return tensor_list, idx_offset
@staticmethod
def insert_c_concat_op(block, idx, tensor, ranks, op_role):
"""Insert c_concat op into block at the given index."""
group = new_process_group(ranks)
idx_offset = 0
# insert c_concat op
op_type = 'c_concat'
# to avoid name conflict with framework
helper = LayerHelper(op_type + "@RESHARD", **locals())
with paddle.static.program_guard(block.program):
c_concat_out = block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])
),
dtype=tensor.dtype,
shape=None,
lod_level=tensor.lod_level,
type=tensor.type,
persistable=False,
stop_gradient=False,
)
cur_rank = paddle.distributed.get_rank()
c_concat_op = block._insert_op(
idx + idx_offset,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [c_concat_out]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
'nranks': group.nranks,
'op_role': op_role,
'rank': group.ranks.index(cur_rank) if cur_rank in ranks else 0,
},
)
c_concat_op._set_attr('op_namescope', "/auto_parallel/reshard")
return c_concat_out
@staticmethod
def concat_partitions_with_op(
partition_tensor_list, tensor, partition_index, block, idx, op_role
......@@ -1535,7 +1611,7 @@ class Resharder:
)
)
# in the same process group, it will use allgahther and slice op.
# In the same process group, it will use allgahther and slice op.
else:
# NOTE: It just supports even partition scene.
partition_index_list = []
......@@ -1599,21 +1675,37 @@ class Resharder:
if not serial
else dist_tensor.local_sizes(rank=process)
)
op_desc_seq[process] = (
[
AllGatherOpDesc(
group=group,
shape=allgather_shape,
is_bool=(source_tensor.dtype == paddle.bool),
),
ConcatOpDesc(
partition_index_list=all_partition_index_list
),
slice_op_desc,
# c_concat pass
if (
target_dims_mapping.count(-1)
== len(target_dims_mapping)
and source_dims_mapping[:-1].count(-1)
== len(source_dims_mapping[:-1])
and source_dims_mapping[-1] != -1
):
op_desc_seq[process] = [
AllGatherConcatOpDesc(
group=group, shape=allgather_shape
)
]
if len(group) > 1
else [slice_op_desc]
)
else:
op_desc_seq[process] = (
[
AllGatherOpDesc(
group=group,
shape=allgather_shape,
is_bool=(
source_tensor.dtype == paddle.bool
),
),
ConcatOpDesc(
partition_index_list=all_partition_index_list
),
slice_op_desc,
]
if len(group) > 1
else [slice_op_desc]
)
return op_desc_seq
......@@ -1850,27 +1942,41 @@ class Resharder:
)
idx = idx_list[0]
elif isinstance(op_desc, SliceOpDesc):
assert (
len(partition_tensor_list) == 1 or not partition_tensor_list
)
to_slice_tensor = (
partition_tensor_list[0][0]
if len(partition_tensor_list) == 1
else source_tensor
)
new_name = unique_name.generate(var_name + "@RESHARD")
target_tensor = Inserter.insert_slice_op(
block,
idx,
to_slice_tensor,
starts=op_desc.starts,
ends=op_desc.ends,
axes=op_desc.axes,
new_var_name=new_name,
op_role=reshard_op.attr('op_role'),
)
elif isinstance(op_desc, SliceOpDesc) or isinstance(
op_desc, AllGatherConcatOpDesc
):
target_tensor = None
if isinstance(op_desc, SliceOpDesc):
assert (
len(partition_tensor_list) == 1
or not partition_tensor_list
)
to_slice_tensor = (
partition_tensor_list[0][0]
if len(partition_tensor_list) == 1
else source_tensor
)
new_name = unique_name.generate(var_name + "@RESHARD")
target_tensor = Inserter.insert_slice_op(
block,
idx,
to_slice_tensor,
starts=op_desc.starts,
ends=op_desc.ends,
axes=op_desc.axes,
new_var_name=new_name,
op_role=reshard_op.attr('op_role'),
)
else:
target_tensor = Inserter.insert_c_concat_op(
block,
idx,
source_tensor,
op_desc.group,
reshard_op.attr('op_role'),
)
assert target_tensor is not None
process_mesh = dist_attr[0]
dims_mapping = dist_attr[1]
......
......@@ -304,6 +304,60 @@ class TestMLPReshard(unittest.TestCase):
# the x should not be slice
self.assertTrue(check_allgather(partitioned_main_prog))
def test_c_concat(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
with static.program_guard(train_program, startup_program):
x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
x = auto.shard_tensor(x, process_mesh, [None, "x"])
w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
w = auto.shard_tensor(w, process_mesh, [None, None])
y = paddle.distributed.shard_op(
paddle.matmul, process_mesh, [[None, None], [None, None]]
)(x, w)
rank_id = 0
dist_context = DistributedContext()
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_context, rank_id)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program
)
dist_context.block_state.parse_forward_blocks(complete_train_program)
(
partitioned_main_prog,
partitioned_startup_prog,
partitioned_params_grads,
) = partitioner.partition(complete_train_program, startup_program, [])
# test estimator
cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2)
cost_estimator = CostEstimator(train_program, cluster)
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(
dist_context
)
# test cache
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(
dist_context
)
assert global_cost.time >= 0
assert max_memory > 0
resharder = Resharder(
partitioned_main_prog,
partitioned_startup_prog,
rank_id,
dist_context,
partitioned_params_grads,
)
resharder.reshard()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册