From 121553587c7811eaaa38a9107c9870292796ae81 Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Wed, 8 Sep 2021 16:34:27 +0800 Subject: [PATCH] [Auto Parallel] Integrate all modules (#35483) * add auto_parallel dir * mv to paddle.distributed * add shard_xx api * add distributed attrs for var * add ut, test=develop * add dist * update * update * update * update * update * update, test=develop * update, test=develop * update, test=develop * update, test=develop * update, test=develop * update, test=develop * update, test=develop * update * update * update * update * update * update, test=develop * update, test=develop * update * update * delete unused proto * resotre op_desc * restore type_defs * update var_desc * remove dimss_mapping for proto_pybind * update interface.py * update framework.py * update * update * add auto_parallel dir * mv to paddle.distributed * add shard_xx api * add distributed attrs for var * add ut, test=develop * [WIP] Add the auto completion feature and related codes * [WIP] Improve the auto completion and related codes * [WIP] Make the auto completion to support data-parallel * [WIP] Make the completion support mp and dp+mp * [WIP] Refactor auto completion unit test for MLP * [WIP] Refactor the implementation of DistributedOperatorImpl * [WIP] Improve dims_mapping update rule and fix a bug * [WIP] Support auto completion for one transformer decoder layer * [WIP] Add a minor change * [WIP] Fix a bug within the uint test * Shard XShape tensor, add embedding completion and refactor code * Add the distributed_operators dir to setup.py.in * Improve the completion process and add the unittest for gpt * fix process_mesh ut * fix process_mesh ut * update * update, test=develop * Add support for automatically completing distributed attrs of special ops * update * update * update * fix doc sample codes, test=develop * improve coverage, test=develop * add static_mode check, test=develop * Model the cluster for cost model and physical mapping * update, test=develop * add set_placement, test=develop * Add the check to make sure the candidate tensors' size is great than zero * update doc, test=develop * update doc, test=develop * update doc, test=develop * update doc, test=develop * update, test=develop * Auto mark dist attrs annotated by user * update ndarray to nested list, test=develop * update, test=develop * Add auto-completion module for auto-parallel (based on PR#33804) * Remove unnecessary files * Remove unrelated files for the auto completion pr * Update the unit test to improve the coverage * Modify codes based on reviews * Minor changes for CI * Improve some codes based on new comments * Fix bugs caused by shallow copy in attributes.py * Imporve amend_distributed_attr_for_program in context.py * Other changes for weihang's comments * support shard reader * support shard reader * add parallel mode * update process mesh * add method to compute comm_group * implement dist_embedding forward func * implement dist matmul forward func * implement dist reshape forward func * add transpiler framework * add transpiler forward * implement transpiler forward * implement transpiler backward & update * add process * add unitest * chmod * chmod * chmod * update unitest * add unitest for gpt * remove unused print * rename transpiler --> partitioner * rename transpiler --> partitioner * chmod * chmod * bug fixed * remove amp function * update case for dp mode * update case for dp mode * [Auto Parallel] Integrate all parts with the newest code * Integrate all parts of auto parallel and 
improve codes * Integrate all parts by AutoParallelizer * Add unit test for AutoParallelizer * Improve auto completion module for pipeline parallel * Add support for matmul_v2 in dist_matmul * Correct the typo "stratergy" to "strategy" * Modify distributed_strategy.proto to conform the main stream * Restore parts of distributed_strategy to conform the develop branch Co-authored-by: sandyhouse Co-authored-by: JZ-LIANG --- .../framework/distributed_strategy.proto | 1 + .../distributed/auto_parallel/completion.py | 170 +++++++++-- .../distributed/auto_parallel/context.py | 45 ++- .../auto_parallel/operators/dist_matmul.py | 271 +++++++++++++++++- .../distributed/auto_parallel/parallelizer.py | 79 +++++ .../distributed/auto_parallel/partitioner.py | 2 +- .../paddle/distributed/auto_parallel/utils.py | 5 +- .../fleet/base/distributed_strategy.py | 35 +++ .../distributed/fleet/base/fleet_base.py | 8 + .../test_auto_parallel_completion.py | 105 ++++--- .../test_auto_parallel_completion_gpt.py | 46 +-- .../test_auto_parallel_parallelizer.py | 138 +++++++++ .../test_auto_parallel_partitioner.py | 88 +++--- .../test_auto_parallel_partitioner_gpt.py | 38 +-- 14 files changed, 851 insertions(+), 180 deletions(-) create mode 100644 python/paddle/distributed/auto_parallel/parallelizer.py create mode 100755 python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 3627a8cf71c..4674ba4007f 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -202,6 +202,7 @@ message DistributedStrategy { optional bool calc_comm_same_stream = 32 [ default = false ]; optional bool asp = 33 [ default = false ]; optional bool fuse_grad_merge = 34 [ default = false ]; + optional bool semi_auto = 35 [ default = false ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 72af14af2c3..6e886d09d67 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -253,6 +253,9 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): if (not tensor_node.is_var()) or (tensor_node.var() is None): return False tensor_desc = tensor_node.var() + # Skip reader tensor + if tensor_desc.type() == core.VarDesc.VarType.READER: + return False tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( tensor_node) assert tensor_dist_attr is not None @@ -263,6 +266,10 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): dims_mapping_list = [] for pred_op_node in tensor_node.inputs: if pred_op_node.op() is not None: + if pred_op_node.op().type() == "create_py_reader" \ + or pred_op_node.op().type() == "create_double_buffer_reader" \ + or pred_op_node.op().type() == "read": + continue op_dist_attr = dist_context.get_op_distributed_attr_for_graph( pred_op_node) op_dims_mapping = op_dist_attr.get_output_dims_mapping( @@ -279,6 +286,10 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): dims_mapping_list = [] for succ_op_node in tensor_node.outputs: if succ_op_node.op() is not None: + if succ_op_node.op().type() == "create_py_reader" \ + or succ_op_node.op().type() == "create_double_buffer_reader" \ + or succ_op_node.op().type() == "read": + 
continue op_dist_attr = dist_context.get_op_distributed_attr_for_graph( succ_op_node) op_dims_mapping = op_dist_attr.get_input_dims_mapping( @@ -298,11 +309,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): changed = False if (not op_node.is_op()) or (op_node.op() is None): return False + # Skip reader op op_desc = op_node.op() + if op_desc.type() == "create_py_reader" \ + or op_desc.type() == "create_double_buffer_reader" \ + or op_desc.type() == "read": + return False op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node) if fwd: for tensor_node in op_node.inputs: if tensor_node.var() is not None: + if tensor_node.var().type() == core.VarDesc.VarType.READER: + continue tensor_desc = tensor_node.var() if op_dist_attr.is_annotated_input_dims_mapping( tensor_desc.name()): @@ -344,6 +362,8 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): else: for tensor_node in op_node.outputs: if tensor_node.var() is not None: + if tensor_node.var().type() == core.VarDesc.VarType.READER: + continue tensor_desc = tensor_node.var() if op_dist_attr.is_annotated_output_dims_mapping( tensor_desc.name()): @@ -400,9 +420,8 @@ def complete_annotation(program, dist_context=None): if dist_context is None: dist_context = get_default_distributed_context() - # Initialize distributed attributes for all var and op node in program + # Initialize distributed attributes for all var and op node in program dist_context.initialize_distributed_attr_for_program(program) - # print_program_with_distributed_attr(program, dist_context) # Convert program to graph graph = framework.IrGraph(core.Graph(program.desc)) @@ -410,37 +429,134 @@ def complete_annotation(program, dist_context=None): # Initialize distributed attributes for all var and op node in graph dist_context.initialize_distributed_attr_for_graph(graph) - # # Complete process mesh for each node + # Complete process mesh for each node all_nodes = list(graph.all_nodes()) + + def sort_key_fun(node): + first = -1 + if node.is_op(): + first = 0 + else: + first = 1 + second = -1 + if node.is_op() and node.op() is not None: + second = node.op().id() + if node.is_var() and node.var() is not None: + second = node.var().id() + return (first, second) + + all_nodes.sort(key=sort_key_fun) + reach_fix_point = False while not reach_fix_point: - changed = False - for node in all_nodes: - if node.is_var() and node.var() is not None: - tensor_changed = update_tensor_node_process_mesh( - dist_context, node, fwd=True) - if tensor_changed: - changed = True - if node.is_op() and node.op() is not None: - op_changed = update_op_node_process_mesh( - dist_context, node, fwd=True) - if op_changed: - changed = True - for node in reversed(all_nodes): - if node.is_var() and node.var() is not None: - tensor_changed = update_tensor_node_process_mesh( - dist_context, node, fwd=False) - if tensor_changed: - changed = True - if node.is_op() and node.op() is not None: - op_changed = update_op_node_process_mesh( - dist_context, node, fwd=False) - if op_changed: - changed = True - if changed: + total_changed = False + reach_fwd_fix_point = False + reach_bwd_fix_point = False + while not reach_fwd_fix_point: + changed = False + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_changed = update_tensor_node_process_mesh( + dist_context, node, fwd=True) + if tensor_changed: + changed = True + if node.is_op() and node.op() is not None: + op_changed = update_op_node_process_mesh( + dist_context, node, fwd=True) + if op_changed: + 
changed = True + if changed: + reach_fwd_fix_point = False + total_changed = True + else: + reach_fwd_fix_point = True + while not reach_bwd_fix_point: + changed = False + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_changed = update_tensor_node_process_mesh( + dist_context, node, fwd=False) + if tensor_changed: + changed = True + if node.is_op() and node.op() is not None: + op_changed = update_op_node_process_mesh( + dist_context, node, fwd=False) + if op_changed: + changed = True + if changed: + reach_bwd_fix_point = False + total_changed = True + else: + reach_bwd_fix_point = True + if total_changed: reach_fix_point = False else: reach_fix_point = True + # Validation the completion of process meshes and should be moved to a proper location + is_wrong = False + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + node) + if tensor_dist_attr.get_process_mesh() is None: + msg_str = "" + for op_node in node.inputs: + if op_node.op() is not None: + op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + op_node) + msg_str += "{} [{}], ".format( + op_node.op().type(), + op_dist_attr.get_process_mesh()) + else: + msg_str += "{} [{}], ".format(op_node.name(), + None) + for op_node in node.outputs: + if op_node.op() is not None: + op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + op_node) + msg_str += "{} [{}], ".format( + op_node.op().type(), + op_dist_attr.get_process_mesh()) + else: + msg_str += "{} [{}], ".format(op_node.name(), + None) + msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_tensor api explicitly to annotate it".format( + node.var().name(), msg_str[:-2]) + is_wrong = True + print(msg_str) + if node.is_op() and node.op() is not None: + op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + node) + if op_dist_attr.get_process_mesh() is None: + msg_str = "" + for tensor_node in node.inputs: + if tensor_node.var() is not None: + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_node) + msg_str += "{} [{}], ".format( + tensor_node.var().name(), + tensor_dist_attr.get_process_mesh()) + else: + msg_str += "{} [{}], ".format( + tensor_node.name(), None) + for tensor_node in node.outputs: + if tensor_node.var() is not None: + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_node) + msg_str += "{} [{}], ".format( + tensor_node.var().name(), + tensor_dist_attr.get_process_mesh()) + else: + msg_str += "{} [{}], ".format( + tensor_node.name(), None) + msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_op api explicitly to annotate it".format( + node.op().type(), msg_str[:-2]) + is_wrong = True + print(msg_str) + if node.is_op() and node.op() is None: + print("op op is None", node.name()) + if is_wrong: + assert False, "Cannot complete process_meshes of the program." 
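The process-mesh completion above is a fixed-point computation: forward and backward sweeps over the id-sorted node list repeat until neither direction changes any distributed attribute. Stripped of the graph details, the control flow reduces to the following sketch (an illustration only, not the Paddle implementation; update_node is a hypothetical stand-in for the update_*_node_process_mesh helpers and returns True when it changed something):

def propagate_to_fixed_point(all_nodes, update_node):
    # Repeat forward/backward rounds until one full round changes nothing.
    reach_fix_point = False
    while not reach_fix_point:
        total_changed = False
        for fwd in (True, False):
            changed = True
            while changed:  # per-direction inner fixed point
                changed = False
                nodes = all_nodes if fwd else list(reversed(all_nodes))
                for node in nodes:
                    if update_node(node, fwd=fwd):
                        changed = True
                        total_changed = True
        reach_fix_point = not total_changed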
# Complete dims_mapping for each node reach_fix_point = False diff --git a/python/paddle/distributed/auto_parallel/context.py b/python/paddle/distributed/auto_parallel/context.py index bddf9368255..4958c5adfae 100644 --- a/python/paddle/distributed/auto_parallel/context.py +++ b/python/paddle/distributed/auto_parallel/context.py @@ -142,12 +142,15 @@ class DistributedContext: tensor.desc, tensor_dist_attr) self.set_tensor_distributed_attr_for_program( tensor, tensor_dist_attr) - tensor_dist_attr.set_shape(tensor.desc.shape()) + if tensor.type == core.VarDesc.VarType.READER: + tensor_dist_attr.set_shape([]) + else: + tensor_dist_attr.set_shape(tensor.desc.shape()) if tensor_dist_attr.get_process_mesh() is not None: tensor_dist_attr.mark_as_annotated("process_mesh") if tensor_dist_attr.get_dims_mapping() is None: tensor_dims_mapping = [ - -1 for _ in range(len(tensor.desc.shape())) + -1 for _ in range(len(tensor_dist_attr.get_shape())) ] tensor_dist_attr.set_dims_mapping(tensor_dims_mapping) else: @@ -168,12 +171,18 @@ class DistributedContext: op_dist_attr.mark_as_annotated("process_mesh") for tensor_name in op.input_arg_names: # There may be a better way to find the tensor by name - tensor = op.block._var_recursive(tensor_name) - op_dist_attr.set_input_shape(tensor_name, - tensor.desc.shape()) + if op.type == "create_py_reader" \ + or tensor.type == core.VarDesc.VarType.READER: + op_dist_attr.set_input_shape(tensor_name, []) + else: + tensor = op.block._var_recursive(tensor_name) + op_dist_attr.set_input_shape(tensor_name, + tensor.desc.shape()) if op_dist_attr.get_input_dims_mapping(tensor_name) is None: tensor_dims_mapping = [ - -1 for _ in range(len(tensor.desc.shape())) + -1 + for _ in range( + len(op_dist_attr.get_input_shape(tensor_name))) ] op_dist_attr.set_input_dims_mapping(tensor_name, tensor_dims_mapping) @@ -184,12 +193,18 @@ class DistributedContext: op_dist_attr.mark_as_parameter(tensor_name) for tensor_name in op.output_arg_names: tensor = op.block._var_recursive(tensor_name) - op_dist_attr.set_output_shape(tensor_name, - tensor.desc.shape()) + if tensor.type == core.VarDesc.VarType.READER: + op_dist_attr.set_output_shape(tensor_name, []) + else: + op_dist_attr.set_output_shape(tensor_name, + tensor.desc.shape()) if op_dist_attr.get_output_dims_mapping( tensor_name) is None: tensor_dims_mapping = [ - -1 for _ in range(len(tensor.desc.shape())) + -1 + for _ in range( + len( + op_dist_attr.get_output_shape(tensor_name))) ] op_dist_attr.set_output_dims_mapping( tensor_name, tensor_dims_mapping) @@ -378,8 +393,8 @@ class DistributedContext: # If the dimension of tensor is less than the sharding dimension of process mesh, # we just amend the dimension mapping to -1. (Is this really OK?) for i in range(len(tensor_shape)): - if dims_mapping[i] != -1 and process_mesh_shape[dims_mapping[ - i]] > tensor_shape[i]: + if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ + and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: dims_mapping[i] = -1 for attr in self._op_distributed_attr_map_for_program.values(): @@ -392,8 +407,8 @@ class DistributedContext: # If the dimension of tensor is less than the sharding dimension of process mesh, # we just amend the dimension mapping to -1. (Is this really OK?) 
for i in range(len(tensor_shape)): - if dims_mapping[i] != -1 and process_mesh_shape[ - dims_mapping[i]] > tensor_shape[i]: + if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ + and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: dims_mapping[i] = -1 for arg_name in attr.get_owner_op().desc.output_arg_names(): @@ -403,8 +418,8 @@ class DistributedContext: # If the dimension of tensor is less than the sharding dimension of process mesh, # we just amend the dimension mapping to -1. (Is this really OK?) for i in range(len(tensor_shape)): - if dims_mapping[i] != -1 and process_mesh_shape[ - dims_mapping[i]] > tensor_shape[i]: + if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ + and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: dims_mapping[i] = -1 def _get_data_parallel_info(self): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 9059feeaf85..91bad5bc347 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -462,10 +462,271 @@ class DistributedMatmulV2(DistributedOperator): register_distributed_operator("matmul_v2", DistributedMatmulV2("matmul_v2")) +# ColumnParallel +class DistributedMatmulV2Impl0(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedMatmulV2Impl0, self).__init__() + self._name = name + self._forward_implemented = True + self._backward_implemented = False + + def is_process_mesh_compatible(self, op_dist_attr): + """ No restriction for now. """ + return True + + def is_input_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + y_name = op_desc.input('Y')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) + if is_dim_shard(x_dims_mapping[-1]): + return False + if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[ + 1]): + return False + for mapping in x_dims_mapping[1:-1]: + if is_dim_shard(mapping): + return False + return True + + def is_output_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + out_name = op_desc.output('Out')[0] + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + if is_dim_replicate(out_dims_mapping[-1]): + return False + for mapping in out_dims_mapping[1:-1]: + if is_dim_shard(mapping): + return False + return True + + def update_dims_mapping(self, op_dist_attr): + changed = False + dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + if dim_changed: + changed = True + return changed + + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 2, "col_parallel_linear take 2 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 1, "col_parallel_linear take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['X'] + ) == 1, "col_parallel_linear input X take 1 variable but got {}".format( + input_name_mapping['X']) + assert len( + input_name_mapping['Y'] + ) == 1, "col_parallel_linear input Y take 1 variable but got {}".format( + input_name_mapping['Y']) + assert len( + output_name_mapping['Out'] + ) == 1, "col_parallel_linear input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + 
X_var = dst_block.var(input_name_mapping['X'][0]) + Weight_var = dst_block.var(input_name_mapping['Y'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + + # TODO infer logic comm presentation + from ..process import new_process_group + from ..transpiler import _get_comm_group + model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( + )._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.topology, + model_parallel_axis, + process_mesh.process_group, rank_id) + group = new_process_group(group_ranks) + # print("@@@@@@@@@@@@@@@@@@@@@ 5", group) + + intermediate_var_0 = dst_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_identity", 'tmp'])), + dtype=X_var.dtype, + shape=X_var.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=X_var.stop_gradient) + + check_variable_and_dtype( + X_var, 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64'], + '_c_identity') + + dst_block.append_op( + type='c_identity', + inputs={'X': [X_var]}, + outputs={'Out': intermediate_var_0}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True, + }) + + check_variable_and_dtype(intermediate_var_0, 'x', + ['float16', 'float32', 'float64'], + 'linear') + check_dtype(intermediate_var_0.dtype, 'dtype', + ['float16', 'float32', 'float64'], 'linear') + attrs = {'trans_x': False, 'trans_y': False} + inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} + dst_block.append_op( + type='matmul_v2', + inputs=inputs, + outputs={'Out': Out_var}, + attrs=attrs) + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + + +# RowParallel +class DistributedMatmulV2Impl1(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedMatmulV2Impl1, self).__init__() + self._name = name + self._forward_implemented = True + self._backward_implemented = False + + def is_process_mesh_compatible(self, op_dist_attr): + """ No restriction for now. 
""" + return True + + def is_input_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + y_name = op_desc.input('Y')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) + if is_dim_replicate(x_dims_mapping[-1]): + return False + if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[ + -1]): + return False + # Other dimensions must be replicate except the batch dimension + for mapping in x_dims_mapping[1:-1]: + if is_dim_shard(mapping): + return False + return True + + def is_output_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + out_name = op_desc.output('Out')[0] + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + if is_dim_shard(out_dims_mapping[-1]): + return False + # Other dimensions must be replicate except the batch dimension + for mapping in out_dims_mapping[1:-1]: + if is_dim_shard(mapping): + return False + return True + + def update_dims_mapping(self, op_dist_attr): + changed = False + dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + if dim_changed: + changed = True + return changed + + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 2, "col_parallel_linear take 2 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 1, "col_parallel_linear take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['X'] + ) == 1, "col_parallel_linear input X take 1 variable but got {}".format( + input_name_mapping['X']) + assert len( + input_name_mapping['Y'] + ) == 1, "col_parallel_linear input Y take 1 variable but got {}".format( + input_name_mapping['Y']) + assert len( + output_name_mapping['Out'] + ) == 1, "col_parallel_linear input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + X_var = dst_block.var(input_name_mapping['X'][0]) + Weight_var = dst_block.var(input_name_mapping['Y'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + + # TODO infer logic comm presentation + from ..process import new_process_group + from ..transpiler import _get_comm_group + model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( + )._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.topology, + model_parallel_axis, + process_mesh.process_group, rank_id) + group = new_process_group(group_ranks) + # print("@@@@@@@@@@@@@@@@@@@@@ 4", group) + + check_variable_and_dtype( + X_var, 'x', ['float16', 'float32', 'float64'], 'linear') + check_dtype(X_var.dtype, 'dtype', + ['float16', 'float32', 'float64'], 'linear') + attrs = {'trans_x': False, 'trans_y': False} + inputs = {'X': X_var, 'Y': Weight_var} + intermediate_var_0 = dst_block.create_var( + shape=Out_var.shape, + dtype=Out_var.dtype, + type=Out_var.type, + lod_level=Out_var.lod_level, + persistable=False, + is_data=False, + need_check_feed=Out_var.desc.need_check_feed()) + dst_block.append_op( + type='matmul_v2', + inputs=inputs, + outputs={'Out': intermediate_var_0}, + attrs=attrs) + + dst_block.append_op( + type='c_allreduce_sum', + inputs={'X': intermediate_var_0}, + outputs={'Out': Out_var}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True + }) + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] 
with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + + # ReplicateParallel -class DistributedMatmulV2Impl(DistributedOperatorImpl): +class DistributedMatmulV2Impl2(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulV2Impl, self).__init__() + super(DistributedMatmulV2Impl2, self).__init__() self._name = name def is_process_mesh_compatible(self, op_dist_attr): @@ -514,5 +775,9 @@ class DistributedMatmulV2Impl(DistributedOperatorImpl): return changed +register_distributed_operator_impl("matmul_v2", + DistributedMatmulV2Impl0("column_parallel")) +register_distributed_operator_impl("matmul_v2", + DistributedMatmulV2Impl1("row_parallel")) register_distributed_operator_impl( - "matmul_v2", DistributedMatmulV2Impl("replicate_parallel")) + "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")) diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py new file mode 100644 index 00000000000..2e36e92b344 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.distributed.fleet import cloud_utils +from .context import DistributedContext +from .context import get_default_distributed_context +from .completion import complete_annotation +from .partitioner import Partitioner +from .process import get_all_process_groups + + +class AutoParallelizer: + """ + AutoParallelizer is the main controller class to do the auto parallel process. + And the auto parallel process will be triggered in the wrapped parallelize function. + To facilitate the auto parallelization, it will contain information about program, cluster and the + related context. In this basic version, the program information will be retrevied from + Fleet object, and the cluster information can be retrevied in the new created Cluster object, + and the context information can be retrevied in the new created DistributedContext. 
+ """ + + def __init__(self, fleet): + self._fleet = fleet + self._optimizer = self._fleet.user_defined_optimizer + self._dist_strategy = self._fleet._user_defined_strategy + # self._dist_context = DistributedContext() + self._dist_context = get_default_distributed_context() + + def parallelize(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + self._original_main_program = loss.block.program + # For now, we only allow user to use the default startup and main program + assert startup_program is not None + if startup_program == None: + self._original_startup_program = \ + paddle.static.default_startup_program().clone(for_test=False) + startup_program = paddle.static.default_startup_program() + else: + self._original_startup_program = \ + startup_program.clone(for_test=False) + + # Annotation completion + completed_main_program = complete_annotation( + self._original_main_program, self._dist_context) + + # Logical partition + rank = paddle.distributed.get_rank() + partitioner = Partitioner(self._dist_strategy, self._dist_context, rank) + partitioned_main_prog, partitioned_startup_prog = partitioner.transpile_forward( + completed_main_program, startup_program) + dist_params_grads = partitioner.apply_backward( + loss, completed_main_program, startup_program, + partitioned_main_prog, partitioned_startup_prog) + dist_optimize_ops = partitioner.apply_optimize( + self._optimizer, dist_params_grads, partitioned_main_prog, + partitioned_startup_prog) + + # Traverse different rank programs and traverse each op of them, + # instantiate communication by process_mapping. + all_process_groups = get_all_process_groups() + for process_group in all_process_groups: + process_group.instantiate() + + return dist_optimize_ops, dist_params_grads, partitioned_startup_prog, partitioned_main_prog diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index 03497f2967c..b67f1e1ab97 100755 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -561,7 +561,7 @@ class Partitioner(object): if not var_dist_attr.is_parameter(): mapping = var_dist_attr.get_dims_mapping() mesh = var_dist_attr.get_process_mesh().topology - if mapping[0] >= 0 and mesh[mapping[0]] > 1: + if mapping and mapping[0] >= 0 and mesh[mapping[0]] > 1: self._enable_data_parallel = True break diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index c864375271b..547495fb848 100755 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -79,11 +79,10 @@ def compute_compatible_process_mesh(process_mesh_list): return compatible_process_mesh for process_mesh in process_mesh_list: if process_mesh is not None: - if compatible_process_mesh is None: + if compatible_process_mesh is None or compatible_process_mesh == process_mesh: compatible_process_mesh = process_mesh else: - assert process_mesh == compatible_process_mesh, \ - "There is no compatible process mesh." 
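Read end to end, the parallelize() method above chains four stages; condensed, its flow is the sketch below (simplified for orientation only, with the program cloning and assertions omitted; the absolute imports mirror the module's own relative imports):

import paddle
from paddle.distributed.auto_parallel.completion import complete_annotation
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.process import get_all_process_groups

def parallelize_flow(loss, startup_program, dist_context, dist_strategy, optimizer):
    serial_main_program = loss.block.program
    # 1. Complete the distributed attributes the user did not annotate.
    completed_main_program = complete_annotation(serial_main_program, dist_context)
    # 2. Logically partition the completed program for the current rank.
    rank = paddle.distributed.get_rank()
    partitioner = Partitioner(dist_strategy, dist_context, rank)
    dist_main_prog, dist_startup_prog = partitioner.transpile_forward(
        completed_main_program, startup_program)
    # 3. Append the distributed backward pass and the optimizer ops.
    dist_params_grads = partitioner.apply_backward(
        loss, completed_main_program, startup_program, dist_main_prog,
        dist_startup_prog)
    dist_optimize_ops = partitioner.apply_optimize(
        optimizer, dist_params_grads, dist_main_prog, dist_startup_prog)
    # 4. Instantiate every communication group the partitioned programs reference.
    for process_group in get_all_process_groups():
        process_group.instantiate()
    return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog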
+ return None return compatible_process_mesh diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index d19cfd21698..378c2ff8d5d 100644 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -1596,6 +1596,41 @@ class DistributedStrategy(object): else: print("WARNING: auto should have value of bool type") + @property + def semi_auto(self): + """ + Indicating whether we are using semi-auto parallel function + This feature is currently an experimental feature. Currently, + auto-parallelism can be used only when a user does not set any other + strategy configs except semi-auto. For details, please reference the following + code example + Default Value: False + + Examples: + + .. code-block:: python + + import paddle + paddle.enable_static() + import paddle.distributed.fleet as fleet + + strategy = fleet.DistributedStrategy() + strategy.semi_auto = True + # if set other strategy at the same time, auto will not apply + # strategy.amp = True + + optimizer = paddle.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy) + """ + return self.strategy.semi_auto + + @semi_auto.setter + def semi_auto(self, flag): + if isinstance(flag, bool): + self.strategy.semi_auto = flag + else: + print("WARNING: semi-auto should have value of bool type") + @property def cudnn_exhaustive_search(self): """ diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index d1f6802919f..ceb1cf4e034 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -1408,6 +1408,14 @@ class Fleet(object): context["origin_startup_program"] = startup_program context["role_maker"] = self._role_maker + # Use the auto-parallel's routines instead + if self._user_defined_strategy.semi_auto: + from ...auto_parallel.parallelizer import AutoParallelizer + auto_parallelizer = AutoParallelizer(self) + optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize( + loss, startup_program, parameter_list, no_grad_set) + return optimize_ops, params_grads, dist_startup_prog, dist_main_prog + # compile time distributed_optimizer_list = \ MetaOptimizerFactory()._get_valid_meta_optimizers( diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py index 9e1943ce6c6..21726596ca7 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py @@ -33,8 +33,9 @@ from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffi from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.context import set_default_distributed_context paddle.enable_static() -_global_parallel_stratergy = None +_global_parallel_strategy = None _global_process_mesh = None +_global_process_mesh2 = None ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) @@ -59,16 +60,22 @@ class MLPLayer(nn.Layer): self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") def forward(self, input): - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) 
auto.shard_tensor( self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) auto.shard_tensor( self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) + elif _global_parallel_strategy == "pp": + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh2, + dim_mapping=[1, -1]) out = self.norm(input) out = self.linear0(out) @@ -90,10 +97,10 @@ def mlp_pretrain_forward(train_program, start_program): shape=[batch_size, sequence_len, hidden_size], dtype='float32') - if _global_parallel_stratergy == "dp": + if _global_parallel_strategy == "dp": auto.shard_tensor( input, _global_process_mesh, dim_mapping=[0, -1, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( input, _global_process_mesh, dim_mapping=[0, -1, -1]) @@ -108,8 +115,8 @@ def mlp_pretrain_forward(train_program, start_program): class TestMLPAutoCompletion(unittest.TestCase): def test_mlp_dp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "dp" + global _global_parallel_strategy + _global_parallel_strategy = "dp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) @@ -127,8 +134,8 @@ class TestMLPAutoCompletion(unittest.TestCase): dist_context)) def test_mlp_mp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "mp" + global _global_parallel_strategy + _global_parallel_strategy = "mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) @@ -147,8 +154,8 @@ class TestMLPAutoCompletion(unittest.TestCase): dist_context)) def test_mlp_dp_mp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "dp_mp" + global _global_parallel_strategy + _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) @@ -167,19 +174,26 @@ class TestMLPAutoCompletion(unittest.TestCase): dist_context)) def test_mlp_misc(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "dp_mp" + # import pdb + global _global_parallel_strategy + _global_parallel_strategy = "pp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + mesh=[[0, 1], [2, 3]], parent=ROOT_MESH) + global _global_process_mesh2 + _global_process_mesh2 = auto.ProcessMesh( + mesh=[[4, 5], [6, 7]], parent=ROOT_MESH) train_program = static.Program() start_program = static.Program() dist_context = DistributedContext() train_program, start_program = mlp_pretrain_forward(train_program, start_program) + # pdb.set_trace() complete_train_program = auto.complete_annotation(train_program, dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) dist_context.finalize_distributed_attr_for_program( complete_train_program) from paddle.distributed.auto_parallel.interface import _g_process_mesh_map @@ -246,10 +260,10 @@ class AttentionLayer(nn.Layer): self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) def forward(self, input): - if _global_parallel_stratergy == "dp": + if _global_parallel_strategy == "dp": auto.shard_tensor( input, _global_process_mesh, dim_mapping=[0, 
-1, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( input, _global_process_mesh, dim_mapping=[0, -1, -1]) @@ -260,14 +274,14 @@ class AttentionLayer(nn.Layer): k = self.k_proj(input) v = self.v_proj(input) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) auto.shard_tensor( @@ -304,11 +318,11 @@ class AttentionLayer(nn.Layer): # project to output out = self.out_proj(out) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.out_proj.weight, _global_process_mesh, dim_mapping=[0, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.out_proj.weight, _global_process_mesh, dim_mapping=[1, -1]) @@ -340,8 +354,8 @@ def attn_pretrain_forward(train_program, start_program): class TestAttentionAutoCompletion(unittest.TestCase): def test_attn_dp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "dp" + global _global_parallel_strategy + _global_parallel_strategy = "dp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) @@ -359,8 +373,8 @@ class TestAttentionAutoCompletion(unittest.TestCase): dist_context)) def test_attn_mp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "mp" + global _global_parallel_strategy + _global_parallel_strategy = "mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) @@ -379,8 +393,8 @@ class TestAttentionAutoCompletion(unittest.TestCase): dist_context)) def test_attn_dp_mp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "dp_mp" + global _global_parallel_strategy + _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) @@ -463,28 +477,29 @@ class DecoderLayer(nn.Layer): d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) self.linear1 = nn.Linear( dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) - self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5) + self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5) self.dropout1 = nn.Dropout(self.dropout_ratio) self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") def forward(self, input_ids, position_ids): - if _global_parallel_stratergy == "dp": + if _global_parallel_strategy == "dp": auto.shard_tensor( input_ids, _global_process_mesh, dim_mapping=[0, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.word_embeddings.weight, _global_process_mesh, 
dim_mapping=[0, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.word_embeddings.weight, _global_process_mesh, @@ -494,7 +509,7 @@ class DecoderLayer(nn.Layer): embeddings = self.dropout1(embeddings) # Pre-norm - target = self.norm(embeddings) + target = self.norm1(embeddings) # The following is the attention part q = self.q_proj(target) @@ -504,14 +519,14 @@ class DecoderLayer(nn.Layer): k = self.k_proj(target) v = self.v_proj(target) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) auto.shard_tensor( @@ -549,11 +564,11 @@ class DecoderLayer(nn.Layer): # project to output out = self.out_proj(out) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.out_proj.weight, _global_process_mesh, dim_mapping=[0, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.out_proj.weight, _global_process_mesh, dim_mapping=[1, -1]) @@ -562,19 +577,19 @@ class DecoderLayer(nn.Layer): residual = embeddings + self.dropout2(out) # Pre-norm - out0 = self.norm(residual) + out0 = self.norm2(residual) # The following is the MLP part out1 = self.linear0(out0) out2 = F.gelu(out1, approximate=True) out3 = self.linear1(out2) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) auto.shard_tensor( @@ -613,8 +628,8 @@ def decoder_pretrain_forward(train_program, start_program): class TestDecoderLayerAutoCompletion(unittest.TestCase): def test_decoder_dp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "dp" + global _global_parallel_strategy + _global_parallel_strategy = "dp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) @@ -632,8 +647,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): dist_context)) def test_decoder_mp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "mp" + global _global_parallel_strategy + _global_parallel_strategy = "mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) @@ -652,8 +667,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): dist_context)) def test_decoder_dp_mp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "dp_mp" + global _global_parallel_strategy + _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py index 
204e8910e05..cd87a72a7e6 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py @@ -36,7 +36,7 @@ from paddle.distributed.auto_parallel.utils import print_program_with_distribute from paddle.distributed.auto_parallel.context import DistributedContext paddle.enable_static() -_global_parallel_stratergy = None +_global_parallel_strategy = None _global_process_mesh = None ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) @@ -106,10 +106,10 @@ class MultiHeadAttention(nn.Layer): """ q = self.q_proj(query) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) @@ -143,19 +143,19 @@ class MultiHeadAttention(nn.Layer): """ k = self.k_proj(key) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) v = self.v_proj(value) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) @@ -236,11 +236,11 @@ class MultiHeadAttention(nn.Layer): # project to output out = self.out_proj(out) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.out_proj.weight, _global_process_mesh, dim_mapping=[0, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.out_proj.weight, _global_process_mesh, dim_mapping=[1, -1]) @@ -409,17 +409,17 @@ class TransformerDecoderLayer(nn.Layer): if self.normalize_before: tgt = self.norm2(tgt) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1]) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1]) @@ -482,12 +482,12 @@ class GPTEmbeddings(nn.Layer): input_embedings = self.word_embeddings(input_ids) - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.word_embeddings.weight, _global_process_mesh, dim_mapping=[0, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.word_embeddings.weight, _global_process_mesh, @@ -715,10 +715,10 @@ def gpt_pretrain_forward(train_program, start_program): loss_mask = static.data( name="loss_mask", shape=[batch_size, 
sequence_len], dtype='float64') - if _global_parallel_stratergy == "dp": + if _global_parallel_strategy == "dp": auto.shard_tensor( input_ids, _global_process_mesh, dim_mapping=[0, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( input_ids, _global_process_mesh, dim_mapping=[0, -1]) @@ -750,8 +750,8 @@ def gpt_pretrain_forward(train_program, start_program): class TestGPTAutoCompletion(unittest.TestCase): def test_gpt_dp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "dp" + global _global_parallel_strategy + _global_parallel_strategy = "dp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) @@ -770,8 +770,8 @@ class TestGPTAutoCompletion(unittest.TestCase): dist_context)) def test_gpt_mp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "mp" + global _global_parallel_strategy + _global_parallel_strategy = "mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) @@ -790,8 +790,8 @@ class TestGPTAutoCompletion(unittest.TestCase): dist_context)) def test_gpt_dp_mp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "dp_mp" + global _global_parallel_strategy + _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py new file mode 100755 index 00000000000..6db7fbf8075 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py @@ -0,0 +1,138 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
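The new unit test that follows drives this path end to end through fleet. Condensed, the user-facing pattern it exercises looks like the sketch below (a minimal illustration under static graph mode, assuming a two-card collective launch; the tiny network here is a stand-in, not the MLP used by the test):

import paddle
import paddle.static as static
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

strategy = fleet.DistributedStrategy()
strategy.semi_auto = True  # route fleet's minimize() through AutoParallelizer
fleet.init(is_collective=True, strategy=strategy)

train_program, start_program = static.Program(), static.Program()
with static.program_guard(train_program, start_program):
    x = static.data(name="x", shape=[4, 1024], dtype="float32")
    # Annotate some tensors (or none); the completion pass fills in the rest.
    auto.shard_tensor(x, auto.ProcessMesh([0, 1]), dim_mapping=[-1, -1])
    y = static.nn.fc(x, size=1)
    loss = paddle.mean(y)

optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
_, _, dist_startup_prog, dist_main_prog = optimizer.minimize(loss, start_program)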
+ +from __future__ import print_function + +import unittest + +# The following statements are used to satisfy fleet initialization +import os +if os.getenv("CUDA_VISIBLE_DEVICES", None) is None: + os.environ["CUDA_VISIBLE_DEVICES"] = '0' + +import paddle +import paddle.nn as nn +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +from paddle.fluid import layers +from paddle.distributed import fleet +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr + +paddle.enable_static() +_global_parallel_strategy = None +_global_process_mesh = None +ROOT_MESH = auto.ProcessMesh([0, 1]) + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") + + def forward(self, input): + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + out = self.dropout(out) + out = self.linear2(out) + + return out + + +def mlp_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input = static.data( + name="input", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32') + label = static.data( + name="label", shape=[batch_size, sequence_len, 1], dtype='float32') + + auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1]) + + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + + predict = mlp(input) + + cost = layers.cross_entropy(input=predict, label=label) + avg_cost = layers.mean(x=cost) + + return avg_cost, train_program, start_program + + +class TestMLPAutoParallelizer(unittest.TestCase): + def test_mlp_serial(self): + + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) + + dist_strategy = fleet.DistributedStrategy() + dist_strategy.amp = False + dist_strategy.pipeline = False + dist_strategy.recompute = False + + # init parallel optimizer + dist_strategy.semi_auto = True + + fleet.init(is_collective=True, strategy=dist_strategy) + + train_program = static.Program() + start_program = static.Program() + loss, train_program, start_program = mlp_pretrain_forward(train_program, + start_program) + + optimizer = paddle.fluid.optimizer.AdamOptimizer( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + + optimizer = fleet.distributed_optimizer(optimizer) + _, _, distributed_startup_program, distributed_main_program = optimizer.minimize( + loss, start_program) + # print_program_with_distributed_attr(distributed_main_program) + self.assertIsNotNone(distributed_startup_program) + self.assertIsNotNone(distributed_main_program) + + +if __name__ 
== "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py index f1049084cfb..18dcc36fe0e 100755 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py @@ -39,7 +39,7 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group from paddle.distributed.auto_parallel.process import new_process_group paddle.enable_static() -_global_parallel_stratergy = None +_global_parallel_strategy = None _global_process_mesh = None ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) @@ -156,12 +156,12 @@ class MLPLayer(nn.Layer): self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") def forward(self, input): - if _global_parallel_stratergy == "mp": + if _global_parallel_strategy == "mp": auto.shard_tensor( self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) auto.shard_tensor( @@ -194,10 +194,10 @@ def mlp_pretrain_forward(train_program, start_program): shape=[batch_size, sequence_len, hidden_size], dtype='float32') - if _global_parallel_stratergy == "dp": + if _global_parallel_strategy == "dp": auto.shard_tensor( input, _global_process_mesh, dim_mapping=[0, -1, -1]) - elif _global_parallel_stratergy == "dp_mp": + elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( input, _global_process_mesh, dim_mapping=[0, -1, -1]) @@ -212,8 +212,8 @@ def mlp_pretrain_forward(train_program, start_program): class TestMLPAutoPartitioner(unittest.TestCase): def test_mlp_dp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "dp" + global _global_parallel_strategy + _global_parallel_strategy = "dp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) @@ -238,13 +238,13 @@ class TestMLPAutoPartitioner(unittest.TestCase): # parameter initialization var_need_broadcast = [] self.assertTrue( - initialization_check(_global_parallel_stratergy, dist_context, + initialization_check(_global_parallel_strategy, dist_context, dist_startup_prog, serial_startup_prog, var_need_broadcast)) def test_mlp_mp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "mp" + global _global_parallel_strategy + _global_parallel_strategy = "mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) @@ -285,13 +285,13 @@ class TestMLPAutoPartitioner(unittest.TestCase): var_need_broadcast = sorted( ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) self.assertTrue( - initialization_check(_global_parallel_stratergy, dist_context, + initialization_check(_global_parallel_strategy, dist_context, dist_startup_prog, serial_startup_prog, var_need_broadcast)) def test_mlp_dp_mp(self): - global _global_parallel_stratergy - _global_parallel_stratergy = "dp_mp" + global _global_parallel_strategy + _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) @@ -332,7 +332,7 @@ class TestMLPAutoPartitioner(unittest.TestCase): var_need_broadcast = sorted( ['layer_norm_0.b_0', 'layer_norm_0.w_0', 
         self.assertTrue(
-            initialization_check(_global_parallel_stratergy, dist_context,
+            initialization_check(_global_parallel_strategy, dist_context,
                                  dist_startup_prog, serial_startup_prog,
                                  var_need_broadcast))
 
@@ -373,10 +373,10 @@ class AttentionLayer(nn.Layer):
             self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
 
     def forward(self, input):
-        if _global_parallel_stratergy == "dp":
+        if _global_parallel_strategy == "dp":
             auto.shard_tensor(
                 input, _global_process_mesh, dim_mapping=[0, -1, -1])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 input, _global_process_mesh, dim_mapping=[0, -1, -1])
 
@@ -387,14 +387,14 @@ class AttentionLayer(nn.Layer):
         k = self.k_proj(input)
         v = self.v_proj(input)
 
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
             auto.shard_tensor(
                 self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
             auto.shard_tensor(
                 self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
             auto.shard_tensor(
@@ -431,11 +431,11 @@ class AttentionLayer(nn.Layer):
 
         # project to output
         out = self.out_proj(out)
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.out_proj.weight, _global_process_mesh,
                 dim_mapping=[0, -1])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.out_proj.weight, _global_process_mesh,
                 dim_mapping=[1, -1])
@@ -467,8 +467,8 @@ def attn_pretrain_forward(train_program, start_program):
 
 class TestAttentionAutoPartitioner(unittest.TestCase):
     def test_attn_dp(self):
-        global _global_parallel_stratergy
-        _global_parallel_stratergy = "dp"
+        global _global_parallel_strategy
+        _global_parallel_strategy = "dp"
         global _global_process_mesh
         _global_process_mesh = auto.ProcessMesh(
             mesh=[0, 1, 2, 3], parent=ROOT_MESH)
@@ -492,13 +492,13 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
         # parameter initialization
         var_need_broadcast = []
         self.assertTrue(
-            initialization_check(_global_parallel_stratergy, dist_context,
+            initialization_check(_global_parallel_strategy, dist_context,
                                  dist_startup_prog, serial_startup_prog,
                                  var_need_broadcast))
 
     def test_attn_mp(self):
-        global _global_parallel_stratergy
-        _global_parallel_stratergy = "mp"
+        global _global_parallel_strategy
+        _global_parallel_strategy = "mp"
         global _global_process_mesh
         _global_process_mesh = auto.ProcessMesh(
             mesh=[0, 1, 2, 3], parent=ROOT_MESH)
@@ -543,13 +543,13 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
         # parameter initialization
         var_need_broadcast = ['linear_3.b_0']
         self.assertTrue(
-            initialization_check(_global_parallel_stratergy, dist_context,
+            initialization_check(_global_parallel_strategy, dist_context,
                                  dist_startup_prog, serial_startup_prog,
                                  var_need_broadcast))
 
     def test_attn_dp_mp(self):
-        global _global_parallel_stratergy
-        _global_parallel_stratergy = "dp_mp"
+        global _global_parallel_strategy
+        _global_parallel_strategy = "dp_mp"
         global _global_process_mesh
         _global_process_mesh = auto.ProcessMesh(
             mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
@@ -594,7 +594,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
         # parameter initialization
         var_need_broadcast = ['linear_3.b_0']
         self.assertTrue(
-            initialization_check(_global_parallel_stratergy, dist_context,
+            initialization_check(_global_parallel_strategy, dist_context,
                                  dist_startup_prog, serial_startup_prog,
                                  var_need_broadcast))
 
@@ -669,22 +669,22 @@ class DecoderLayer(nn.Layer):
         self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
 
     def forward(self, input_ids, position_ids):
-        if _global_parallel_stratergy == "dp":
+        if _global_parallel_strategy == "dp":
             auto.shard_tensor(
                 input_ids, _global_process_mesh, dim_mapping=[0, -1])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 input_ids, _global_process_mesh, dim_mapping=[0, -1])
 
         input_embeddings = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
 
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.word_embeddings.weight,
                 _global_process_mesh,
                 dim_mapping=[0, -1])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.word_embeddings.weight,
                 _global_process_mesh,
@@ -704,14 +704,14 @@ class DecoderLayer(nn.Layer):
         k = self.k_proj(target)
         v = self.v_proj(target)
 
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
             auto.shard_tensor(
                 self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
             auto.shard_tensor(
                 self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
             auto.shard_tensor(
@@ -749,11 +749,11 @@ class DecoderLayer(nn.Layer):
 
         # project to output
         out = self.out_proj(out)
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.out_proj.weight, _global_process_mesh,
                 dim_mapping=[0, -1])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.out_proj.weight, _global_process_mesh,
                 dim_mapping=[1, -1])
@@ -774,12 +774,12 @@ class DecoderLayer(nn.Layer):
         out2 = F.gelu(out1, approximate=True)
         out3 = self.linear1(out2)
 
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
             auto.shard_tensor(
                 self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
             auto.shard_tensor(
@@ -818,8 +818,8 @@ def decoder_pretrain_forward(train_program, start_program):
 
 class TestDecoderLayerPartitioner(unittest.TestCase):
     def test_decoder_dp_mp(self):
-        global _global_parallel_stratergy
-        _global_parallel_stratergy = "dp_mp"
+        global _global_parallel_strategy
+        _global_parallel_strategy = "dp_mp"
         global _global_process_mesh
         _global_process_mesh = auto.ProcessMesh(
             mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
@@ -877,13 +877,13 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
             'layer_norm_0.w_0', 'linear_5.b_0'
         ])
         self.assertTrue(
-            initialization_check(_global_parallel_stratergy, dist_context,
+            initialization_check(_global_parallel_strategy, dist_context,
                                  dist_startup_prog, serial_startup_prog,
                                  var_need_broadcast))
 
     def test_decoder_noparallel(self):
-        global _global_parallel_stratergy
-        _global_parallel_stratergy = "None"
+        global _global_parallel_strategy
+        _global_parallel_strategy = "None"
         global _global_process_mesh
         _global_process_mesh = auto.ProcessMesh(
             mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
index b02c5f8a84f..16cbad3ef6f 100755
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
@@ -40,7 +40,7 @@ from paddle.distributed.auto_parallel.process import new_process_group
 
 paddle.enable_static()
 ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
-_global_parallel_stratergy = None
+_global_parallel_strategy = None
 _global_process_mesh = None
 
 
@@ -120,10 +120,10 @@ class MultiHeadAttention(nn.Layer):
         """
         q = self.q_proj(query)
 
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
 
@@ -157,19 +157,19 @@ class MultiHeadAttention(nn.Layer):
         """
         k = self.k_proj(key)
 
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
 
         v = self.v_proj(value)
 
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
 
@@ -250,11 +250,11 @@ class MultiHeadAttention(nn.Layer):
 
         # project to output
         out = self.out_proj(out)
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.out_proj.weight, _global_process_mesh,
                 dim_mapping=[0, -1])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.out_proj.weight, _global_process_mesh,
                 dim_mapping=[1, -1])
@@ -423,17 +423,17 @@ class TransformerDecoderLayer(nn.Layer):
         if self.normalize_before:
             tgt = self.norm2(tgt)
 
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1])
 
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1])
 
@@ -496,12 +496,12 @@ class GPTEmbeddings(nn.Layer):
 
         input_embedings = self.word_embeddings(input_ids)
 
-        if _global_parallel_stratergy == "mp":
+        if _global_parallel_strategy == "mp":
             auto.shard_tensor(
                 self.word_embeddings.weight,
                 _global_process_mesh,
                 dim_mapping=[0, -1])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 self.word_embeddings.weight,
                 _global_process_mesh,
@@ -729,10 +729,10 @@ def gpt_pretrain_forward(train_program, start_program):
         loss_mask = static.data(
             name="loss_mask", shape=[batch_size, sequence_len], dtype='float64')
 
-        if _global_parallel_stratergy == "dp":
+        if _global_parallel_strategy == "dp":
             auto.shard_tensor(
                 input_ids, _global_process_mesh, dim_mapping=[0, -1])
-        elif _global_parallel_stratergy == "dp_mp":
+        elif _global_parallel_strategy == "dp_mp":
             auto.shard_tensor(
                 input_ids, _global_process_mesh, dim_mapping=[0, -1])
 
@@ -764,8 +764,8 @@ def gpt_pretrain_forward(train_program, start_program):
 
 class TestGPTPartitioner(unittest.TestCase):
     def test_gpt_dp_mp(self):
-        global _global_parallel_stratergy
-        _global_parallel_stratergy = "dp_mp"
+        global _global_parallel_strategy
+        _global_parallel_strategy = "dp_mp"
         global _global_process_mesh
         _global_process_mesh = auto.ProcessMesh(
-- 
GitLab
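
For readers who want the end-to-end flow without stepping through the patch, the snippet below condenses what test_auto_parallel_parallelizer.py (added above) does in TestMLPAutoParallelizer.test_mlp_serial. It is a minimal sketch, assuming the paddle.distributed.fleet and paddle.distributed.auto_parallel APIs used by that test; mlp_pretrain_forward and the process-mesh globals come from the test file, so the snippet is not standalone, and the learning rate and variable names are illustrative only.

    import paddle
    import paddle.static as static
    from paddle.distributed import fleet
    import paddle.distributed.auto_parallel as auto

    paddle.enable_static()

    # Process mesh over two ranks, as in the test above (illustrative).
    ROOT_MESH = auto.ProcessMesh([0, 1])
    _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)

    # Turn on the semi-auto flag used by the test and initialize fleet with it.
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.semi_auto = True
    fleet.init(is_collective=True, strategy=dist_strategy)

    # Build the serial program with the mlp_pretrain_forward helper defined in
    # the test above, then let the wrapped optimizer hand back the partitioned
    # startup and main programs from minimize().
    train_program, start_program = static.Program(), static.Program()
    loss, train_program, start_program = mlp_pretrain_forward(train_program,
                                                              start_program)
    optimizer = fleet.distributed_optimizer(
        paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001))
    _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
        loss, start_program)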