diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc
index a9950a48e0f42ce46ffd27ecc972c35d82a98fdb..60ad2efbe973b33485189f708e67da85d454330b 100644
--- a/paddle/operators/dropout_op.cc
+++ b/paddle/operators/dropout_op.cc
@@ -37,6 +37,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
   DropoutOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddAttr<float>("dropout_prob", "Dropout probability.").SetDefault(.5f);
+    AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
     AddInput("X", "The input of dropout op.");
     AddOutput("Out", "The output of dropout op.");
     AddOutput("Mask", "The dropout mask.").AsIntermediate();
@@ -75,7 +77,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
             ops::DropoutOpGrad);
-REGISTER_OP_CPU_KERNEL(dropout,
-                       ops::DropoutKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    dropout, ops::CPUDropoutKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     dropout_grad, ops::DropoutGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu
index 9e9efaa3b1af7e6905ea9ed565c7b40433010db7..c869ddf3e5481d88bb3b3e1142a4e3a23b7980a7 100644
--- a/paddle/operators/dropout_op.cu
+++ b/paddle/operators/dropout_op.cu
@@ -16,7 +16,7 @@
 #include "paddle/operators/dropout_op.h"
 
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(dropout,
-                       ops::DropoutKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
     dropout_grad, ops::DropoutGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h
index d5d32df74b70c9d81b094d4f4ea6d6681b9573ba..becf89aca3b4f91d1dbe2501e1bf1548ebcf52c4 100644
--- a/paddle/operators/dropout_op.h
+++ b/paddle/operators/dropout_op.h
@@ -13,6 +13,11 @@ limitations under the License. */
 
 #pragma once
+#include <random>
+#include <thrust/device_ptr.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/random.h>
+#include <thrust/transform.h>
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 
@@ -25,25 +30,85 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename Place, typename T>
-class DropoutKernel : public framework::OpKernel {
+class CPUDropoutKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* y = context.Output<Tensor>("Out");
+    auto* mask = context.Output<Tensor>("Mask");
+    T* mask_data = mask->mutable_data<T>(context.GetPlace());
+    T* y_data = y->mutable_data<T>(context.GetPlace());
+    const T* x_data = x->data<T>();
+
+    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
+    int seed = context.op_.GetAttr<int>("seed");
+
+    std::minstd_rand engine;
+    engine.seed(seed);
+    std::uniform_real_distribution<float> dist(0, 1);
+    size_t size = framework::product(mask->dims());
+    for (size_t i = 0; i < size; ++i) {
+      if (dist(engine) < dropout_prob) {
+        mask_data[i] = 0;
+        y_data[i] = 0;
+      } else {
+        mask_data[i] = 1;
+        y_data[i] = (1 - dropout_prob) * x_data[i];
+      }
+    }
+  }
+};
+
+template <typename T>
+struct MaskGenerator {
+  float dropout_prob_;
+  int seed_;
+
+  __host__ __device__ MaskGenerator(float dropout_prob, int seed)
+      : dropout_prob_(dropout_prob), seed_(seed) {}
+
+  __host__ __device__ T operator()(const unsigned int n) const {
+    thrust::minstd_rand rng;
+    rng.seed(seed_);
+    thrust::uniform_real_distribution<float> dist(0, 1);
+    rng.discard(n);
+    if (dist(rng) < dropout_prob_) {
+      return static_cast<T>(0);
+    } else {
+      return static_cast<T>(1);
+    }
+  }
+};
+
+// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
+// Use std::random and thrust::random(thrust is a std library in CUDA) to
+// implement uniform random.
+template <typename Place, typename T>
+class GPUDropoutKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Output<Tensor>("Out");
     auto* mask = context.Output<Tensor>("Mask");
-    mask->mutable_data<T>(context.GetPlace());
     y->mutable_data<T>(context.GetPlace());
 
+    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
+    int seed = context.op_.GetAttr<int>("seed");
+    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
+    int size = framework::product(mask->dims());
+    T* mask_data = mask->mutable_data<T>(context.GetPlace());
+    thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                      thrust::device_ptr<T>(mask_data),
+                      MaskGenerator<T>(dropout_prob, seed));
+
     auto dims = x->dims();
