From 789743e1905e30e90af315fa7a90e94378de7d6a Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Fri, 8 Jan 2021 19:13:35 +0800 Subject: [PATCH] use cuda generator in bernoulli cuda kernel (#30199) --- paddle/fluid/framework/generator.cc | 3 +-- paddle/fluid/operators/bernoulli_op.cu | 24 +++++++++++++++++------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/generator.cc b/paddle/fluid/framework/generator.cc index d51e97d98e..759a5754d9 100644 --- a/paddle/fluid/framework/generator.cc +++ b/paddle/fluid/framework/generator.cc @@ -172,8 +172,7 @@ std::pair Generator::IncrementOffset( PADDLE_THROW(platform::errors::PermissionDenied( "Increment Offset only support in CUDA place")); #endif - return std::make_pair(static_cast(this->state_.current_seed), - cur_offset); + return std::make_pair(this->state_.current_seed, cur_offset); } void Generator::SetIsInitPy(bool is_init_py) { diff --git a/paddle/fluid/operators/bernoulli_op.cu b/paddle/fluid/operators/bernoulli_op.cu index 6565f5a9a2..5bdf20afe2 100644 --- a/paddle/fluid/operators/bernoulli_op.cu +++ b/paddle/fluid/operators/bernoulli_op.cu @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/bernoulli_op.h" @@ -27,7 +28,10 @@ namespace operators { template struct BernoulliCudaFunctor { unsigned int seed_; - __host__ __device__ BernoulliCudaFunctor(int seed) : seed_(seed) {} + unsigned int offset_; + __host__ __device__ BernoulliCudaFunctor(unsigned int seed, + unsigned int offset) + : seed_(seed), offset_(offset) {} __host__ __device__ T operator()(const unsigned int n, const T p) const { // NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several @@ -37,7 +41,7 @@ struct BernoulliCudaFunctor { thrust::minstd_rand rng; rng.seed(seed_); thrust::uniform_real_distribution dist(0.0, 1.0); - rng.discard(n); + rng.discard(n + offset_); return static_cast(dist(rng) < p); } }; @@ -47,20 +51,26 @@ class BernoulliOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - std::random_device rd; - auto seed = rd(); const auto x = ctx.Input("X"); auto out = ctx.Output("Out"); auto* in_data = x->data(); auto* out_data = out->mutable_data(ctx.GetPlace()); - int64_t size = x->numel(); - thrust::counting_iterator index_sequence_begin(0); + + int device_id = + BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId(); + auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); + auto seed_offset = gen_cuda->IncrementOffset(1); + int gen_offset = size * seed_offset.second; platform::Transform trans; + thrust::counting_iterator index_sequence_begin(0); auto* context = static_cast(&ctx.device_context()); + trans(*context, index_sequence_begin, index_sequence_begin + size, in_data, - out_data, BernoulliCudaFunctor(seed)); + out_data, + BernoulliCudaFunctor(static_cast(seed_offset.first), + static_cast(gen_offset))); } }; -- GitLab