diff --git a/python/paddle/distributed/auto_parallel/__init__.py b/python/paddle/distributed/auto_parallel/__init__.py index 269a0ec644dbd2e54ab255fa88f2e80ae744985f..835ca68df2d1c1169b9eaca767715070ca00e013 100644 --- a/python/paddle/distributed/auto_parallel/__init__.py +++ b/python/paddle/distributed/auto_parallel/__init__.py @@ -19,5 +19,6 @@ from .interface import shard_tensor from .interface import shard_op from .interface import recompute from .interface import fetch +from .random import parallel_manual_seed __all__ = [] diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 4a6181758d114d0e462310917ca6385da56a2c57..a84bea42d538fc23713920c88ac8bd392603cc75 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -45,7 +45,7 @@ from .dist_loader import ( from .dist_op import DistributedOperator from .dist_saver import DistributedSaver from .helper import ProgramHelper -from .interface import CollectionNames, get_collection +from .interface import CollectionNames, fetch, get_collection from .parallelizer_v2 import Parallelizer from .planner_v2 import Planner from .process_group import get_all_process_groups, new_process_group @@ -410,6 +410,8 @@ class Engine: ), "user_fetches must be a list, but receive {}".format( type(user_fetches).__name__ ) + else: + user_fetches = [] fetch_names = [] fetch_indices = [] @@ -434,10 +436,13 @@ class Engine: _process_fetch_group("metrics_" + str(i), var_list) if mode == "predict": _process_fetch_group("outputs", self._fetch_vars[mode]["outputs"]) + for usr_fetch in user_fetches: + var_name = _to_name_str(usr_fetch) + fetch(var_name) user_fetches_collection = [ item[1] for item in get_collection(CollectionNames.FETCHES) ] - var_list = (user_fetches_collection or []) + (user_fetches or []) + var_list = user_fetches_collection or [] _process_fetch_group("fetches", var_list) return fetch_names, fetch_indices diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index 406ec4d8b36da07f56d1883dea77116f2ca6d165..bc5bf4b7379e7259ea716fb017b8b232d1554e06 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -36,3 +36,4 @@ from . import dist_reduce_sum_p from . import dist_shape from . import dist_assign from . import dist_scale +from . import dist_dropout diff --git a/python/paddle/distributed/auto_parallel/operators/dist_dropout.py b/python/paddle/distributed/auto_parallel/operators/dist_dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..e43870b26883ba2210aef8d8729886e61d99b8d6 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_dropout.py @@ -0,0 +1,186 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import paddle
+from paddle.framework import core
+from paddle.utils import unique_name
+
+from ...utils.log_utils import get_logger
+
+_logger = get_logger(logging.INFO)
+from ..random import determinate_rng, is_enable_auto_rand_ctrl
+from ..utils import (
+    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
+    set_var_dist_attr,
+)
+from .common import (
+    DistributedOperatorImplContainer,
+    register_distributed_operator_impl,
+    register_distributed_operator_impl_container,
+)
+from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0
+
+
+class DistributedDropout(DistributedOperatorImplContainer):
+    def __init__(self, op_type):
+        super().__init__(op_type)
+
+
+register_distributed_operator_impl_container(DistributedDropout("dropout"))
+
+
+# Dist Dropout with Random Control
+# Dropout reuses the compatibility and cost functions of elementwise
+class DistributedDropoutImpl0(DistributedElementwiseImpl0):
+    def __init__(self, name):
+        super().__init__(name)
+        self._forward_implemented = True
+        self._backward_implemented = True
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+
+        dist_op_context = ctx.dist_op_context
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
+        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+
+        if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute:
+
+            assert (
+                op_dist_attr is not None
+            ), f"forward op [{str(src_op)}] doesn't have dist attribute!"
+
+            # check validation of inputs / outputs
+            assert 'X' in kwargs, "input [{}] is not given".format('X')
+            assert (
+                len(kwargs['X']) == 1
+            ), "input X should be only one tensor but got {}".format(
+                kwargs['X']
+            )
+            assert 'Seed' in kwargs, "input [{}] is not given".format('Seed')
+
+            if (
+                src_op.has_attr("fix_seed")
+                and src_op.attr("fix_seed")
+                and src_op.has_attr("seed")
+                and src_op.attr("seed")
+            ):
+                _logger.info(
+                    "Auto Parallel Random Control Skipped since manual seed is set by user: {}".format(
+                        src_op
+                    )
+                )
+            elif rank_id not in op_dist_attr.process_mesh.process_ids:
+                pass
+            # NOTE Adapt for recompute
+            # If the user has already set a seed, we should not modify it. But if the seed was added by the recompute pass, it should be under control.
+            # TODO: in the future the recompute pass should happen after parallel partitioning; remove this branch at that time.
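+            # A recompute-injected seed is recognized by the 'rc_seed' name prefix
+            # assigned in auto_parallel_recompute.py; seed vars created by the user
+            # keep their original names and are skipped by the branch below.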
+            elif len(kwargs['Seed']) > 0 or len(src_op.input("Seed")) > 0:
+                seed_var_name = kwargs['Seed'][0]
+                if seed_var_name.startswith('rc_seed'):
+                    pre_op = main_block.ops[-1]
+                    assert (
+                        pre_op.type == "seed"
+                        and len(pre_op.attr("rng_name")) == 0
+                    ), f"found exception op {str(pre_op)}"
+
+                    # determinate rng
+                    X_var = main_block._var_recursive(kwargs['X'][0])
+                    X_dims_mapping = op_dist_attr.get_input_dims_mapping(
+                        X_var.name
+                    )
+                    process_mesh = op_dist_attr.process_mesh
+                    rng_name = determinate_rng(
+                        rank_id, X_dims_mapping, process_mesh
+                    )
+                    # bring the recompute seed under control
+                    pre_op._set_attr("rng_name", rng_name)
+                    pre_op._set_attr("deterministic", True)
+                    pre_op._set_attr("force_cpu", True)
+                else:
+                    _logger.info(
+                        "Auto Parallel Random Control Skipped since manual seed is set by user: {}".format(
+                            src_op
+                        )
+                    )
+            else:
+                # determinate rng
+                X_var = main_block._var_recursive(kwargs['X'][0])
+                X_dims_mapping = op_dist_attr.get_input_dims_mapping(X_var.name)
+                process_mesh = op_dist_attr.process_mesh
+
+                rng_name = determinate_rng(
+                    rank_id, X_dims_mapping, process_mesh
+                )
+                assert rng_name is not None and rng_name != ""
+
+                # insert seed op
+                seed_var = main_block.create_var(
+                    name=unique_name.generate_with_ignorable_key(
+                        ".".join(["tensor_parallel_seed", 'tmp'])
+                    ),
+                    dtype=paddle.int32,
+                    type=core.VarDesc.VarType.LOD_TENSOR,
+                    persistable=False,
+                    stop_gradient=False,
+                )
+
+                # set new seed_var's dist_attr
+                seed_var_dims_mapping = [-1]
+                seed_var_dist_attr = set_var_dist_attr(
+                    ctx, seed_var, seed_var_dims_mapping, process_mesh
+                )
+
+                # adapt for recompute
+                # force_cpu to reduce the sync copy CPU->GPU->CPU and reduce pipeline hang
+                seed_op = main_block.append_op(
+                    type='seed',
+                    outputs={'Out': seed_var},
+                    attrs={
+                        'deterministic': True,
+                        'rng_name': rng_name,
+                        'force_cpu': True,
+                    },
+                )
+                seed_op._set_attr('op_namescope', 'auto_tensor_parallel_seed')
+                # set new seed op's dist_attr
+                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+                    seed_op, process_mesh, seed_var_dims_mapping, ctx
+                )
+
+                # modify dropout op
+                src_op.desc.set_input("Seed", [seed_var.name])
+                src_op._remove_attr("fix_seed")
+                src_op._remove_attr("seed")
+                op_dist_attr.set_input_dist_attr(
+                    seed_var.name, seed_var_dist_attr
+                )
+                kwargs['Seed'] = [seed_var.name]
+
+        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        # dropout backward is deterministic by mask, and does not need random state control
+        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
+
+
+register_distributed_operator_impl(
+    "dropout", DistributedDropoutImpl0("random_control")
+)
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
index 92fc5f31a81eb94ce2861348ebdc09180a3879d3..7176341feedfb9ed7c863882ea1cd855f403a356 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -362,7 +362,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert (
             op_dist_attr is not None
-        ), f"backward op [{str(src_op)}] don't have dist attribute !"
+        ), f"forward op [{str(src_op)}] don't have dist attribute !"
 
         # check validation of inputs / outputs
         assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
index 3f7c3999cebef1e0a16c67f911c42583a110bc05..a76a3f5dcb9abd69419e480b59db1ee1ad573c3f 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
@@ -23,6 +23,7 @@ from paddle.utils import unique_name
 from ..utils.log_utils import get_logger
 from .partitioner import Partitioner
 from .process_group import get_world_process_group
+from .random import init_auto_parallel_rng
 from .reshard import Resharder
 from .utils import set_grad_var_shape
 
@@ -83,6 +84,9 @@ class Parallelizer:
             ) = partitioner.partition(
                 serial_main_program, serial_startup_program, params_grads
             )
+
+            init_auto_parallel_rng()
+
             self._logger.debug(
                 "within parallel partitioner time: {}, mode {}".format(
                     time.time() - time0, self._mode
diff --git a/python/paddle/distributed/auto_parallel/process_mesh.py b/python/paddle/distributed/auto_parallel/process_mesh.py
index 531de9b545e9373fcd8941f0cf3e4788b50b83e9..e2ccd16aaaad4bd150a09cf65995ecfbf2c6da9e 100644
--- a/python/paddle/distributed/auto_parallel/process_mesh.py
+++ b/python/paddle/distributed/auto_parallel/process_mesh.py
@@ -22,6 +22,8 @@ from paddle.framework import core
 # Use to store the previous and current process mesh
 _g_previous_process_mesh = None
 _g_current_process_mesh = None
+# {shape_process_ids : unique_id}
+_g_unique_process_mesh_map = {}
 
 
 def get_current_process_mesh():
@@ -42,6 +44,30 @@ def reset_current_process_mesh():
     _g_current_process_mesh = _g_previous_process_mesh
 
 
+def get_unique_id_for_process_mesh(shape, process_ids):
+    key = f"shape {shape}, process_ids {process_ids}"
+    global _g_unique_process_mesh_map
+    if key in _g_unique_process_mesh_map:
+        unique_id = _g_unique_process_mesh_map[key]
+    else:
+        unique_id = len(_g_unique_process_mesh_map) + 1
+        _g_unique_process_mesh_map[key] = unique_id
+
+    return unique_id
+
+
+def retrive_unique_id_for_process_mesh(shape, process_ids):
+    key = f"shape {shape}, process_ids {process_ids}"
+    global _g_unique_process_mesh_map
+    assert key in _g_unique_process_mesh_map
+    return _g_unique_process_mesh_map[key]
+
+
+def get_unique_process_mesh_map():
+    global _g_unique_process_mesh_map
+    return _g_unique_process_mesh_map
+
+
 class ProcessMesh(core.ProcessMesh):
     """
     The `ProcessMesh` object describes the Cartesian topology of the used processes.
@@ -124,6 +150,11 @@ class ProcessMesh(core.ProcessMesh):
         pg0 = get_process_group(0)
         pg0.add_ranks(self.process_ids)
 
+        # Unique Mesh Id
+        self._unique_id = get_unique_id_for_process_mesh(
+            self._shape, self._process_ids
+        )
+
     @property
     def mesh(self):
         """
@@ -131,6 +162,16 @@ class ProcessMesh(core.ProcessMesh):
         """
         return self._mesh
 
+    @property
+    def unique_id(self):
+        """
+        Get the unique id of the ProcessMesh.
+        NOTE: the unique id only takes process_ids and shape into account;
+        different ProcessMesh objects with the same process_ids and shape share the same unique id.
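+
+        A minimal sketch of the intent (mesh values are illustrative):
+
+            mesh_a = ProcessMesh([[0, 1], [2, 3]])
+            mesh_b = ProcessMesh([[0, 1], [2, 3]])
+            assert mesh_a.unique_id == mesh_b.unique_id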
+        """
+        return self._unique_id
+
     def __getitem__(self, index):
         if isinstance(index, tuple):
             new_dim_names = []
diff --git a/python/paddle/distributed/auto_parallel/random.py b/python/paddle/distributed/auto_parallel/random.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ca6d9e9ea06961e67b82b6962f241dfcc8ce64e
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/random.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import paddle
+
+from ..utils.log_utils import get_logger
+from .process_mesh import retrive_unique_id_for_process_mesh
+from .utils import _get_idx_in_axis
+
+_logger = get_logger(logging.INFO)
+
+_rng_name_to_seed = {}
+_inited_rng_name_to_seed = {}
+_enable_random_control = False
+_basic_seed = 42
+
+# use prime numbers as offsets to avoid conflicts
+_mesh_offset = 173
+_dim_offsets = [11, 23, 37, 73]
+
+
+def is_enable_auto_rand_ctrl():
+    global _enable_random_control
+    return _enable_random_control
+
+
+def enable_auto_rand_ctrl():
+    global _enable_random_control
+    _enable_random_control = True
+
+
+def parallel_manual_seed(seed):
+    """Enable auto parallel random control.
+
+    Random control maintains randomness when a tensor is distributed across devices on a Mesh (in any order).
+
+    * Independence: if a tensor is **Sharded** on a Mesh dimension, devices along that Mesh dimension should have Different randomness.
+
+    * Consistency: meanwhile, if the tensor is **Replicated** on another Mesh dimension, the randomness of devices along that Mesh dimension should be Consistent.
+
+    For instance: ranks 0~7 constitute a Mesh of shape [2, 4]; a 2D tensor is distributed on that Mesh using dims_mapping [-1, 1].
+    Randomness for rank0-rank1-rank2-rank3 (rank4-rank5-rank6-rank7) should be Independent;
+    randomness for rank0 and rank4 (rank1 and rank5, ...) should be Consistent.
+
+    This function should be called only once, before auto parallel compiles the computation graph (e.g. auto_parallel.engine.prepare() or fit()).
+
+    This seed only affects how randomness-related **operators** (dropout, fused ops with dropout inside, etc.) are executed across the mesh, and would NOT affect other processes like parameter initialization.
+
+    Examples:
+        # seed relative to training step
+        parallel_manual_seed((step + 13) * 257)
+        ...
+        engine.prepare()
+    """
+
+    enable_auto_rand_ctrl()
+    global _basic_seed
+    _basic_seed = seed
+
+
+def determinate_rng(rank, dims_mapping, process_mesh):
+
+    # TODO(JZ-LIANG) Support Mesh with arbitrarily high rank:
+    # use a string-to-unique-integer hashing algorithm for seed computation
+    # instead of using offsets to coordinate seeds across devices.
+    if len(process_mesh.shape) > 4:
+        raise NotImplementedError(
+            "Auto Parallel Random Control for Mesh's rank > 4 is NOT supported! Got {}".format(
Got {}".format( + str(process_mesh) + ) + ) + global _basic_seed + seed_ = _basic_seed + + # FIXME + # unique_id = process_mesh.unique_id + unique_id = retrive_unique_id_for_process_mesh( + process_mesh.shape, process_mesh.process_ids + ) + sharding_expr = f'mesh:{unique_id}' + seed_ += _mesh_offset * (unique_id + 1) + + for i in range(len(process_mesh.shape)): + if i not in dims_mapping: + relative_idx = -1 + else: + relative_idx = _get_idx_in_axis( + process_mesh.process_ids, + process_mesh.shape, + i, + rank, + ) + + sharding_expr += f"_dim{i}:{relative_idx}" + seed_ += _dim_offsets[i] * (relative_idx + 1) + + global _rng_name_to_seed + if sharding_expr in _rng_name_to_seed: + assert _rng_name_to_seed[sharding_expr] == seed_ + else: + assert ( + seed_ not in _rng_name_to_seed.values() + ), "Seed Confilt! current seed: {}, current sharding expr: {}, generated seed: {}".format( + seed_, sharding_expr, _rng_name_to_seed + ) + _rng_name_to_seed[sharding_expr] = seed_ + + return sharding_expr + + +def init_auto_parallel_rng(): + + if not is_enable_auto_rand_ctrl(): + return + + global _rng_name_to_seed + # NOTE init rng maybe call multiple times, avoid init same rng twice + global _inited_rng_name_to_seed + + for rng_name, seed in _rng_name_to_seed.items(): + if rng_name in _inited_rng_name_to_seed: + assert _inited_rng_name_to_seed[rng_name] == seed + else: + _logger.info( + f"Init Auto Parallel RNG: {rng_name}, with seed {seed}" + ) + paddle.framework.random.set_random_seed_generator(rng_name, seed) + _inited_rng_name_to_seed[rng_name] = seed diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index b6a13540caf098dfdca5223a7484f162bb0137b7..cb3dd480fbf47cbde9bf6db11bd3456544695f70 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -136,7 +136,9 @@ class RecomputeState(ProgramStats): cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op) # insert seed op to guarantee that two dropout op have the same outputs - op_unique_name = unique_name.generate("seed") + # NOTE Hack for adopt recompute for random control, for more info see dist_dropout.py + # new seed added by recompute should have a prefix to distinguish with seed added by user or other moudule. 
+ op_unique_name = unique_name.generate("rc_seed") var_unique_name = unique_name.generate_with_ignorable_key( ".".join([op_unique_name, 'tmp']) ) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 7db4d58bd8b917660a627b184a8421144bb61344..d1ba09ee8b47dec9ef69b7c62d77dec3c299dd4b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -37,6 +37,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_pass_grad_clip MODULES test_pass_grad_clip) set_tests_properties(test_pass_grad_clip PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_random_ctrl MODULES test_random_ctrl) + set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" + TIMEOUT 50) py_test_modules(test_pass_gradient_merge MODULES test_pass_gradient_merge) set_tests_properties(test_pass_gradient_merge PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py index 35bf1a323d15c4a0f07d51bbb9f3394ab65ee4b9..f23b3faf8dfe6199c6bfa3619d9b1b961352cf17 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py @@ -75,7 +75,7 @@ def create_data_holder(batch_size, vocab_size=1000, sequence_len=512): return [tokens, position_ids, attention_mask], [labels, loss_mask] -def generate_model(strategy): +def generate_model(strategy, dropout_prob=0.0): modeling.init_global() ranks = list(range(paddle.distributed.get_world_size())) modeling._global_process_mesh = auto.ProcessMesh( @@ -97,8 +97,8 @@ def generate_model(strategy): num_attention_heads=8, intermediate_size=256, hidden_act="gelu", - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, + hidden_dropout_prob=dropout_prob, + attention_probs_dropout_prob=dropout_prob, max_position_embeddings=1024, type_vocab_size=1, initializer_range=0.02, diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/random_control_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/random_control_unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..52e6e216074fdbfa08b944e7c83161f2c675c994 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/random_control_unittest.py @@ -0,0 +1,273 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
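+
+# NOTE: launched on 2 GPUs by test_random_ctrl.py; checks that dropout masks are
+# consistent across mp2 ranks for replicated activations and differ for sharded ones.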
+
+import random
+import unittest
+
+import numpy as np
+from get_gpt_model import FakeDataset, generate_model
+
+import paddle
+
+paddle.enable_static()
+from paddle import _legacy_C_ops
+from paddle.distributed.fleet import auto
+
+
+def dy_broadcast_helper(tensor):
+    _legacy_C_ops.c_broadcast(
+        tensor, tensor, 'root', 1, 'use_calc_stream', True, 'ring_id', 0
+    )
+    _legacy_C_ops.c_sync_calc_stream(tensor, tensor)
+
+
+def apply_pass(use_recompute=False, no_recompute_segments=[]):
+    strategy = auto.Strategy()
+    strategy.auto_mode = "semi"
+    strategy.reinit = True
+    if use_recompute:
+        recompute = strategy.recompute
+        recompute.enable = True
+        recompute.no_recompute_segments = no_recompute_segments
+    return strategy
+
+
+def reset_prog():
+    paddle.fluid.framework.switch_main_program(paddle.static.Program())
+    paddle.fluid.framework.switch_startup_program(paddle.static.Program())
+
+
+class TestRandomControl(unittest.TestCase):
+    def setUp(self):
+        self.rtol = 1e-6
+        self.atol = 1e-8
+        self.batch_size = 1
+        self.batch_num = 10
+        self.clip_norm = 0.2
+        self.dataset = FakeDataset(self.batch_size * self.batch_num)
+        paddle.distributed.auto_parallel.parallel_manual_seed(100)
+
+    def init(self, engine):
+        paddle.seed(2022)
+        np.random.seed(2022)
+        random.seed(2022)
+        place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
+        engine._executor = paddle.static.Executor(place)
+
+    def get_engine(self, use_recompute=False, no_recompute_segments=[]):
+        reset_prog()
+
+        strategy = apply_pass(use_recompute, no_recompute_segments)
+        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
+        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
+        model, loss = generate_model("mp", dropout_prob=0.1)
+
+        engine = auto.Engine(model, loss, opt, strategy=strategy)
+        self.init(engine)
+        return engine
+
+    def compare_mask_between_ranks(
+        self, rank, mask_np_list, compare_idx, equal
+    ):
+
+        for np_mask in [mask_np_list[i] for i in compare_idx]:
+            mask_tensor_local = paddle.to_tensor(np_mask.astype("float32"))
+            if rank == 0:
+                mask_tensor_remote = paddle.ones_like(mask_tensor_local)
+                dy_broadcast_helper(mask_tensor_remote)
+                if equal:
+                    assert np.array_equal(
+                        mask_tensor_remote.numpy(), mask_tensor_local.numpy()
+                    )
+                else:
+                    assert not np.array_equal(
+                        mask_tensor_remote.numpy(),
+                        mask_tensor_local.numpy(),
+                    )
+            else:
+                dy_broadcast_helper(mask_tensor_local)
+
+    def test_random_ctrl_vanilla(self):
+        # mp2 training without recompute
+        rc_engine = self.get_engine(False)
+        train_dataloader = rc_engine.dataloader(
+            self.dataset,
+            batch_size=self.batch_size,
+            mode="train",
+            sample_split=3,
+        )
+
+        rc_engine.prepare(mode="train")
+        mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
+        mask_var_list = [
+            rc_engine.main_program.global_block().var(varname)
+            for varname in mask_name_list
+        ]
+
+        for data in train_dataloader:
+            outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
+        mask_np_list = [outs['fetches'][varname] for varname in mask_name_list]
+
+        paddle.disable_static()
+        rank = paddle.distributed.get_rank()
+        # check global masks are consistent across ranks
+        global_index = [0, 2, 3, 5, 6]
+        self.compare_mask_between_ranks(
+            rank, mask_np_list, global_index, equal=True
+        )
+        local_index = [1, 4]
+        # check local masks differ across ranks
+        self.compare_mask_between_ranks(
+            rank, mask_np_list, local_index, equal=False
+        )
+        paddle.enable_static()
+
+        # check program
+        ops = rc_engine.main_program.global_block().ops
+        rng_names = []
+        seed_var_names = []
+        for op in ops:
+            if op.type == "seed":
+                rng_names.append(op.attr('rng_name'))
+            if op.type == "dropout":
+                seed_var_names.append(op.input("Seed")[0])
+        rank = paddle.distributed.get_rank()
+
+        self.assertEqual(
+            rng_names,
+            [
+                'mesh:1_dim0:-1',
+                f'mesh:1_dim0:{rank}',
+                'mesh:1_dim0:-1',
+                'mesh:1_dim0:-1',
+                f'mesh:1_dim0:{rank}',
+                'mesh:1_dim0:-1',
+                'mesh:1_dim0:-1',
+            ],
+        )
+        self.assertEqual(
+            seed_var_names,
+            [
+                'tensor_parallel_seed.tmp_0',
+                'tensor_parallel_seed.tmp_1',
+                'tensor_parallel_seed.tmp_2',
+                'tensor_parallel_seed.tmp_3',
+                'tensor_parallel_seed.tmp_4',
+                'tensor_parallel_seed.tmp_5',
+                'tensor_parallel_seed.tmp_6',
+            ],
+        )
+
+    def test_random_ctrl_with_recompute(self):
+        # mp2 recompute training
+        rc_engine = self.get_engine(True)
+        train_dataloader = rc_engine.dataloader(
+            self.dataset,
+            batch_size=self.batch_size,
+            mode="train",
+            sample_split=3,
+        )
+
+        rc_engine.prepare(mode="train")
+        mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
+        recompute_mask_name_list = [
+            'dropout_0.tmp_1.subprog_1',
+            'dropout_1.tmp_1.subprog_1',
+            'dropout_2.tmp_1.subprog_1',
+            'dropout_3.tmp_1.subprog_1',
+            'dropout_4.tmp_1.subprog_0',
+            'dropout_5.tmp_1.subprog_0',
+            'dropout_6.tmp_1.subprog_0',
+        ]
+        mask_var_list = [
+            rc_engine.main_program.global_block().var(varname)
+            for varname in mask_name_list + recompute_mask_name_list
+        ]
+
+        for data in train_dataloader:
+            outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
+        mask_np_list = [
+            outs['fetches'][varname]
+            for varname in mask_name_list + recompute_mask_name_list
+        ]
+
+        # check that recompute masks match the forward masks within the local device
+        for i in range(7):
+            mask_fw = mask_np_list[i].astype("float32")
+            mask_rc = mask_np_list[i + 7].astype("float32")
+            assert np.array_equal(
+                mask_fw,
+                mask_rc,
+            )
+
+        paddle.disable_static()
+        # check global masks are consistent across ranks
+        rank = paddle.distributed.get_rank()
+        global_index = [0, 2, 3, 5, 6]
+        self.compare_mask_between_ranks(
+            rank, mask_np_list, global_index, equal=True
+        )
+        local_index = [1, 4]
+        # check local masks differ across ranks
+        self.compare_mask_between_ranks(
+            rank, mask_np_list, local_index, equal=False
+        )
+        paddle.enable_static()
+
+        # check program
+        rank = paddle.distributed.get_rank()
+        ops = rc_engine.main_program.global_block().ops
+        rng_names = []
+        seed_var_names = []
+        for op in ops:
+            if op.type == "seed":
+                rng_names.append(op.attr('rng_name'))
+            if op.type == "dropout":
+                seed_var_names.append(op.input("Seed")[0])
+
+        self.assertEqual(
+            rng_names,
+            [
+                'mesh:1_dim0:-1',
+                f'mesh:1_dim0:{rank}',
+                'mesh:1_dim0:-1',
+                'mesh:1_dim0:-1',
+                f'mesh:1_dim0:{rank}',
+                'mesh:1_dim0:-1',
+                'mesh:1_dim0:-1',
+            ],
+        )
+        self.assertEqual(
+            seed_var_names,
+            [
+                'rc_seed_0.tmp_0',
+                'rc_seed_1.tmp_0',
+                'rc_seed_2.tmp_0',
+                'rc_seed_3.tmp_0',
+                'rc_seed_4.tmp_0',
+                'rc_seed_5.tmp_0',
+                'rc_seed_6.tmp_0',
+                'rc_seed_4.tmp_0',
+                'rc_seed_5.tmp_0',
+                'rc_seed_6.tmp_0',
+                'rc_seed_0.tmp_0',
+                'rc_seed_1.tmp_0',
+                'rc_seed_2.tmp_0',
+                'rc_seed_3.tmp_0',
+            ],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_random_ctrl.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_random_ctrl.py
new file mode 100644
index 0000000000000000000000000000000000000000..6162db5e93ee7f3fee75932821cc8bf5fcdaa542
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_random_ctrl.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+import sys
+import tempfile
+import unittest
+
+
+class TestRandomCtrlPass(unittest.TestCase):
+    def test_mp2_with_recompute(self):
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        launch_model_path = os.path.join(file_dir, "random_control_unittest.py")
+
+        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
+            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
+        else:
+            coverage_args = []
+
+        tmp_dir = tempfile.TemporaryDirectory()
+        cmd = (
+            [sys.executable, "-u"]
+            + coverage_args
+            + [
+                "-m",
+                "paddle.distributed.launch",
+                "--devices",
+                "0,1",
+                "--log_dir",
+                tmp_dir.name,
+                launch_model_path,
+            ]
+        )
+
+        process = subprocess.Popen(cmd)
+        process.wait()
+        self.assertEqual(process.returncode, 0)
+
+        tmp_dir.cleanup()
+
+
+if __name__ == "__main__":
+    unittest.main()