diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 1074ed6acc22a81f46c466d917ef973945a12898..e4436549f6185ba04a5f270893596a6dcb11e89b 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel { } }; -template class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { public: DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker) @@ -73,7 +72,6 @@ are set equal to their corresponding inputs. } }; -template class DropoutOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad, - ops::DropoutOpGrad); +REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad, + ops::DropoutOpGrad); REGISTER_OP_CPU_KERNEL( - dropout, - ops::CPUDropoutKernel); + dropout, ops::CPUDropoutKernel); REGISTER_OP_CPU_KERNEL( dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index c949968a744089ea756b48d2a6ee092be90d25d9..f6c85a2a537b37feb20e6d62729dc5075af2a5d9 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -23,13 +23,13 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template __global__ void RandomGenerator(const size_t n, const int seed, - const AttrType dropout_prob, const T* src, + const float dropout_prob, const T* src, T* mask_data, T* dst) { thrust::minstd_rand rng; rng.seed(seed); - thrust::uniform_real_distribution dist(0, 1); + thrust::uniform_real_distribution dist(0, 1); int idx = blockDim.x * blockIdx.x + threadIdx.x; for (; idx < n; idx += blockDim.x * gridDim.x) { @@ -45,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed, // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT. // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. -template +template class GPUDropoutKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); auto* y = context.Output("Out"); y->mutable_data(context.GetPlace()); - AttrType dropout_prob = context.Attr("dropout_prob")); + float dropout_prob = context.Attr("dropout_prob"); auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); @@ -71,8 +71,8 @@ class GPUDropoutKernel : public framework::OpKernel { int threads = 512; int grid = (x->numel() + threads - 1) / threads; - RandomGenerator<<>>( + RandomGenerator< + T><<>>( size, seed, dropout_prob, x_data, mask_data, y_data); } else { Y.device(place) = X * static_cast(1.0f - dropout_prob); @@ -86,7 +86,7 @@ class GPUDropoutKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( - dropout, ops::GPUDropoutKernel, - ops::GPUDropoutKernel); + dropout, ops::GPUDropoutKernel, + ops::GPUDropoutKernel); REGISTER_OP_CUDA_KERNEL(dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 209e4dec1756dc65fbf147c4dbbf0913d3c6ef7e..b5ee86ae2d11dfc835e1a3a6826ce016baf38a29 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -25,7 +25,7 @@ template using EigenMatrix = framework::EigenMatrix; -template +template class CPUDropoutKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 6fcd5ac1a6665f37980beecda363ccee98d7bbbc..5e2c460c41e45b5edb75567aa57278714346edbd 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import paddle.fluid.core as core from op_test import OpTest