From e3faf3456c3abae4a8a2de4039876dbb88d328a8 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 29 Dec 2021 15:36:17 +0800 Subject: [PATCH] [Auto Parallel] Sharding Pass (#38502) * auto parallel sharding base * chmod * add unitest * set unitest cmake dist label * revise code according to rewiew * chmod --- .../framework/distributed_strategy.proto | 1 + .../auto_parallel/operators/common.py | 9 + .../auto_parallel/operators/dist_default.py | 10 +- .../auto_parallel/operators/dist_embedding.py | 61 +- .../auto_parallel/operators/dist_matmul.py | 11 +- .../distributed/auto_parallel/parallelizer.py | 37 +- .../distributed/auto_parallel/partitioner.py | 56 +- .../paddle/distributed/auto_parallel/utils.py | 62 +- python/paddle/distributed/passes/__init__.py | 1 + .../passes/auto_parallel_sharding.py | 694 ++++++++++++++++++ .../distributed_passes/CMakeLists.txt | 1 + .../auto_parallel_pass_test_base.py | 10 +- .../test_auto_parallel_sharding_pass.py | 70 ++ 13 files changed, 931 insertions(+), 92 deletions(-) create mode 100644 python/paddle/distributed/passes/auto_parallel_sharding.py mode change 100644 => 100755 python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt create mode 100644 python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 7380e0f129..28108e78d9 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -45,6 +45,7 @@ message ShardingConfig { optional bool optimize_cast = 12 [ default = false ]; // Optimizer sharding. Temporary plans and may be deprecated optional bool _dp_as_optimizer_sharding = 13 [ default = false ]; + optional int32 stage = 14 [ default = 1 ]; } message HybridConfig { diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 0e0b2eae9c..32496b94b9 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -15,6 +15,7 @@ from ..dist_attribute import OperatorDistributedAttribute _g_distributed_operator_impl_registries = {} +BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale'} class DistributedOperatorImplContainer: @@ -116,6 +117,14 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True): return best_compatible_impl, idx +def is_parameter_related(varname, block): + if ".cast_fp" in varname: + varname = varname[:varname.index(".cast_fp")] + assert block.has_var(varname) + var = block.var(varname) + return var.is_parameter + + def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr): var_shape = block.var(src_var.name).shape var_topoloy = src_var_dist_attr.process_mesh.topology diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index 72e750e5a4..e2ebf1cfe6 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -15,7 +15,7 @@ from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container -from .common import register_distributed_operator_impl +from .common import register_distributed_operator_impl, is_parameter_related from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -183,8 +183,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): need_gradient_allreduce = False for input_name in backward_op.desc.input_names(): for varname in backward_op.desc.input(input_name): - if "@GRAD" not in varname and not main_block.var( - varname).is_parameter: + if "@GRAD" not in varname and not is_parameter_related( + varname, main_block): # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op process_mesh = dist_attr.process_mesh @@ -210,8 +210,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): allreduce_vars = [] for input_name in backward_op.desc.input_names(): for varname in backward_op.desc.input(input_name): - if "@GRAD" not in varname and main_block.var( - varname).is_parameter: + if "@GRAD" not in varname and is_parameter_related( + varname, main_block): assert len( backward_op.desc.input(input_name) ) == 1, "parameter input to grad op should be length 1, but got [{}]".format( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 18d976e965..866fed1ae6 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -16,7 +16,7 @@ from .common import infer_shape from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container -from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program +from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -26,7 +26,7 @@ from ..utils import compute_compatible_and_update_dim_mapping from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from paddle.fluid import core, unique_name from paddle.fluid.framework import in_dygraph_mode -from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.framework import Program, Parameter, Variable from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from ..process_group import new_process_group @@ -283,34 +283,35 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): allreduce_op_dist_attr) # param initialization sync - assert Weight_var.name not in dist_op_context.already_init_sync_vars - dist_op_context.already_init_sync_vars.add(Weight_var.name) - param = startup_block.var(Weight_var.name) - param_dist_attr = ctx.get_tensor_dist_attr_for_program(param) - process_mesh = param_dist_attr.process_mesh - dim_mapping = param_dist_attr.dims_mapping - - # NOTE all not splited axis should be presented in mesh - for axis, size in enumerate(process_mesh.topology): - if size <= 1 or axis in dim_mapping: - pass - else: - group_ranks = _get_comm_group(process_mesh.processes, - process_mesh.topology, axis, - rank_id) - sync_group = new_process_group(group_ranks) - - startup_block.append_op( - type='c_broadcast', - inputs={'X': param}, - outputs={'Out': param}, - attrs={ - 'ring_id': sync_group.id, - 'root': 0, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Forward - }) - startup_block._sync_with_cpp() + if Weight_var.is_parameter: + assert Weight_var.name not in dist_op_context.already_init_sync_vars + dist_op_context.already_init_sync_vars.add(Weight_var.name) + param = startup_block.var(Weight_var.name) + param_dist_attr = ctx.get_tensor_dist_attr_for_program(param) + process_mesh = param_dist_attr.process_mesh + dim_mapping = param_dist_attr.dims_mapping + + # NOTE all not splited axis should be presented in mesh + for axis, size in enumerate(process_mesh.topology): + if size <= 1 or axis in dim_mapping: + pass + else: + group_ranks = _get_comm_group(process_mesh.processes, + process_mesh.topology, axis, + rank_id) + sync_group = new_process_group(group_ranks) + + startup_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': sync_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward + }) + startup_block._sync_with_cpp() @staticmethod def backward(ctx, *args, **kwargs): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index aeaf9eb76b..9b0bdabc6d 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -18,7 +18,7 @@ from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl -from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program +from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -184,7 +184,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): Out_grad = main_block.var(kwargs['Out@GRAD'][0]) Y_grad = main_block.var(kwargs['Y@GRAD'][0]) - assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format( + assert not is_parameter_related( + X_var.name, main_block + ), "left operand(X) [{}] of dist matmul should not be parameter".format( X_var.name) Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name) @@ -200,7 +202,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): Y_var_partitioned = True break - if Y_var.is_parameter and Y_var_partitioned: + if is_parameter_related(Y_var.name, main_block) and Y_var_partitioned: if Y_var_dim_mapping[0] >= 0: # row parallel: c_identity + matmul @@ -322,7 +324,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): dp_degree = len(group_ranks) dp_group = new_process_group(group_ranks) - if need_gradient_allreduce and Y_var.is_parameter: + if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block): Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0]) allreduce_op = main_block.append_op( type='c_allreduce_sum', @@ -444,6 +446,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): y_dims_mapping), "now just support x dims > y dims" if len(y_dims_mapping) != 2: return False + if len(x_dims_mapping) == len(y_dims_mapping) and len( x_dims_mapping) == 4: if x_dims_mapping[:2] != y_dims_mapping[:2]: diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index 1b4fcd6983..04d5f1db59 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -27,6 +27,7 @@ from paddle.distributed.utils import get_logger from paddle.distributed.fleet import cloud_utils import paddle.fluid.core as core from paddle.fluid import program_guard +from paddle.distributed.passes import new_pass, PassContext from .dist_context import DistributedContext from .dist_context import get_default_distributed_context from .dist_context import set_default_distributed_context @@ -139,23 +140,9 @@ class AutoParallelizer: def _apply_optimize(self, main_program, startup_program, params_grads): - if self._dist_strategy.sharding: - auto_parallel_sharding_pass = new_pass( - "auto_parallel_sharding_pass", self._dist_strategy) - params_grads = auto_parallel_sharding_pass.apply( - main_program, startup_program, params_grads, self._pass_context) - - if self._dist_strategy.gradient_merge: - auto_parallel_gradient_merge_pass = new_pass( - "auto_parallel_gradient_merge_pass", - self._dist_strategy.gradient_merge_configs) - auto_parallel_gradient_merge_pass.apply( - main_program, startup_program, params_grads, self._pass_context) - - else: - with program_guard(main_program, startup_program): - optimizer = copy.deepcopy(self._optimizer) - optimize_ops = optimizer.apply_gradients(params_grads) + with program_guard(main_program, startup_program): + optimize_ops = copy.deepcopy(self._optimizer).apply_gradients( + params_grads) # update completion complete_update_annotation( @@ -163,6 +150,19 @@ class AutoParallelizer: return optimize_ops + def _apply_post_optimization_passed(self, main_program, startup_program, + rank, params_grads): + + if self._dist_strategy.sharding: + config = copy.deepcopy(self._dist_strategy.sharding_configs) + config["dist_context"] = self._dist_context + config["params_grads"] = params_grads + config["global_rank"] = rank + auto_parallel_sharding_pass = new_pass("auto_parallel_sharding", + config) + auto_parallel_sharding_pass.apply( + [main_program], [startup_program], self._pass_context) + def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False): completed_main_program = None serial_main_program = self._main_program.clone() @@ -203,7 +203,8 @@ class AutoParallelizer: make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context) reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context) - + self._apply_post_optimization_passed(dist_main_prog, dist_startup_prog, + rank, dist_params_grads) g_process_group_map = None if not relaunch_phase: g_process_group_map = copy.deepcopy(_g_process_group_map) diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index e4d913cb9c..096de1c206 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -24,7 +24,8 @@ from paddle.distributed.auto_parallel.operators.common import get_distributed_op from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext from .dist_attribute import OperatorDistributedAttribute from .process_group import new_process_group -from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op +from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op +from .operators.common import BACKWARD_ONLY_DIST_OPS __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"] @@ -102,22 +103,17 @@ class Partitioner(object): partitioned_startup_prog = fluid.Program() ref_block = serial_main_program.global_block() target_block = partitioned_startup_prog.global_block() - param2shape = {} + var2shape = {} temp_varname_map = {} # tensors for var in serial_startup_program.list_vars(): - if isinstance(var, Parameter): - # TODO if var not belong to this rank, should be filtered - serial_main_var = ref_block.var(var.name) - dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - serial_main_var) - target_shape = _get_dist_shape(serial_main_var, dist_attr) - new_name = var.name + self._dist_varname_suffix - temp_varname_map[var.name] = new_name - _partition_parameter(self._dist_context, serial_main_var, - target_block, new_name, target_shape) - param2shape[new_name] = target_shape + assert var.persistable + new_name = var.name + self._dist_varname_suffix + temp_varname_map[var.name] = new_name + target_shape = _partition_var(self._dist_context, ref_block, + target_block, var.name, new_name) + var2shape[new_name] = target_shape # ops for op in serial_startup_program.global_block().ops: @@ -128,14 +124,14 @@ class Partitioner(object): ) == 1, "initializer should output only ONE variable, but got [{}]".format( str(op.desc)) assert temp_varname_map[output_vars[ - 0]] in param2shape, "try to initialize [{}] which is not a Parameter".format( + 0]] in var2shape, "try to initialize [{}] which is not a persistable var".format( output_vars[0]) new_op_desc = target_block.desc.append_op() new_op_desc.copy_from(op.desc) new_op_desc._rename_output(output_vars[0], temp_varname_map[output_vars[0]]) new_op_desc._set_attr("shape", - param2shape[temp_varname_map[output_vars[0]]]) + var2shape[temp_varname_map[output_vars[0]]]) target_block._sync_with_cpp() # set distribute atrribute @@ -211,7 +207,6 @@ class Partitioner(object): **koutputs) elif is_backward_op(op): - print(str(op)) kinputs, koutputs = dist_op_context.prepare_context(op) dist_op_backward_impl = _get_dist_op_backward_implement( op, self._dist_context, forward_op_id2forward_op) @@ -351,6 +346,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname, name=dst_varname, persistable=True, stop_gradient=True) + target_shape = None else: dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var) target_shape = _get_dist_shape(src_var, dist_attr) @@ -361,6 +357,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname, else: _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname, target_shape) + return target_shape def _get_dist_op_backward_implement(backward_op, dist_context, @@ -371,25 +368,32 @@ def _get_dist_op_backward_implement(backward_op, dist_context, forward_op = forward_op_id2forward_op[forward_op_id] forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( forward_op) - dist_ops = get_distributed_operator_impl_container(forward_op.type) + dist_op = get_distributed_operator_impl_container(forward_op.type) # TODO backward should have its own impl_idx - if dist_ops and forward_op_dist_attr.impl_idx >= 0 and dist_ops.get_impl( \ + if dist_op and forward_op_dist_attr.impl_idx >= 0 and dist_op.get_impl( \ forward_op_dist_attr.impl_idx)._backward_implemented: - return dist_ops.get_impl(forward_op_dist_attr.impl_idx) + return dist_op.get_impl(forward_op_dist_attr.impl_idx) - dist_ops = get_distributed_operator_impl_container("default") - return dist_ops.get_impl(0) + # NOTE trick for dist ops that only have backward implement + if backward_op.type in BACKWARD_ONLY_DIST_OPS: + op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op) + assert op_dist_attr.impl_idx >= 0 + return get_distributed_operator_impl_container( + backward_op.type).get_impl(op_dist_attr.impl_idx) + + dist_op = get_distributed_operator_impl_container("default") + return dist_op.get_impl(0) def _get_dist_op_forward_implement(forward_op, dist_context): dist_attr = dist_context.get_op_dist_attr_for_program(forward_op) - dist_ops = get_distributed_operator_impl_container(forward_op.type) + dist_op = get_distributed_operator_impl_container(forward_op.type) - if dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( + if dist_op and dist_attr.impl_idx >= 0 and dist_op.get_impl( dist_attr.impl_idx)._forward_implemented: - return dist_ops.get_impl(dist_attr.impl_idx) + return dist_op.get_impl(dist_attr.impl_idx) else: - dist_ops = get_distributed_operator_impl_container("default") - return dist_ops.get_impl(0) + dist_op = get_distributed_operator_impl_container("default") + return dist_op.get_impl(0) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 8f7ac36040..5198b8f5fd 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -25,6 +25,7 @@ import paddle.fluid.core as core from paddle.framework.io import _to_LodTensor from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.fluid.io import is_parameter, is_belong_to_optimizer +from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute def is_valid_list_index(list, index): @@ -993,18 +994,23 @@ def set_grad_var_shape(program, dist_context): block = program.global_block() vars = block.vars for op in block.ops: - if op.type in [ - "sum", "check_finite_and_unscale", "update_loss_scaling" - ]: + + if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: + break + + if op.type in ["sum"]: continue if int(op.attr('op_role')) == int(OpRole.Backward): op_dist_attr = dist_context.get_op_dist_attr_for_program(op) assert op_dist_attr is not None for var_name in op.output_arg_names: + assert "@GRAD" in var_name forward_var_name = var_name[:var_name.find("@GRAD")] - if op.type == "c_allreduce_sum" or op.type == "c_identity" or op.type == "scale": + if op.type in [ + "c_allreduce_sum", "c_identity", "scale", "cast" + ]: forward_var_name = op.input_arg_names[0] elif op.type == "matmul_v2_grad": forward_var_name = None @@ -1038,6 +1044,7 @@ def set_grad_var_shape(program, dist_context): forward_input_dist_attr = op_dist_attr.get_input_dist_attr( forward_var_name) + assert forward_input_dist_attr is not None, f"{forward_var_name}" forward_var = vars[forward_var_name] forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( @@ -1069,6 +1076,53 @@ def is_backward_op(op): int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward) +def is_recompute_op(op): + return OP_ROLE_KEY in op.attr_names and \ + int(op.all_attrs()[OP_ROLE_KEY]) == 9 + + +def is_loss_op(op): + return OP_ROLE_KEY in op.attr_names and \ + int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss)) + + +def get_loss_op(block): + loss_ops = [] + for op in block.ops: + if is_loss_op(op): + assert len(op.desc.output_arg_names( + )) == 1, "loss op should only output loss var" + loss_ops.append(op) + + assert len(loss_ops) == 1, "num of loss op is not equal to one" + return loss_ops[0] + + +def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs): + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = dims_mapping + # TODO get global mesh group + tensor_dist_attr.process_mesh = process_mesh + dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) + return tensor_dist_attr + + +def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(new_op, process_mesh, + ref_mapping, ctx): + assert process_mesh is not None + assert ref_mapping is not None + + new_op_dist_attr = OperatorDistributedAttribute() + + for input_varname in new_op.desc.input_arg_names(): + new_op_dist_attr.set_input_dims_mapping(input_varname, ref_mapping) + for output_varname in new_op.desc.output_arg_names(): + new_op_dist_attr.set_output_dims_mapping(output_varname, ref_mapping) + + new_op_dist_attr.process_mesh = process_mesh + ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) + + def update_op_dims_mapping_by_default_dist_impl(dist_op): changed = False op_dist_attr = dist_op.dist_attr diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index 55c90abf14..a5e9b76334 100644 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -14,6 +14,7 @@ from .pass_base import new_pass, PassManager, PassContext from .fuse_all_reduce import * +from .auto_parallel_sharding import * from .cpp_pass import * __all__ = [ diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py new file mode 100644 index 0000000000..5e799c5209 --- /dev/null +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -0,0 +1,694 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import reduce +from collections import OrderedDict +import numpy as np + +import paddle +from paddle.framework import core +from paddle.fluid import unique_name +from .pass_base import PassBase, register_pass +from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op +from paddle.distributed.auto_parallel.process_group import get_world_process_groups, new_process_group +from paddle.distributed.auto_parallel.operators.common import is_parameter_related +from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr + +OpRole = core.op_proto_and_checker_maker.OpRole +OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() +_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read', 'slice'] +# update here to support new optimizers +_supported_optimizer_type = [ + "adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum", + "lars_momentum", "merged_momentum", "lamb", "sgd" +] + + +# NOTE we add the "auto_parallel" prefix to the pass in order to +# indicate that this pass should obey some constrains by auto_parallel +# for example all ops and vars should has dist attr before and after pass +# should use dist op instead of custom comm op +@register_pass("auto_parallel_sharding") +class ShardingPass(PassBase): + def __init__(self): + super(ShardingPass, self).__init__() + self.set_attr("dist_context", None) + self.set_attr("stage", None) + self.set_attr("sharding_degree", None) + self.set_attr("params_grads", []) + self.set_attr("global_rank", -1) + self.dp_groups = set() + self.sharding_infos = [] + self.varname_to_sharding_info = {} + self.partial_sharding = False + self.outer_dp_group = None + + def _check_self(self): + if self.get_attr("dist_context") is None: + return False + + if self.get_attr("stage") not in [1, 2, 3]: + return False + if (not isinstance(self.get_attr("sharding_degree"), + int)) or self.get_attr("sharding_degree") <= 1: + return False + if len(self.get_attr("params_grads")) <= 0: + return False + if (not isinstance(self.get_attr("global_rank"), + int)) or self.get_attr("global_rank") < 0: + return False + + return True + + def _check_conflict(self, other_pass): + return True + + def _apply_single_impl(self, main_program, startup_program, context): + self._dist_context = self.get_attr("dist_context") + self.sharding_world_size = int(self.get_attr("sharding_degree")) + self.stage = int(self.get_attr("stage")) + self.global_rank = int(self.get_attr("global_rank")) + params_grads = self.get_attr("params_grads") + main_block, startup_block = main_program.global_block( + ), startup_program.global_block() + + self._build_sharding_groups(main_block, params_grads) + self._shard_optimizer(main_block, startup_block, params_grads, context) + self._shard_gradient_synchronization(main_block) + self._shard_parameter(main_block, startup_block) + + def _build_sharding_groups(self, main_block, params_grads): + self._collective_data_parallel_groups(main_block) + self._build_sharding_infos(params_grads) + + def _collective_data_parallel_groups(self, main_block): + for op in main_block.ops: + if op.type in _skip_ops: + continue + group = _inference_data_parallel_group_for_operator( + self.global_rank, op, self._dist_context) + if group is not None: + self.dp_groups.add(group) + + # TODO(JZ-LIANG) allow more than one dp groups in network, support more general distribution + # genetated by auto search + if len(self.dp_groups) != 1: + raise NotImplementedError( + "So far Only and Exactly one data parallel group in network are supported, but got [{}] different data parallel groups". + format(len(groups))) + + def _build_sharding_infos(self, params_grads): + + for dp_group in self.dp_groups: + + assert dp_group.nranks >= self.sharding_world_size, "sharding world size [{}] should not larger than dp world size [{}]".format( + self.sharding_world_size, dp_group.nranks) + assert dp_group.nranks % self.sharding_world_size == 0, "sharding world size [{}] should be divisible by dp world size [{}]".format( + self.sharding_world_size, dp_group.nranks) + assert self.global_rank in dp_group.ranks, "current ranks [{}] does NOT belong to the data parallel group [{}]".format( + self.global_rank, dp_group.ranks) + assert len( + params_grads + ) >= self.sharding_world_size, "number of parameters [{}] is not enough to be shard among [{}] ranks".format( + len(params_grads), self.sharding_world_size) + + # sharding hybrid data parallel: partial sharding param within + if dp_group.nranks > self.sharding_world_size: + self.partial_sharding = True + assert len( + self.dp_groups + ) == 1, "hybrid sharding and data parallelism are supported only when there is excatly one data parallel group in the network" + outer_dp_group, sharding_group = _get_dp_and_sharding_groups( + dp_group.ranks, self.sharding_world_size, self.global_rank) + sharding_group = new_process_group(sharding_group) + self.outer_dp_group = new_process_group(outer_dp_group) + else: + sharding_group = dp_group + + # TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group + params_in_group = [p for p, g in params_grads] + assert len(params_in_group) == len(set( + params_in_group)), "found duplicated param in params_grads" + sharding_info = ShardingInfo(sharding_group, self.global_rank, + params_in_group) + self.sharding_infos.append(sharding_info) + for param in params_in_group: + self.varname_to_sharding_info[param.name] = sharding_info + + def _shard_optimizer(self, main_block, startup_block, params_grads, + pass_context): + """ + sharding all optimizer related ops and vars, include: + gradient clip ops & vars + weight decay ops & vars + optimizer ops and states + """ + self._shard_amp_related_op_and_vars(main_block, pass_context) + self._shard_weight_decay(main_block) + self._shard_gradient_clip(main_block) + self._shard_optimizer_ops_and_states(main_block, startup_block) + self._insert_optimizer_broadcasts(main_block, startup_block) + + def _shard_amp_related_op_and_vars(self, main_block, pass_context): + + if self.stage < 2: + return + + for idx, op in reversed(list(enumerate(main_block.ops))): + # shard amp related param_grad cast + if _is_param_grad_fp32_cast_op(main_block, op): + output_name = op.output_arg_names[0] + param_name = output_name[:output_name.find("@")] + if not self._is_parameter_in_local_shard(param_name): + main_block._remove_op(idx, sync=False) + main_block._remove_var(output_name, sync=False) + + # shard check nan inf + elif op.type in ["check_finite_and_unscale", "update_loss_scaling"]: + reversed_x = [] + for input_name in op.desc.input('X'): + param_name = input_name[:input_name.find("@")] + + if self._is_parameter_in_local_shard(param_name): + reversed_x.append(input_name) + op.desc.set_input('X', reversed_x) + op.desc.set_output('Out', reversed_x) + + main_block._sync_with_cpp() + + def _shard_gradient_clip(self, main_block): + + if self.stage < 2: + return + + # TODO (JZ-LIANG) support calculate global norm with tensor parallelism + is_clip_grad_by_global_norm = False + for idx, op in list(enumerate(main_block.ops)): + if not _is_gradient_clip_op(op): + continue + if op.type == 'sum': + is_clip_grad_by_global_norm = True + break + if not is_clip_grad_by_global_norm: + return + + removed_op_idx = set() + removed_tmp_var = set() + for idx, op in list(enumerate(main_block.ops)): + if not _is_gradient_clip_op(op): + continue + if op.type == 'sum': + reserved_vars = [] + for input_name in op.input_arg_names: + if input_name not in removed_tmp_var: + reserved_vars.append(input_name) + op.desc.set_input("X", reserved_vars) + + sum_op_output = op.desc.output_arg_names()[0] + for i, sharding_info in enumerate(self.sharding_infos): + new_op = main_block._insert_op( + idx + i, + type='c_allreduce_sum', + inputs={'X': [sum_op_output]}, + outputs={'Out': [sum_op_output]}, + attrs={ + 'ring_id': sharding_info.group.id, + 'op_namescope': "/gradient_clip_model_parallelism", + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize, + }) + dist_attr = self._dist_context.get_tensor_dist_attr_for_program( + main_block.var(sum_op_output)) + assert dist_attr is not None + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + new_op, dist_attr.process_mesh, dist_attr.dims_mapping, + self._dist_context) + break + for input_name in op.input_arg_names: + param_name = input_name[:input_name.find("@GRAD")] + if not self._is_parameter_in_local_shard(param_name): + removed_op_idx.add(idx) + for output_name in op.output_arg_names: + removed_tmp_var.add(output_name) + + for idx, op in reversed(list(enumerate(main_block.ops))): + if not _is_gradient_clip_op(op): + continue + if idx in removed_op_idx: + main_block._remove_op(idx, sync=False) + + for varname in removed_tmp_var: + main_block._remove_var(varname, sync=False) + + main_block._sync_with_cpp() + + def _shard_weight_decay(self, main_block): + + if self.stage < 2: + return + + for idx, op in reversed(list(enumerate(main_block.ops))): + if not _is_weight_decay_op(op): + continue + else: + raise NotImplementedError( + "weight decay is NOT supported by now") + main_block._sync_with_cpp() + + def _shard_optimizer_ops_and_states(self, main_block, startup_block): + + should_removed_optimizer_states = [] + for idx, op in reversed(list(enumerate(main_block.ops))): + if not is_optimizer_op(op): + break + + if op.type in _supported_optimizer_type: + assert "Param" in op.input_names + assert len(op.input("Param")) == 1 + param_name = op.input("Param")[0] + if not self._is_parameter_in_local_shard(param_name): + should_removed_optimizer_states.extend([ + varname for varname in op.output_arg_names + if varname != param_name + ]) + main_block._remove_op(idx, sync=False) + + for idx, op in reversed(list(enumerate(startup_block.ops))): + if len(op.output_arg_names) == 1 and op.output_arg_names[ + 0] in should_removed_optimizer_states: + startup_block._remove_op(idx, sync=False) + + for varname in should_removed_optimizer_states: + if main_block.has_var(varname): + main_block._remove_var(varname, sync=False) + if startup_block.has_var(varname): + startup_block._remove_var(varname, sync=False) + + main_block._sync_with_cpp() + startup_block._sync_with_cpp() + + def _insert_optimizer_broadcasts(self, main_block, startup_block): + + if self.stage > 2: + return + + for sharding_info in self.sharding_infos: + for param in sharding_info.params: + assert main_block.has_var(param.name) + assert startup_block.has_var(param.name) + + new_op = main_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': sharding_info.group.id, + 'root': sharding_info.get_var_rank(param.name), + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize + }) + param_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( + param) + assert param_dist_attr is not None + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + new_op, param_dist_attr.process_mesh, + param_dist_attr.dims_mapping, self._dist_context) + main_block._sync_with_cpp() + + def _is_parameter_in_local_shard(self, param_name): + assert param_name in self.varname_to_sharding_info + sharding_info = self.varname_to_sharding_info[param_name] + return sharding_info.is_in_local_shard(param_name) + + def _shard_gradient_synchronization(self, main_block): + + if self.stage < 2: + return + + dp_ring_ids = [group.id for group in self.dp_groups] + for idx, op in reversed(list(enumerate(main_block.ops))): + if _is_param_grad_allreduce_op(op, main_block, dp_ring_ids): + input_name = op.input_arg_names[0] + base_name = _get_base_name_from_grad_name(input_name) + sharding_info = self.varname_to_sharding_info[base_name] + _insert_reduce_op( + main_block, idx, input_name, sharding_info.group.id, + sharding_info.get_var_rank(base_name), self._dist_context) + if not self.partial_sharding: + main_block._remove_op(idx + 1, sync=False) + else: + op._set_attr("ring_id", self.outer_dp_group.id) + + main_block._sync_with_cpp() + + def _shard_parameter(self, main_block, startup_block): + + if self.stage < 3: + return + + dp_ring_ids = [group.id for group in self.dp_groups] + for sharding_info in self.sharding_infos: + need_broadcast_vars, param_usage = sharding_info.get_broadcast_vars_and_param_usage( + main_block) + not_used_param_nane = [] + for param_name in param_usage: + if param_usage[param_name] == 0 and sharding_info.get_var_rank( + param_name) != sharding_info.local_rank: + not_used_param_nane.append(param_name) + + for idx, op in reversed(list(enumerate(main_block.ops))): + if is_optimizer_op(op): + continue + + for input_name in op.desc.input_arg_names(): + if op.type == "cast": + continue + if input_name not in need_broadcast_vars: + continue + root_rank = sharding_info.get_var_rank(input_name) + if root_rank == sharding_info.local_rank: + broadcast_varname = input_name + else: + broadcast_varname = unique_name.generate(input_name + + "@BroadCast") + input_var = main_block.var(input_name) + new_var = main_block.create_var( + name=broadcast_varname, + shape=input_var.shape, + dtype=input_var.dtype, + persistable=False) + ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( + input_var) + out_var_dist_attr = set_var_dist_attr( + self._dist_context, new_var, + ref_dist_attr.dims_mapping, + ref_dist_attr.process_mesh) + op._rename_input(input_name, broadcast_varname) + + _insert_init_and_broadcast_op( + main_block, idx, broadcast_varname, + sharding_info.local_rank, root_rank, + sharding_info.group.id, + op.attr('op_role'), self._dist_context) + + for idx, op in reversed(list(enumerate(main_block.ops))): + if op.type != "cast": + continue + input_name = op.input_arg_names[0] + output_name = op.output_arg_names[0] + if input_name in not_used_param_nane: + main_block._remove_op(idx, sync=False) + main_block._remove_var(output_name, sync=False) + + for idx, op in reversed(list(enumerate(startup_block.ops))): + assert len(op.output_arg_names) == 1 + output_name = op.output_arg_names[0] + + if op.type == "c_broadcast" and op.attr( + "ring_id") in dp_ring_ids: + if self.outer_dp_group and sharding_info.get_var_rank( + output_name) == sharding_info.local_rank: + op._set_attr("ring_id", self.outer_dp_group.id) + else: + startup_block._remove_op(idx, sync=False) + continue + + if op.type != "c_broadcast" and output_name in not_used_param_nane: + startup_block._remove_op(idx, sync=False) + + for varname in not_used_param_nane: + main_block._remove_var(varname, sync=False) + startup_block._remove_var(varname, sync=False) + + main_block._sync_with_cpp() + startup_block._sync_with_cpp() + + +def _insert_init_and_broadcast_op(block, insert_idx, varname, local_rank, + root_rank, ring_id, op_role, dist_context): + """ + empty op for initialization + """ + broadcast_var = block.var(varname) + broadcast_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( + broadcast_var) + + new_op = block._insert_op_without_sync( + insert_idx, + type='c_broadcast', + inputs={'X': varname}, + outputs={'Out': varname}, + attrs={ + 'ring_id': ring_id, + 'root': root_rank, + 'use_calc_stream': True, + OP_ROLE_KEY: op_role + }) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + new_op, broadcast_var_dist_attr.process_mesh, + broadcast_var_dist_attr.dims_mapping, dist_context) + if local_rank != root_rank: + + new_op = block._insert_op_without_sync( + insert_idx, + type="empty", + outputs={"Out": broadcast_var.name}, + attrs={ + "shape": broadcast_var.shape, + "dtype": broadcast_var.dtype, + OP_ROLE_KEY: op_role + }) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + new_op, broadcast_var_dist_attr.process_mesh, + broadcast_var_dist_attr.dims_mapping, dist_context) + return + + +def _insert_reduce_op(block, + insert_idx, + reduce_var, + ring_id, + root_id, + dist_context, + op_role=OpRole.Backward, + use_calc_stream=True): + assert root_id >= 0, "root id should be a positive int, but now root id is {}".format( + root_id) + new_op = block._insert_op_without_sync( + insert_idx, + type='c_reduce_sum', + inputs={'X': [reduce_var]}, + outputs={'Out': [reduce_var]}, + attrs={ + 'ring_id': ring_id, + 'root_id': root_id, + 'use_calc_stream': use_calc_stream, + OP_ROLE_KEY: op_role + }) + + dist_attr = dist_context.get_tensor_dist_attr_for_program( + block.var(reduce_var)) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + new_op, dist_attr.process_mesh, dist_attr.dims_mapping, dist_context) + + +def _get_dp_and_sharding_groups(origin_group, sharding_group_size, rank): + dp_axis = 0 + sharding_axis = 1 + shape = [len(origin_group) // sharding_group_size, sharding_group_size] + + dp_group = _get_comm_group(origin_group, shape, dp_axis, rank) + sharding_group = _get_comm_group(origin_group, shape, sharding_axis, rank) + + return dp_group, sharding_group + + +def _is_gradient_clip_op(op): + return op.desc.has_attr("op_namescope") \ + and op.desc.attr("op_namescope").startswith("/gradient_clip") + + +def _is_weight_decay_op(op): + return op.desc.has_attr("op_namescope") \ + and op.desc.attr("op_namescope").startswith("/regularization") + + +def _is_param_grad_fp32_cast_op(block, op): + if not is_backward_op(op): + return False + if not _is_desired_cast_op(block, op, core.VarDesc.VarType.FP16, + core.VarDesc.VarType.FP32): + return False + output_name = op.desc.output_arg_names()[0] + base_name = output_name[:output_name.find("@")] + if not block.has_var(base_name): + return False + return block.var(base_name).is_parameter + + +def _is_param_fp16_cast_op(block, op, params): + + if is_optimizer_op(op): + return False + if not _is_desired_cast_op(block, op): + return False + input_name = op.desc.input_arg_names()[0] + if input_name not in params: + return False + return True + + +def _is_desired_cast_op(block, + op, + src_var_type=core.VarDesc.VarType.FP32, + dst_var_type=core.VarDesc.VarType.FP16): + if op.type != "cast": + return False + assert (len(op.desc.input_arg_names()) == 1) + assert (len(op.desc.output_arg_names()) == 1) + input_var = block.var(op.desc.input_arg_names()[0]) + output_var = block.var(op.desc.output_arg_names()[0]) + + if input_var.dtype != src_var_type or \ + output_var.dtype != dst_var_type: + return False + + return True + + +def _get_base_name_from_grad_name(grad_name): + base_name = None + if ".cast_fp16@GRAD" in grad_name: + base_name = grad_name[:grad_name.find(".cast_fp16@GRAD")] + elif "@GRAD" in grad_name: + base_name = grad_name[:grad_name.find("@GRAD")] + return base_name + + +def _is_param_grad_allreduce_op(op, block, dp_ring_ids): + + if not is_backward_op(op): + return False + if op.type != "c_allreduce_sum": + return False + if op.attr('ring_id') not in dp_ring_ids: + return False + + output_name = op.output_arg_names[0] + base_name = _get_base_name_from_grad_name(output_name) + + if not block.has_var(base_name): + return False + + return block.var(base_name).is_parameter + + +def _inference_data_parallel_group_for_operator(rank_id, op, dist_context): + + dp_group = None + for input_name in op.input_arg_names: + if not is_parameter_related(input_name, op.block): + dist_attr = dist_context.get_op_dist_attr_for_program(op) + process_mesh = dist_attr.process_mesh + input_dim_mapping = dist_attr.get_input_dims_mapping(input_name) + mesh_shape = process_mesh.topology + # TODO(JZ-LIANG) replace with specific batch size dimension + batch_size_axis = input_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + group_ranks = _get_comm_group(process_mesh.processes, + process_mesh.topology, + batch_size_axis, rank_id) + dp_group = new_process_group(group_ranks) + break + + return dp_group + + +def shard_parameters(params, group_size): + # TODO(JZ-LIANG) support multiple partition methods + # method1: greedy even but unorder + # method2: roughly even with oreder + mapping = {} + for rank_ in range(group_size): + mapping[rank_] = [] + sizes = [0] * group_size + for param in params: + rank = sizes.index(min(sizes)) + mapping[rank].append(param) + numel = reduce(lambda x, y: x * y, param.shape) + assert numel > 0, "param [{}] should larger than 0, but it is [{}]".format( + param.name, numel) + sizes[rank] += numel + + return mapping + + +class ShardingInfo(object): + def __init__(self, group, rank, params): + self.group = group + self.params = params + self.param_names = [p.name for p in self.params] + self.group_size = group.nranks + self.global_rank = rank + self.local_rank = group.ranks.index(self.global_rank) + # rank in below mapping are local rank in this sharding group + self.rank_to_params = shard_parameters(self.params, self.group_size) + # include fp32 and fp16 param + self.param_to_rank = dict() + self._map_param_to_rank() + + def _map_param_to_rank(self): + """ + mapping parameters to the rank which holds it. + """ + for rank, params in self.rank_to_params.items(): + for param in params: + self.param_to_rank[param.name] = rank + + def get_var_rank(self, varname): + if varname in self.param_to_rank: + return self.param_to_rank[varname] + return -1 + + def is_in_local_shard(self, param_name): + return self.get_var_rank(param_name) == self.local_rank + + def get_broadcast_vars_and_param_usage(self, block): + broadcast_vars = set([]) + fp16_params = set([]) + fp16_to_fp32 = {} + + param_usage = {x: 0 for x in self.param_names} + for op in block.ops: + if is_optimizer_op(op): + continue + for input_name in op.desc.input_arg_names(): + if input_name in self.param_names: + param_usage[input_name] += 1 + + for op in block.ops: + if not _is_param_fp16_cast_op(block, op, self.param_names): + continue + input_name = op.input_arg_names[0] + output_name = op.output_arg_names[0] + broadcast_vars.add(output_name) + fp16_params.add(output_name) + fp16_to_fp32[output_name] = input_name + param_usage[input_name] -= 1 + self.param_to_rank[output_name] = self.param_to_rank[input_name] + + for param, usage in param_usage.items(): + if usage > 0: + broadcast_vars.add(param) + return broadcast_vars, param_usage diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt b/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt old mode 100644 new mode 100755 index e9146b68a9..48a9f7204a --- a/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt @@ -5,4 +5,5 @@ foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) list(APPEND DIST_TEST_OPS ${TEST_OP}) set_tests_properties(${TEST_OP} PROPERTIES TIMEOUT 90) + set_tests_properties(${TEST_OP} PROPERTIES LABELS "RUN_TYPE=DIST") endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py index 2689d7d945..f5eda2fdbf 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py @@ -63,12 +63,12 @@ class AutoPallelPassTestBase(DistPassTestBase): def check_main(self, gpus=None, **kwargs): no_pass_rets = self._distributed_launch( - apply_pass=False, gpus=gpus, **kwargs) + model=None, apply_pass=False, gpus=gpus, **kwargs) pass_rets = self._distributed_launch( - apply_pass=True, gpus=gpus, **kwargs) + model=None, apply_pass=True, gpus=gpus, **kwargs) self.check_results(no_pass_rets, pass_rets) - def _run_gpu_main(self, apply_pass, dump_file, **kwargs): + def _run_gpu_main(self, model, apply_pass, dump_file, **kwargs): gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) place = paddle.CUDAPlace(gpu_id) scope = paddle.static.Scope() @@ -82,8 +82,8 @@ class AutoPallelPassTestBase(DistPassTestBase): with paddle.fluid.unique_name.guard(): main_prog, startup_prog, inputs, outputs, reader = self.get_model( place, **kwargs) - inputs = self._to_var_names(main_prog, inputs) - outputs = self._to_var_names(main_prog, outputs) + inputs = self._to_var_names(inputs) + outputs = self._to_var_names(outputs) all_fetch_values = [] exe = paddle.static.Executor(place) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py new file mode 100644 index 0000000000..f6b42701c2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py @@ -0,0 +1,70 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import random +import numpy as np + +import unittest +import paddle +import paddle.nn as nn +import paddle.distributed.fleet as fleet +import paddle.distributed.auto_parallel as auto +from paddle.distributed.passes import new_pass, PassManager +from auto_parallel_pass_test_base import AutoPallelPassTestBase +sys.path.append("..") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion + + +class TestShardingPass(AutoPallelPassTestBase): + def init(self): + if paddle.is_compiled_with_cuda(): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + self.rtol = 1e-5 + self.atol = 1e-8 + + rank = paddle.distributed.get_rank() + paddle.seed(rank + 2021) + random.seed(rank + 2021) + np.random.seed(rank + 2021) + + def apply_passes(self): + dist_strategy = fleet.DistributedStrategy() + dist_strategy.semi_auto = True + dist_strategy.sharding = True + dist_strategy.sharding_configs = { + "sharding_degree": 2, + "stage": 3, + } + fleet.init(is_collective=True, strategy=dist_strategy) + + def apply_no_passes(self): + dist_strategy = fleet.DistributedStrategy() + dist_strategy.pipeline = False + dist_strategy.recompute = False + dist_strategy.semi_auto = True + fleet.init(is_collective=True, strategy=dist_strategy) + + def test_bs_8(self): + self.check_main( + gpus=[0, 1], batch_size=8, sequence_len=512, vocab_size=1000) + + def get_model(self, place, batch_size, sequence_len, vocab_size): + return self.get_gpt_model('dp', place, batch_size, sequence_len, + vocab_size) + + +if __name__ == "__main__": + unittest.main() -- GitLab