diff --git a/python/paddle/distributed/auto_parallel/__init__.py b/python/paddle/distributed/auto_parallel/__init__.py index 269a0ec644dbd2e54ab255fa88f2e80ae744985f..835ca68df2d1c1169b9eaca767715070ca00e013 100644 --- a/python/paddle/distributed/auto_parallel/__init__.py +++ b/python/paddle/distributed/auto_parallel/__init__.py @@ -19,5 +19,6 @@ from .interface import shard_tensor from .interface import shard_op from .interface import recompute from .interface import fetch +from .random import parallel_manual_seed __all__ = [] diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 5df89a277dfa4bba5da4f20f45a4add3bc9c6858..7d988c6c95ed85fe84f67a243571ca270ed1de4b 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -51,7 +51,7 @@ from .dist_loader import ( from .process_group import new_process_group, get_all_process_groups from .dist_context import DistributedContext, get_default_distributed_context from .strategy import Strategy -from .interface import CollectionNames, get_collection +from .interface import CollectionNames, get_collection, fetch from .utils import to_list, get_dist_attr, get_lr, validate_opt from .utils import initialize_pg_in_full_mode, get_input_split_info from .cost.estimate_cost import get_cost_from_engine @@ -438,6 +438,8 @@ class Engine: ), "user_fetches must be a list, but receive {}".format( type(user_fetches).__name__ ) + else: + user_fetches = [] fetch_names = [] fetch_indices = [] @@ -464,10 +466,13 @@ class Engine: _process_fetch_group("metrics_" + str(i), var_list) if mode == "predict": _process_fetch_group("outputs", fetch_vars["outputs"]) + for usr_fetch in user_fetches: + var_name = _to_name_str(usr_fetch) + fetch(var_name) user_fetches_collection = [ item[1] for item in get_collection(CollectionNames.FETCHES) ] - var_list = (user_fetches_collection or []) + (user_fetches or []) + var_list = user_fetches_collection or [] _process_fetch_group("fetches", var_list) return fetch_names, fetch_indices @@ -522,10 +527,10 @@ class Engine: # logging user fetches collect_fetches = get_collection(CollectionNames.FETCHES) logs_fetch = {} - for name, var in collect_fetches: - if var.name in fetch_names: - idx = fetch_names.index(var.name) - logs_fetch[name or var.name] = outs[idx] + for name, var_name in collect_fetches: + if var_name in fetch_names: + idx = fetch_names.index(var_name) + logs_fetch[name or var_name] = outs[idx] logs["fetches"] = logs_fetch return logs diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index b154209700a313632a9eb3cdb6be9d8e0a2d35e2..f97d7b39f05d0ab91a8cf088e92069e2e3d58643 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -258,6 +258,16 @@ def add_to_collection(collection_name, value, name=None): def fetch(tensor, name=None, logging=False): + if isinstance(tensor, paddle.fluid.framework.Variable): + tensor = tensor.name + elif isinstance(tensor, str): + tensor = tensor + else: + raise TypeError( + "Only support fetch `Variable` or `str`[`Variable`'s name], but got `{}`".format( + type(tensor) + ) + ) add_to_collection(CollectionNames.FETCHES, tensor, name) if logging: add_to_collection(CollectionNames.LOGGING, tensor, name) diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py 
b/python/paddle/distributed/auto_parallel/operators/__init__.py index 406ec4d8b36da07f56d1883dea77116f2ca6d165..bc5bf4b7379e7259ea716fb017b8b232d1554e06 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -36,3 +36,4 @@ from . import dist_reduce_sum_p from . import dist_shape from . import dist_assign from . import dist_scale +from . import dist_dropout diff --git a/python/paddle/distributed/auto_parallel/operators/dist_dropout.py b/python/paddle/distributed/auto_parallel/operators/dist_dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..ebbd133ebbe3c289c153fa87808cd27e13d7c024 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_dropout.py @@ -0,0 +1,186 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import logging + +import paddle +from paddle.framework import core +from paddle.utils import unique_name + +from ...utils.log_utils import get_logger + +_logger = get_logger(logging.INFO) +from ..random import determinate_rng, is_enable_auto_rand_ctrl +from ..utils import ( + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, + set_var_dist_attr, +) +from .common import ( + DistributedOperatorImplContainer, + register_distributed_operator_impl, + register_distributed_operator_impl_container, +) +from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0 + + +class DistributedDropout(DistributedOperatorImplContainer): + def __init__(self, op_type): + super().__init__(op_type) + + +register_distributed_operator_impl_container(DistributedDropout("dropout")) + + +# Dist Dropout with Random Control +# Dropout re-use the compatible and cost function of elementwise +class DistributedDropoutImpl0(DistributedElementwiseImpl0): + def __init__(self, name): + super().__init__(name) + self._forward_implemented = True + self._backward_implemented = True + + @staticmethod + def forward(ctx, *args, **kwargs): + + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.work_block + startup_block = dist_op_context.startup_block + src_op = dist_op_context.cur_src_op + rank_id = dist_op_context.rank_id + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) + + if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute: + + assert ( + op_dist_attr is not None + ), f"forward op [{str(src_op)}] don't have dist attribute !" 
+ + # check validation of inputs / outputs + assert 'X' in kwargs, "input [{}] is not given".format('X') + assert ( + len(kwargs['X']) == 1 + ), "input X should be only one tensor but got {}".format( + kwargs['X'] + ) + assert 'Seed' in kwargs, "input [{}] is not given".format('Seed') + + if ( + src_op.has_attr("fix_seed") + and src_op.attr("fix_seed") + and src_op.has_attr("seed") + and src_op.attr("seed") + ): + _logger.info( + "Auto Parallel Random Control Skipped since manual seed is set by user: {}".format( + src_op + ) + ) + elif rank_id not in op_dist_attr.process_mesh.process_ids: + pass + # NOTE Adopt for recompute + # If the user has already set the seed, we should not modify it. But if the seed is added by the recompute pass, it should be under control. + # TODO in the future the recompute pass should happen after the parallel partition, and this should be removed at that time. + elif len(kwargs['Seed']) > 0 or len(src_op.input("Seed")) > 0: + seed_var_name = kwargs['Seed'][0] + if seed_var_name.startswith('rc_seed'): + pre_op = main_block.ops[-1] + assert ( + pre_op.type == "seed" + and len(pre_op.attr("rng_name")) == 0 + ), f"found exception op {str(pre_op)}" + + # determinate rng + X_var = main_block._var_recursive(kwargs['X'][0]) + X_dims_mapping = op_dist_attr.get_input_dims_mapping( + X_var.name + ) + process_mesh = op_dist_attr.process_mesh + rng_name = determinate_rng( + rank_id, X_dims_mapping, process_mesh + ) + # make recompute seed under control + pre_op._set_attr("rng_name", rng_name) + pre_op._set_attr("deterministic", True) + pre_op._set_attr("force_cpu", True) + else: + _logger.info( + "Auto Parallel Random Control Skipped since manual seed is set by user: {}".format( + src_op + ) + ) + else: + # determinate rng + X_var = main_block._var_recursive(kwargs['X'][0]) + X_dims_mapping = op_dist_attr.get_input_dims_mapping(X_var.name) + process_mesh = op_dist_attr.process_mesh + + rng_name = determinate_rng( + rank_id, X_dims_mapping, process_mesh + ) + assert rng_name is not None and rng_name != "" + + # insert seed op + seed_var = main_block.create_var( + name=paddle.fluid.unique_name.generate_with_ignorable_key( + ".".join(["tensor_parallel_seed", 'tmp']) + ), + dtype=paddle.int32, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False, + ) + + # set new seed_var's dist_attr + seed_var_dims_mapping = [-1] + seed_var_dist_attr = set_var_dist_attr( + ctx, seed_var, seed_var_dims_mapping, process_mesh + ) + + # adopt for recompute + # force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang + seed_op = main_block.append_op( + type='seed', + outputs={'Out': seed_var}, + attrs={ + 'deterministic': True, + 'rng_name': rng_name, + 'force_cpu': True, + }, + ) + seed_op._set_attr('op_namescope', 'auto_tensor_parallel_seed') + # set new seed op's dist_attr + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + seed_op, process_mesh, seed_var_dims_mapping, ctx + ) + + # modify dropout op + src_op.desc.set_input("Seed", [seed_var.name]) + src_op._remove_attr("fix_seed") + src_op._remove_attr("seed") + op_dist_attr.set_input_dist_attr( + seed_var.name, seed_var_dist_attr + ) + kwargs['Seed'] = [seed_var.name] + + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + + @staticmethod + def backward(ctx, *args, **kwargs): + # dropout backward is deterministic given the mask, and does not need random state control + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) + + +register_distributed_operator_impl( + "dropout", DistributedDropoutImpl0("random_control") +) diff --git
a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 856d9c36bb4e17d7354918a065aaf745df39ec3f..26bed30871ce8420cf7bcf080337ae60416bbcbe 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -17,47 +17,74 @@ from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container from .common import gradient_synchronization -from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related +from .common import ( + register_distributed_operator_impl, + set_comm_op_dist_attr_for_program, + naive_copy_op_dist_attr_for_program, + is_parameter_related, +) from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping -from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute +from ..dist_attribute import ( + OperatorDistributedAttribute, + TensorDistributedAttribute, +) from paddle.fluid import core, unique_name from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import Program, Parameter, Variable from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype -from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY +from paddle.distributed.fleet.meta_optimizers.common import ( + OpRole, + OP_ROLE_KEY, + OP_ROLE_VAR_KEY, +) from ..process_group import new_process_group -from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank, set_var_dist_attr +from ..utils import ( + _get_comm_group, + _get_idx_in_axis, + _get_corresponding_rank, + set_var_dist_attr, +) from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op -from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs +from ..cost import ( + build_comm_costs_from_descs, + build_comp_costs_from_descs, + build_dp_costs, +) from ..cost import EmbeddingOpCost, EmbeddingGradOpCost -from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost +from paddle.distributed.auto_parallel.cost.comm_op_cost import ( + AllreduceSumOpCost, + IdentityOpCost, +) class DistributedEmbedding(DistributedOperatorImplContainer): - def __init__(self, op_type): super(DistributedEmbedding, self).__init__(op_type) register_distributed_operator_impl_container( - DistributedEmbedding("lookup_table_v2")) + DistributedEmbedding("lookup_table_v2") +) register_distributed_operator_impl_container( - DistributedEmbedding("c_embedding")) + DistributedEmbedding("c_embedding") +) register_distributed_operator_impl_container( - DistributedEmbedding("lookup_table")) + DistributedEmbedding("lookup_table") +) def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var): - assert len( - Ids_var.shape - ) == 3, "input Ids to lookup_table should have 3 dimensions but got [{}] with shape [{}]".format( - Ids_var.name, Ids_var.shape) + assert ( + len(Ids_var.shape) == 3 + ), "input Ids to lookup_table should have 3 dimensions but got [{}] with shape [{}]".format( + Ids_var.name, Ids_var.shape + ) if not 
Ids_var.stop_gradient: raise NotImplementedError( 'Requiring the gradient of Ids of lookup_table(v1)dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding this (for instance, adversarial training for language model).' @@ -65,59 +92,72 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var): target_shape = list(Ids_var.shape[:-1]) intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["dist_reshape", 'tmp'])), + name=unique_name.generate_with_ignorable_key( + ".".join(["dist_reshape", 'tmp']) + ), dtype=Ids_var.dtype, shape=target_shape, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=True) + stop_gradient=True, + ) target_shape = [0] + list(Ids_var.shape[:-1]) xshape_var = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["dist_Xshape", 'tmp'])), + name=unique_name.generate_with_ignorable_key( + ".".join(["dist_Xshape", 'tmp']) + ), dtype=Ids_var.dtype, shape=target_shape, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=True) + stop_gradient=True, + ) # TODO use inplace reshape for memory saving - reshape_op = main_block.append_op(type='reshape2', - inputs={'X': [Ids_var]}, - outputs={ - 'Out': [intermediate_var_0], - 'XShape': [xshape_var] - }, - attrs={ - "shape": [0, -1], - }) + reshape_op = main_block.append_op( + type='reshape2', + inputs={'X': [Ids_var]}, + outputs={'Out': [intermediate_var_0], 'XShape': [xshape_var]}, + attrs={ + "shape": [0, -1], + }, + ) # set dist attr op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) Ids_var_dist_attr = op_dist_attr.get_input_dist_attr(Ids_var.name) assert Ids_var_dist_attr is not None intermediate_var_0_dist_attr = set_var_dist_attr( - ctx, intermediate_var_0, Ids_var_dist_attr.dims_mapping, - Ids_var_dist_attr.process_mesh) - set_var_dist_attr(ctx, xshape_var, - [-1] + list(Ids_var_dist_attr.dims_mapping), - Ids_var_dist_attr.process_mesh) + ctx, + intermediate_var_0, + Ids_var_dist_attr.dims_mapping, + Ids_var_dist_attr.process_mesh, + ) + set_var_dist_attr( + ctx, + xshape_var, + [-1] + list(Ids_var_dist_attr.dims_mapping), + Ids_var_dist_attr.process_mesh, + ) op_dist_attr.del_input_dist_attr(Ids_var.name) - op_dist_attr.set_input_dist_attr(intermediate_var_0.name, - intermediate_var_0_dist_attr) + op_dist_attr.set_input_dist_attr( + intermediate_var_0.name, intermediate_var_0_dist_attr + ) new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh new_op_dist_attr.impl_type = "default" new_op_dist_attr.impl_idx = 0 - new_op_dist_attr.set_input_dims_mapping(Ids_var.name, - Ids_var_dist_attr.dims_mapping) - new_op_dist_attr.set_output_dims_mapping(intermediate_var_0.name, - Ids_var_dist_attr.dims_mapping) + new_op_dist_attr.set_input_dims_mapping( + Ids_var.name, Ids_var_dist_attr.dims_mapping + ) + new_op_dist_attr.set_output_dims_mapping( + intermediate_var_0.name, Ids_var_dist_attr.dims_mapping + ) new_op_dist_attr.set_output_dims_mapping( - xshape_var.name, [-1] + list(Ids_var_dist_attr.dims_mapping)) + xshape_var.name, [-1] + list(Ids_var_dist_attr.dims_mapping) + ) ctx.set_op_dist_attr_for_program(reshape_op, new_op_dist_attr) return intermediate_var_0 @@ -125,7 +165,6 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var): # RowParallel class DistributedEmbeddingImpl(DistributedOperatorImpl): - def __init__(self, name): super(DistributedEmbeddingImpl, 
self).__init__(name) self._forward_implemented = True @@ -143,17 +182,19 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes # embedding need start_index - cost_mapping = build_comp_costs_from_descs(EmbeddingOpCost, ctx, - processes, desc_mapping, - cluster) + cost_mapping = build_comp_costs_from_descs( + EmbeddingOpCost, ctx, processes, desc_mapping, cluster + ) serial_op = dist_op.serial_op parallel_axis = dist_op.dist_attr.get_input_dims_mapping( - serial_op.input("W")[0])[0] + serial_op.input("W")[0] + )[0] attrs = {"use_calc_stream": True, "use_model_parallel": True} var_names = serial_op.output("Out") c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op( @@ -162,11 +203,16 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) comm_op_cost_list = build_comm_costs_from_descs( - AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, - cluster) + AllreduceSumOpCost, + ctx, + processes, + c_allreduce_sum_desc_mapping, + cluster, + ) res_cost = [cost_mapping, comm_op_cost_list] @@ -180,7 +226,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): dist_attr = dist_op.dist_attr embedding_row_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("W")[0])[0] + backward_op.input("W")[0] + )[0] parallel_axis = embedding_row_dim_mapping attrs = {"use_calc_stream": True, "use_model_parallel": True} var_names = [backward_op.input("Out@GRAD")[0]] @@ -190,33 +237,38 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) process_mesh = dist_attr.process_mesh processes = process_mesh.processes comm_op_cost_list = build_comm_costs_from_descs( - IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster + ) res.append(comm_op_cost_list) # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) - cost_mapping = build_comp_costs_from_descs(EmbeddingGradOpCost, ctx, - processes, desc_mapping, - cluster) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) + cost_mapping = build_comp_costs_from_descs( + EmbeddingGradOpCost, ctx, processes, desc_mapping, cluster + ) res.append(cost_mapping) # need gradient allreduce var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("Ids")[0]) + backward_op.input("Ids")[0] + ) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [backward_op.output('W@GRAD')[0]] - build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, - cluster) + build_dp_costs( + res, dist_op, ctx, var_names, attrs, parallel_axis, cluster + ) return res @@ -228,7 +280,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name) if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard( - w_dims_mapping[-1]): + 
w_dims_mapping[-1] + ): return False # Other dimensions must be replicate except the batch dimension for mapping in ids_dims_mapping[1:]: @@ -248,8 +301,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - if (not self.is_input_compatible(dist_op)) or \ - (not self.is_output_compatible(dist_op)): + if (not self.is_input_compatible(dist_op)) or ( + not self.is_output_compatible(dist_op) + ): return False op_desc = dist_op.serial_op.desc @@ -261,7 +315,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name) - if ids_dims_mapping != out_dims_mapping[:len(ids_dims_mapping)]: + if ids_dims_mapping != out_dims_mapping[: len(ids_dims_mapping)]: return False return True @@ -279,12 +333,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): for i in range(len(ids_dims_mapping)): dim_changed = compute_compatible_and_update_dim_mapping( - [ids_dims_mapping, out_dims_mapping], [i, i]) + [ids_dims_mapping, out_dims_mapping], [i, i] + ) if dim_changed: changed = True dim_changed = compute_compatible_and_update_dim_mapping( - [w_dims_mapping, out_dims_mapping], [-1, -1]) + [w_dims_mapping, out_dims_mapping], [-1, -1] + ) if dim_changed: changed = True @@ -302,26 +358,30 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( - str(src_op)) + assert ( + op_dist_attr is not None + ), "forward op [{}] don't have dist attribute !".format(str(src_op)) # check validation of inputs / outputs assert 'Ids' in kwargs, "input [{}] is not given".format('Ids') assert 'W' in kwargs, "input [{}] is not given".format('W') assert 'Out' in kwargs, "output [{}] is not given".format('Out') - assert len( + assert ( + len(kwargs['Ids']) == 1 + ), "row_parallel_embedding input Ids take 1 variable but got {}".format( kwargs['Ids'] - ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['Ids']) - assert len( + ) + assert ( + len(kwargs['W']) == 1 + ), "row_parallel_embedding input W take 1 variable but got {}".format( kwargs['W'] - ) == 1, "row_parallel_embedding input W take 1 variable but got {}".format( - kwargs['W']) - assert len( + ) + assert ( + len(kwargs['Out']) == 1 + ), "row_parallel_embedding output Out take 1 variable but got {}".format( kwargs['Out'] - ) == 1, "row_parallel_embedding output Out take 1 variable but got {}".format( - kwargs['Out']) + ) Ids_var = main_block.var(kwargs['Ids'][0]) Weight_var = main_block._var_recursive(kwargs['W'][0]) @@ -333,70 +393,85 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): # got dist attribute info embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[0] - assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( - embedding_row_dim_mapping) + Weight_var.name + )[0] + assert ( + embedding_row_dim_mapping >= 0 + ), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( + embedding_row_dim_mapping + ) process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_group = op_dist_attr.process_mesh.processes # FIXME (JZ-LIANG) Remove this hack to support any op mesh group 
for Pipeline Parallelism if rank_id not in process_mesh_group: - rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, - rank_id) + rank_id = _get_corresponding_rank( + ctx, op_dist_attr.process_mesh, rank_id + ) # A generalized method to caculate embedding offset using cartisian product - relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape, - embedding_row_dim_mapping, rank_id) + relative_idx = _get_idx_in_axis( + process_mesh_group, + process_mesh_shape, + embedding_row_dim_mapping, + rank_id, + ) per_part_size = Weight_var.shape[0] relative_idx = relative_idx * per_part_size # TODO caculate ring id parallel_axis = embedding_row_dim_mapping - group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, - parallel_axis, rank_id) + group_ranks = _get_comm_group( + process_mesh_group, process_mesh_shape, parallel_axis, rank_id + ) group = new_process_group(group_ranks) # append op - check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'], - 'c_embedding') + check_variable_and_dtype( + Ids_var, 'input', ['int32', 'int64'], 'c_embedding' + ) # infer new var shape with op dist attr out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr, - out_var_dist_attr) + ref_shape = infer_shape( + main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr + ) intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_embedding", 'tmp'])), + name=unique_name.generate_with_ignorable_key( + ".".join(["c_embedding", 'tmp']) + ), dtype=Weight_var.dtype, shape=Out_var.shape, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=Out_var.stop_gradient) + stop_gradient=Out_var.stop_gradient, + ) # set intermediate_var_0's dist_attr with Out_var's dist_attr - ctx.set_tensor_dist_attr_for_program(intermediate_var_0, - out_var_dist_attr) + ctx.set_tensor_dist_attr_for_program( + intermediate_var_0, out_var_dist_attr + ) check_variable_and_dtype( - Out_var, 'tensor', + Out_var, + 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], - 'c_allreduce_sum') + 'c_allreduce_sum', + ) c_embedding_op = main_block.append_op( type='c_embedding', - inputs={ - 'Ids': [Ids_var], - 'W': [Weight_var] - }, + inputs={'Ids': [Ids_var], 'W': [Weight_var]}, outputs={'Out': [intermediate_var_0]}, attrs={ "start_index": relative_idx, - OP_ROLE_KEY: src_op.attr('op_role') - }) + OP_ROLE_KEY: src_op.attr('op_role'), + }, + ) if intermediate_var_0.shape != ref_shape: intermediate_var_0.desc.set_shape(ref_shape) @@ -409,8 +484,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, - OP_ROLE_KEY: src_op.attr('op_role') - }) + OP_ROLE_KEY: src_op.attr('op_role'), + }, + ) if Out_var.shape != ref_shape: Out_var.desc.set_shape(ref_shape) @@ -423,15 +499,19 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): for input_varname in c_embedding_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) assert input_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - embedding_op_dist_attr.set_input_dist_attr(input_varname, - input_dist_attr) + op_dist_attr + ) + embedding_op_dist_attr.set_input_dist_attr( + input_varname, input_dist_attr + ) output_varname = 
c_embedding_op.desc.output_arg_names()[0] output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert output_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - embedding_op_dist_attr.set_output_dist_attr(output_varname, - output_dist_attr) + op_dist_attr + ) + embedding_op_dist_attr.set_output_dist_attr( + output_varname, output_dist_attr + ) ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr) # allreduce @@ -443,16 +523,20 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): input_var = main_block.var(input_varname) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var) assert tensor_dist_attr is not None - allreduce_op_dist_attr.set_input_dist_attr(input_varname, - tensor_dist_attr) + allreduce_op_dist_attr.set_input_dist_attr( + input_varname, tensor_dist_attr + ) for output_varname in c_allreduce_sum_op.desc.output_arg_names(): output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) assert output_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - allreduce_op_dist_attr.set_output_dist_attr(output_varname, - output_dist_attr) - ctx.set_op_dist_attr_for_program(c_allreduce_sum_op, - allreduce_op_dist_attr) + op_dist_attr + ) + allreduce_op_dist_attr.set_output_dist_attr( + output_varname, output_dist_attr + ) + ctx.set_op_dist_attr_for_program( + c_allreduce_sum_op, allreduce_op_dist_attr + ) # param initialization sync if Weight_var.is_parameter and not op_dist_attr.is_recompute: @@ -469,20 +553,25 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): if size <= 1 or axis in dim_mapping: pass else: - group_ranks = _get_comm_group(process_mesh.processes, - process_mesh.topology, axis, - rank_id) + group_ranks = _get_comm_group( + process_mesh.processes, + process_mesh.topology, + axis, + rank_id, + ) sync_group = new_process_group(group_ranks) - startup_block.append_op(type='c_broadcast', - inputs={'X': param}, - outputs={'Out': param}, - attrs={ - 'ring_id': sync_group.id, - 'root': 0, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Forward - }) + startup_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': sync_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward, + }, + ) @staticmethod def backward(ctx, *args, **kwargs): @@ -493,35 +582,43 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): backward_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id dist_attr = ctx.get_op_dist_attr_for_program(backward_op) - assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( - str(backward_op)) + assert ( + dist_attr is not None + ), "backward op [{}] don't have dist attribute !".format( + str(backward_op) + ) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in dist_attr.process_mesh.processes: - rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, - rank_id) + rank_id = _get_corresponding_rank( + ctx, dist_attr.process_mesh, rank_id + ) assert 'Ids' in kwargs, "input [{}] is not given".format('Ids') assert 'W' in kwargs, "input [{}] is not given".format('W') assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out') assert 'W@GRAD' in kwargs, "output [{}] is not given".format('W@GRAD') - assert len( + assert ( + len(kwargs['Ids']) == 1 + ), "row_parallel_embedding input Ids take 1 variable but got {}".format( kwargs['Ids'] - ) == 1, "row_parallel_embedding input Ids take 1 variable 
but got {}".format( - kwargs['Ids']) - assert len( + ) + assert ( + len(kwargs['W']) == 1 + ), "row_parallel_embedding input Ids take 1 variable but got {}".format( kwargs['W'] - ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['W']) - assert len( - kwargs['Out@GRAD'] - ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['Out']) - assert len( + ) + assert ( + len(kwargs['Out@GRAD']) == 1 + ), "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Out'] + ) + assert ( + len(kwargs['W@GRAD']) == 1 + ), "row_parallel_embedding output Ids take 1 variable but got {}".format( kwargs['W@GRAD'] - ) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format( - kwargs['W@GRAD']) + ) Ids_var = main_block.var(kwargs['Ids'][0]) Weight_var = main_block.var(kwargs['W'][0]) @@ -529,39 +626,57 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): Weight_grad = main_block.var(kwargs['W@GRAD'][0]) embedding_row_dim_mapping = dist_attr.get_input_dims_mapping( - Weight_var.name)[0] - assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( - embedding_row_dim_mapping) + Weight_var.name + )[0] + assert ( + embedding_row_dim_mapping >= 0 + ), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( + embedding_row_dim_mapping + ) process_mesh_shape = dist_attr.process_mesh.topology process_mesh_group = dist_attr.process_mesh.processes # A generalized method to caculate embedding offset using cartisian product - relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape, - embedding_row_dim_mapping, rank_id) + relative_idx = _get_idx_in_axis( + process_mesh_group, + process_mesh_shape, + embedding_row_dim_mapping, + rank_id, + ) per_part_size = Weight_var.shape[0] relative_idx = relative_idx * per_part_size check_variable_and_dtype( - Out_grad, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') + Out_grad, + 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64'], + '_c_identity', + ) intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_embedding", '@tmp_0@GRAD'])), + name=unique_name.generate_with_ignorable_key( + ".".join(["c_embedding", '@tmp_0@GRAD']) + ), dtype=Out_grad.dtype, shape=Out_grad.shape, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=Out_grad.stop_gradient) + stop_gradient=Out_grad.stop_gradient, + ) # copy X_var's dist_attr to intermediate_var_0's dist_attr out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name) assert out_grad_dist_attr is not None - ctx.set_tensor_dist_attr_for_program(intermediate_var_0, - out_grad_dist_attr) + ctx.set_tensor_dist_attr_for_program( + intermediate_var_0, out_grad_dist_attr + ) - group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, - embedding_row_dim_mapping, rank_id) + group_ranks = _get_comm_group( + process_mesh_group, + process_mesh_shape, + embedding_row_dim_mapping, + rank_id, + ) group = new_process_group(group_ranks) c_identity_op = main_block.append_op( @@ -573,41 +688,54 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): 'use_calc_stream': True, 'use_model_parallel': True, OP_ROLE_KEY: OpRole.Backward, - }) - check_variable_and_dtype(intermediate_var_0, 'x', - ['float16', 'float32', 'float64'], 'linear') - check_dtype(intermediate_var_0.dtype, 'dtype', - 
['float16', 'float32', 'float64'], 'linear') + }, + ) + check_variable_and_dtype( + intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear' + ) + check_dtype( + intermediate_var_0.dtype, + 'dtype', + ['float16', 'float32', 'float64'], + 'linear', + ) - set_comm_op_dist_attr_for_program(c_identity_op, dist_attr.process_mesh, - out_grad_dist_attr, ctx) + set_comm_op_dist_attr_for_program( + c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx + ) c_embedding_grad_op_desc = main_block.append_op(type='nop').desc c_embedding_grad_op_desc.set_type("c_embedding_grad") c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name]) c_embedding_grad_op_desc.set_input('W', [Weight_var.name]) - c_embedding_grad_op_desc.set_input('Out@GRAD', - [intermediate_var_0.name]) + c_embedding_grad_op_desc.set_input( + 'Out@GRAD', [intermediate_var_0.name] + ) c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name]) c_embedding_grad_op_desc._set_attr('start_index', relative_idx) c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward) c_embedding_grad_op = main_block.ops[-1] assert c_embedding_grad_op.type == "c_embedding_grad" - naive_copy_op_dist_attr_for_program(c_embedding_grad_op, backward_op, - ctx) + naive_copy_op_dist_attr_for_program( + c_embedding_grad_op, backward_op, ctx + ) # data parallel gradient synchronization act_grad_names = [Ids_var.name] out_grad_names = [kwargs['W@GRAD'][0]] - gradient_synchronization(ctx, backward_op, act_grad_names, - out_grad_names, rank_id) + gradient_synchronization( + ctx, backward_op, act_grad_names, out_grad_names, rank_id + ) -register_distributed_operator_impl("lookup_table_v2", - DistributedEmbeddingImpl("row_parallel")) -register_distributed_operator_impl("c_embedding", - DistributedEmbeddingImpl("row_parallel")) -register_distributed_operator_impl("lookup_table", - DistributedEmbeddingImpl("row_parallel")) +register_distributed_operator_impl( + "lookup_table_v2", DistributedEmbeddingImpl("row_parallel") +) +register_distributed_operator_impl( + "c_embedding", DistributedEmbeddingImpl("row_parallel") +) +register_distributed_operator_impl( + "lookup_table", DistributedEmbeddingImpl("row_parallel") +) diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 1e30467c4f722779f03d22a2a37a467190dbffb3..6b997d888a481f0932f66f3739a560b90fa6d1fd 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -25,6 +25,7 @@ from .reshard import Resharder from .partitioner import Partitioner from .utils import set_grad_var_shape from .process_group import get_world_process_group +from .random import init_auto_parallel_rng from ..utils.log_utils import get_logger @@ -84,6 +85,9 @@ class Parallelizer: ) = partitioner.partition( serial_main_program, serial_startup_program, params_grads ) + + init_auto_parallel_rng() + self._logger.debug( "within parallel partitioner time: {}, mode {}".format( time.time() - time0, self._mode diff --git a/python/paddle/distributed/auto_parallel/process_mesh.py b/python/paddle/distributed/auto_parallel/process_mesh.py index 72dc9043cabd6ab3524019a9e5abe805fdf15b97..4d6df30bc32cd4463cdecac5dd410436d4f0849f 100644 --- a/python/paddle/distributed/auto_parallel/process_mesh.py +++ b/python/paddle/distributed/auto_parallel/process_mesh.py @@ -19,6 +19,8 @@ import paddle # Use to store the previous and current process mesh _g_previous_process_mesh = None 
_g_current_process_mesh = None +# {shape_process_ids : unique_id} +_g_unique_process_mesh_map = {} def get_current_process_mesh(): @@ -39,6 +41,30 @@ def reset_current_process_mesh(): _g_current_process_mesh = _g_previous_process_mesh +def get_unique_id_for_process_mesh(shape, process_ids): + key = f"shape {shape}, process_ids {process_ids}" + global _g_unique_process_mesh_map + if key in _g_unique_process_mesh_map: + unique_id = _g_unique_process_mesh_map[key] + else: + unique_id = len(_g_unique_process_mesh_map) + 1 + _g_unique_process_mesh_map[key] = unique_id + + return unique_id + + +def retrive_unique_id_for_process_mesh(shape, process_ids): + key = f"shape {shape}, process_ids {process_ids}" + global _g_unique_process_mesh_map + assert key in _g_unique_process_mesh_map + return _g_unique_process_mesh_map[key] + + +def get_unique_process_mesh_map(): + global _g_unique_process_mesh_map + return _g_unique_process_mesh_map + + class ProcessMesh(object): """ The `Processmesh` object describes the topology of the used processes. @@ -113,6 +139,11 @@ class ProcessMesh(object): pg0 = get_process_group(0) pg0.add_ranks(self.processes) + # Unique Mesh Id + self._unique_id = get_unique_id_for_process_mesh( + self._shape, self._process_ids + ) + @property def shape(self): """ @@ -148,6 +179,16 @@ class ProcessMesh(object): """ return self._mesh + @property + def unique_id(self): + """ + Get the unique id of ProcessMesh. + NOTE + The unique id only takes process_ids and shape into account. + Different ProcessMesh objects with the same process_ids and shape share the same unique id. + """ + return self._unique_id + @property def topology(self): return self._shape diff --git a/python/paddle/distributed/auto_parallel/random.py b/python/paddle/distributed/auto_parallel/random.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca6d9e9ea06961e67b82b6962f241dfcc8ce64e --- /dev/null +++ b/python/paddle/distributed/auto_parallel/random.py @@ -0,0 +1,138 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import paddle + +from ..utils.log_utils import get_logger +from .process_mesh import retrive_unique_id_for_process_mesh +from .utils import _get_idx_in_axis + +_logger = get_logger(logging.INFO) + +_rng_name_to_seed = {} +_inited_rng_name_to_seed = {} +_enable_random_control = False +_basic_seed = 42 + +# use prime numbers as offsets to avoid conflicts +_mesh_offset = 173 +_dim_offsets = [11, 23, 37, 73] + + +def is_enable_auto_rand_ctrl(): + global _enable_random_control + return _enable_random_control + + +def enable_auto_rand_ctrl(): + global _enable_random_control + _enable_random_control = True + + +def parallel_manual_seed(seed): + """Enable auto parallel random control. + Random control maintains the randomness when a tensor is distributed across devices on a Mesh (in any order). + * Independency: If a tensor is **Sharded** on a Mesh dimension, Devices along that Mesh dimension should have Different randomness.
+ + * Consistency: Meanwhile, if the tensor is **Replicated** on another Mesh dimension, randomness of Devices along that Mesh dimension should be Consistent. + + For instance: rank0 ~ rank7 form a Mesh of shape [2, 4]; a 2D tensor is distributed in that Mesh using dims_mapping [-1, 1]. + Randomness for rank0-rank1-rank2-rank3 (rank4-rank5-rank6-rank7) should be Independent; + Randomness for rank0 and rank4 (rank1 and rank5, ...) should be Consistent. + + This function should be called only once before auto parallel compiles the computation graph (e.g. auto_parallel.engine.prepare() or fit()). + + This seed only affects how randomness-related **operators** (dropout, fused ops with dropout inside, etc.) are executed among the mesh, and would NOT affect other processes like parameter initialization. + + Examples: + # seed relative to training step + parallel_manual_seed((step + 13) * 257) + ... + engine.prepare() + """ + + enable_auto_rand_ctrl() + global _basic_seed + _basic_seed = seed + + +def determinate_rng(rank, dims_mapping, process_mesh): + + # TODO(JZ-LIANG) Support Mesh with any high rank + # use a string to unique integer hashing algorithm for seed computation, + # instead of using offsets to coordinate seeds across devices. + if len(process_mesh.shape) > 4: + raise NotImplementedError( + "Auto Parallel Random Control for Mesh's rank > 4 is NOT supported! Got {}".format( + str(process_mesh) + ) + ) + global _basic_seed + seed_ = _basic_seed + + # FIXME + # unique_id = process_mesh.unique_id + unique_id = retrive_unique_id_for_process_mesh( + process_mesh.shape, process_mesh.process_ids + ) + sharding_expr = f'mesh:{unique_id}' + seed_ += _mesh_offset * (unique_id + 1) + + for i in range(len(process_mesh.shape)): + if i not in dims_mapping: + relative_idx = -1 + else: + relative_idx = _get_idx_in_axis( + process_mesh.process_ids, + process_mesh.shape, + i, + rank, + ) + + sharding_expr += f"_dim{i}:{relative_idx}" + seed_ += _dim_offsets[i] * (relative_idx + 1) + + global _rng_name_to_seed + if sharding_expr in _rng_name_to_seed: + assert _rng_name_to_seed[sharding_expr] == seed_ + else: + assert ( + seed_ not in _rng_name_to_seed.values() + ), "Seed Conflict!
current seed: {}, current sharding expr: {}, generated seed: {}".format( + seed_, sharding_expr, _rng_name_to_seed + ) + _rng_name_to_seed[sharding_expr] = seed_ + + return sharding_expr + + +def init_auto_parallel_rng(): + + if not is_enable_auto_rand_ctrl(): + return + + global _rng_name_to_seed + # NOTE init rng maybe call multiple times, avoid init same rng twice + global _inited_rng_name_to_seed + + for rng_name, seed in _rng_name_to_seed.items(): + if rng_name in _inited_rng_name_to_seed: + assert _inited_rng_name_to_seed[rng_name] == seed + else: + _logger.info( + f"Init Auto Parallel RNG: {rng_name}, with seed {seed}" + ) + paddle.framework.random.set_random_seed_generator(rng_name, seed) + _inited_rng_name_to_seed[rng_name] = seed diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index a9c83a98c19fcb3283011603d7fc65101e79022a..74b142ab7e134d801830e8797c6af066e80fa259 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -22,13 +22,20 @@ from paddle.fluid.framework import Variable, Operator from paddle.fluid.backward import _append_grad_suffix_, _get_no_grad_set_name from paddle.fluid.backward import ProgramStats, _rename_arg_, _find_op_path_ from paddle.distributed.auto_parallel.process_mesh import ProcessMesh -from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute -from paddle.distributed.auto_parallel.utils import get_loss_op, set_var_dist_attr, set_dist_op_desc_original_id -from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping +from paddle.distributed.auto_parallel.dist_attribute import ( + OperatorDistributedAttribute, +) +from paddle.distributed.auto_parallel.utils import ( + get_loss_op, + set_var_dist_attr, + set_dist_op_desc_original_id, +) +from paddle.distributed.auto_parallel.utils import ( + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, +) class RecomputeState(ProgramStats): - def __init__(self, block, ops): super(RecomputeState, self).__init__(block=block, ops=ops) self._block = block @@ -54,7 +61,7 @@ class RecomputeState(ProgramStats): self.var_op_deps[name]["var_as_output_ops"] = [i] def get_recompute_segments(self, checkpoints): - """ get recompute segments from checkpoints """ + """get recompute segments from checkpoints""" segments = [] start_idx = -1 pre_segment_end_idx = -1 @@ -69,33 +76,43 @@ class RecomputeState(ProgramStats): segments.append([0, max(op_idx_list) + 1]) else: flag, min_idx, max_idx = self.is_subgraph( - [checkpoints[start_idx]], [checkpoints[start_idx + 1]]) + [checkpoints[start_idx]], [checkpoints[start_idx + 1]] + ) if flag: min_idx = self._update_segment_start( - min_idx, pre_segment_end_idx) + min_idx, pre_segment_end_idx + ) segments.append([min_idx, max_idx + 1]) else: logging.info( "Could not recompute op range [{}] - [{}] ".format( - min_idx, max_idx + 1)) + min_idx, max_idx + 1 + ) + ) start_idx += 1 for i, (idx1, idx2) in enumerate(segments): logging.info("recompute segment[{}]".format(i)) - logging.info("segment start op: [{}]: [{}] [{}]".format( - self._ops[idx1].desc.type(), - self._ops[idx1].desc.input_arg_names(), - self._ops[idx1].desc.output_arg_names())) - logging.info("segment end op: [{}]: [{}] [{}]".format( - self._ops[idx2 - 1].desc.type(), - self._ops[idx2 - 1].desc.input_arg_names(), - self._ops[idx2 - 1].desc.output_arg_names())) + logging.info( + 
"segment start op: [{}]: [{}] [{}]".format( + self._ops[idx1].desc.type(), + self._ops[idx1].desc.input_arg_names(), + self._ops[idx1].desc.output_arg_names(), + ) + ) + logging.info( + "segment end op: [{}]: [{}] [{}]".format( + self._ops[idx2 - 1].desc.type(), + self._ops[idx2 - 1].desc.input_arg_names(), + self._ops[idx2 - 1].desc.output_arg_names(), + ) + ) return segments def modify_forward_desc_for_recompute(self, dist_context): """ - If program's foward part has 'dropout' op, this function will insert + If program's foward part has 'dropout' op, this function will insert a seed op before it to guarantee that two dropout op have the same outputs. """ op_types = [op.desc.type() for op in self._ops] @@ -116,45 +133,52 @@ class RecomputeState(ProgramStats): cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op) # insert seed op to guarantee that two dropout op have the same outputs - op_unique_name = unique_name.generate("seed") - var_unique_name = unique_name.generate_with_ignorable_key(".".join( - [op_unique_name, 'tmp'])) + # NOTE Hack for adopt recompute for random control, for more info see dist_dropout.py + # new seed added by recompute should have a prefix to distinguish with seed added by user or other moudule. + op_unique_name = unique_name.generate("rc_seed") + var_unique_name = unique_name.generate_with_ignorable_key( + ".".join([op_unique_name, 'tmp']) + ) seed_var = self._block.create_var( name=var_unique_name, dtype='int32', type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=False) + stop_gradient=False, + ) # set new seed_var's dist_attr ref_dims_mapping = [-1] ref_process_mesh = cur_op_dist_attr.process_mesh - seed_var_dist_attr = set_var_dist_attr(dist_context, seed_var, - ref_dims_mapping, - ref_process_mesh) - - seed = 0 if cur_op.attr("fix_seed") is False else int( - cur_op.attr("seed")) + seed_var_dist_attr = set_var_dist_attr( + dist_context, seed_var, ref_dims_mapping, ref_process_mesh + ) + + seed = ( + 0 + if cur_op.attr("fix_seed") is False + else int(cur_op.attr("seed")) + ) seed_op = self._block._insert_op_without_sync( index=cur_op.idx, type="seed", inputs={}, outputs={"Out": seed_var}, - attrs={ - "seed": seed, - "force_cpu": True - }) + attrs={"seed": seed, "force_cpu": True}, + ) # set new seed op's dist_attr naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - seed_op, ref_process_mesh, ref_dims_mapping, dist_context) + seed_op, ref_process_mesh, ref_dims_mapping, dist_context + ) # modify dropout op's desc self._ops.insert(op_idx, seed_op) cur_op.desc.set_input("Seed", [var_unique_name]) cur_op._remove_attr("fix_seed") cur_op._remove_attr("seed") - cur_op_dist_attr.set_input_dist_attr(seed_var.name, - seed_var_dist_attr) + cur_op_dist_attr.set_input_dist_attr( + seed_var.name, seed_var_dist_attr + ) op_idx += 2 self._block._sync_with_cpp() @@ -168,7 +192,7 @@ def _find_op_index(block, cur_op): def _get_stop_gradients(program, no_grad_set): - """ get no grad var """ + """get no grad var""" if no_grad_set is None: no_grad_set = set() else: @@ -185,8 +209,9 @@ def _get_stop_gradients(program, no_grad_set): return no_grad_set_name -def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars, - dist_context): +def _add_needed_descs_to_block( + descs, block, main_block, in_memory_vars, dist_context +): """ Get the recomputed ops which will insert the backward part """ @@ -217,7 +242,6 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars, @register_pass("auto_parallel_recompute") class 
RecomputePass(PassBase): - def __init__(self): super(RecomputePass, self).__init__() self.set_attr("checkpoints", None) @@ -261,12 +285,15 @@ class RecomputePass(PassBase): vars_should_be_hold = [] for segment in segments: vars_should_be_hold.extend( - rc_state.get_out_of_subgraph_vars(segment[0], segment[1])) + rc_state.get_out_of_subgraph_vars(segment[0], segment[1]) + ) cross_vars = set(vars_should_be_hold) - set(checkpoints) logging.info( "found [{}] vars which cross recompute segment: [{}]," "better checkpoints might be set to reduce those vars".format( - len(cross_vars), cross_vars)) + len(cross_vars), cross_vars + ) + ) vars_should_be_hold.extend(rc_state.get_reserved_vars()) vars_should_be_hold.extend(rc_state.get_input_nodes()) vars_should_be_hold = list(set(vars_should_be_hold)) @@ -277,14 +304,15 @@ class RecomputePass(PassBase): ckpt_ops_dict = {} buffer_block = main_block.program._create_block() for i, segment in enumerate(segments[::-1]): - fwd_ops = op_path[segment[0]:segment[1]] + fwd_ops = op_path[segment[0] : segment[1]] var_suffix = ".subprog_%d" % i for op in fwd_ops: input_and_output_names = [] input_and_output_names.extend(op.desc.input_arg_names()) input_and_output_names.extend(op.desc.output_arg_names()) - cur_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( - op) + cur_op_dist_attr = ( + self._dist_context.get_op_dist_attr_for_program(op) + ) assert cur_op_dist_attr is not None for name in input_and_output_names: if main_block.var(name).persistable or name in checkpoints: @@ -294,11 +322,13 @@ class RecomputePass(PassBase): if name not in var_name_dict: ref_process_mesh = cur_op_dist_attr.process_mesh if name in op.desc.input_arg_names(): - ref_dims_mapping = cur_op_dist_attr.get_input_dims_mapping( - name) + ref_dims_mapping = ( + cur_op_dist_attr.get_input_dims_mapping(name) + ) else: - ref_dims_mapping = cur_op_dist_attr.get_output_dims_mapping( - name) + ref_dims_mapping = ( + cur_op_dist_attr.get_output_dims_mapping(name) + ) # record recomputed var's old_name and new_name (old_name.subprog_XXX) # create new var with new name var_name_dict[name] = name + var_suffix @@ -309,15 +339,23 @@ class RecomputePass(PassBase): dtype=ref_var.dtype, type=ref_var.type, persistable=ref_var.persistable, - stop_gradient=ref_var.stop_gradient) + stop_gradient=ref_var.stop_gradient, + ) # set new recomputed var's dist attr - set_var_dist_attr(self._dist_context, rc_var, - ref_dims_mapping, ref_process_mesh) + set_var_dist_attr( + self._dist_context, + rc_var, + ref_dims_mapping, + ref_process_mesh, + ) # get recomputed segment's descs - segment_descs = _add_needed_descs_to_block(fwd_ops, buffer_block, - main_block, - vars_in_memory, - self._dist_context) + segment_descs = _add_needed_descs_to_block( + fwd_ops, + buffer_block, + main_block, + vars_in_memory, + self._dist_context, + ) # rename recomputed ops' input and output var name for key in var_name_dict: _rename_arg_(segment_descs, key, var_name_dict[key]) @@ -345,7 +383,10 @@ class RecomputePass(PassBase): # rename grad op's var_name which is not in 'vars_in_memory' for key in var_name_dict: - if key not in grad_op.input_arg_names + grad_op.output_arg_names: + if ( + key + not in grad_op.input_arg_names + grad_op.output_arg_names + ): continue self.reset_op_dist_attr(grad_op, var_name_dict) _rename_arg_([grad_op.desc], key, var_name_dict[key]) @@ -360,17 +401,20 @@ class RecomputePass(PassBase): idx -= 1 segment_descs = ckpt_ops_dict[fwd_op_id][1] for _, op_desc in reversed(list(enumerate(segment_descs))): - 
rc_op = main_block._insert_op_without_sync(idx, - type='nop') + rc_op = main_block._insert_op_without_sync( + idx, type='nop' + ) rc_desc = rc_op.desc rc_desc.copy_from(op_desc) rc_desc.set_original_id(rc_desc.id()) # set recomputed ops' dist attr fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id( - op_desc.original_id()) + op_desc.original_id() + ) assert fwd_op_dist_attr is not None - self.set_op_dist_attr(rc_op, fwd_op_dist_attr, - var_name_dict) + self.set_op_dist_attr( + rc_op, fwd_op_dist_attr, var_name_dict + ) ckpt_ops_dict[fwd_op_id][0] = False @@ -382,13 +426,15 @@ class RecomputePass(PassBase): for input in op.desc.input_arg_names(): if input in var_name_dict.keys(): in_dist_attr = op_dist_attr.get_input_dist_attr(input) - op_dist_attr.set_input_dist_attr(var_name_dict[input], - in_dist_attr) + op_dist_attr.set_input_dist_attr( + var_name_dict[input], in_dist_attr + ) for output in op.desc.output_arg_names(): if output in var_name_dict.keys(): out_dist_attr = op_dist_attr.get_output_dist_attr(output) - op_dist_attr.set_output_dist_attr(var_name_dict[output], - out_dist_attr) + op_dist_attr.set_output_dist_attr( + var_name_dict[output], out_dist_attr + ) def set_op_dist_attr(self, op, old_dist_attr, var_name_dict): new_dist_attr = OperatorDistributedAttribute() @@ -399,16 +445,18 @@ class RecomputePass(PassBase): for input in old_dist_attr.inputs_dist_attrs.keys(): if input in var_name_dict.keys(): in_dist_attr = old_dist_attr.inputs_dist_attrs[input] - new_dist_attr.set_input_dist_attr(var_name_dict[input], - in_dist_attr) + new_dist_attr.set_input_dist_attr( + var_name_dict[input], in_dist_attr + ) else: in_dist_attr = old_dist_attr.inputs_dist_attrs[input] new_dist_attr.set_input_dist_attr(input, in_dist_attr) for output in old_dist_attr.outputs_dist_attrs.keys(): if output in var_name_dict.keys(): out_dist_attr = old_dist_attr.outputs_dist_attrs[output] - new_dist_attr.set_output_dist_attr(var_name_dict[output], - out_dist_attr) + new_dist_attr.set_output_dist_attr( + var_name_dict[output], out_dist_attr + ) else: out_dist_attr = old_dist_attr.outputs_dist_attrs[output] new_dist_attr.set_output_dist_attr(output, out_dist_attr) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 446461a045b286d455cf8e5f10b8c3ab052ae7a4..9079a0b75357fd93b172cce0730501de7cbf5753 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -37,6 +37,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ${dist_ENVS}) set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS}) + set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" + TIMEOUT 50) py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS ${dist_ENVS}) set_tests_properties(test_iterable_dataset diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py index 318773c71e09eb28e1a3a8423b738c6696b91fc7..71f16f97206185aed848441d65a60d425d8c0c9f 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py @@ -21,14 +21,17 @@ from paddle.distributed.fleet import auto sys.path.append("..") import auto_parallel_gpt_model 
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
index 446461a045b286d455cf8e5f10b8c3ab052ae7a4..9079a0b75357fd93b172cce0730501de7cbf5753 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -37,6 +37,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
                   ${dist_ENVS})
   set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                        TIMEOUT 50)
+  py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS})
+  set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
+                                                   TIMEOUT 50)
   py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
                   ${dist_ENVS})
   set_tests_properties(test_iterable_dataset
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py
index 318773c71e09eb28e1a3a8423b738c6696b91fc7..71f16f97206185aed848441d65a60d425d8c0c9f 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py
@@ -21,14 +21,17 @@ from paddle.distributed.fleet import auto
 
 sys.path.append("..")
 import auto_parallel_gpt_model as modeling
-from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
+from auto_parallel_gpt_model import (
+    GPTModel,
+    GPTForPretraining,
+    GPTPretrainingCriterion,
+)
 
 sequence_len = 512
 vocab_size = 1000
 
 
 class FakeDataset(paddle.io.Dataset):
-
     def __init__(self, num_samples):
         self.num_samples = num_samples
         self.sequence_len = sequence_len
@@ -40,8 +43,11 @@ class FakeDataset(paddle.io.Dataset):
         random.seed(2021)
         tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
         position_ids = np.arange(self.sequence_len)
-        attention_mask = np.tril(np.ones(self.sequence_len)).reshape(
-            (1, self.sequence_len, self.sequence_len)).astype(np.float32)
+        attention_mask = (
+            np.tril(np.ones(self.sequence_len))
+            .reshape((1, self.sequence_len, self.sequence_len))
+            .astype(np.float32)
+        )
         labels = np.random.randint(self.vocab_size, size=self.sequence_len)
         loss_mask = np.ones(self.sequence_len).astype(np.float32)
         return tokens, position_ids, attention_mask, labels, loss_mask
@@ -51,30 +57,32 @@ class FakeDataset(paddle.io.Dataset):
 
 
 def create_data_holder(batch_size):
-    tokens = paddle.static.InputSpec(name="tokens",
-                                     shape=[batch_size, sequence_len],
-                                     dtype='int64')
-    position_ids = paddle.static.InputSpec(name="position_ids",
-                                           shape=[batch_size, sequence_len],
-                                           dtype='int64')
+    tokens = paddle.static.InputSpec(
+        name="tokens", shape=[batch_size, sequence_len], dtype='int64'
+    )
+    position_ids = paddle.static.InputSpec(
+        name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
+    )
     attention_mask = paddle.static.InputSpec(
         name="attention_mask",
         shape=[batch_size, 1, sequence_len, sequence_len],
-        dtype='float32')
-    labels = paddle.static.InputSpec(name="labels",
-                                     shape=[batch_size, sequence_len],
-                                     dtype='int64')
-    loss_mask = paddle.static.InputSpec(name="loss_mask",
-                                        shape=[batch_size, sequence_len],
-                                        dtype='float32')
+        dtype='float32',
+    )
+    labels = paddle.static.InputSpec(
+        name="labels", shape=[batch_size, sequence_len], dtype='int64'
+    )
+    loss_mask = paddle.static.InputSpec(
+        name="loss_mask", shape=[batch_size, sequence_len], dtype='float32'
+    )
     return [tokens, position_ids, attention_mask], [labels, loss_mask]
 
 
-def generate_model(strategy):
+def generate_model(strategy, dropout_prob=0.0):
     modeling.init_global()
     ranks = list(range(paddle.distributed.get_world_size()))
-    modeling._global_process_mesh = auto.ProcessMesh(mesh=ranks,
-                                                     dim_names=["x"])
+    modeling._global_process_mesh = auto.ProcessMesh(
+        mesh=ranks, dim_names=["x"]
+    )
     if strategy == "serial":
         modeling._global_parallel_strategy = "serial"
     elif strategy == "mp":
@@ -84,24 +92,25 @@ def generate_model(strategy):
     else:
         raise ValueError("Only support serial, mp2 and dp2.")
 
-    gpt = GPTModel(vocab_size=1000,
-                   hidden_size=64,
-                   num_hidden_layers=2,
-                   num_attention_heads=8,
-                   intermediate_size=256,
-                   hidden_act="gelu",
-                   hidden_dropout_prob=0.0,
-                   attention_probs_dropout_prob=0.0,
-                   max_position_embeddings=1024,
-                   type_vocab_size=1,
-                   initializer_range=0.02,
-                   pad_token_id=0,
-                   eos_token_id=7,
-                   bos_token_id=0,
-                   eol_token_id=3)
-    model = GPTForPretraining(gpt,
-                              vocab_size=1000,
-                              hidden_size=64,
-                              initializer_range=0.02)
+    gpt = GPTModel(
+        vocab_size=1000,
+        hidden_size=64,
+        num_hidden_layers=2,
+        num_attention_heads=8,
+        intermediate_size=256,
+        hidden_act="gelu",
+        hidden_dropout_prob=dropout_prob,
+        attention_probs_dropout_prob=dropout_prob,
+        max_position_embeddings=1024,
+        type_vocab_size=1,
+        initializer_range=0.02,
+        pad_token_id=0,
+        eos_token_id=7,
+        bos_token_id=0,
+        eol_token_id=3,
+    )
+    model = GPTForPretraining(
+        gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
+    )
     criterion = GPTPretrainingCriterion()
     return model, criterion
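The only functional change in get_gpt_model.py is the new dropout_prob argument, which is forwarded to both GPT dropout probabilities so that the random-control tests can exercise live dropout. A rough usage sketch under the same assumptions as the unittest below; it is only meaningful when run under a multi-card distributed launch.

# Rough usage sketch (mirrors the unittest below; assumes the test helpers
# from get_gpt_model.py are importable and the process was started by
# paddle.distributed.launch).
import paddle
from paddle.distributed.fleet import auto
from get_gpt_model import generate_model, create_data_holder

paddle.enable_static()
model, loss = generate_model("mp", dropout_prob=0.1)  # dropout layers active
inputs_spec, labels_spec = create_data_holder(batch_size=1)
opt = paddle.optimizer.AdamW(learning_rate=0.00001)
engine = auto.Engine(model, loss, opt)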
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/random_control_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/random_control_unittest.py
new file mode 100644
index 0000000000000000000000000000000000000000..8145f254ec8723ea6841a0fdd322cd0aaae95a39
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/random_control_unittest.py
@@ -0,0 +1,273 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+
+import numpy as np
+from get_gpt_model import FakeDataset, generate_model
+
+import paddle
+
+paddle.enable_static()
+from paddle import _legacy_C_ops
+from paddle.distributed.fleet import auto
+
+
+def dy_broadcast_helper(tensor):
+    _legacy_C_ops.c_broadcast(
+        tensor, tensor, 'root', 1, 'use_calc_stream', True, 'ring_id', 1000
+    )
+    _legacy_C_ops.c_sync_calc_stream(tensor, tensor)
+
+
+def apply_pass(use_recompute=False, no_recompute_segments=[]):
+    strategy = auto.Strategy()
+    strategy.auto_mode = "semi"
+    strategy.reinit = True
+    if use_recompute:
+        recompute = strategy.recompute
+        recompute.enable = True
+        recompute.no_recompute_segments = no_recompute_segments
+    return strategy
+
+
+def reset_prog():
+    paddle.fluid.framework.switch_main_program(paddle.static.Program())
+    paddle.fluid.framework.switch_startup_program(paddle.static.Program())
+
+
+class TestRandomControl(unittest.TestCase):
+    def setUp(self):
+        self.rtol = 1e-6
+        self.atol = 1e-8
+        self.batch_size = 1
+        self.batch_num = 10
+        self.clip_norm = 0.2
+        self.dataset = FakeDataset(self.batch_size * self.batch_num)
+        paddle.distributed.auto_parallel.parallel_manual_seed(100)
+
+    def init(self, engine):
+        paddle.seed(2022)
+        np.random.seed(2022)
+        random.seed(2022)
+        place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
+        engine._executor = paddle.static.Executor(place)
+
+    def get_engine(self, use_recompute=False, no_recompute_segments=[]):
+        reset_prog()
+
+        strategy = apply_pass(use_recompute, no_recompute_segments)
+        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
+        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
+        model, loss = generate_model("mp", dropout_prob=0.1)
+
+        engine = auto.Engine(model, loss, opt, strategy=strategy)
+        self.init(engine)
+        return engine
+
+    def compare_mask_between_ranks(
+        self, rank, mask_np_list, compare_idx, equal
+    ):
+
+        for np_mask in [mask_np_list[i] for i in compare_idx]:
+            mask_tensor_local = paddle.to_tensor(np_mask.astype("float32"))
+            if rank == 0:
+                mask_tensor_remote = paddle.ones_like(mask_tensor_local)
+                dy_broadcast_helper(mask_tensor_remote)
+                if equal:
+                    assert np.array_equal(
+                        mask_tensor_remote.numpy(), mask_tensor_local.numpy()
+                    )
+                else:
+                    assert not np.array_equal(
+                        mask_tensor_remote.numpy(),
+                        mask_tensor_local.numpy(),
+                    )
+            else:
+                dy_broadcast_helper(mask_tensor_local)
+
+    def test_random_ctrl_vanilla(self):
+        # mp2 training without recompute
+        rc_engine = self.get_engine(False)
+        train_dataloader = rc_engine.dataloader(
+            self.dataset,
+            batch_size=self.batch_size,
+            mode="train",
+            sample_split=3,
+        )
+
+        rc_engine.prepare(mode="train")
+        mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
+        mask_var_list = [
+            rc_engine.main_program.global_block().var(varname)
+            for varname in mask_name_list
+        ]
+
+        for data in train_dataloader:
+            outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
+        mask_np_list = [outs['fetches'][varname] for varname in mask_name_list]
+
+        paddle.disable_static()
+        rank = paddle.distributed.get_rank()
+        # check global mask consistent across ranks
+        global_index = [0, 2, 3, 5, 6]
+        self.compare_mask_between_ranks(
+            rank, mask_np_list, global_index, equal=True
+        )
+        local_index = [1, 4]
+        # check local mask different across ranks
+        self.compare_mask_between_ranks(
+            rank, mask_np_list, local_index, equal=False
+        )
+        paddle.enable_static()
+
+        # check program
+        ops = rc_engine.main_program.global_block().ops
+        rng_names = []
+        seed_var_names = []
+        for op in ops:
+            if op.type == "seed":
+                rng_names.append(op.attr('rng_name'))
+            if op.type == "dropout":
+                seed_var_names.append(op.input("Seed")[0])
+        rank = paddle.distributed.get_rank()
+
+        self.assertEqual(
+            rng_names,
+            [
+                'mesh:1_dim0:-1',
+                f'mesh:1_dim0:{rank}',
+                'mesh:1_dim0:-1',
+                'mesh:1_dim0:-1',
+                f'mesh:1_dim0:{rank}',
+                'mesh:1_dim0:-1',
+                'mesh:1_dim0:-1',
+            ],
+        )
+        self.assertEqual(
+            seed_var_names,
+            [
+                'tensor_parallel_seed.tmp_0',
+                'tensor_parallel_seed.tmp_1',
+                'tensor_parallel_seed.tmp_2',
+                'tensor_parallel_seed.tmp_3',
+                'tensor_parallel_seed.tmp_4',
+                'tensor_parallel_seed.tmp_5',
+                'tensor_parallel_seed.tmp_6',
+            ],
+        )
+
+    def test_random_ctrl_with_recompute(self):
+        # mp2 recompute training
+        rc_engine = self.get_engine(True)
+        train_dataloader = rc_engine.dataloader(
+            self.dataset,
+            batch_size=self.batch_size,
+            mode="train",
+            sample_split=3,
+        )
+
+        rc_engine.prepare(mode="train")
+        mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
+        recompute_mask_name_list = [
+            'dropout_0.tmp_1.subprog_1',
+            'dropout_1.tmp_1.subprog_1',
+            'dropout_2.tmp_1.subprog_1',
+            'dropout_3.tmp_1.subprog_1',
+            'dropout_4.tmp_1.subprog_0',
+            'dropout_5.tmp_1.subprog_0',
+            'dropout_6.tmp_1.subprog_0',
+        ]
+        mask_var_list = [
+            rc_engine.main_program.global_block().var(varname)
+            for varname in mask_name_list + recompute_mask_name_list
+        ]
+
+        for data in train_dataloader:
+            outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
+        mask_np_list = [
+            outs['fetches'][varname]
+            for varname in mask_name_list + recompute_mask_name_list
+        ]
+
+        # check the recomputed mask is the same as the forward mask within the local device
+        for i in range(7):
+            mask_fw = mask_np_list[i].astype("float32")
+            mask_rc = mask_np_list[i + 7].astype("float32")
+            assert np.array_equal(
+                mask_fw,
+                mask_rc,
+            )
+
+        paddle.disable_static()
+        # check global mask consistent across ranks
+        rank = paddle.distributed.get_rank()
+        global_index = [0, 2, 3, 5, 6]
+        self.compare_mask_between_ranks(
+            rank, mask_np_list, global_index, equal=True
+        )
+        local_index = [1, 4]
+        # check local mask different across ranks
+        self.compare_mask_between_ranks(
+            rank, mask_np_list, local_index, equal=False
+        )
+        paddle.enable_static()
+
+        # check program
+        rank = paddle.distributed.get_rank()
+        ops = rc_engine.main_program.global_block().ops
+        rng_names = []
+        seed_var_names = []
+        for op in ops:
+            if op.type == "seed":
+                rng_names.append(op.attr('rng_name'))
+            if op.type == "dropout":
+                seed_var_names.append(op.input("Seed")[0])
+
+        self.assertEqual(
+            rng_names,
+            [
+                'mesh:1_dim0:-1',
+                f'mesh:1_dim0:{rank}',
+                'mesh:1_dim0:-1',
+                'mesh:1_dim0:-1',
+                f'mesh:1_dim0:{rank}',
+                'mesh:1_dim0:-1',
+                'mesh:1_dim0:-1',
+            ],
+        )
+        self.assertEqual(
+            seed_var_names,
+            [
+                'rc_seed_0.tmp_0',
+                'rc_seed_1.tmp_0',
+                'rc_seed_2.tmp_0',
+                'rc_seed_3.tmp_0',
+                'rc_seed_4.tmp_0',
+                'rc_seed_5.tmp_0',
+                'rc_seed_6.tmp_0',
+                'rc_seed_4.tmp_0',
+                'rc_seed_5.tmp_0',
+                'rc_seed_6.tmp_0',
+                'rc_seed_0.tmp_0',
+                'rc_seed_1.tmp_0',
+                'rc_seed_2.tmp_0',
+                'rc_seed_3.tmp_0',
+            ],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
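The test above encodes the expected seeding policy: dropouts driven by a rank-independent rng_name ('mesh:1_dim0:-1') must produce identical masks on every rank, while the rank-suffixed ones must differ. A small, plain-Python sketch of that classification, mirroring the asserted lists (not part of the patch itself):

# Classify dropout indices by whether their rng_name depends on the rank;
# rng_names and rank mirror the values asserted in the test above.
def split_mask_indices(rng_names, rank):
    global_idx, local_idx = [], []
    for i, name in enumerate(rng_names):
        if name == f'mesh:1_dim0:{rank}':
            local_idx.append(i)   # per-rank seed -> masks differ across ranks
        else:
            global_idx.append(i)  # shared seed -> masks equal across ranks
    return global_idx, local_idx


# With the rng_names asserted above this yields ([0, 2, 3, 5, 6], [1, 4]),
# i.e. exactly the global_index / local_index lists used by the test.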
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_random_ctrl.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_random_ctrl.py
new file mode 100644
index 0000000000000000000000000000000000000000..6162db5e93ee7f3fee75932821cc8bf5fcdaa542
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_random_ctrl.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+import sys
+import tempfile
+import unittest
+
+
+class TestRandomCtrlPass(unittest.TestCase):
+    def test_mp2_with_recompute(self):
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        launch_model_path = os.path.join(file_dir, "random_control_unittest.py")
+
+        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
+            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
+        else:
+            coverage_args = []
+
+        tmp_dir = tempfile.TemporaryDirectory()
+        cmd = (
+            [sys.executable, "-u"]
+            + coverage_args
+            + [
+                "-m",
+                "paddle.distributed.launch",
+                "--devices",
+                "0,1",
+                "--log_dir",
+                tmp_dir.name,
+                launch_model_path,
+            ]
+        )
+
+        process = subprocess.Popen(cmd)
+        process.wait()
+        self.assertEqual(process.returncode, 0)
+
+        tmp_dir.cleanup()
+
+
+if __name__ == "__main__":
+    unittest.main()
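For reference, the launcher test above boils down to running the worker script under paddle.distributed.launch with two devices. A standalone sketch of the same invocation (paths illustrative; assumes two visible GPUs):

# Minimal equivalent of what test_mp2_with_recompute executes:
#   python -m paddle.distributed.launch --devices 0,1 \
#       --log_dir <tmpdir> random_control_unittest.py
import subprocess
import sys
import tempfile

with tempfile.TemporaryDirectory() as log_dir:
    cmd = [
        sys.executable, "-m", "paddle.distributed.launch",
        "--devices", "0,1",
        "--log_dir", log_dir,
        "random_control_unittest.py",  # worker script added by the diff above
    ]
    assert subprocess.run(cmd).returncode == 0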