-    auto X = EigenMatrix<T>::From(*x);
-    auto Y = EigenMatrix<T>::From(*y);
-    auto M = EigenMatrix<T>::From(*mask);
+    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
+    auto X = EigenMatrix<T>::From(*x, new_dims);
+    auto Y = EigenMatrix<T>::From(*y, new_dims);
+    auto M = EigenMatrix<T>::From(*mask, new_dims);
 
     auto place = context.GetEigenDevice<Place>();
-    M.device(place).setRandom();
-    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
-    M.device(place) = (M > dropout_prob).cast<T>();
-    Y.device(place) = X * Y;
+    Y.device(place) = X * M * (1 - dropout_prob);
   }
 };
@@ -57,12 +122,15 @@ class DropoutGradKernel : public framework::OpKernel {
     grad_x->mutable_data<T>(context.GetPlace());
 
     auto dims = grad_x->dims();
-    auto M = EigenMatrix<T>::From(*mask);
-    auto dX = EigenMatrix<T>::From(*grad_x);
-    auto dY = EigenMatrix<T>::From(*grad_y);
+    int size = static_cast<int>(framework::product(dims));
+    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
+    auto M = EigenMatrix<T>::From(*mask, new_dims);
+    auto dX = EigenMatrix<T>::From(*grad_x, new_dims);
+    auto dY = EigenMatrix<T>::From(*grad_y, new_dims);
 
     auto place = context.GetEigenDevice<Place>();
-    dX.device(place) = dY * M;
+    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
+    dX.device(place) = dY * M * (1 - dropout_prob);
   }
 };
 
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 661ebd89648feec77367c278e5f045b8238e1dc1..850910363d5543053bc7e1cc5bb9431be1cb8b06 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -4,6 +4,7 @@
 py_test(test_scope SRCS test_scope.py)
 
 py_test(test_tensor SRCS test_tensor.py)
 py_test(test_mul_op SRCS test_mul_op.py)
+py_test(test_dropout_op SRCS test_dropout_op.py)
 
 py_test(test_mean_op SRCS test_mean_op.py)
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index 3bc05a0feccbbd3d5e7852d85bd3dc8edaccfd07..a4899355b53d62903b97999ebf9c2c7ecfc6c4cd 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -6,13 +6,13 @@ from paddle.v2.framework.op import Operator
 
 class OpTestMeta(type):
     """
     Operator Test ClassMeta.
-    
-    It injects `test_all` method into user's OperatorTest class, to make Python 
+
+    It injects `test_all` method into user's OperatorTest class, to make Python
     unittest module run that method.
-    
+
     The `test_all` read what value is stored in `self`. It use self's values to
     create and run a operator, and check whether that op is OK or not.
-    
+
     See `test_add_two_op` for example usage.
""" diff --git a/python/paddle/v2/framework/tests/test_dropout_op.py b/python/paddle/v2/framework/tests/test_dropout_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4738f6145187f06b093fdaf7ee3aa6ef9410d0 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_dropout_op.py @@ -0,0 +1,42 @@ +import unittest +import numpy as np +from gradient_checker import GradientChecker, create_op +from op_test_util import OpTestMeta + + +class TestDropoutOpProbZero(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 0.0} + self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))} + + +class TestDropoutOpAllProbOne(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 1.0} + self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))} + + +class DropoutGradOpTest(GradientChecker): + def test_dropout_2d(self): + op = create_op("dropout") + inputs = {'X': np.random.random((10, 5)).astype("float32")} + self.compare_grad(op, inputs) + self.check_grad(op, inputs, set(["X"]), "Out") + + def test_dropout_3d(self): + op = create_op("dropout") + inputs = {'X': np.random.random((10, 5, 4)).astype("float32")} + self.compare_grad(op, inputs) + self.check_grad(op, inputs, set(["X"]), "Out") + + +if __name__ == '__main__': + unittest.main()