From 8cc8e411121649be36af8396536502e7ef7539b7 Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 19 Oct 2021 14:59:38 +0800 Subject: [PATCH] [hybrid] static model parallel dropout support deterministic RandomSeedGenerator (#36228) --- paddle/fluid/framework/generator.cc | 37 +++++ paddle/fluid/framework/generator.h | 6 + paddle/fluid/operators/dropout_impl_util.h | 10 +- paddle/fluid/operators/seed_op.cc | 11 ++ paddle/fluid/operators/seed_op.cu | 11 +- paddle/fluid/operators/seed_op.h | 34 +++-- paddle/fluid/pybind/generator_py.cc | 2 + .../meta_parallel/parallel_layers/random.py | 137 ++++++++++++++++++ python/paddle/fluid/backward.py | 6 +- .../fluid/tests/unittests/test_dropout_op.py | 44 ++++++ .../fluid/tests/unittests/test_optimizer.py | 48 +++++- .../fluid/tests/unittests/test_seed_op.py | 32 +++- python/paddle/framework/random.py | 8 + 13 files changed, 354 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/framework/generator.cc b/paddle/fluid/framework/generator.cc index 4b64722a7ab..154154fc795 100644 --- a/paddle/fluid/framework/generator.cc +++ b/paddle/fluid/framework/generator.cc @@ -63,6 +63,43 @@ const std::shared_ptr& DefaultCPUGenerator() { return default_cpu_generator; } +using RNGMap = std::unordered_map>; + +static RNGMap& GetRandomSeedGeneratorMap() { + static auto random_seed_generator_map = RNGMap(); + return random_seed_generator_map; +} + +const std::shared_ptr& SetRandomSeedGenerator( + const std::string& name, uint64_t seed) { + auto& rng_map = GetRandomSeedGeneratorMap(); + auto iter = rng_map.find(name); + PADDLE_ENFORCE_EQ(iter == rng_map.end(), true, + platform::errors::AlreadyExists( + "%s RandomSeedGenerator is already exist", name)); + + auto generator = std::make_shared(seed); + bool emplace_success = rng_map.emplace(name, generator).second; + PADDLE_ENFORCE_EQ( + emplace_success, true, + platform::errors::PermissionDenied( + "SetRandomSeedGenerator cannot emplace %s RandomSeedGenerator", + name)); + return rng_map[name]; +} + +const std::shared_ptr& GetRandomSeedGenerator( + const std::string& name) { + auto& rng_map = GetRandomSeedGeneratorMap(); + auto iter = rng_map.find(name); + PADDLE_ENFORCE_EQ(iter != rng_map.end(), true, + platform::errors::NotFound( + "%s RandomSeedGenerator is not found, please " + "use `set_random_seed_generator` to set rng first", + name)); + return iter->second; +} + std::shared_ptr OpDefaultCPUEngine() { static auto op_default_cpu_engine = std::make_shared(); return op_default_cpu_engine; diff --git a/paddle/fluid/framework/generator.h b/paddle/fluid/framework/generator.h index 862e63c4c6a..d0a5b4443e3 100644 --- a/paddle/fluid/framework/generator.h +++ b/paddle/fluid/framework/generator.h @@ -126,5 +126,11 @@ std::shared_ptr GetCPURandomEngine(uint64_t); const std::shared_ptr& GetDefaultCUDAGenerator( int64_t device_id = -1); +const std::shared_ptr& SetRandomSeedGenerator( + const std::string& name, uint64_t seed); + +const std::shared_ptr& GetRandomSeedGenerator( + const std::string& name); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/dropout_impl_util.h b/paddle/fluid/operators/dropout_impl_util.h index a7188efe713..f2038d12528 100644 --- a/paddle/fluid/operators/dropout_impl_util.h +++ b/paddle/fluid/operators/dropout_impl_util.h @@ -29,7 +29,7 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx, BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId(); auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); - if ((seed) && platform::is_gpu_place(seed->place())) { + if (seed) { framework::Tensor seed_cpu_tensor; TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor); *seed_data = static_cast(seed_cpu_tensor.data()[0]); @@ -39,12 +39,8 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx, *seed_data = seed_offset.first; *increment = seed_offset.second; } else { - if (seed) { - *seed_data = *(seed->data()); - } else { - std::random_device rnd; - *seed_data = is_fix_seed ? seed_val : rnd(); - } + std::random_device rnd; + *seed_data = is_fix_seed ? seed_val : rnd(); *increment = offset; } } diff --git a/paddle/fluid/operators/seed_op.cc b/paddle/fluid/operators/seed_op.cc index 32daa8c3934..837ccae0284 100644 --- a/paddle/fluid/operators/seed_op.cc +++ b/paddle/fluid/operators/seed_op.cc @@ -39,6 +39,17 @@ class SeedOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddOutput("Out", "The output of seed op."); AddAttr("seed", "Dropout random seed.").SetDefault(0); + AddAttr("deterministic", + "(bool, default false) Whether to use deterministic " + "RandomSeedGenerator which " + "generate by `set_random_seed_generator`") + .SetDefault(false) + .AsExtra(); + AddAttr( + "rng_name", + "use deterministic RandomSeedGenerator which name is `rng_name`") + .SetDefault("") + .AsExtra(); AddAttr("force_cpu", "(bool, default false) Force fill output variable to cpu " "memory. Otherwise, fill output variable to the running " diff --git a/paddle/fluid/operators/seed_op.cu b/paddle/fluid/operators/seed_op.cu index 4593b880196..4ca75bcf76e 100644 --- a/paddle/fluid/operators/seed_op.cu +++ b/paddle/fluid/operators/seed_op.cu @@ -23,16 +23,9 @@ class GPUSeedKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { auto *out = context.Output("Out"); - int user_seed = context.Attr("seed"); - auto force_cpu = context.Attr("force_cpu"); - std::random_device rnd; - int seed; - if (user_seed != 0) { - seed = user_seed; - } else { - seed = rnd(); - } + int seed = get_seed(context); + auto force_cpu = context.Attr("force_cpu"); bool cpu_place = force_cpu || context.GetPlace() == platform::CPUPlace(); if (cpu_place) { platform::DeviceContextPool &pool = diff --git a/paddle/fluid/operators/seed_op.h b/paddle/fluid/operators/seed_op.h index 671f397d4ea..202f25e0b4c 100644 --- a/paddle/fluid/operators/seed_op.h +++ b/paddle/fluid/operators/seed_op.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -20,24 +21,37 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -class CPUSeedKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* out = context.Output("Out"); - auto* out_data = out->mutable_data(context.GetPlace()); - int user_seed = context.Attr("seed"); +static int get_seed(const framework::ExecutionContext& context) { + int user_seed = context.Attr("seed"); + bool deterministic = context.Attr("deterministic"); + int seed = 0; + if (!deterministic) { // NOTE: fixed seed should only be used in unittest or for debug. // Guarantee to use random seed in training. - std::random_device rnd; - int seed; if (user_seed != 0) { seed = user_seed; } else { + std::random_device rnd; seed = rnd(); } - out_data[0] = seed; + } else { + std::string name = context.Attr("rng_name"); + auto rng = framework::GetRandomSeedGenerator(name); + do { // NOTE(wangxi): cpu dropout will use random seed if seed == 0 + seed = static_cast(rng->Random64()); + } while (seed == 0); + } + return seed; +} + +template +class CPUSeedKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* out = context.Output("Out"); + auto* out_data = out->mutable_data(context.GetPlace()); + out_data[0] = get_seed(context); } }; diff --git a/paddle/fluid/pybind/generator_py.cc b/paddle/fluid/pybind/generator_py.cc index 67121e24089..fa924ce6581 100644 --- a/paddle/fluid/pybind/generator_py.cc +++ b/paddle/fluid/pybind/generator_py.cc @@ -60,6 +60,8 @@ void BindGenerator(py::module* m_ptr) { &framework::Generator::SetIsInitPy); m.def("default_cpu_generator", &framework::DefaultCPUGenerator); m.def("default_cuda_generator", &framework::GetDefaultCUDAGenerator); + m.def("set_random_seed_generator", &framework::SetRandomSeedGenerator); + m.def("get_random_seed_generator", &framework::GetRandomSeedGenerator); } } // namespace pybind } // namespace paddle diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py index ec80ba71036..0a96745c2a4 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py @@ -15,6 +15,11 @@ import paddle import contextlib import numpy as np +from paddle import _C_ops +from paddle.fluid import core +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle.fluid.framework import in_dygraph_mode, default_main_program +from paddle.fluid.layer_helper import LayerHelper __all__ = [] @@ -93,3 +98,135 @@ def model_parallel_random_seed(seed=None): RNG_STATE_TRACKER.reset() RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed) paddle.seed(global_seed) + + +def determinate_seed(rng_name): + assert rng_name is not None and rng_name != "" + helper = LayerHelper('seed', **locals()) + out = helper.create_variable_for_type_inference(dtype=paddle.int32) + # set force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang + helper.append_op( + type='seed', + outputs={'Out': out}, + attrs={'deterministic': True, + 'rng_name': rng_name, + 'force_cpu': True}) + return out + + +def dropout(x, + p=0.5, + axis=None, + rng_name=None, + training=True, + mode="upscale_in_train", + name=None): + """ + Dropout is a regularization technique for reducing overfitting by preventing + neuron co-adaption during training. The dropout operator randomly sets the + outputs of some units to zero, while upscale others according to the given + dropout probability. + + Args: + x (Tensor): The input tensor. The data type is float32 or float64. + p (float|int): Probability of setting units to zero. Default 0.5. + axis (int|list|tuple): The axis along which the dropout is performed. Default None. + rng_name (str): The random seed generator name, which used to obtain deterministic results. + training (bool): A flag indicating whether it is in train phrase or not. Default True. + mode(str): ['upscale_in_train'(default) | 'downscale_in_infer']. + + 1. upscale_in_train(default), upscale the output at training time + + - train: out = input * mask / ( 1.0 - dropout_prob ) + - inference: out = input + + 2. downscale_in_infer, downscale the output at inference + + - train: out = input * mask + - inference: out = input * (1.0 - dropout_prob) + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor representing the dropout, has same shape and data type as `x` . + + + Examples: + We use ``p=0.5`` in the following description for simplicity. + + 1. When ``axis=None`` , this is commonly used dropout, which dropout each element of x randomly. + + .. code-block:: text + + Let's see a simple case when x is a 2d tensor with shape 2*3: + [[1 2 3] + [4 5 6]] + we generate mask with the same shape as x, which is 2*3. The value of mask is + sampled from a Bernoulli distribution randomly. For example, we may get such mask: + [[0 1 0] + [1 0 1]] + So the output is obtained from elementwise multiply of x and mask: + [[0 2 0] + [4 0 6]] + Using default setting, i.e. ``mode='upscale_in_train'`` , + if in training phase, the final upscale output is: + [[0 4 0 ] + [8 0 12]] + if in test phase, the output is the same as input: + [[1 2 3] + [4 5 6]] + we can also set ``mode='downscale_in_infer'`` , then + if in training phase, the final output is: + [[0 2 0] + [4 0 6]] + if in test phase, the scale output is: + [[0.5 1. 1.5] + [2. 2.5 3. ]] + + """ + if rng_name is None: + return paddle.nn.functional.dropout(x, p, axis, training, mode, name) + + # fast return for p == 0 + if p == 0: return x + + assert isinstance(p, (float, int)), \ + TypeError("p argument should be a number") + assert 0 <= p <= 1, ValueError("p argument should between 0 and 1") + assert mode in ('downscale_in_infer', 'upscale_in_train'), \ + ValueError( + "mode argument should be 'downscale_in_infer' or 'upscale_in_train'") + + assert axis is None, \ + TypeError("unsupport axis when using random seed generator") + + mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer + + # dygraph using tracker, doesn't need determinate seed + if in_dygraph_mode(): + out, mask = _C_ops.dropout(x, 'dropout_prob', p, 'is_test', + not training, 'fix_seed', False, 'seed', 0, + 'dropout_implementation', mode) + return out + + seed = determinate_seed(rng_name) + + helper = LayerHelper('dropout', **locals()) + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'dropout') + + out = helper.create_variable_for_type_inference(dtype=x.dtype) + mask = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + + helper.append_op( + type='dropout', + inputs={'X': [x], + 'Seed': seed}, + outputs={'Out': [out], + 'Mask': [mask]}, + attrs={ + 'dropout_prob': p, + 'is_test': not training, + 'dropout_implementation': mode, + }) + return out diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 7ab060be6df..d62f7b59411 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -175,11 +175,15 @@ class ProgramStats(object): return op_idx = 0 - while (op_idx < len(self.ops)): + while op_idx < len(self.ops): op = self.ops[op_idx] if op.desc.type() != "dropout": op_idx += 1 continue + # already insert seed op before dropout + if op.input('Seed') is not None and len(op.input('Seed')) == 1: + op_idx += 1 + continue # add a seed op so that the two dropout op can generate same output op_unique_name = unique_name.generate("seed") var_unique_name = unique_name.generate_with_ignorable_key(".".join( diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 396d55b3d0a..bf10e07ba0d 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -19,6 +19,7 @@ import numpy as np import paddle.fluid.core as core from op_test import OpTest, skip_check_grad_ci import paddle +import paddle.static as static import paddle.fluid as fluid from paddle.fluid import Program, program_guard @@ -856,5 +857,48 @@ class TestAlphaDropoutCAPI(unittest.TestCase): self.assertTrue(np.allclose(result.numpy(), result_np)) +class TestDropoutWithDeterminateSeedGenerator(unittest.TestCase): + def setUp(self): + paddle.framework.random.set_random_seed_generator('seed0', 123) + paddle.framework.random.set_random_seed_generator('seed1', 123) + rng0 = paddle.framework.random.get_random_seed_generator('seed0') + rng1 = paddle.framework.random.get_random_seed_generator('seed1') + self.places = [paddle.CPUPlace()] + if paddle.is_compiled_with_cuda(): + self.places.append(paddle.CUDAPlace(0)) + + def check_static_result(self, place): + from paddle.distributed.fleet.meta_parallel.parallel_layers.random import dropout + with static.program_guard(static.Program(), static.Program()): + input = static.data(name="input", shape=[40, 40], dtype="float32") + res1 = dropout( + input, + p=0.3, + training=True, + mode='upscale_in_train', + rng_name='seed0') + res2 = dropout( + input, + p=0.3, + training=True, + mode='upscale_in_train', + rng_name='seed1') + res3 = dropout(input, p=0.3) + + in_np = np.random.random([40, 40]).astype("float32") + + exe = static.Executor(place) + res_list = [res1, res2] + for i in range(2): + out1, out2 = exe.run(static.default_main_program(), + feed={"input": in_np}, + fetch_list=res_list) + self.assertTrue(np.allclose(out1, out2)) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index 31704ebcd91..89c7be18a7d 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -619,7 +619,7 @@ class TestLookaheadOptimizer(unittest.TestCase): class TestRecomputeOptimizer(unittest.TestCase): - def net(self, return_input=False, with_dropout=False): + def net(self, return_input=False, with_dropout=False, with_seed=False): program = framework.Program() block = program.global_block() mul_x = block.create_parameter( @@ -628,7 +628,8 @@ class TestRecomputeOptimizer(unittest.TestCase): dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") mul_out = block.create_var( dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") - if with_dropout == True: + + if with_dropout is True: mul_out_drop = block.create_var( dtype="float32", shape=[5, 8], @@ -636,6 +637,10 @@ class TestRecomputeOptimizer(unittest.TestCase): name="mul.out.dropout") mul_out_mask = block.create_var( dtype="uint8", shape=[5, 8], lod_level=0, name="mul.out.mask") + if with_seed is True: + seed_out = block.create_var( + dtype="int32", shape=[1], name="seed.out") + b1 = block.create_parameter( dtype="float32", shape=[5, 8], lod_level=0, name="b1") b1_out = block.create_var( @@ -652,10 +657,23 @@ class TestRecomputeOptimizer(unittest.TestCase): "Y": mul_y}, outputs={"Out": mul_out}, attrs={"x_num_col_dims": 1}) - if with_dropout == True: + + if with_dropout is True: + dropout_inputs = {'X': [mul_out]} + if with_seed is True: + block.append_op( + type='seed', + outputs={'Out': seed_out}, + attrs={ + 'deterministic': True, + 'rng_name': 'rng0', + 'force_cpu': True + }) + dropout_inputs = {'X': [mul_out], 'Seed': [seed_out]} + block.append_op( type='dropout', - inputs={'X': [mul_out]}, + inputs=dropout_inputs, outputs={'Out': [mul_out_drop], 'Mask': [mul_out_mask]}, attrs={'dropout_prob': 0.5, }) @@ -670,6 +688,7 @@ class TestRecomputeOptimizer(unittest.TestCase): inputs={"X": mul_out, "Y": b1}, outputs={"Out": b1_out}) + block.append_op( type="elementwise_add", inputs={"X": b1_out, @@ -864,6 +883,27 @@ class TestRecomputeOptimizer(unittest.TestCase): "sgd", "sgd", "sgd" ]) + def test_dropout_with_determinate_seed(self): + mul_out, b1_out, b2_out, mean_out = self.net(with_dropout=True, + with_seed=True) + self.assertEqual(len(mean_out.block.ops), 6) + self.assertEqual([op.type for op in mean_out.block.ops], [ + "mul", "seed", "dropout", "elementwise_add", "elementwise_add", + "mean" + ]) + sgd_optimizer = optimizer.SGD(learning_rate=1.0) + recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer) + recompute_optimizer._set_checkpoints([b1_out]) + opts, params_grads = recompute_optimizer.minimize(mean_out) + + self.assertEqual(len(mean_out.block.ops), 17) + self.assertEqual([op.type for op in mean_out.block.ops], [ + "mul", "seed", "dropout", "elementwise_add", "elementwise_add", + "mean", "fill_constant", "mean_grad", "elementwise_add_grad", "mul", + "dropout", "elementwise_add_grad", "dropout_grad", "mul_grad", + "sgd", "sgd", "sgd" + ]) + def test_dropout_with_seed(self): """ when we recompute a dropout op, make sure that the recomputed one diff --git a/python/paddle/fluid/tests/unittests/test_seed_op.py b/python/paddle/fluid/tests/unittests/test_seed_op.py index 08478d7140d..0dcc197ece7 100644 --- a/python/paddle/fluid/tests/unittests/test_seed_op.py +++ b/python/paddle/fluid/tests/unittests/test_seed_op.py @@ -17,7 +17,10 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest -import paddle.fluid as fluid +import paddle +import paddle.static as static + +paddle.enable_static() class TestSeedOpFixSeed(OpTest): @@ -42,5 +45,32 @@ class TestSeedOpDiffSeed(OpTest): self.check_output(no_check_set=["Out"]) +class TestDropoutWithRandomSeedGenerator(unittest.TestCase): + def setUp(self): + paddle.framework.random.set_random_seed_generator('seed0', 123) + paddle.framework.random.set_random_seed_generator('seed1', 123) + self.rng0 = paddle.framework.random.get_random_seed_generator('seed0') + self.rng1 = paddle.framework.random.get_random_seed_generator('seed1') + self.places = [paddle.CPUPlace()] + if paddle.is_compiled_with_cuda(): + self.places.append(paddle.CUDAPlace(0)) + + def check_static_result(self, place): + import paddle.distributed.fleet.meta_parallel.parallel_layers.random as random + with static.program_guard(static.Program(), static.Program()): + res1 = random.determinate_seed('seed0') + + exe = static.Executor(place) + res_list = [res1] + for i in range(2): + out1, = exe.run(static.default_main_program(), + fetch_list=res_list) + self.assertEqual(out1, np.cast['int32'](self.rng1.random())) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/framework/random.py b/python/paddle/framework/random.py index 701f8b5352c..a560072cf5a 100644 --- a/python/paddle/framework/random.py +++ b/python/paddle/framework/random.py @@ -122,3 +122,11 @@ def _manual_program_seed(seed): fluid.default_startup_program().random_seed = seed program = fluid.Program() program.global_seed(seed) + + +def set_random_seed_generator(name, seed): + core.set_random_seed_generator(name, seed) + + +def get_random_seed_generator(name): + return core.get_random_seed_generator(name) -- GitLab