diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu
index c869ddf3e5481d88bb3b3e1142a4e3a23b7980a7..ccee7cfa7ac74694515234847ec99237d3a7c8fd 100644
--- a/paddle/operators/dropout_op.cu
+++ b/paddle/operators/dropout_op.cu
@@ -13,8 +13,72 @@
    limitations under the License. */
 
 #define EIGEN_USE_GPU
+#include <thrust/device_ptr.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/random.h>
+#include <thrust/transform.h>
 #include "paddle/operators/dropout_op.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct MaskGenerator {
+  float dropout_prob;
+  int seed;
+
+  __host__ __device__ MaskGenerator(float dropout_prob, int seed)
+      : dropout_prob(dropout_prob), seed(seed) {}
+
+  __host__ __device__ T operator()(const unsigned int n) const {
+    thrust::minstd_rand rng;
+    rng.seed(seed);
+    thrust::uniform_real_distribution<float> dist(0, 1);
+    rng.discard(n);
+    if (dist(rng) < dropout_prob) {
+      return static_cast<T>(0);
+    } else {
+      return static_cast<T>(1);
+    }
+  }
+};
+
+// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
+// Use std::random and thrust::random(thrust is a std library in CUDA) to
+// implement uniform random.
+template <typename Place, typename T>
+class GPUDropoutKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* y = context.Output<Tensor>("Out");
+    auto* mask = context.Output<Tensor>("Mask");
+    y->mutable_data<T>(context.GetPlace());
+
+    float dropout_prob = context.Attr<float>("dropout_prob");
+    int seed = context.Attr<int>("seed");
+    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
+    int size = framework::product(mask->dims());
+    T* mask_data = mask->mutable_data<T>(context.GetPlace());
+    thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                      thrust::device_ptr<T>(mask_data),
+                      MaskGenerator<T>(dropout_prob, seed));
+
+    auto dims = x->dims();
+    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
+    auto X = EigenMatrix<T>::From(*x, new_dims);
+    auto Y = EigenMatrix<T>::From(*y, new_dims);
+    auto M = EigenMatrix<T>::From(*mask, new_dims);
+
+    auto place = context.GetEigenDevice<Place>();
+    Y.device(place) = X * M;
+    // TODO(xinghai-sun): add test time logits.
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
     dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h
index 8f4363bcb879ffa62ac3e5d8c7157458edb2fbda..c9e45fa22038c0050431ccfda35cb63261c2978c 100644
--- a/paddle/operators/dropout_op.h
+++ b/paddle/operators/dropout_op.h
@@ -13,10 +13,6 @@
    limitations under the License. */
 
 #pragma once
-#include <thrust/device_ptr.h>
-#include <thrust/iterator/counting_iterator.h>
-#include <thrust/random.h>
-#include <thrust/transform.h>
 #include <random>
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
@@ -60,60 +56,6 @@ class CPUDropoutKernel : public framework::OpKernel {
   }
 };
 
-template <typename T>
-struct MaskGenerator {
-  float dropout_prob;
-  int seed;
-
-  __host__ __device__ MaskGenerator(float dropout_prob, int seed)
-      : dropout_prob(dropout_prob), seed(seed) {}
-
-  __host__ __device__ T operator()(const unsigned int n) const {
-    thrust::minstd_rand rng;
-    rng.seed(seed);
-    thrust::uniform_real_distribution<float> dist(0, 1);
-    rng.discard(n);
-    if (dist(rng) < dropout_prob) {
-      return static_cast<T>(0);
-    } else {
-      return static_cast<T>(1);
-    }
-  }
-};
-
-// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
-// Use std::random and thrust::random(thrust is a std library in CUDA) to
-// implement uniform random.
-template <typename Place, typename T>
-class GPUDropoutKernel : public framework::OpKernel {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<Tensor>("X");
-    auto* y = context.Output<Tensor>("Out");
-    auto* mask = context.Output<Tensor>("Mask");
-    y->mutable_data<T>(context.GetPlace());
-
-    float dropout_prob = context.Attr<float>("dropout_prob");
-    int seed = context.Attr<int>("seed");
-    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
-    int size = framework::product(mask->dims());
-    T* mask_data = mask->mutable_data<T>(context.GetPlace());
-    thrust::transform(index_sequence_begin, index_sequence_begin + size,
-                      thrust::device_ptr<T>(mask_data),
-                      MaskGenerator<T>(dropout_prob, seed));
-
-    auto dims = x->dims();
-    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
-    auto X = EigenMatrix<T>::From(*x, new_dims);
-    auto Y = EigenMatrix<T>::From(*y, new_dims);
-    auto M = EigenMatrix<T>::From(*mask, new_dims);
-
-    auto place = context.GetEigenDevice<Place>();
-    Y.device(place) = X * M;
-    // TODO: add test time logits.
-  }
-};
-
 template <typename Place, typename T>
 class DropoutGradKernel : public framework::OpKernel {
  public:
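
Note for reviewers: the patch moves the Thrust-based mask generation into the `.cu` file because `Eigen::Tensor::setRandom` reportedly segfaults on GPU. Below is a minimal, self-contained sketch (not part of this patch) of the same recipe the kernel relies on: a `thrust::counting_iterator` feeding `thrust::transform` with a functor that re-seeds a `thrust::minstd_rand` and calls `discard(n)` per element. The `MaskGen` name, the vector size, and the seed are made up for illustration; compile with `nvcc`.

```cpp
// Standalone sketch of the mask-generation pattern used by GPUDropoutKernel
// (illustrative only). MaskGen mirrors MaskGenerator<float>: element n
// re-seeds a minstd_rand, discards n draws, and emits 0 with probability
// dropout_prob, else 1, so the mask is reproducible for a fixed seed.
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <iostream>

struct MaskGen {
  float dropout_prob;
  int seed;
  __host__ __device__ MaskGen(float p, int s) : dropout_prob(p), seed(s) {}
  __host__ __device__ float operator()(unsigned int n) const {
    thrust::minstd_rand rng(seed);
    thrust::uniform_real_distribution<float> dist(0.0f, 1.0f);
    rng.discard(n);  // jump to the n-th draw so elements are decorrelated
    return dist(rng) < dropout_prob ? 0.0f : 1.0f;
  }
};

int main() {
  const int size = 8;
  thrust::device_vector<float> mask(size);
  thrust::counting_iterator<unsigned int> begin(0);
  // Same shape as the thrust::transform call in GPUDropoutKernel::Compute.
  thrust::transform(begin, begin + size, mask.begin(), MaskGen(0.5f, 42));
  for (int i = 0; i < size; ++i) std::cout << mask[i] << ' ';
  std::cout << '\n';
  return 0;
}
```

Re-seeding plus `discard(n)` per element trades some throughput for determinism: the same seed always produces the same mask, which is what the `Mask` output needs so the backward pass can reuse it.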