Unverified commit c38b0488 authored by caozhou, committed by GitHub

add reshard module (#35779)

* add reshard module

* fix conflict

* update reshard module

* update and add unit test

* update reshard module and unit test

* add more unit tests
Parent commit: 7a724ddb
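A minimal sketch of how the new pass is wired in, assuming a main/startup program already partitioned for the current rank (mirroring the parallelizer change and the unit tests below):

    complete_backward_annotation(partitioned_main_prog, dist_context)
    reshard(partitioned_main_prog, partitioned_startup_prog, rank, dist_context)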
......@@ -19,5 +19,7 @@ from .interface import set_offload_device # noqa: F401
from .interface import set_pipeline_stage # noqa: F401
from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
from .completion import complete_backward_annotation # noqa: F401
from .reshard import reshard # noqa: F401
__all__ = []
......@@ -23,6 +23,7 @@ from .utils import compute_compatible_dims_mapping
from .utils import print_program_with_distributed_attr
from .context import get_default_distributed_context
from .operators import find_best_compatible_distributed_operator_impl
from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute
ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"]
......@@ -597,3 +598,172 @@ def complete_annotation(program, dist_context=None):
dist_context.amend_distributed_attr_for_program()
return program
def complete_backward_annotation(auto_parallel_main_prog, dist_context):
"""Complete the annotation of vars and ops in the backward phase for parallel program."""
def _is_grad_var_name(name):
if "@GRAD" in name:
return True
return False
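# e.g. "linear_0.w_0@GRAD" names the gradient of the forward var
# "linear_0.w_0"; the forward name is recovered below by stripping "@GRAD".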
grad_start_idx = None
for idx, op in enumerate(auto_parallel_main_prog.global_block().ops):
for var_name in op.output_arg_names:
# TODO: use _is_loss_op to judge
if "@GRAD" in var_name and op.type == "fill_constant":
grad_start_idx = idx
break
assert grad_start_idx is not None, "No backward procedure found in this program."
ops = list(auto_parallel_main_prog.global_block().ops)
vars = auto_parallel_main_prog.global_block().vars
for idx in range(grad_start_idx, len(ops)):
# complete the loss op
if idx == grad_start_idx:
grad_var = vars[ops[idx].output_arg_names[0]]
grad_var_name = grad_var.name
forward_var_name = grad_var_name[:grad_var_name.find("@GRAD")]
forward_var = vars[forward_var_name]
tensor_attr = TensorDistributedAttribute(grad_var, dist_context)
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_process_mesh()
dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_dims_mapping()
tensor_attr.set_dims_mapping(dims_mapping)
tensor_attr.set_process_mesh(process_mesh)
dist_context.set_tensor_distributed_attr_for_program(grad_var,
tensor_attr)
op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
op_attr.set_process_mesh(process_mesh)
dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
# in data parallel mode, the loss op is followed by a scale op.
if ops[idx + 1].type == "scale" and grad_var_name in ops[idx + 1].input_arg_names \
and grad_var_name in ops[idx + 1].output_arg_names:
op_attr = OperatorDistributedAttribute(ops[idx + 1],
dist_context)
op_attr.set_process_mesh(process_mesh)
dist_context.set_op_distributed_attr_for_program(ops[idx + 1],
op_attr)
continue
# complete the annotation of the optimizer op.
# TODO: use _is_optimizer_op to judge
if "Grad" in ops[idx].input_names and "Param" in ops[idx].input_names:
assert len(ops[idx].input(
"Param")) == 1, "Only support one-to-one now."
assert len(ops[idx].input(
"Grad")) == 1, "Only support one-to-one now."
var = vars[ops[idx].input("Param")[0]]
grad_var = vars[ops[idx].input("Grad")[0]]
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
var).get_process_mesh()
dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
var).get_dims_mapping()
op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
op_attr.set_process_mesh(process_mesh)
op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
continue
# complete the c_allreduce_sum op for gradient in the data parallel mode.
if ops[idx].type == "c_allreduce_sum" and ops[
idx].input_arg_names == ops[idx].output_arg_names:
grad_var = vars[ops[idx].output_arg_names[0]]
op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
grad_var).get_process_mesh()
op_attr.set_process_mesh(process_mesh)
dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
continue
# complete the annotation of grad op
grad_op = ops[idx]
for i, op in enumerate(ops[:grad_start_idx]):
match_op = None
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(op.desc,
set(),
[])
grad_op_input = []
for input_arg_name in grad_op.desc.input_arg_names():
if "@GRAD" in input_arg_name:
name = input_arg_name[:input_arg_name.find("@GRAD") + 5]
grad_op_input.append(name)
else:
grad_op_input.append(input_arg_name)
# for ops like sum, the number of generated grad op descs is larger than 1
if len(grad_op_desc_list) > 1:
for grad_op_desc in grad_op_desc_list:
if grad_op_input == grad_op_desc.input_arg_names() \
and grad_op.desc.type() == grad_op_desc.type():
match_op = op
break
elif len(grad_op_desc_list) == 1:
if grad_op_input == grad_op_desc_list[0].input_arg_names() \
and grad_op.desc.type() == grad_op_desc_list[0].type():
match_op = op
if match_op is not None:
op_attr = dist_context.get_op_distributed_attr_for_program(op)
grad_op_attr = OperatorDistributedAttribute(grad_op,
dist_context)
grad_op_attr.set_process_mesh(op_attr.get_process_mesh())
for var_name in grad_op.input_arg_names:
if "@GRAD" in var_name:
dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
vars[var_name]).get_dims_mapping()
grad_op_attr.set_input_dims_mapping(var_name,
dims_mapping)
else:
dims_mapping = op_attr.get_input_dims_mapping(var_name)
grad_op_attr.set_input_dims_mapping(var_name,
dims_mapping)
dist_context.set_op_distributed_attr_for_program(grad_op,
grad_op_attr)
for var_name in grad_op.output_arg_names:
if "@GRAD" in var_name:
forward_var = vars[var_name[:var_name.find("@GRAD")]]
tensor_attr = TensorDistributedAttribute(vars[var_name],
dist_context)
process_mesh = grad_op_attr.get_process_mesh()
dims_mapping = grad_op_attr.get_input_dims_mapping(
forward_var.name)
tensor_attr.set_process_mesh(process_mesh)
tensor_attr.set_dims_mapping(dims_mapping)
dist_context.set_tensor_distributed_attr_for_program(
vars[var_name], tensor_attr)
break
# complete the annotation of the sum op for multiple renamed grad vars
if grad_op.type == "sum" and all(
map(_is_grad_var_name, grad_op.input_arg_names)):
assert len(grad_op.output_arg_names
) == 1, "The output count of sum op should be one."
grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context)
for var_name in grad_op.input_arg_names:
if "@GRAD" in var_name:
forward_var = vars[var_name[:var_name.find("@GRAD")]]
dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_dims_mapping()
grad_op_attr.set_input_dims_mapping(var_name, dims_mapping)
for var_name in grad_op.output_arg_names:
forward_var = vars[var_name[:var_name.find("@GRAD")]]
tensor_attr = TensorDistributedAttribute(vars[var_name],
dist_context)
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_process_mesh()
dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_dims_mapping()
tensor_attr.set_dims_mapping(dims_mapping)
tensor_attr.set_process_mesh(process_mesh)
dist_context.set_tensor_distributed_attr_for_program(
vars[var_name], tensor_attr)
grad_op_attr.set_process_mesh(
dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_process_mesh())
dist_context.set_op_distributed_attr_for_program(grad_op,
grad_op_attr)
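# Illustrative call order (mirrors the unit tests below), assuming a program
# already partitioned per rank:
#     complete_backward_annotation(dist_main_prog, dist_context)
# Afterwards, grad vars and grad ops carry the process_mesh and dims_mapping
# of their forward counterparts, which the reshard pass then reads.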
......@@ -59,6 +59,9 @@ class DistributedContext:
if self._process_mesh.ndim == 1:
self._data_parallel_axis = 0
self._model_parallel_axis = 0
elif self._process_mesh.ndim == 3:
self._data_parallel_axis = 1
self._model_parallel_axis = 2
else:
self._data_parallel_axis = 0
self._model_parallel_axis = 1
......
......@@ -146,8 +146,18 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
assert mesh_shape <= 2, "row_parallel_embedding only support 1 or 2 dimensional process mesh, but got {}".format(
process_mesh_shape)
num_partition = process_mesh_shape[embedding_row_dim_mapping]
# TODO generalize here, support any mesh group
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
if mesh_shape == 1:
if rank_id not in process_mesh_group:
assert len(
process_mesh.topology
) == 2, " row_parallel_embedding process mapping only support 2 dimensional process mesh, \
but got {}".format(len(process_mesh.topology))
rank_id = process_mesh_group[
process_mesh.process_group.index(rank_id) %
process_mesh_shape[0]]
relative_idx = process_mesh_group.index(rank_id)
else:
relative_idx = rank_id % num_partition
......@@ -156,8 +166,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
relative_idx = relative_idx * per_part_size
# TODO: calculate ring id
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
......
......@@ -17,9 +17,10 @@ from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from .context import DistributedContext
from .context import get_default_distributed_context
from .completion import complete_annotation
from .completion import complete_annotation, complete_backward_annotation
from .partitioner import Partitioner
from .process import get_all_process_groups
from .reshard import reshard
class AutoParallelizer:
......@@ -85,10 +86,16 @@ class AutoParallelizer:
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if rank not in process_group._ranks:
continue
process_group.instantiate()
# The last step: remove all distributed attributes to be compatible
# with inference.
self._remove_distributed_attrs(partitioned_main_prog)
complete_backward_annotation(partitioned_main_prog, self._dist_context)
reshard(partitioned_main_prog, partitioned_startup_prog, rank,
self._dist_context)
return dist_optimize_ops, dist_params_grads, partitioned_startup_prog, partitioned_main_prog
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from functools import reduce
import paddle
import paddle.fluid.core as core
from paddle.utils import unique_name
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Program, OpProtoHolder
import paddle.fluid.layers.utils as utils
from ..collective import _get_global_env
from .context import DistributedContext
from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .process import new_process_group, ProcessGroup, PROCESS_GROUP_MAP
class AllGatherOpDesc:
"""
Describe the allgather op in the reshard phase.
Args:
group (list): Process group.
"""
def __init__(self, group):
self._group = group
self._desc = "all_gather"
@property
def group(self):
return self._group
@property
def desc(self):
return self._desc
def __repr__(self):
return f"op: {self._desc}, group: {self._group}."
class SendOpDesc:
"""
Describe the send op in the reshard phase.
Args:
partition_index (list): The index of the partition in the complete tensor.
dst (int): The destination process that receives the partition.
"""
def __init__(self, partition_index, dst):
self._dst = dst
self._partition_index = partition_index
self._desc = "send"
@property
def partition_index(self):
return self._partition_index
@property
def dst(self):
return self._dst
@property
def desc(self):
return self._desc
def __repr__(self):
return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}."
class RecvOpDesc:
"""
Describe the recv op in the reshard phase.
Args:
partition_index (list): The index of the partition in the complete tensor.
src (int): The source process that sends the partition.
"""
def __init__(self, partition_index, src):
self._src = src
self._partition_index = partition_index
self._desc = "recv"
@property
def partition_index(self):
return self._partition_index
@property
def src(self):
return self._src
@property
def desc(self):
return self._desc
def __repr__(self):
return f"op: {self._desc}, partition_index: {self._partition_index}, src: {self._src}."
class SliceOpDesc:
"""
Describe the slice op in the reshard phase.
Args:
starts (list): It represents starting indices of corresponding axis in ``axes``.
ends (list): It represents ending indices of corresponding axis in ``axes``.
axes (list): Axes that `starts` and `ends` apply to.
"""
def __init__(self, starts, ends, axes):
self._starts = starts
self._ends = ends
self._axes = axes
self._desc = "slice"
@property
def starts(self):
return self._starts
@property
def ends(self):
return self._ends
@property
def axes(self):
return self._axes
@property
def desc(self):
return self._desc
def __repr__(self):
return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}."
class ConcatOpDesc:
"""
Describe the concat op in the reshard phase.
Args:
partition_index_list (list): A list containing all partition indices.
"""
def __init__(self, partition_index_list):
self._partition_index_list = partition_index_list
self._desc = "concat"
@property
def partition_index_list(self):
return self._partition_index_list
@property
def desc(self):
return self._desc
def __repr__(self):
return f"op: {self._desc}, partition_index_list: {self._partition_index_list}."
def _compute_partition_shape(complete_shape, dims_mapping, process_shape):
"""Compute the shape of partition."""
partition_shape = []
for idx, item in enumerate(complete_shape):
if dims_mapping[idx] == -1:
partition_shape.append(item)
else:
partition_shape.append(item // process_shape[dims_mapping[idx]])
return partition_shape
def _compute_process_index(process, process_group, process_shape):
"""Compute the index of process_shape corresponding to the process."""
relative_process = process_group.index(process)
process_index = []
product = reduce(lambda x, y: x * y, process_shape)
for i in range(len(process_shape)):
idx = relative_process // (product // process_shape[i])
product = product // process_shape[i]
relative_process = relative_process - relative_process // product * product
process_index.append(idx)
return process_index
def _compute_partition_index(process, complete_shape, dims_mapping,
process_shape, process_group):
"""Compute the partition index in complete tensor."""
partition_shape = _compute_partition_shape(complete_shape, dims_mapping,
process_shape)
process_index = _compute_process_index(process, process_group,
process_shape)
partition_index = []
for i in range(len(complete_shape)):
if dims_mapping[i] == -1:
partition_index.append([0, partition_shape[i]])
else:
partition_index.append([
process_index[dims_mapping[i]] * partition_shape[i],
(process_index[dims_mapping[i]] + 1) * partition_shape[i]
])
return partition_index
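# Worked example (illustration only): an [8, 4] tensor with dims_mapping
# [0, -1] on a 1-D mesh of two processes is split along axis 0, so process 1
# owns rows 4..8 and every column.
assert _compute_partition_shape([8, 4], [0, -1], [2]) == [4, 4]
assert _compute_process_index(1, [0, 1], [2]) == [1]
assert _compute_partition_index(1, [8, 4], [0, -1], [2], [0, 1]) == [[4, 8], [0, 4]]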
def _compute_concat_info(partition_index_x, partition_index_y):
"""Judge whether two partition can be concatenated and compute concatenated partition index."""
differ_count = 0
concat_axis = -1
first_order = 0
new_partition = []
for idx, item in enumerate(partition_index_x):
if item != partition_index_y[idx]:
differ_count += 1
if item[1] == partition_index_y[idx][0] and item[
0] < partition_index_y[idx][1]:
concat_axis = idx
new_partition.append([item[0], partition_index_y[idx][1]])
elif item[0] == partition_index_y[idx][1] and item[
1] > partition_index_y[idx][0]:
first_order = 1
concat_axis = idx
new_partition.append([partition_index_y[idx][0], item[1]])
else:
new_partition.append(item)
if differ_count == 1:
return concat_axis, first_order, new_partition
else:
return -1, first_order, new_partition
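# Worked example (illustration only): two row-adjacent partitions of an
# [8, 4] tensor (as in the example above) concatenate along axis 0;
# first_order reports which operand comes first in the concatenation.
assert _compute_concat_info([[0, 4], [0, 4]], [[4, 8], [0, 4]]) == (0, 0, [[0, 8], [0, 4]])
assert _compute_concat_info([[4, 8], [0, 4]], [[0, 4], [0, 4]]) == (0, 1, [[0, 8], [0, 4]])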
def _concat_partitions(partition_index_list, partition_index):
"""Concat the given partitions without inserting concat op."""
if not partition_index_list:
partition_index_list.append(partition_index)
else:
i = 0
has_concat = False
while i < len(partition_index_list):
concat_axis, _, new_partition = _compute_concat_info(
partition_index_list[i], partition_index)
if concat_axis != -1:
has_concat = True
partition_index_list.pop(i)
_concat_partitions(partition_index_list, new_partition)
break
i += 1
if not has_concat:
partition_index_list.append(partition_index)
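# Worked example (illustration only; _example_partitions is just a scratch
# name): merging the two row partitions yields a single complete index.
_example_partitions = []
_concat_partitions(_example_partitions, [[0, 4], [0, 4]])
_concat_partitions(_example_partitions, [[4, 8], [0, 4]])
assert _example_partitions == [[[0, 8], [0, 4]]]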
def _is_overlapped(shape_x, shape_y):
"""Judge whether two partitions intersect on the specified dimension."""
overlapped = False
if (shape_y[0] <= shape_x[0] < shape_y[1]) or (
shape_x[0] <= shape_y[0] < shape_x[1]):
overlapped = True
return overlapped
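# Illustrative check: half-open ranges [0, 4) and [2, 6) overlap, while
# [0, 4) and [4, 8) merely touch and therefore do not.
assert _is_overlapped([0, 4], [2, 6])
assert not _is_overlapped([0, 4], [4, 8])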
def _need_reshard(tensor_dist_attr, op_dist_attr):
"""Judge the tensor whether needs to be resharded."""
is_reshard = False
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
tensor_process_mesh = tensor_dist_attr.get_process_mesh()
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_dist_attr.get_owner_tensor().name)
op_process_mesh = op_dist_attr.get_process_mesh()
if all(
map(lambda x: x is not None, [
tensor_dims_mapping, tensor_process_mesh, op_input_dims_mapping,
op_process_mesh
])):
if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh._id != op_process_mesh._id:
is_reshard = True
return is_reshard
def _compute_complete_shape(slice_shape, process_shape, dims_mapping):
"""compute the complete shape of the slice tensor with its process mesh and dims mapping"""
complete_shape = []
for idx, item in enumerate(slice_shape):
if dims_mapping[idx] == -1:
complete_shape.append(item)
else:
complete_shape.append(item * process_shape[dims_mapping[idx]])
return complete_shape
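# Illustrative check: this inverts _compute_partition_shape above, recovering
# the complete [8, 4] shape from the per-process [4, 4] slice.
assert _compute_complete_shape([4, 4], [2], [0, -1]) == [8, 4]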
def find_op_desc_seq(source_tensor, tensor_dist_attr, op_dist_attr):
"""
Find the op description sequence needed to reshard the source tensor so that it matches the op requirement.
Args:
source_tensor (Variable): A tensor with distributed attribute.
tensor_dist_attr (TensorDistributedAttribute): The distributed attribute of tensor.
op_dist_attr (OperatorDistributedAttribute): The distributed attribute of operator.
Returns:
Dict, mapping each process to the op description sequence it must execute. The key of the dict is
the process and the value is a list of op descriptions.
"""
source_dims_mapping = tensor_dist_attr.get_dims_mapping()
source_process_mesh = tensor_dist_attr.get_process_mesh()
source_process_group = source_process_mesh.process_group
source_process_shape = source_process_mesh.topology
target_process_mesh = op_dist_attr.get_process_mesh()
target_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_dist_attr.get_owner_tensor().name)
target_process_group = target_process_mesh.process_group
target_process_shape = target_process_mesh.topology
complete_shape = _compute_complete_shape(
source_tensor.shape, source_process_shape, source_dims_mapping)
op_desc_seq = {}
# TODO: handle the case where the target process group shares some processes with the source process group
if set(target_process_group).intersection(set(
source_process_group)) and set(target_process_group).difference(
set(source_process_group)):
pass
# when the process groups differ, use send, recv, concat and slice ops
elif target_process_group != source_process_group:
partition_process_mapping_list = []
for source_process in source_process_group:
source_partition_index = _compute_partition_index(source_process, complete_shape, source_dims_mapping, \
source_process_shape, source_process_group)
if not partition_process_mapping_list:
partition_process_mapping_list.append(
[source_partition_index, [source_process], [False]])
else:
partition_list = list(
[item[0] for item in partition_process_mapping_list])
process_list = list(
[item[1] for item in partition_process_mapping_list])
has_used = list(
[item[2] for item in partition_process_mapping_list])
if partition_list.count(source_partition_index) == 1:
index = partition_list.index(source_partition_index)
process_list[index].append(source_process)
has_used[index].append(False)
else:
partition_process_mapping_list.append(
[source_partition_index, [source_process], [False]])
for target_process in target_process_group:
has_sent = []
target_partition_index = _compute_partition_index(
target_process, complete_shape, target_dims_mapping,
target_process_shape, target_process_group)
partition_index_list = []
all_partition_index_list = []
for source_process in source_process_group:
source_partition_index = _compute_partition_index(
source_process, complete_shape, source_dims_mapping,
source_process_shape, source_process_group)
to_send_process = None
if all(_ for _ in list(map(_is_overlapped, source_partition_index, target_partition_index))) \
and source_partition_index not in has_sent:
idx = list([
item[0] for item in partition_process_mapping_list
]).index(source_partition_index)
has_used = list(
[item[2]
for item in partition_process_mapping_list])[idx]
process_list = list(
[item[1]
for item in partition_process_mapping_list])[idx]
i = 0
while i < len(has_used):
if not has_used[i]:
to_send_process = process_list[i]
has_used[i] = True
break
i += 1
if i == len(has_used):
has_used = list(map(lambda x: False, has_used))
to_send_process = process_list[0]
has_used[0] = True
assert to_send_process is not None, "Failed to find the send process."
if to_send_process not in op_desc_seq.keys():
op_desc_seq[to_send_process] = []
if target_process not in op_desc_seq.keys():
op_desc_seq[target_process] = []
all_partition_index_list.append(source_partition_index)
# append send and recv op desc
send_op_desc = SendOpDesc(source_partition_index,
target_process)
recv_op_desc = RecvOpDesc(source_partition_index,
to_send_process)
op_desc_seq[to_send_process].append(send_op_desc)
op_desc_seq[target_process].append(recv_op_desc)
has_sent.append(source_partition_index)
_concat_partitions(partition_index_list,
source_partition_index)
# append concat op desc
op_desc_seq[target_process].append(
ConcatOpDesc(all_partition_index_list))
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
concatenated_partition_index = partition_index_list[0]
for idx, item in enumerate(concatenated_partition_index):
slice_starts.append(target_partition_index[idx][0] - item[0])
slice_ends.append(target_partition_index[idx][1] - item[0])
slices_axes.append(idx)
op_desc_seq[target_process].append(
SliceOpDesc(slice_starts, slice_ends, slices_axes))
# in the same process group, use allgather and slice ops
else:
partition_index_list = []
all_partition_index_list = []
process_index = []
for source_process in source_process_group:
source_partition_index = _compute_partition_index(
source_process, complete_shape, source_dims_mapping,
source_process_shape, source_process_group)
if source_partition_index not in partition_index_list:
partition_index_list.append(source_partition_index)
process_index.append(
[[source_process, ], source_partition_index])
else:
process_index[partition_index_list.index(
source_partition_index)][0].append(source_process)
for i in range(len(process_index[0][0])):
group = []
for j in range(len(process_index)):
group.append(process_index[j][0][i])
if i == 0:
all_partition_index_list.append(process_index[j][1])
for process in group:
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
target_partition_index = _compute_partition_index(
process, complete_shape, target_dims_mapping,
target_process_shape, target_process_group)
for idx, item in enumerate(target_partition_index):
slice_starts.append(item[0])
slice_ends.append(item[1])
slices_axes.append(idx)
slice_op_desc = SliceOpDesc(
starts=slice_starts, ends=slice_ends, axes=slices_axes)
op_desc_seq[process] = [AllGatherOpDesc(group=group),
ConcatOpDesc(partition_index_list=all_partition_index_list), slice_op_desc] \
if len(group) > 1 else [slice_op_desc]
return op_desc_seq
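# Illustrative sketch (hypothetical values; _example_op_desc_seq is only for
# illustration): the returned dict maps each process id to the ordered
# descriptor list it must execute, e.g. a send/recv/concat/slice exchange
# between two disjoint process groups.
_example_op_desc_seq = {
    0: [SendOpDesc(partition_index=[[0, 4], [0, 4]], dst=1)],
    1: [RecvOpDesc(partition_index=[[0, 4], [0, 4]], src=0),
        ConcatOpDesc(partition_index_list=[[[0, 4], [0, 4]]]),
        SliceOpDesc(starts=[0], ends=[4], axes=[0, 1])],
}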
def _insert_send_op(block, idx, tensor, dst):
"""Insert send op into block at the given index."""
op_type = 'send_v2'
block._insert_op(
idx,
type=op_type,
inputs={'X': [tensor]},
attrs={
'ring_id': 0,
'peer': dst,
'use_calc_stream': True,
})
def _insert_recv_op(block, idx, tensor, src):
"""Insert recv op into block at the given index."""
op_type = 'recv_v2'
block._insert_op(
idx,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'ring_id': 0,
'peer': src,
'out_shape': tensor.shape,
'dtype': tensor.dtype,
'use_calc_stream': True,
})
def _insert_concat_op(block, idx, tensors, axis):
"""Insert concat op into block at the given block."""
inputs = {'X': tensors}
attrs = {}
attrs['axis'] = axis
helper = LayerHelper('concat', **locals())
with paddle.static.program_guard(block.program):
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
block._insert_op(
idx, type='concat', inputs=inputs, outputs={'Out': [out]}, attrs=attrs)
return out
def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name):
"""Insert slice op into block at the given block."""
inputs = {'Input': tensor}
infer_flags = list(1 for i in range(len(axes)))
attrs = {
"axes": axes,
"starts": starts,
"ends": ends,
"infer_flags": infer_flags
}
helper = LayerHelper('slice', **locals())
out = block.create_var(
name=new_var_name,
dtype=tensor.dtype,
type=core.VarDesc.VarType.LOD_TENSOR)
block._insert_op(
idx, type="slice", inputs=inputs, outputs={'Out': [out]}, attrs=attrs)
return out
def _insert_split_op(block, idx, tensor, num_or_sections):
"""Insert split op into block at the given index."""
helper = LayerHelper('split', **locals())
input_shape = tensor.shape
inputs = {'X': tensor}
attrs = {'num': num_or_sections, "axis": 0}
with paddle.static.program_guard(block.program):
outs = [
helper.create_variable_for_type_inference(
dtype=helper.input_dtype()) for i in range(num_or_sections)
]
block._insert_op(
idx, type="split", inputs=inputs, outputs={'Out': outs}, attrs=attrs)
return outs
def _insert_allgather_op(block, idx, tensor, ranks):
"""Insert allgather op into block at the given index."""
def _insert_fill_constant_op(block, idx):
"""Insert fill constant op into block at the given index."""
helper = LayerHelper("fill_constant", **locals())
with paddle.static.program_guard(block.program):
out = helper.create_variable_for_type_inference(dtype="int32")
inputs = {}
attrs = {'force_cpu': False}
attrs['str_value'] = str(int("1"))
attrs['value'] = int("1")
attrs['dtype'] = out.dtype
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant')
block._insert_op(
idx,
type='fill_constant',
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
out.stop_gradient = True
return out
tensor_list = []
group = new_process_group(ranks)
idx_offset = 0
# instantiate the process group before inserting the allgather op.
if not group.is_instantiate():
# insert fill_constant op
fill_constant_out = _insert_fill_constant_op(block, idx)
fill_constant_out.stop_gradient = True
# insert c_allreduce_sum op
block._insert_op(
idx + 1,
type="c_allreduce_sum",
inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]},
attrs={'ring_id': 0,
'use_calc_stream': True})
# insert c_sync_calc_stream op
block._insert_op(
idx + 2,
type="c_sync_calc_stream",
inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]})
idx_offset = 3
# insert c_allgather op
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
with paddle.static.program_guard(block.program):
allgather_out = helper.create_variable_for_type_inference(
dtype=tensor.dtype)
block._insert_op(
idx + idx_offset,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [allgather_out]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'nranks': group._nranks
})
idx_offset += 1
# insert split op
split_out = _insert_split_op(block, idx + idx_offset, allgather_out,
group._nranks)
idx_offset += 1
tensor_list.extend(split_out)
return tensor_list, idx_offset
def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
block, idx):
"""Concat the tensors and insert concat op."""
if not partition_tensor_list:
partition_tensor_list.append((tensor, partition_index))
else:
i = 0
has_concat = False
while i < len(partition_tensor_list):
concat_axis, first_order, new_partition = _compute_concat_info(
partition_tensor_list[i][1], partition_index)
if concat_axis != -1:
has_concat = True
_ = _insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis) \
if first_order == 0 else \
_insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis)
partition_tensor_list.pop(i)
idx[0] += 1
_concat_partitions_with_op(partition_tensor_list, _,
new_partition, block, idx)
break
i += 1
if not has_concat:
partition_tensor_list.append((tensor, partition_index))
def _init_comm_for_send_recv():
if not PROCESS_GROUP_MAP["global_group"].is_instantiate():
PROCESS_GROUP_MAP["global_group"].instantiate()
HAS_SENT = {}
HAS_RECV = {}
HAS_ALLGATHER = {}
def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
dist_context):
"""Parse op desc sequence and insert op in the block"""
global HAS_SENT
global HAS_RECV
global HAS_ALLGATHER
tensor_list = []
partition_tensor_list = []
if rank_id not in op_desc_seq.keys():
return
op_desc_list = op_desc_seq[rank_id]
block = program.global_block()
assert var_name in block.vars.keys(
), "The var {} cannot be found in the rank {} program.".format(var_name, rank_id)
idx = None
for index, op in list(enumerate(block.ops)):
if op.desc.id == reshard_op.desc.id:
idx = index
break
assert idx is not None, "The op for reshard cannot be found in the rank {} program.".format(
rank_id)
matched_op = block.ops[idx]
source_tensor = block.vars[var_name]
for op_desc in op_desc_list:
if isinstance(op_desc, AllGatherOpDesc): # noqa: F401
if var_name not in HAS_ALLGATHER.keys():
HAS_ALLGATHER[var_name] = []
if not HAS_ALLGATHER[var_name] or op_desc.group not in list(
map(lambda x: x[0], HAS_ALLGATHER[var_name])):
tensor_list, idx_offset = _insert_allgather_op(
block, idx, source_tensor, op_desc.group)
idx += idx_offset
tensor_name_list = [var.name for var in tensor_list]
HAS_ALLGATHER[var_name].append(
[op_desc.group, tensor_name_list])
else:
for item in HAS_ALLGATHER[var_name]:
if op_desc.group == item[0]:
tensor_list = [
program.global_block().vars[var_name]
for var_name in item[1]
]
break
assert tensor_list, "The result of parsing allgather op should not be None."
elif isinstance(op_desc, SendOpDesc):
_init_comm_for_send_recv()
if var_name not in HAS_SENT.keys():
HAS_SENT[var_name] = []
if op_desc.dst not in HAS_SENT[var_name]:
_insert_send_op(block, idx, source_tensor, op_desc.dst)
idx += 1
HAS_SENT[var_name].append(op_desc.dst)
elif isinstance(op_desc, RecvOpDesc):
_init_comm_for_send_recv()
if var_name not in HAS_RECV.keys():
HAS_RECV[var_name] = {}
if op_desc.src not in HAS_RECV[var_name].keys():
partition_index = op_desc.partition_index
shape = []
for index in partition_index:
shape.append(index[1] - index[0])
recv_tensor = block.create_var(
name=unique_name.generate(var_name + "@recv"),
shape=shape,
dtype=source_tensor.dtype)
_insert_recv_op(block, idx, recv_tensor, op_desc.src)
tensor_list.append(recv_tensor)
idx += 1
HAS_RECV[var_name][op_desc.src] = recv_tensor
else:
tensor_list.append(HAS_RECV[var_name][op_desc.src])
elif isinstance(op_desc, ConcatOpDesc):
partition_index_list = op_desc.partition_index_list
idx_list = [idx]
for index, tensor in enumerate(tensor_list):
_concat_partitions_with_op(partition_tensor_list, tensor,
partition_index_list[index], block,
idx_list)
idx = idx_list[0]
elif isinstance(op_desc, SliceOpDesc):
assert len(partition_tensor_list) == 1 or not partition_tensor_list
to_slice_tensor = partition_tensor_list[0][0] if len(
partition_tensor_list) == 1 else source_tensor
new_name = unique_name.generate(var_name + "@RESHARD")
target_tensor = _insert_slice_op(
block,
idx,
to_slice_tensor,
starts=op_desc.starts,
ends=op_desc.ends,
axes=op_desc.axes,
new_var_name=new_name)
tensor_attr = TensorDistributedAttribute(target_tensor,
dist_context)
process_mesh = dist_context.get_op_distributed_attr_for_program(
matched_op).get_process_mesh()
dims_mapping = dist_context.get_op_distributed_attr_for_program(
matched_op).get_input_dims_mapping(var_name)
tensor_attr.set_dims_mapping(dims_mapping)
tensor_attr.set_process_mesh(process_mesh)
dist_context.set_tensor_distributed_attr_for_program(target_tensor,
tensor_attr)
# rename op input name according to new name
for op in block.ops:
for name in op.input_arg_names:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(
op)
if name == var_name and op_dist_attr is not None:
op_process_mesh = op_dist_attr.get_process_mesh()
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
if op_process_mesh._id == process_mesh._id and op_input_dims_mapping == dims_mapping:
op.desc._rename_input(name, target_tensor.name)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr._dims_mapping.pop(name, None)
def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
"""Remove no need ops in the main program"""
not_remove_op_ref = [
"create_py_reader", "create_double_buffer_reader", "read"
]
remove_op_idx = []
block = auto_parallel_main_prog.global_block()
ops = block.ops
vars = block.vars
for idx, op in enumerate(ops):
# handle the read op specially in the pipeline scene; this will be removed in the future.
if op.type == "read":
dim_list = []
for var_name in op.output_arg_names:
dim_list.extend(vars[var_name].shape)
for i in range(idx, -1, -1):
if ops[i].type == "create_py_reader":
ops[i]._set_attr("shape_concat", dim_list)
break
continue
# replace the inputs and outputs of the c_sync_comm_stream op in the pipeline scene.
if op.type == "c_sync_comm_stream":
need_save = []
for var_name in op.input_arg_names:
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
vars[var_name]).get_process_mesh()
if rank_id in process_mesh.process_group:
need_save.append(var_name)
if not need_save:
remove_op_idx.append(idx)
continue
proto = OpProtoHolder.instance().get_op_proto(op.type)
op.desc.set_input(proto.inputs[0].name, need_save)
op.desc.set_output(proto.outputs[0].name, need_save)
continue
# judge whether the remaining ops should be removed.
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op)
if op_dist_attr is not None:
op_process_mesh = op_dist_attr.get_process_mesh()
if rank_id not in op_process_mesh.process_group and op.type not in not_remove_op_ref:
remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]:
block._remove_op(idx)
def _remove_no_need_vars(auto_parallel_main_prog):
"""Remove no need vars in the main program"""
remove_vars = set()
block = auto_parallel_main_prog.global_block()
ops = block.ops
vars = block.vars
need_vars = set()
for op in ops:
for var_name in op.input_arg_names:
if var_name in vars:
need_vars.add(var_name)
for var_name in op.output_arg_names:
if var_name in vars:
need_vars.add(var_name)
for var in vars:
if var not in need_vars:
remove_vars.add(var)
for var in remove_vars:
block._remove_var(var)
def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id):
"""Remove no need vars and ops in the main program."""
_remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id)
_remove_no_need_vars(auto_parallel_main_prog)
def remove_no_need_in_startup(auto_parallel_main_prog,
auto_parallel_startup_prog):
"""Remove no need vars and ops in the startup program."""
main_input_vars = set()
main_ops = auto_parallel_main_prog.global_block().ops
for op in main_ops:
for var_name in op.input_arg_names:
main_input_vars.add(var_name)
startup_block = auto_parallel_startup_prog.global_block()
startup_output_vars = set()
startup_ops = startup_block.ops
for op in startup_ops:
# skip c_sync_comm_stream op
if op.type == "c_sync_comm_stream":
continue
for var_name in op.output_arg_names:
startup_output_vars.add(var_name)
need_vars = set()
for var_name in startup_output_vars:
if var_name in main_input_vars:
need_vars.add(var_name)
startup_ops = startup_block.ops
actual_need_vars = set()
for idx, op in enumerate(startup_ops):
is_need_op = False
if op.type == "c_sync_comm_stream":
continue
for var_name in op.output_arg_names:
if var_name in need_vars:
is_need_op = True
break
if is_need_op:
for var_name in op.output_arg_names:
actual_need_vars.add(var_name)
for var_name in op.input_arg_names:
actual_need_vars.add(var_name)
remove_vars = set()
for var_name in startup_block.vars:
if var_name not in actual_need_vars:
remove_vars.add(var_name)
for var in remove_vars:
startup_block._remove_var(var)
remove_op_idx = []
vars = startup_block.vars
for idx, op in enumerate(startup_block.ops):
is_no_need_op = False
if op.type == "c_sync_comm_stream":
var_names = []
for var_name in op.input_arg_names:
if var_name in vars:
var_names.append(var_name)
if not var_names:
remove_op_idx.append(idx)
else:
proto = OpProtoHolder.instance().get_op_proto(op.type)
op.desc.set_input(proto.inputs[0].name, var_names)
op.desc.set_output(proto.outputs[0].name, var_names)
continue
for var_name in op.output_arg_names:
if var_name not in vars:
is_no_need_op = True
break
if is_no_need_op:
remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]:
startup_block._remove_op(idx)
def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
dist_context):
"""
Reshard tensors in the program according to their dist attrs and the corresponding op dist attrs.
Args:
auto_parallel_main_prog (Program): An auto parallel main program.
auto_parallel_startup_prog (Program): An auto parallel startup program.
rank_id (int): The process id.
dist_context (DistributedContext): The distributed context of the program.
"""
assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
"but got {}.".format(type(auto_parallel_main_prog))
assert isinstance(auto_parallel_startup_prog, Program), "The type of auto_parallel_startup_prog should be Program, " \
"but got {}.".format(type(auto_parallel_startup_prog))
assert isinstance(rank_id, int), "The type of rank_id should be int, " \
"but got {}.".format(type(rank_id))
assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
"but got {}.".format(type(dist_context))
block = auto_parallel_main_prog.global_block()
idx = 0
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op)
if op_dist_attr is not None:
idx_offset = 0
for var_name in op.input_arg_names:
# skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0":
continue
var = block.vars[var_name]
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
var)
if tensor_dist_attr is not None and _need_reshard(
tensor_dist_attr, op_dist_attr):
reshard_op_desc = find_op_desc_seq(var, tensor_dist_attr,
op_dist_attr)
parse_op_desc(auto_parallel_main_prog, rank_id,
reshard_op_desc, var_name, op, dist_context)
cur_op_count = len(block.ops)
idx_offset = idx_offset + cur_op_count - pre_op_count
pre_op_count = cur_op_count
idx = idx + idx_offset + 1
else:
idx += 1
# remove unneeded vars and ops from the main program
remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id)
# remove unneeded vars and ops from the startup program
remove_no_need_in_startup(auto_parallel_main_prog,
auto_parallel_startup_prog)
......@@ -86,6 +86,10 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach()
......@@ -225,6 +229,10 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp)
elseif(WITH_GPU)
if (${CUDNN_VERSION} VERSION_LESS 7100)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
......@@ -589,6 +597,10 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_fleet_lamb_meta_optimizer MODULES test_fleet_lamb_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_partitioner MODULES test_auto_parallel_partitioner ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard MODULES test_auto_parallel_reshard ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_dpmppp MODULES test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS})
endif(NOT WIN32)
endif(NOT APPLE)
if(WITH_DGC)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.reshard import reshard
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([0, 1])
PP_MESH_0 = None
PP_MESH_1 = None
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
if _global_parallel_strategy == "pp":
auto.shard_tensor(
self.linear0.weight, PP_MESH_0, dim_mapping=[-1, -1])
auto.shard_tensor(
self.linear1.weight, PP_MESH_1, dim_mapping=[-1, -1])
else:
auto.shard_tensor(
self.linear0.weight, _global_process_mesh,
dim_mapping=[-1, -1])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh,
dim_mapping=[-1, -1])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
return out
def mlp_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="input", shape=[batch_size, hidden_size], dtype='float32')
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
if _global_parallel_strategy == "pp":
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1, -1])
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1])
elif _global_parallel_strategy == "dp":
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[0, -1])
else:
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1])
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
predict = mlp(input)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss, train_program, start_program
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
# auto completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
return auto_parallel_main_prog, auto_parallel_startup_prog
def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check):
has_dist_attr = True
vars = dist_main_prog.global_block().vars
op_dist_attr = dist_context.get_op_distributed_attr_for_program(
op_need_check)
if not op_dist_attr or not op_dist_attr.get_process_mesh():
has_dist_attr = False
for var_name in op_need_check.input_arg_names:
if not op_dist_attr.get_input_dims_mapping(var_name) or \
not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_dims_mapping() or \
not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_process_mesh():
has_dist_attr = False
break
if has_dist_attr:
for var_name in op_need_check.output_arg_names:
if not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_dims_mapping() or \
not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_process_mesh():
has_dist_attr = False
break
return has_dist_attr
def check_send_recv_result(dist_main_prog, rank_id):
send_result = False
recv_result = False
ops = dist_main_prog.global_block().ops
if rank_id == 0:
for idx, op in enumerate(ops):
if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names:
send_result = True
if op.type == "recv_v2" and "gelu_0.tmp_0@GRAD" in op.output_arg_names[
0]:
recv_result = True
else:
for idx, op in enumerate(ops):
if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names:
send_result = True
if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[
0]:
recv_result = True
return send_result and recv_result
def check_initialization(dist_startup_prog, rank_id):
if rank_id == 0:
need_check_params = [
"layer_norm_0.b_0", "layer_norm_0.w_0", "linear_0.w_0",
"linear_0.b_0"
]
else:
need_check_params = ['linear_1.w_0', 'linear_1.b_0']
params = []
for var_name, var in dist_startup_prog.global_block().vars.items():
if var.is_parameter:
params.append(var_name)
return params == need_check_params
def check_initialization_for_dp(dist_startup_prog):
need_check_params = [
"layer_norm_0.b_0", "layer_norm_0.w_0", "linear_0.w_0", "linear_0.b_0"
] + ['linear_1.w_0', 'linear_1.b_0']
params = []
for var_name, var in dist_startup_prog.global_block().vars.items():
if var.is_parameter:
params.append(var_name)
broadcast_varnames = []
for op in dist_startup_prog.global_block().ops:
if op.type == "c_broadcast":
broadcast_varnames.append(op.output_arg_names[0])
return params == need_check_params == broadcast_varnames
class TestMLPReshard(unittest.TestCase):
def test_complete_backward_annotation(self):
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 0
dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, 0)
complete_backward_annotation(dist_main_prog, dist_context)
op_need_check = None
for op in dist_main_prog.global_block().ops:
if op.type == "gelu_grad":
op_need_check = op
break
# grad op should have dist attr
self.assertTrue(
check_backward_dist_attr(dist_context, dist_main_prog,
op_need_check))
def test_mlp_pp(self):
global _global_parallel_strategy
_global_parallel_strategy = "pp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)
global PP_MESH_0
PP_MESH_0 = auto.ProcessMesh(mesh=[0], parent=ROOT_MESH)
global PP_MESH_1
PP_MESH_1 = auto.ProcessMesh(mesh=[1], parent=ROOT_MESH)
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 1
dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
complete_backward_annotation(dist_main_prog, dist_context)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
# parameter initialization of every rank should be different in the pipeline scene
self.assertTrue(check_initialization(dist_startup_prog, rank_id))
def test_mlp_dp(self):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 0
dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
complete_backward_annotation(dist_main_prog, dist_context)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
# send and recv should not exist in dp scene.
self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))
# all parameters should be initialized in dp scene
self.assertTrue(check_initialization_for_dp(dist_startup_prog))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.reshard import reshard
paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
ROOT_MESH = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
_global_process_mesh = auto.ProcessMesh(
[[[0, 1], [4, 5]], [[2, 3], [6, 7]]], parent=ROOT_MESH)
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], parent=ROOT_MESH)
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], parent=ROOT_MESH)
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
auto.shard_tensor(self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 1])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, dim_mapping=[1, -1])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
return out
def mlp_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="input", shape=[batch_size, hidden_size], dtype='float32')
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[0, -1])
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[0, -1])
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
predict = mlp(input)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss, train_program, start_program
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
# auto completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
return auto_parallel_main_prog, auto_parallel_startup_prog
def check_send_recv_result(dist_main_prog, rank_id):
send_result = False
recv_result = False
ops = dist_main_prog.global_block().ops
if rank_id in [0, 1, 4, 5]:
for idx, op in enumerate(ops):
if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names:
send_result = True
if op.type == "recv_v2" and "gelu_0.tmp_0@GRAD" in op.output_arg_names[
0]:
recv_result = True
else:
for idx, op in enumerate(ops):
if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names:
send_result = True
if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[
0]:
recv_result = True
return send_result and recv_result
def check_initialization_for_dpmppp(dist_startup_prog):
broadcast_varnames = []
for op in dist_startup_prog.global_block().ops:
if op.type == "c_broadcast":
broadcast_varnames.append(op.output_arg_names[0])
result = len(broadcast_varnames) > 0
return result
class TestMLPReshard(unittest.TestCase):
def test_mlp_dpmppp(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 2
dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
print(dist_main_prog)
complete_backward_annotation(dist_main_prog, dist_context)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
print(dist_main_prog)
print(dist_startup_prog)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
# check parameter initialization
self.assertTrue(check_initialization_for_dpmppp(dist_startup_prog))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.reshard import reshard
paddle.enable_static()
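# mp_pp over a 2x2 mesh: ranks [0, 1] form pipeline stage 0 and ranks [2, 3] form
# stage 1, with model parallelism inside each stage.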
_global_parallel_strategy = "mp_pp"
ROOT_MESH = auto.ProcessMesh([[0, 1], [2, 3]])
_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], parent=ROOT_MESH)
PP_MESH_0 = auto.ProcessMesh([0, 1], parent=ROOT_MESH)
PP_MESH_1 = auto.ProcessMesh([2, 3], parent=ROOT_MESH)
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.word_embeddings = nn.Embedding(
hidden_size,
hidden_size,
weight_attr=paddle.ParamAttr(
name="word_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.linear2 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
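    # The embedding and linear0 are placed on pipeline stage 0, linear1/linear2 on
    # stage 1, so the gelu output is the activation that crosses the stage boundary.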
def forward(self, input):
auto.shard_tensor(
self.word_embeddings.weight, PP_MESH_0, dim_mapping=[0, -1])
auto.shard_tensor(self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 0])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, dim_mapping=[0, -1])
auto.shard_tensor(self.linear2.weight, PP_MESH_1, dim_mapping=[0, -1])
w_out = self.word_embeddings(input)
out = self.linear0(w_out)
gelu_out = F.gelu(out, approximate=True)
out = self.linear1(gelu_out)
out1 = self.linear2(gelu_out)
out = out + out1
return out
def mlp_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(name="input", shape=[batch_size], dtype='int32')
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1])
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1])
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
predict = mlp(input)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss, train_program, start_program
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
# auto completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
return auto_parallel_main_prog, auto_parallel_startup_prog
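# Ranks 0 and 1 (stage 0) should send the gelu activation and receive its gradient;
# ranks 2 and 3 (stage 1) should do the opposite.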
def check_send_recv_result(dist_main_prog, rank_id):
send_result = False
recv_result = False
ops = dist_main_prog.global_block().ops
if rank_id in [0, 1]:
for idx, op in enumerate(ops):
if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names:
send_result = True
if op.type == "recv_v2" and "gelu_0.tmp_0@GRAD" in op.output_arg_names[
0]:
recv_result = True
else:
for idx, op in enumerate(ops):
if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names[
0]:
send_result = True
if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[
0]:
recv_result = True
return send_result and recv_result
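# Only the biases of linear1 and linear2, which are not sliced by model parallelism,
# need a broadcast, and only on the second-stage ranks.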
def check_initialization_for_mppp(dist_startup_prog, rank_id):
if rank_id in [0, 1]:
need_check_params = []
else:
need_check_params = ["linear_1.b_0", "linear_2.b_0"]
broadcast_varnames = []
for op in dist_startup_prog.global_block().ops:
if op.type == "c_broadcast":
broadcast_varnames.append(op.output_arg_names[0])
return need_check_params == broadcast_varnames
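# reshard should materialize the gathered tensor as x@RESHARD_0 with the full
# (4, 4) shape and feed it into the matmul_v2 op.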
def check_allgather(dist_main_program):
allgather_out = "x@RESHARD_0"
var_result = False
op_result = False
vars = dist_main_program.global_block().vars
if allgather_out in vars and vars[allgather_out].shape == (4, 4):
var_result = True
for op in dist_main_program.global_block().ops:
if op.type == "matmul_v2":
if allgather_out in op.input_arg_names:
op_result = True
return var_result and op_result
class TestMLPReshard(unittest.TestCase):
def test_mlp_mppp(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 2
dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
complete_backward_annotation(dist_main_prog, dist_context)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
        # parameters that are not sliced should stay identical across ranks in the mp scene
self.assertTrue(
check_initialization_for_mppp(dist_startup_prog, rank_id))
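    # x is sharded by rows over the [0, 3] mesh but the matmul is annotated as fully
    # replicated, so reshard has to gather x back before the op runs.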
def test_allgather(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
process_mesh = auto.ProcessMesh(mesh=[0, 3], parent=ROOT_MESH)
with static.program_guard(train_program, startup_program):
x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
x = auto.shard_tensor(x, process_mesh, dim_mapping=[0, -1])
w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
w = auto.shard_tensor(w, process_mesh, dim_mapping=[-1, -1])
y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
x.name: [-1, -1],
w.name: [-1, -1]
}, **{"x": x,
"y": w})[0]
rank_id = 0
dist_context = DistributedContext()
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
reshard(auto_parallel_main_prog, startup_program, rank_id, dist_context)
        # x should not stay sliced: reshard should restore it to the full shape
self.assertTrue(check_allgather(auto_parallel_main_prog))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import os
if os.getenv("CUDA_VISIBLE_DEVICES", None) is None:
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import get_default_distributed_context
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.process import new_process_group
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([0])
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
if _global_parallel_strategy == "pp":
auto.shard_tensor(
self.linear0.weight, PP_MESH_0, dim_mapping=[-1, -1])
auto.shard_tensor(
self.linear1.weight, PP_MESH_1, dim_mapping=[-1, -1])
else:
auto.shard_tensor(
self.linear0.weight, _global_process_mesh,
dim_mapping=[-1, -1])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh,
dim_mapping=[-1, -1])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
return out
def mlp_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="input", shape=[batch_size, hidden_size], dtype='float32')
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
if _global_parallel_strategy == "pp":
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1, -1])
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1])
elif _global_parallel_strategy == "dp":
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[0, -1])
else:
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1])
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
predict = mlp(input)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss, train_program, start_program
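# Build the distributed programs through the fleet parallelizer (semi_auto=True)
# instead of driving the partitioner by hand.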
def get_dist_prog_with_parallelizer(train_program, startup_program,
dist_context):
global _global_process_mesh
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
    # enable the semi-auto parallel optimizer
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = fleet.distributed_optimizer(optimizer)
# fake a comm group
pg = new_process_group([3, 4])
_, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
loss, startup_program)
return distributed_main_program, distributed_startup_program
def check_send_recv_result(dist_main_prog, rank_id):
send_result = False
recv_result = False
ops = dist_main_prog.global_block().ops
if rank_id == 0:
for idx, op in enumerate(ops):
if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names:
send_result = True
if op.type == "recv_v2" and "gelu_0.tmp_0@GRAD" in op.output_arg_names[
0]:
recv_result = True
else:
for idx, op in enumerate(ops):
if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names:
send_result = True
if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[
0]:
recv_result = True
return send_result and recv_result
class TestMLPReshard(unittest.TestCase):
def test_mlp_serial(self):
global _global_parallel_strategy
_global_parallel_strategy = None
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0], parent=ROOT_MESH)
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = get_default_distributed_context()
rank_id = 0
dist_main_prog, dist_startup_prog = get_dist_prog_with_parallelizer(
train_program, startup_program, dist_context)
        # send and recv ops should not exist in the serial scene.
self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))
if __name__ == "__main__":
unittest.main()