未验证 提交 8cc8e411 编写于 作者: W WangXi 提交者: GitHub

[hybrid] static model parallel dropout support deterministic RandomSeedGenerator (#36228)

上级 d89a759b
......@@ -63,6 +63,43 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() {
return default_cpu_generator;
}
using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>;
static RNGMap& GetRandomSeedGeneratorMap() {
static auto random_seed_generator_map = RNGMap();
return random_seed_generator_map;
}
const std::shared_ptr<Generator>& SetRandomSeedGenerator(
const std::string& name, uint64_t seed) {
auto& rng_map = GetRandomSeedGeneratorMap();
auto iter = rng_map.find(name);
PADDLE_ENFORCE_EQ(iter == rng_map.end(), true,
platform::errors::AlreadyExists(
"%s RandomSeedGenerator is already exist", name));
auto generator = std::make_shared<Generator>(seed);
bool emplace_success = rng_map.emplace(name, generator).second;
PADDLE_ENFORCE_EQ(
emplace_success, true,
platform::errors::PermissionDenied(
"SetRandomSeedGenerator cannot emplace %s RandomSeedGenerator",
name));
return rng_map[name];
}
const std::shared_ptr<Generator>& GetRandomSeedGenerator(
const std::string& name) {
auto& rng_map = GetRandomSeedGeneratorMap();
auto iter = rng_map.find(name);
PADDLE_ENFORCE_EQ(iter != rng_map.end(), true,
platform::errors::NotFound(
"%s RandomSeedGenerator is not found, please "
"use `set_random_seed_generator` to set rng first",
name));
return iter->second;
}
std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine() {
static auto op_default_cpu_engine = std::make_shared<std::mt19937_64>();
return op_default_cpu_engine;
......
......@@ -126,5 +126,11 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
int64_t device_id = -1);
const std::shared_ptr<Generator>& SetRandomSeedGenerator(
const std::string& name, uint64_t seed);
const std::shared_ptr<Generator>& GetRandomSeedGenerator(
const std::string& name);
} // namespace framework
} // namespace paddle
......@@ -29,7 +29,7 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if ((seed) && platform::is_gpu_place(seed->place())) {
if (seed) {
framework::Tensor seed_cpu_tensor;
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
*seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
......@@ -39,12 +39,8 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
*seed_data = seed_offset.first;
*increment = seed_offset.second;
} else {
if (seed) {
*seed_data = *(seed->data<int>());
} else {
std::random_device rnd;
*seed_data = is_fix_seed ? seed_val : rnd();
}
std::random_device rnd;
*seed_data = is_fix_seed ? seed_val : rnd();
*increment = offset;
}
}
......
......@@ -39,6 +39,17 @@ class SeedOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddOutput("Out", "The output of seed op.");
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddAttr<bool>("deterministic",
"(bool, default false) Whether to use deterministic "
"RandomSeedGenerator which "
"generate by `set_random_seed_generator`")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>(
"rng_name",
"use deterministic RandomSeedGenerator which name is `rng_name`")
.SetDefault("")
.AsExtra();
AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
......
......@@ -23,16 +23,9 @@ class GPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out = context.Output<Tensor>("Out");
int user_seed = context.Attr<int>("seed");
auto force_cpu = context.Attr<bool>("force_cpu");
std::random_device rnd;
int seed;
if (user_seed != 0) {
seed = user_seed;
} else {
seed = rnd();
}
int seed = get_seed(context);
auto force_cpu = context.Attr<bool>("force_cpu");
bool cpu_place = force_cpu || context.GetPlace() == platform::CPUPlace();
if (cpu_place) {
platform::DeviceContextPool &pool =
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
......@@ -20,24 +21,37 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class CPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<Tensor>("Out");
auto* out_data = out->mutable_data<T>(context.GetPlace());
int user_seed = context.Attr<int>("seed");
static int get_seed(const framework::ExecutionContext& context) {
int user_seed = context.Attr<int>("seed");
bool deterministic = context.Attr<bool>("deterministic");
int seed = 0;
if (!deterministic) {
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
std::random_device rnd;
int seed;
if (user_seed != 0) {
seed = user_seed;
} else {
std::random_device rnd;
seed = rnd();
}
out_data[0] = seed;
} else {
std::string name = context.Attr<std::string>("rng_name");
auto rng = framework::GetRandomSeedGenerator(name);
do { // NOTE(wangxi): cpu dropout will use random seed if seed == 0
seed = static_cast<int>(rng->Random64());
} while (seed == 0);
}
return seed;
}
template <typename DeviceContext, typename T>
class CPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<Tensor>("Out");
auto* out_data = out->mutable_data<T>(context.GetPlace());
out_data[0] = get_seed(context);
}
};
......
......@@ -60,6 +60,8 @@ void BindGenerator(py::module* m_ptr) {
&framework::Generator::SetIsInitPy);
m.def("default_cpu_generator", &framework::DefaultCPUGenerator);
m.def("default_cuda_generator", &framework::GetDefaultCUDAGenerator);
m.def("set_random_seed_generator", &framework::SetRandomSeedGenerator);
m.def("get_random_seed_generator", &framework::GetRandomSeedGenerator);
}
} // namespace pybind
} // namespace paddle
......@@ -15,6 +15,11 @@
import paddle
import contextlib
import numpy as np
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import in_dygraph_mode, default_main_program
from paddle.fluid.layer_helper import LayerHelper
__all__ = []
......@@ -93,3 +98,135 @@ def model_parallel_random_seed(seed=None):
RNG_STATE_TRACKER.reset()
RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
paddle.seed(global_seed)
def determinate_seed(rng_name):
assert rng_name is not None and rng_name != ""
helper = LayerHelper('seed', **locals())
out = helper.create_variable_for_type_inference(dtype=paddle.int32)
# set force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang
helper.append_op(
type='seed',
outputs={'Out': out},
attrs={'deterministic': True,
'rng_name': rng_name,
'force_cpu': True})
return out
def dropout(x,
p=0.5,
axis=None,
rng_name=None,
training=True,
mode="upscale_in_train",
name=None):
"""
Dropout is a regularization technique for reducing overfitting by preventing
neuron co-adaption during training. The dropout operator randomly sets the
outputs of some units to zero, while upscale others according to the given
dropout probability.
Args:
x (Tensor): The input tensor. The data type is float32 or float64.
p (float|int): Probability of setting units to zero. Default 0.5.
axis (int|list|tuple): The axis along which the dropout is performed. Default None.
rng_name (str): The random seed generator name, which used to obtain deterministic results.
training (bool): A flag indicating whether it is in train phrase or not. Default True.
mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - dropout_prob )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - dropout_prob)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor representing the dropout, has same shape and data type as `x` .
Examples:
We use ``p=0.5`` in the following description for simplicity.
1. When ``axis=None`` , this is commonly used dropout, which dropout each element of x randomly.
.. code-block:: text
Let's see a simple case when x is a 2d tensor with shape 2*3:
[[1 2 3]
[4 5 6]]
we generate mask with the same shape as x, which is 2*3. The value of mask is
sampled from a Bernoulli distribution randomly. For example, we may get such mask:
[[0 1 0]
[1 0 1]]
So the output is obtained from elementwise multiply of x and mask:
[[0 2 0]
[4 0 6]]
Using default setting, i.e. ``mode='upscale_in_train'`` ,
if in training phase, the final upscale output is:
[[0 4 0 ]
[8 0 12]]
if in test phase, the output is the same as input:
[[1 2 3]
[4 5 6]]
we can also set ``mode='downscale_in_infer'`` , then
if in training phase, the final output is:
[[0 2 0]
[4 0 6]]
if in test phase, the scale output is:
[[0.5 1. 1.5]
[2. 2.5 3. ]]
"""
if rng_name is None:
return paddle.nn.functional.dropout(x, p, axis, training, mode, name)
# fast return for p == 0
if p == 0: return x
assert isinstance(p, (float, int)), \
TypeError("p argument should be a number")
assert 0 <= p <= 1, ValueError("p argument should between 0 and 1")
assert mode in ('downscale_in_infer', 'upscale_in_train'), \
ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
assert axis is None, \
TypeError("unsupport axis when using random seed generator")
mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer
# dygraph using tracker, doesn't need determinate seed
if in_dygraph_mode():
out, mask = _C_ops.dropout(x, 'dropout_prob', p, 'is_test',
not training, 'fix_seed', False, 'seed', 0,
'dropout_implementation', mode)
return out
seed = determinate_seed(rng_name)
helper = LayerHelper('dropout', **locals())
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'dropout')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
mask = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
helper.append_op(
type='dropout',
inputs={'X': [x],
'Seed': seed},
outputs={'Out': [out],
'Mask': [mask]},
attrs={
'dropout_prob': p,
'is_test': not training,
'dropout_implementation': mode,
})
return out
......@@ -175,11 +175,15 @@ class ProgramStats(object):
return
op_idx = 0
while (op_idx < len(self.ops)):
while op_idx < len(self.ops):
op = self.ops[op_idx]
if op.desc.type() != "dropout":
op_idx += 1
continue
# already insert seed op before dropout
if op.input('Seed') is not None and len(op.input('Seed')) == 1:
op_idx += 1
continue
# add a seed op so that the two dropout op can generate same output
op_unique_name = unique_name.generate("seed")
var_unique_name = unique_name.generate_with_ignorable_key(".".join(
......
......@@ -19,6 +19,7 @@ import numpy as np
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.static as static
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
......@@ -856,5 +857,48 @@ class TestAlphaDropoutCAPI(unittest.TestCase):
self.assertTrue(np.allclose(result.numpy(), result_np))
class TestDropoutWithDeterminateSeedGenerator(unittest.TestCase):
def setUp(self):
paddle.framework.random.set_random_seed_generator('seed0', 123)
paddle.framework.random.set_random_seed_generator('seed1', 123)
rng0 = paddle.framework.random.get_random_seed_generator('seed0')
rng1 = paddle.framework.random.get_random_seed_generator('seed1')
self.places = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
def check_static_result(self, place):
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import dropout
with static.program_guard(static.Program(), static.Program()):
input = static.data(name="input", shape=[40, 40], dtype="float32")
res1 = dropout(
input,
p=0.3,
training=True,
mode='upscale_in_train',
rng_name='seed0')
res2 = dropout(
input,
p=0.3,
training=True,
mode='upscale_in_train',
rng_name='seed1')
res3 = dropout(input, p=0.3)
in_np = np.random.random([40, 40]).astype("float32")
exe = static.Executor(place)
res_list = [res1, res2]
for i in range(2):
out1, out2 = exe.run(static.default_main_program(),
feed={"input": in_np},
fetch_list=res_list)
self.assertTrue(np.allclose(out1, out2))
def test_static(self):
for place in self.places:
self.check_static_result(place=place)
if __name__ == '__main__':
unittest.main()
......@@ -619,7 +619,7 @@ class TestLookaheadOptimizer(unittest.TestCase):
class TestRecomputeOptimizer(unittest.TestCase):
def net(self, return_input=False, with_dropout=False):
def net(self, return_input=False, with_dropout=False, with_seed=False):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
......@@ -628,7 +628,8 @@ class TestRecomputeOptimizer(unittest.TestCase):
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
if with_dropout == True:
if with_dropout is True:
mul_out_drop = block.create_var(
dtype="float32",
shape=[5, 8],
......@@ -636,6 +637,10 @@ class TestRecomputeOptimizer(unittest.TestCase):
name="mul.out.dropout")
mul_out_mask = block.create_var(
dtype="uint8", shape=[5, 8], lod_level=0, name="mul.out.mask")
if with_seed is True:
seed_out = block.create_var(
dtype="int32", shape=[1], name="seed.out")
b1 = block.create_parameter(
dtype="float32", shape=[5, 8], lod_level=0, name="b1")
b1_out = block.create_var(
......@@ -652,10 +657,23 @@ class TestRecomputeOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
if with_dropout == True:
if with_dropout is True:
dropout_inputs = {'X': [mul_out]}
if with_seed is True:
block.append_op(
type='seed',
outputs={'Out': seed_out},
attrs={
'deterministic': True,
'rng_name': 'rng0',
'force_cpu': True
})
dropout_inputs = {'X': [mul_out], 'Seed': [seed_out]}
block.append_op(
type='dropout',
inputs={'X': [mul_out]},
inputs=dropout_inputs,
outputs={'Out': [mul_out_drop],
'Mask': [mul_out_mask]},
attrs={'dropout_prob': 0.5, })
......@@ -670,6 +688,7 @@ class TestRecomputeOptimizer(unittest.TestCase):
inputs={"X": mul_out,
"Y": b1},
outputs={"Out": b1_out})
block.append_op(
type="elementwise_add",
inputs={"X": b1_out,
......@@ -864,6 +883,27 @@ class TestRecomputeOptimizer(unittest.TestCase):
"sgd", "sgd", "sgd"
])
def test_dropout_with_determinate_seed(self):
mul_out, b1_out, b2_out, mean_out = self.net(with_dropout=True,
with_seed=True)
self.assertEqual(len(mean_out.block.ops), 6)
self.assertEqual([op.type for op in mean_out.block.ops], [
"mul", "seed", "dropout", "elementwise_add", "elementwise_add",
"mean"
])
sgd_optimizer = optimizer.SGD(learning_rate=1.0)
recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
recompute_optimizer._set_checkpoints([b1_out])
opts, params_grads = recompute_optimizer.minimize(mean_out)
self.assertEqual(len(mean_out.block.ops), 17)
self.assertEqual([op.type for op in mean_out.block.ops], [
"mul", "seed", "dropout", "elementwise_add", "elementwise_add",
"mean", "fill_constant", "mean_grad", "elementwise_add_grad", "mul",
"dropout", "elementwise_add_grad", "dropout_grad", "mul_grad",
"sgd", "sgd", "sgd"
])
def test_dropout_with_seed(self):
"""
when we recompute a dropout op, make sure that the recomputed one
......
......@@ -17,7 +17,10 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle
import paddle.static as static
paddle.enable_static()
class TestSeedOpFixSeed(OpTest):
......@@ -42,5 +45,32 @@ class TestSeedOpDiffSeed(OpTest):
self.check_output(no_check_set=["Out"])
class TestDropoutWithRandomSeedGenerator(unittest.TestCase):
def setUp(self):
paddle.framework.random.set_random_seed_generator('seed0', 123)
paddle.framework.random.set_random_seed_generator('seed1', 123)
self.rng0 = paddle.framework.random.get_random_seed_generator('seed0')
self.rng1 = paddle.framework.random.get_random_seed_generator('seed1')
self.places = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
def check_static_result(self, place):
import paddle.distributed.fleet.meta_parallel.parallel_layers.random as random
with static.program_guard(static.Program(), static.Program()):
res1 = random.determinate_seed('seed0')
exe = static.Executor(place)
res_list = [res1]
for i in range(2):
out1, = exe.run(static.default_main_program(),
fetch_list=res_list)
self.assertEqual(out1, np.cast['int32'](self.rng1.random()))
def test_static(self):
for place in self.places:
self.check_static_result(place=place)
if __name__ == '__main__':
unittest.main()
......@@ -122,3 +122,11 @@ def _manual_program_seed(seed):
fluid.default_startup_program().random_seed = seed
program = fluid.Program()
program.global_seed(seed)
def set_random_seed_generator(name, seed):
core.set_random_seed_generator(name, seed)
def get_random_seed_generator(name):
return core.get_random_seed_generator(name)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册