未验证 提交 12155358 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Integrate all modules (#35483)

* add auto_parallel dir

* mv to paddle.distributed

* add shard_xx api

* add distributed attrs for var

* add ut, test=develop

* add dist

* update

* update

* update

* update

* update

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update

* update

* update

* update

* update

* update, test=develop

* update, test=develop

* update

* update

* delete unused proto

* resotre op_desc

* restore type_defs

* update var_desc

* remove dimss_mapping for proto_pybind

* update interface.py

* update framework.py

* update

* update

* add auto_parallel dir

* mv to paddle.distributed

* add shard_xx api

* add distributed attrs for var

* add ut, test=develop

* [WIP] Add the auto completion feature and related codes

* [WIP] Improve the auto completion and related codes

* [WIP] Make the auto completion to support data-parallel

* [WIP] Make the completion support mp and dp+mp

* [WIP] Refactor auto completion unit test for MLP

* [WIP] Refactor the implementation of DistributedOperatorImpl

* [WIP] Improve dims_mapping update rule and fix a bug

* [WIP] Support auto completion for one transformer decoder layer

* [WIP] Add a minor change

* [WIP] Fix a bug within the uint test

* Shard XShape tensor, add embedding completion and refactor code

* Add the distributed_operators dir to setup.py.in

* Improve the completion process and add the unittest for gpt

* fix process_mesh ut

* fix process_mesh ut

* update

* update, test=develop

* Add support for automatically completing distributed attrs of special ops

* update

* update

* update

* fix doc sample codes, test=develop

* improve coverage, test=develop

* add static_mode check, test=develop

* Model the cluster for cost model and physical mapping

* update, test=develop

* add set_placement, test=develop

* Add the check to make sure the candidate tensors' size is great than zero

* update doc, test=develop

* update doc, test=develop

* update doc, test=develop

* update doc, test=develop

* update, test=develop

* Auto mark dist attrs annotated by user

* update ndarray to nested list, test=develop

* update, test=develop

* Add auto-completion module for auto-parallel (based on PR#33804)

* Remove unnecessary files

* Remove unrelated files for the auto completion pr

* Update the unit test to improve the coverage

* Modify codes based on reviews

* Minor changes for CI

* Improve some codes based on new comments

* Fix bugs caused by shallow copy in attributes.py
* Imporve amend_distributed_attr_for_program in context.py
* Other changes for weihang's comments

* support shard reader

* support shard reader

* add parallel mode

* update process mesh

* add method to compute comm_group

* implement dist_embedding forward func

* implement dist matmul forward func

* implement dist reshape forward func

* add transpiler framework

* add transpiler forward

* implement transpiler forward

* implement transpiler backward & update

* add process

* add unitest

* chmod

* chmod

* chmod

* update unitest

* add unitest for gpt

* remove unused print

* rename transpiler --> partitioner

* rename transpiler --> partitioner

* chmod

* chmod

* bug fixed

* remove amp function

* update case for dp mode

* update case for dp mode

* [Auto Parallel] Integrate all parts with the newest code

* Integrate all parts of auto parallel and improve codes

* Integrate all parts by AutoParallelizer
* Add unit test for AutoParallelizer
* Improve auto completion module for pipeline parallel
* Add support for matmul_v2 in dist_matmul
* Correct the typo "stratergy" to "strategy"

* Modify distributed_strategy.proto to conform the main stream

* Restore parts of distributed_strategy to conform the develop branch
Co-authored-by: Nsandyhouse <lilong12@baidu.com>
Co-authored-by: NJZ-LIANG <jianzhongliang10@gmail.com>
上级 db5fd2a1
...@@ -202,6 +202,7 @@ message DistributedStrategy { ...@@ -202,6 +202,7 @@ message DistributedStrategy {
optional bool calc_comm_same_stream = 32 [ default = false ]; optional bool calc_comm_same_stream = 32 [ default = false ];
optional bool asp = 33 [ default = false ]; optional bool asp = 33 [ default = false ];
optional bool fuse_grad_merge = 34 [ default = false ]; optional bool fuse_grad_merge = 34 [ default = false ];
optional bool semi_auto = 35 [ default = false ];
optional RecomputeConfig recompute_configs = 101; optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102; optional AMPConfig amp_configs = 102;
......
...@@ -253,6 +253,9 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): ...@@ -253,6 +253,9 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
if (not tensor_node.is_var()) or (tensor_node.var() is None): if (not tensor_node.is_var()) or (tensor_node.var() is None):
return False return False
tensor_desc = tensor_node.var() tensor_desc = tensor_node.var()
# Skip reader tensor
if tensor_desc.type() == core.VarDesc.VarType.READER:
return False
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node) tensor_node)
assert tensor_dist_attr is not None assert tensor_dist_attr is not None
...@@ -263,6 +266,10 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): ...@@ -263,6 +266,10 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
dims_mapping_list = [] dims_mapping_list = []
for pred_op_node in tensor_node.inputs: for pred_op_node in tensor_node.inputs:
if pred_op_node.op() is not None: if pred_op_node.op() is not None:
if pred_op_node.op().type() == "create_py_reader" \
or pred_op_node.op().type() == "create_double_buffer_reader" \
or pred_op_node.op().type() == "read":
continue
op_dist_attr = dist_context.get_op_distributed_attr_for_graph( op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
pred_op_node) pred_op_node)
op_dims_mapping = op_dist_attr.get_output_dims_mapping( op_dims_mapping = op_dist_attr.get_output_dims_mapping(
...@@ -279,6 +286,10 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): ...@@ -279,6 +286,10 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
dims_mapping_list = [] dims_mapping_list = []
for succ_op_node in tensor_node.outputs: for succ_op_node in tensor_node.outputs:
if succ_op_node.op() is not None: if succ_op_node.op() is not None:
if succ_op_node.op().type() == "create_py_reader" \
or succ_op_node.op().type() == "create_double_buffer_reader" \
or succ_op_node.op().type() == "read":
continue
op_dist_attr = dist_context.get_op_distributed_attr_for_graph( op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
succ_op_node) succ_op_node)
op_dims_mapping = op_dist_attr.get_input_dims_mapping( op_dims_mapping = op_dist_attr.get_input_dims_mapping(
...@@ -298,11 +309,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): ...@@ -298,11 +309,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
changed = False changed = False
if (not op_node.is_op()) or (op_node.op() is None): if (not op_node.is_op()) or (op_node.op() is None):
return False return False
# Skip reader op
op_desc = op_node.op() op_desc = op_node.op()
if op_desc.type() == "create_py_reader" \
or op_desc.type() == "create_double_buffer_reader" \
or op_desc.type() == "read":
return False
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node) op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node)
if fwd: if fwd:
for tensor_node in op_node.inputs: for tensor_node in op_node.inputs:
if tensor_node.var() is not None: if tensor_node.var() is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER:
continue
tensor_desc = tensor_node.var() tensor_desc = tensor_node.var()
if op_dist_attr.is_annotated_input_dims_mapping( if op_dist_attr.is_annotated_input_dims_mapping(
tensor_desc.name()): tensor_desc.name()):
...@@ -344,6 +362,8 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): ...@@ -344,6 +362,8 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
else: else:
for tensor_node in op_node.outputs: for tensor_node in op_node.outputs:
if tensor_node.var() is not None: if tensor_node.var() is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER:
continue
tensor_desc = tensor_node.var() tensor_desc = tensor_node.var()
if op_dist_attr.is_annotated_output_dims_mapping( if op_dist_attr.is_annotated_output_dims_mapping(
tensor_desc.name()): tensor_desc.name()):
...@@ -402,7 +422,6 @@ def complete_annotation(program, dist_context=None): ...@@ -402,7 +422,6 @@ def complete_annotation(program, dist_context=None):
# Initialize distributed attributes for all var and op node in program # Initialize distributed attributes for all var and op node in program
dist_context.initialize_distributed_attr_for_program(program) dist_context.initialize_distributed_attr_for_program(program)
# print_program_with_distributed_attr(program, dist_context)
# Convert program to graph # Convert program to graph
graph = framework.IrGraph(core.Graph(program.desc)) graph = framework.IrGraph(core.Graph(program.desc))
...@@ -410,10 +429,30 @@ def complete_annotation(program, dist_context=None): ...@@ -410,10 +429,30 @@ def complete_annotation(program, dist_context=None):
# Initialize distributed attributes for all var and op node in graph # Initialize distributed attributes for all var and op node in graph
dist_context.initialize_distributed_attr_for_graph(graph) dist_context.initialize_distributed_attr_for_graph(graph)
# # Complete process mesh for each node # Complete process mesh for each node
all_nodes = list(graph.all_nodes()) all_nodes = list(graph.all_nodes())
def sort_key_fun(node):
first = -1
if node.is_op():
first = 0
else:
first = 1
second = -1
if node.is_op() and node.op() is not None:
second = node.op().id()
if node.is_var() and node.var() is not None:
second = node.var().id()
return (first, second)
all_nodes.sort(key=sort_key_fun)
reach_fix_point = False reach_fix_point = False
while not reach_fix_point: while not reach_fix_point:
total_changed = False
reach_fwd_fix_point = False
reach_bwd_fix_point = False
while not reach_fwd_fix_point:
changed = False changed = False
for node in all_nodes: for node in all_nodes:
if node.is_var() and node.var() is not None: if node.is_var() and node.var() is not None:
...@@ -426,7 +465,14 @@ def complete_annotation(program, dist_context=None): ...@@ -426,7 +465,14 @@ def complete_annotation(program, dist_context=None):
dist_context, node, fwd=True) dist_context, node, fwd=True)
if op_changed: if op_changed:
changed = True changed = True
for node in reversed(all_nodes): if changed:
reach_fwd_fix_point = False
total_changed = True
else:
reach_fwd_fix_point = True
while not reach_bwd_fix_point:
changed = False
for node in all_nodes:
if node.is_var() and node.var() is not None: if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_process_mesh( tensor_changed = update_tensor_node_process_mesh(
dist_context, node, fwd=False) dist_context, node, fwd=False)
...@@ -438,9 +484,79 @@ def complete_annotation(program, dist_context=None): ...@@ -438,9 +484,79 @@ def complete_annotation(program, dist_context=None):
if op_changed: if op_changed:
changed = True changed = True
if changed: if changed:
reach_bwd_fix_point = False
total_changed = True
else:
reach_bwd_fix_point = True
if total_changed:
reach_fix_point = False reach_fix_point = False
else: else:
reach_fix_point = True reach_fix_point = True
# Validation the completion of process meshes and should be moved to a proper location
is_wrong = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
node)
if tensor_dist_attr.get_process_mesh() is None:
msg_str = ""
for op_node in node.inputs:
if op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
op_node)
msg_str += "{} [{}], ".format(
op_node.op().type(),
op_dist_attr.get_process_mesh())
else:
msg_str += "{} [{}], ".format(op_node.name(),
None)
for op_node in node.outputs:
if op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
op_node)
msg_str += "{} [{}], ".format(
op_node.op().type(),
op_dist_attr.get_process_mesh())
else:
msg_str += "{} [{}], ".format(op_node.name(),
None)
msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_tensor api explicitly to annotate it".format(
node.var().name(), msg_str[:-2])
is_wrong = True
print(msg_str)
if node.is_op() and node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
node)
if op_dist_attr.get_process_mesh() is None:
msg_str = ""
for tensor_node in node.inputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
msg_str += "{} [{}], ".format(
tensor_node.var().name(),
tensor_dist_attr.get_process_mesh())
else:
msg_str += "{} [{}], ".format(
tensor_node.name(), None)
for tensor_node in node.outputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
msg_str += "{} [{}], ".format(
tensor_node.var().name(),
tensor_dist_attr.get_process_mesh())
else:
msg_str += "{} [{}], ".format(
tensor_node.name(), None)
msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_op api explicitly to annotate it".format(
node.op().type(), msg_str[:-2])
is_wrong = True
print(msg_str)
if node.is_op() and node.op() is None:
print("op op is None", node.name())
if is_wrong:
assert False, "Cannot complete process_meshes of the program."
# Complete dims_mapping for each node # Complete dims_mapping for each node
reach_fix_point = False reach_fix_point = False
......
...@@ -142,12 +142,15 @@ class DistributedContext: ...@@ -142,12 +142,15 @@ class DistributedContext:
tensor.desc, tensor_dist_attr) tensor.desc, tensor_dist_attr)
self.set_tensor_distributed_attr_for_program( self.set_tensor_distributed_attr_for_program(
tensor, tensor_dist_attr) tensor, tensor_dist_attr)
if tensor.type == core.VarDesc.VarType.READER:
tensor_dist_attr.set_shape([])
else:
tensor_dist_attr.set_shape(tensor.desc.shape()) tensor_dist_attr.set_shape(tensor.desc.shape())
if tensor_dist_attr.get_process_mesh() is not None: if tensor_dist_attr.get_process_mesh() is not None:
tensor_dist_attr.mark_as_annotated("process_mesh") tensor_dist_attr.mark_as_annotated("process_mesh")
if tensor_dist_attr.get_dims_mapping() is None: if tensor_dist_attr.get_dims_mapping() is None:
tensor_dims_mapping = [ tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape())) -1 for _ in range(len(tensor_dist_attr.get_shape()))
] ]
tensor_dist_attr.set_dims_mapping(tensor_dims_mapping) tensor_dist_attr.set_dims_mapping(tensor_dims_mapping)
else: else:
...@@ -168,12 +171,18 @@ class DistributedContext: ...@@ -168,12 +171,18 @@ class DistributedContext:
op_dist_attr.mark_as_annotated("process_mesh") op_dist_attr.mark_as_annotated("process_mesh")
for tensor_name in op.input_arg_names: for tensor_name in op.input_arg_names:
# There may be a better way to find the tensor by name # There may be a better way to find the tensor by name
if op.type == "create_py_reader" \
or tensor.type == core.VarDesc.VarType.READER:
op_dist_attr.set_input_shape(tensor_name, [])
else:
tensor = op.block._var_recursive(tensor_name) tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_input_shape(tensor_name, op_dist_attr.set_input_shape(tensor_name,
tensor.desc.shape()) tensor.desc.shape())
if op_dist_attr.get_input_dims_mapping(tensor_name) is None: if op_dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [ tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape())) -1
for _ in range(
len(op_dist_attr.get_input_shape(tensor_name)))
] ]
op_dist_attr.set_input_dims_mapping(tensor_name, op_dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping) tensor_dims_mapping)
...@@ -184,12 +193,18 @@ class DistributedContext: ...@@ -184,12 +193,18 @@ class DistributedContext:
op_dist_attr.mark_as_parameter(tensor_name) op_dist_attr.mark_as_parameter(tensor_name)
for tensor_name in op.output_arg_names: for tensor_name in op.output_arg_names:
tensor = op.block._var_recursive(tensor_name) tensor = op.block._var_recursive(tensor_name)
if tensor.type == core.VarDesc.VarType.READER:
op_dist_attr.set_output_shape(tensor_name, [])
else:
op_dist_attr.set_output_shape(tensor_name, op_dist_attr.set_output_shape(tensor_name,
tensor.desc.shape()) tensor.desc.shape())
if op_dist_attr.get_output_dims_mapping( if op_dist_attr.get_output_dims_mapping(
tensor_name) is None: tensor_name) is None:
tensor_dims_mapping = [ tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape())) -1
for _ in range(
len(
op_dist_attr.get_output_shape(tensor_name)))
] ]
op_dist_attr.set_output_dims_mapping( op_dist_attr.set_output_dims_mapping(
tensor_name, tensor_dims_mapping) tensor_name, tensor_dims_mapping)
...@@ -378,8 +393,8 @@ class DistributedContext: ...@@ -378,8 +393,8 @@ class DistributedContext:
# If the dimension of tensor is less than the sharding dimension of process mesh, # If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?) # we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)): for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[dims_mapping[ if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
i]] > tensor_shape[i]: and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1 dims_mapping[i] = -1
for attr in self._op_distributed_attr_map_for_program.values(): for attr in self._op_distributed_attr_map_for_program.values():
...@@ -392,8 +407,8 @@ class DistributedContext: ...@@ -392,8 +407,8 @@ class DistributedContext:
# If the dimension of tensor is less than the sharding dimension of process mesh, # If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?) # we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)): for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[ if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
dims_mapping[i]] > tensor_shape[i]: and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1 dims_mapping[i] = -1
for arg_name in attr.get_owner_op().desc.output_arg_names(): for arg_name in attr.get_owner_op().desc.output_arg_names():
...@@ -403,8 +418,8 @@ class DistributedContext: ...@@ -403,8 +418,8 @@ class DistributedContext:
# If the dimension of tensor is less than the sharding dimension of process mesh, # If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?) # we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)): for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[ if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
dims_mapping[i]] > tensor_shape[i]: and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1 dims_mapping[i] = -1
def _get_data_parallel_info(self): def _get_data_parallel_info(self):
......
...@@ -462,10 +462,271 @@ class DistributedMatmulV2(DistributedOperator): ...@@ -462,10 +462,271 @@ class DistributedMatmulV2(DistributedOperator):
register_distributed_operator("matmul_v2", DistributedMatmulV2("matmul_v2")) register_distributed_operator("matmul_v2", DistributedMatmulV2("matmul_v2"))
# ColumnParallel
class DistributedMatmulV2Impl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl0, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
1]):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_replicate(out_dims_mapping[-1]):
return False
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 2, "col_parallel_linear take 2 inputs variable but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 1, "col_parallel_linear take 2 inputs variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "col_parallel_linear input X take 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['Y']
) == 1, "col_parallel_linear input Y take 1 variable but got {}".format(
input_name_mapping['Y'])
assert len(
output_name_mapping['Out']
) == 1, "col_parallel_linear input Out take 1 variable but got {}".format(
input_name_mapping['Out'])
X_var = dst_block.var(input_name_mapping['X'][0])
Weight_var = dst_block.var(input_name_mapping['Y'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
# TODO infer logic comm presentation
from ..process import new_process_group
from ..transpiler import _get_comm_group
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.topology,
model_parallel_axis,
process_mesh.process_group, rank_id)
group = new_process_group(group_ranks)
# print("@@@@@@@@@@@@@@@@@@@@@ 5", group)
intermediate_var_0 = dst_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])),
dtype=X_var.dtype,
shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')
dst_block.append_op(
type='c_identity',
inputs={'X': [X_var]},
outputs={'Out': intermediate_var_0},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'],
'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
attrs = {'trans_x': False, 'trans_y': False}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
dst_block.append_op(
type='matmul_v2',
inputs=inputs,
outputs={'Out': Out_var},
attrs=attrs)
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
# RowParallel
class DistributedMatmulV2Impl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl1, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 2, "col_parallel_linear take 2 inputs variable but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 1, "col_parallel_linear take 2 inputs variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "col_parallel_linear input X take 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['Y']
) == 1, "col_parallel_linear input Y take 1 variable but got {}".format(
input_name_mapping['Y'])
assert len(
output_name_mapping['Out']
) == 1, "col_parallel_linear input Out take 1 variable but got {}".format(
input_name_mapping['Out'])
X_var = dst_block.var(input_name_mapping['X'][0])
Weight_var = dst_block.var(input_name_mapping['Y'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
# TODO infer logic comm presentation
from ..process import new_process_group
from ..transpiler import _get_comm_group
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.topology,
model_parallel_axis,
process_mesh.process_group, rank_id)
group = new_process_group(group_ranks)
# print("@@@@@@@@@@@@@@@@@@@@@ 4", group)
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear')
check_dtype(X_var.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
attrs = {'trans_x': False, 'trans_y': False}
inputs = {'X': X_var, 'Y': Weight_var}
intermediate_var_0 = dst_block.create_var(
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
lod_level=Out_var.lod_level,
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
dst_block.append_op(
type='matmul_v2',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)
dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': intermediate_var_0},
outputs={'Out': Out_var},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True
})
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
# ReplicateParallel # ReplicateParallel
class DistributedMatmulV2Impl(DistributedOperatorImpl): class DistributedMatmulV2Impl2(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMatmulV2Impl, self).__init__() super(DistributedMatmulV2Impl2, self).__init__()
self._name = name self._name = name
def is_process_mesh_compatible(self, op_dist_attr): def is_process_mesh_compatible(self, op_dist_attr):
...@@ -514,5 +775,9 @@ class DistributedMatmulV2Impl(DistributedOperatorImpl): ...@@ -514,5 +775,9 @@ class DistributedMatmulV2Impl(DistributedOperatorImpl):
return changed return changed
register_distributed_operator_impl("matmul_v2",
DistributedMatmulV2Impl0("column_parallel"))
register_distributed_operator_impl("matmul_v2",
DistributedMatmulV2Impl1("row_parallel"))
register_distributed_operator_impl( register_distributed_operator_impl(
"matmul_v2", DistributedMatmulV2Impl("replicate_parallel")) "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.distributed.fleet import cloud_utils
from .context import DistributedContext
from .context import get_default_distributed_context
from .completion import complete_annotation
from .partitioner import Partitioner
from .process import get_all_process_groups
class AutoParallelizer:
"""
AutoParallelizer is the main controller class to do the auto parallel process.
And the auto parallel process will be triggered in the wrapped parallelize function.
To facilitate the auto parallelization, it will contain information about program, cluster and the
related context. In this basic version, the program information will be retrevied from
Fleet object, and the cluster information can be retrevied in the new created Cluster object,
and the context information can be retrevied in the new created DistributedContext.
"""
def __init__(self, fleet):
self._fleet = fleet
self._optimizer = self._fleet.user_defined_optimizer
self._dist_strategy = self._fleet._user_defined_strategy
# self._dist_context = DistributedContext()
self._dist_context = get_default_distributed_context()
def parallelize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
self._original_main_program = loss.block.program
# For now, we only allow user to use the default startup and main program
assert startup_program is not None
if startup_program == None:
self._original_startup_program = \
paddle.static.default_startup_program().clone(for_test=False)
startup_program = paddle.static.default_startup_program()
else:
self._original_startup_program = \
startup_program.clone(for_test=False)
# Annotation completion
completed_main_program = complete_annotation(
self._original_main_program, self._dist_context)
# Logical partition
rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
partitioned_main_prog, partitioned_startup_prog = partitioner.transpile_forward(
completed_main_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, completed_main_program, startup_program,
partitioned_main_prog, partitioned_startup_prog)
dist_optimize_ops = partitioner.apply_optimize(
self._optimizer, dist_params_grads, partitioned_main_prog,
partitioned_startup_prog)
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
process_group.instantiate()
return dist_optimize_ops, dist_params_grads, partitioned_startup_prog, partitioned_main_prog
...@@ -561,7 +561,7 @@ class Partitioner(object): ...@@ -561,7 +561,7 @@ class Partitioner(object):
if not var_dist_attr.is_parameter(): if not var_dist_attr.is_parameter():
mapping = var_dist_attr.get_dims_mapping() mapping = var_dist_attr.get_dims_mapping()
mesh = var_dist_attr.get_process_mesh().topology mesh = var_dist_attr.get_process_mesh().topology
if mapping[0] >= 0 and mesh[mapping[0]] > 1: if mapping and mapping[0] >= 0 and mesh[mapping[0]] > 1:
self._enable_data_parallel = True self._enable_data_parallel = True
break break
......
...@@ -79,11 +79,10 @@ def compute_compatible_process_mesh(process_mesh_list): ...@@ -79,11 +79,10 @@ def compute_compatible_process_mesh(process_mesh_list):
return compatible_process_mesh return compatible_process_mesh
for process_mesh in process_mesh_list: for process_mesh in process_mesh_list:
if process_mesh is not None: if process_mesh is not None:
if compatible_process_mesh is None: if compatible_process_mesh is None or compatible_process_mesh == process_mesh:
compatible_process_mesh = process_mesh compatible_process_mesh = process_mesh
else: else:
assert process_mesh == compatible_process_mesh, \ return None
"There is no compatible process mesh."
return compatible_process_mesh return compatible_process_mesh
......
...@@ -1596,6 +1596,41 @@ class DistributedStrategy(object): ...@@ -1596,6 +1596,41 @@ class DistributedStrategy(object):
else: else:
print("WARNING: auto should have value of bool type") print("WARNING: auto should have value of bool type")
@property
def semi_auto(self):
"""
Indicating whether we are using semi-auto parallel function
This feature is currently an experimental feature. Currently,
auto-parallelism can be used only when a user does not set any other
strategy configs except semi-auto. For details, please reference the following
code example
Default Value: False
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.semi_auto = True
# if set other strategy at the same time, auto will not apply
# strategy.amp = True
optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
"""
return self.strategy.semi_auto
@semi_auto.setter
def semi_auto(self, flag):
if isinstance(flag, bool):
self.strategy.semi_auto = flag
else:
print("WARNING: semi-auto should have value of bool type")
@property @property
def cudnn_exhaustive_search(self): def cudnn_exhaustive_search(self):
""" """
......
...@@ -1408,6 +1408,14 @@ class Fleet(object): ...@@ -1408,6 +1408,14 @@ class Fleet(object):
context["origin_startup_program"] = startup_program context["origin_startup_program"] = startup_program
context["role_maker"] = self._role_maker context["role_maker"] = self._role_maker
# Use the auto-parallel's routines instead
if self._user_defined_strategy.semi_auto:
from ...auto_parallel.parallelizer import AutoParallelizer
auto_parallelizer = AutoParallelizer(self)
optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize(
loss, startup_program, parameter_list, no_grad_set)
return optimize_ops, params_grads, dist_startup_prog, dist_main_prog
# compile time # compile time
distributed_optimizer_list = \ distributed_optimizer_list = \
MetaOptimizerFactory()._get_valid_meta_optimizers( MetaOptimizerFactory()._get_valid_meta_optimizers(
......
...@@ -33,8 +33,9 @@ from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffi ...@@ -33,8 +33,9 @@ from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffi
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.context import set_default_distributed_context from paddle.distributed.auto_parallel.context import set_default_distributed_context
paddle.enable_static() paddle.enable_static()
_global_parallel_stratergy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = None
_global_process_mesh2 = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
...@@ -59,16 +60,22 @@ class MLPLayer(nn.Layer): ...@@ -59,16 +60,22 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input): def forward(self, input):
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
elif _global_parallel_strategy == "pp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh2,
dim_mapping=[1, -1])
out = self.norm(input) out = self.norm(input)
out = self.linear0(out) out = self.linear0(out)
...@@ -90,10 +97,10 @@ def mlp_pretrain_forward(train_program, start_program): ...@@ -90,10 +97,10 @@ def mlp_pretrain_forward(train_program, start_program):
shape=[batch_size, sequence_len, hidden_size], shape=[batch_size, sequence_len, hidden_size],
dtype='float32') dtype='float32')
if _global_parallel_stratergy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input, _global_process_mesh, dim_mapping=[0, -1, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input, _global_process_mesh, dim_mapping=[0, -1, -1])
...@@ -108,8 +115,8 @@ def mlp_pretrain_forward(train_program, start_program): ...@@ -108,8 +115,8 @@ def mlp_pretrain_forward(train_program, start_program):
class TestMLPAutoCompletion(unittest.TestCase): class TestMLPAutoCompletion(unittest.TestCase):
def test_mlp_dp(self): def test_mlp_dp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -127,8 +134,8 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -127,8 +134,8 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context)) dist_context))
def test_mlp_mp(self): def test_mlp_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -147,8 +154,8 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -147,8 +154,8 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context)) dist_context))
def test_mlp_dp_mp(self): def test_mlp_dp_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
...@@ -167,19 +174,26 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -167,19 +174,26 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context)) dist_context))
def test_mlp_misc(self): def test_mlp_misc(self):
global _global_parallel_stratergy # import pdb
_global_parallel_stratergy = "dp_mp" global _global_parallel_strategy
_global_parallel_strategy = "pp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1], [2, 3]], parent=ROOT_MESH)
global _global_process_mesh2
_global_process_mesh2 = auto.ProcessMesh(
mesh=[[4, 5], [6, 7]], parent=ROOT_MESH)
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program, train_program, start_program = mlp_pretrain_forward(train_program,
start_program) start_program)
# pdb.set_trace()
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
dist_context.finalize_distributed_attr_for_program( dist_context.finalize_distributed_attr_for_program(
complete_train_program) complete_train_program)
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
...@@ -246,10 +260,10 @@ class AttentionLayer(nn.Layer): ...@@ -246,10 +260,10 @@ class AttentionLayer(nn.Layer):
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
def forward(self, input): def forward(self, input):
if _global_parallel_stratergy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input, _global_process_mesh, dim_mapping=[0, -1, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input, _global_process_mesh, dim_mapping=[0, -1, -1])
...@@ -260,14 +274,14 @@ class AttentionLayer(nn.Layer): ...@@ -260,14 +274,14 @@ class AttentionLayer(nn.Layer):
k = self.k_proj(input) k = self.k_proj(input)
v = self.v_proj(input) v = self.v_proj(input)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor( auto.shard_tensor(
...@@ -304,11 +318,11 @@ class AttentionLayer(nn.Layer): ...@@ -304,11 +318,11 @@ class AttentionLayer(nn.Layer):
# project to output # project to output
out = self.out_proj(out) out = self.out_proj(out)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1]) dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1]) dim_mapping=[1, -1])
...@@ -340,8 +354,8 @@ def attn_pretrain_forward(train_program, start_program): ...@@ -340,8 +354,8 @@ def attn_pretrain_forward(train_program, start_program):
class TestAttentionAutoCompletion(unittest.TestCase): class TestAttentionAutoCompletion(unittest.TestCase):
def test_attn_dp(self): def test_attn_dp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -359,8 +373,8 @@ class TestAttentionAutoCompletion(unittest.TestCase): ...@@ -359,8 +373,8 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context)) dist_context))
def test_attn_mp(self): def test_attn_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -379,8 +393,8 @@ class TestAttentionAutoCompletion(unittest.TestCase): ...@@ -379,8 +393,8 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context)) dist_context))
def test_attn_dp_mp(self): def test_attn_dp_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
...@@ -463,28 +477,29 @@ class DecoderLayer(nn.Layer): ...@@ -463,28 +477,29 @@ class DecoderLayer(nn.Layer):
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear( self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5) self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout1 = nn.Dropout(self.dropout_ratio) self.dropout1 = nn.Dropout(self.dropout_ratio)
self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
def forward(self, input_ids, position_ids): def forward(self, input_ids, position_ids):
if _global_parallel_stratergy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_embeddings = self.word_embeddings(input_ids) input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, _global_process_mesh,
dim_mapping=[0, -1]) dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, _global_process_mesh,
...@@ -494,7 +509,7 @@ class DecoderLayer(nn.Layer): ...@@ -494,7 +509,7 @@ class DecoderLayer(nn.Layer):
embeddings = self.dropout1(embeddings) embeddings = self.dropout1(embeddings)
# Pre-norm # Pre-norm
target = self.norm(embeddings) target = self.norm1(embeddings)
# The following is the attention part # The following is the attention part
q = self.q_proj(target) q = self.q_proj(target)
...@@ -504,14 +519,14 @@ class DecoderLayer(nn.Layer): ...@@ -504,14 +519,14 @@ class DecoderLayer(nn.Layer):
k = self.k_proj(target) k = self.k_proj(target)
v = self.v_proj(target) v = self.v_proj(target)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor( auto.shard_tensor(
...@@ -549,11 +564,11 @@ class DecoderLayer(nn.Layer): ...@@ -549,11 +564,11 @@ class DecoderLayer(nn.Layer):
# project to output # project to output
out = self.out_proj(out) out = self.out_proj(out)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1]) dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1]) dim_mapping=[1, -1])
...@@ -562,19 +577,19 @@ class DecoderLayer(nn.Layer): ...@@ -562,19 +577,19 @@ class DecoderLayer(nn.Layer):
residual = embeddings + self.dropout2(out) residual = embeddings + self.dropout2(out)
# Pre-norm # Pre-norm
out0 = self.norm(residual) out0 = self.norm2(residual)
# The following is the MLP part # The following is the MLP part
out1 = self.linear0(out0) out1 = self.linear0(out0)
out2 = F.gelu(out1, approximate=True) out2 = F.gelu(out1, approximate=True)
out3 = self.linear1(out2) out3 = self.linear1(out2)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor( auto.shard_tensor(
...@@ -613,8 +628,8 @@ def decoder_pretrain_forward(train_program, start_program): ...@@ -613,8 +628,8 @@ def decoder_pretrain_forward(train_program, start_program):
class TestDecoderLayerAutoCompletion(unittest.TestCase): class TestDecoderLayerAutoCompletion(unittest.TestCase):
def test_decoder_dp(self): def test_decoder_dp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -632,8 +647,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): ...@@ -632,8 +647,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
dist_context)) dist_context))
def test_decoder_mp(self): def test_decoder_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -652,8 +667,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): ...@@ -652,8 +667,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
dist_context)) dist_context))
def test_decoder_dp_mp(self): def test_decoder_dp_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
......
...@@ -36,7 +36,7 @@ from paddle.distributed.auto_parallel.utils import print_program_with_distribute ...@@ -36,7 +36,7 @@ from paddle.distributed.auto_parallel.utils import print_program_with_distribute
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.context import DistributedContext
paddle.enable_static() paddle.enable_static()
_global_parallel_stratergy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
...@@ -106,10 +106,10 @@ class MultiHeadAttention(nn.Layer): ...@@ -106,10 +106,10 @@ class MultiHeadAttention(nn.Layer):
""" """
q = self.q_proj(query) q = self.q_proj(query)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
...@@ -143,19 +143,19 @@ class MultiHeadAttention(nn.Layer): ...@@ -143,19 +143,19 @@ class MultiHeadAttention(nn.Layer):
""" """
k = self.k_proj(key) k = self.k_proj(key)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
v = self.v_proj(value) v = self.v_proj(value)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
...@@ -236,11 +236,11 @@ class MultiHeadAttention(nn.Layer): ...@@ -236,11 +236,11 @@ class MultiHeadAttention(nn.Layer):
# project to output # project to output
out = self.out_proj(out) out = self.out_proj(out)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1]) dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1]) dim_mapping=[1, -1])
...@@ -409,17 +409,17 @@ class TransformerDecoderLayer(nn.Layer): ...@@ -409,17 +409,17 @@ class TransformerDecoderLayer(nn.Layer):
if self.normalize_before: if self.normalize_before:
tgt = self.norm2(tgt) tgt = self.norm2(tgt)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1])
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1]) self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1]) self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1])
...@@ -482,12 +482,12 @@ class GPTEmbeddings(nn.Layer): ...@@ -482,12 +482,12 @@ class GPTEmbeddings(nn.Layer):
input_embedings = self.word_embeddings(input_ids) input_embedings = self.word_embeddings(input_ids)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, _global_process_mesh,
dim_mapping=[0, -1]) dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, _global_process_mesh,
...@@ -715,10 +715,10 @@ def gpt_pretrain_forward(train_program, start_program): ...@@ -715,10 +715,10 @@ def gpt_pretrain_forward(train_program, start_program):
loss_mask = static.data( loss_mask = static.data(
name="loss_mask", shape=[batch_size, sequence_len], dtype='float64') name="loss_mask", shape=[batch_size, sequence_len], dtype='float64')
if _global_parallel_stratergy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids, _global_process_mesh, dim_mapping=[0, -1])
...@@ -750,8 +750,8 @@ def gpt_pretrain_forward(train_program, start_program): ...@@ -750,8 +750,8 @@ def gpt_pretrain_forward(train_program, start_program):
class TestGPTAutoCompletion(unittest.TestCase): class TestGPTAutoCompletion(unittest.TestCase):
def test_gpt_dp(self): def test_gpt_dp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -770,8 +770,8 @@ class TestGPTAutoCompletion(unittest.TestCase): ...@@ -770,8 +770,8 @@ class TestGPTAutoCompletion(unittest.TestCase):
dist_context)) dist_context))
def test_gpt_mp(self): def test_gpt_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -790,8 +790,8 @@ class TestGPTAutoCompletion(unittest.TestCase): ...@@ -790,8 +790,8 @@ class TestGPTAutoCompletion(unittest.TestCase):
dist_context)) dist_context))
def test_gpt_dp_mp(self): def test_gpt_dp_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
# The following statements are used to satisfy fleet initialization
import os
if os.getenv("CUDA_VISIBLE_DEVICES", None) is None:
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([0, 1])
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
return out
def mlp_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1])
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
predict = mlp(input)
cost = layers.cross_entropy(input=predict, label=label)
avg_cost = layers.mean(x=cost)
return avg_cost, train_program, start_program
class TestMLPAutoParallelizer(unittest.TestCase):
def test_mlp_serial(self):
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
loss, train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = fleet.distributed_optimizer(optimizer)
_, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
loss, start_program)
# print_program_with_distributed_attr(distributed_main_program)
self.assertIsNotNone(distributed_startup_program)
self.assertIsNotNone(distributed_main_program)
if __name__ == "__main__":
unittest.main()
...@@ -39,7 +39,7 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group ...@@ -39,7 +39,7 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.process import new_process_group from paddle.distributed.auto_parallel.process import new_process_group
paddle.enable_static() paddle.enable_static()
_global_parallel_stratergy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
...@@ -156,12 +156,12 @@ class MLPLayer(nn.Layer): ...@@ -156,12 +156,12 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input): def forward(self, input):
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor( auto.shard_tensor(
...@@ -194,10 +194,10 @@ def mlp_pretrain_forward(train_program, start_program): ...@@ -194,10 +194,10 @@ def mlp_pretrain_forward(train_program, start_program):
shape=[batch_size, sequence_len, hidden_size], shape=[batch_size, sequence_len, hidden_size],
dtype='float32') dtype='float32')
if _global_parallel_stratergy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input, _global_process_mesh, dim_mapping=[0, -1, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input, _global_process_mesh, dim_mapping=[0, -1, -1])
...@@ -212,8 +212,8 @@ def mlp_pretrain_forward(train_program, start_program): ...@@ -212,8 +212,8 @@ def mlp_pretrain_forward(train_program, start_program):
class TestMLPAutoPartitioner(unittest.TestCase): class TestMLPAutoPartitioner(unittest.TestCase):
def test_mlp_dp(self): def test_mlp_dp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -238,13 +238,13 @@ class TestMLPAutoPartitioner(unittest.TestCase): ...@@ -238,13 +238,13 @@ class TestMLPAutoPartitioner(unittest.TestCase):
# parameter initialization # parameter initialization
var_need_broadcast = [] var_need_broadcast = []
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context, initialization_check(_global_parallel_strategy, dist_context,
dist_startup_prog, serial_startup_prog, dist_startup_prog, serial_startup_prog,
var_need_broadcast)) var_need_broadcast))
def test_mlp_mp(self): def test_mlp_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -285,13 +285,13 @@ class TestMLPAutoPartitioner(unittest.TestCase): ...@@ -285,13 +285,13 @@ class TestMLPAutoPartitioner(unittest.TestCase):
var_need_broadcast = sorted( var_need_broadcast = sorted(
['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0'])
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context, initialization_check(_global_parallel_strategy, dist_context,
dist_startup_prog, serial_startup_prog, dist_startup_prog, serial_startup_prog,
var_need_broadcast)) var_need_broadcast))
def test_mlp_dp_mp(self): def test_mlp_dp_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
...@@ -332,7 +332,7 @@ class TestMLPAutoPartitioner(unittest.TestCase): ...@@ -332,7 +332,7 @@ class TestMLPAutoPartitioner(unittest.TestCase):
var_need_broadcast = sorted( var_need_broadcast = sorted(
['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0'])
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context, initialization_check(_global_parallel_strategy, dist_context,
dist_startup_prog, serial_startup_prog, dist_startup_prog, serial_startup_prog,
var_need_broadcast)) var_need_broadcast))
...@@ -373,10 +373,10 @@ class AttentionLayer(nn.Layer): ...@@ -373,10 +373,10 @@ class AttentionLayer(nn.Layer):
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
def forward(self, input): def forward(self, input):
if _global_parallel_stratergy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input, _global_process_mesh, dim_mapping=[0, -1, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input, _global_process_mesh, dim_mapping=[0, -1, -1])
...@@ -387,14 +387,14 @@ class AttentionLayer(nn.Layer): ...@@ -387,14 +387,14 @@ class AttentionLayer(nn.Layer):
k = self.k_proj(input) k = self.k_proj(input)
v = self.v_proj(input) v = self.v_proj(input)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor( auto.shard_tensor(
...@@ -431,11 +431,11 @@ class AttentionLayer(nn.Layer): ...@@ -431,11 +431,11 @@ class AttentionLayer(nn.Layer):
# project to output # project to output
out = self.out_proj(out) out = self.out_proj(out)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1]) dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1]) dim_mapping=[1, -1])
...@@ -467,8 +467,8 @@ def attn_pretrain_forward(train_program, start_program): ...@@ -467,8 +467,8 @@ def attn_pretrain_forward(train_program, start_program):
class TestAttentionAutoPartitioner(unittest.TestCase): class TestAttentionAutoPartitioner(unittest.TestCase):
def test_attn_dp(self): def test_attn_dp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -492,13 +492,13 @@ class TestAttentionAutoPartitioner(unittest.TestCase): ...@@ -492,13 +492,13 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
# parameter initialization # parameter initialization
var_need_broadcast = [] var_need_broadcast = []
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context, initialization_check(_global_parallel_strategy, dist_context,
dist_startup_prog, serial_startup_prog, dist_startup_prog, serial_startup_prog,
var_need_broadcast)) var_need_broadcast))
def test_attn_mp(self): def test_attn_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH) mesh=[0, 1, 2, 3], parent=ROOT_MESH)
...@@ -543,13 +543,13 @@ class TestAttentionAutoPartitioner(unittest.TestCase): ...@@ -543,13 +543,13 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
# parameter initialization # parameter initialization
var_need_broadcast = ['linear_3.b_0'] var_need_broadcast = ['linear_3.b_0']
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context, initialization_check(_global_parallel_strategy, dist_context,
dist_startup_prog, serial_startup_prog, dist_startup_prog, serial_startup_prog,
var_need_broadcast)) var_need_broadcast))
def test_attn_dp_mp(self): def test_attn_dp_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
...@@ -594,7 +594,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase): ...@@ -594,7 +594,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
# parameter initialization # parameter initialization
var_need_broadcast = ['linear_3.b_0'] var_need_broadcast = ['linear_3.b_0']
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context, initialization_check(_global_parallel_strategy, dist_context,
dist_startup_prog, serial_startup_prog, dist_startup_prog, serial_startup_prog,
var_need_broadcast)) var_need_broadcast))
...@@ -669,22 +669,22 @@ class DecoderLayer(nn.Layer): ...@@ -669,22 +669,22 @@ class DecoderLayer(nn.Layer):
self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
def forward(self, input_ids, position_ids): def forward(self, input_ids, position_ids):
if _global_parallel_stratergy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_embeddings = self.word_embeddings(input_ids) input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, _global_process_mesh,
dim_mapping=[0, -1]) dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, _global_process_mesh,
...@@ -704,14 +704,14 @@ class DecoderLayer(nn.Layer): ...@@ -704,14 +704,14 @@ class DecoderLayer(nn.Layer):
k = self.k_proj(target) k = self.k_proj(target)
v = self.v_proj(target) v = self.v_proj(target)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor( auto.shard_tensor(
...@@ -749,11 +749,11 @@ class DecoderLayer(nn.Layer): ...@@ -749,11 +749,11 @@ class DecoderLayer(nn.Layer):
# project to output # project to output
out = self.out_proj(out) out = self.out_proj(out)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1]) dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1]) dim_mapping=[1, -1])
...@@ -774,12 +774,12 @@ class DecoderLayer(nn.Layer): ...@@ -774,12 +774,12 @@ class DecoderLayer(nn.Layer):
out2 = F.gelu(out1, approximate=True) out2 = F.gelu(out1, approximate=True)
out3 = self.linear1(out2) out3 = self.linear1(out2)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor( auto.shard_tensor(
...@@ -818,8 +818,8 @@ def decoder_pretrain_forward(train_program, start_program): ...@@ -818,8 +818,8 @@ def decoder_pretrain_forward(train_program, start_program):
class TestDecoderLayerPartitioner(unittest.TestCase): class TestDecoderLayerPartitioner(unittest.TestCase):
def test_decoder_dp_mp(self): def test_decoder_dp_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
...@@ -877,13 +877,13 @@ class TestDecoderLayerPartitioner(unittest.TestCase): ...@@ -877,13 +877,13 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
'layer_norm_0.w_0', 'linear_5.b_0' 'layer_norm_0.w_0', 'linear_5.b_0'
]) ])
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context, initialization_check(_global_parallel_strategy, dist_context,
dist_startup_prog, serial_startup_prog, dist_startup_prog, serial_startup_prog,
var_need_broadcast)) var_need_broadcast))
def test_decoder_noparallel(self): def test_decoder_noparallel(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "None" _global_parallel_strategy = "None"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
......
...@@ -40,7 +40,7 @@ from paddle.distributed.auto_parallel.process import new_process_group ...@@ -40,7 +40,7 @@ from paddle.distributed.auto_parallel.process import new_process_group
paddle.enable_static() paddle.enable_static()
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
_global_parallel_stratergy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = None
...@@ -120,10 +120,10 @@ class MultiHeadAttention(nn.Layer): ...@@ -120,10 +120,10 @@ class MultiHeadAttention(nn.Layer):
""" """
q = self.q_proj(query) q = self.q_proj(query)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
...@@ -157,19 +157,19 @@ class MultiHeadAttention(nn.Layer): ...@@ -157,19 +157,19 @@ class MultiHeadAttention(nn.Layer):
""" """
k = self.k_proj(key) k = self.k_proj(key)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
v = self.v_proj(value) v = self.v_proj(value)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
...@@ -250,11 +250,11 @@ class MultiHeadAttention(nn.Layer): ...@@ -250,11 +250,11 @@ class MultiHeadAttention(nn.Layer):
# project to output # project to output
out = self.out_proj(out) out = self.out_proj(out)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1]) dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1]) dim_mapping=[1, -1])
...@@ -423,17 +423,17 @@ class TransformerDecoderLayer(nn.Layer): ...@@ -423,17 +423,17 @@ class TransformerDecoderLayer(nn.Layer):
if self.normalize_before: if self.normalize_before:
tgt = self.norm2(tgt) tgt = self.norm2(tgt)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1])
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1]) self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1]) self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1])
...@@ -496,12 +496,12 @@ class GPTEmbeddings(nn.Layer): ...@@ -496,12 +496,12 @@ class GPTEmbeddings(nn.Layer):
input_embedings = self.word_embeddings(input_ids) input_embedings = self.word_embeddings(input_ids)
if _global_parallel_stratergy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, _global_process_mesh,
dim_mapping=[0, -1]) dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, _global_process_mesh,
...@@ -729,10 +729,10 @@ def gpt_pretrain_forward(train_program, start_program): ...@@ -729,10 +729,10 @@ def gpt_pretrain_forward(train_program, start_program):
loss_mask = static.data( loss_mask = static.data(
name="loss_mask", shape=[batch_size, sequence_len], dtype='float64') name="loss_mask", shape=[batch_size, sequence_len], dtype='float64')
if _global_parallel_stratergy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids, _global_process_mesh, dim_mapping=[0, -1])
...@@ -764,8 +764,8 @@ def gpt_pretrain_forward(train_program, start_program): ...@@ -764,8 +764,8 @@ def gpt_pretrain_forward(train_program, start_program):
class TestGPTPartitioner(unittest.TestCase): class TestGPTPartitioner(unittest.TestCase):
def test_gpt_dp_mp(self): def test_gpt_dp_mp(self):
global _global_parallel_stratergy global _global_parallel_strategy
_global_parallel_stratergy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册