diff --git a/paddle/fluid/framework/generator.cc b/paddle/fluid/framework/generator.cc index d51e97d98e902a87cd2a44d2019e93e8dfc30fc8..759a5754d9b6c47fe312d2654a4e13cce7af44c7 100644 --- a/paddle/fluid/framework/generator.cc +++ b/paddle/fluid/framework/generator.cc @@ -172,8 +172,7 @@ std::pair Generator::IncrementOffset( PADDLE_THROW(platform::errors::PermissionDenied( "Increment Offset only support in CUDA place")); #endif - return std::make_pair(static_cast(this->state_.current_seed), - cur_offset); + return std::make_pair(this->state_.current_seed, cur_offset); } void Generator::SetIsInitPy(bool is_init_py) { diff --git a/paddle/fluid/operators/bernoulli_op.cu b/paddle/fluid/operators/bernoulli_op.cu index 6565f5a9a2176972e9e5085c6646097e8349f259..5bdf20afe2006aec23d4e8fa2a1e79080f6ba7d2 100644 --- a/paddle/fluid/operators/bernoulli_op.cu +++ b/paddle/fluid/operators/bernoulli_op.cu @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/bernoulli_op.h" @@ -27,7 +28,10 @@ namespace operators { template struct BernoulliCudaFunctor { unsigned int seed_; - __host__ __device__ BernoulliCudaFunctor(int seed) : seed_(seed) {} + unsigned int offset_; + __host__ __device__ BernoulliCudaFunctor(unsigned int seed, + unsigned int offset) + : seed_(seed), offset_(offset) {} __host__ __device__ T operator()(const unsigned int n, const T p) const { // NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several @@ -37,7 +41,7 @@ struct BernoulliCudaFunctor { thrust::minstd_rand rng; rng.seed(seed_); thrust::uniform_real_distribution dist(0.0, 1.0); - rng.discard(n); + rng.discard(n + offset_); return static_cast(dist(rng) < p); } }; @@ -47,20 +51,26 @@ class BernoulliOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - std::random_device rd; - auto seed = rd(); const auto x = ctx.Input("X"); auto out = ctx.Output("Out"); auto* in_data = x->data(); auto* out_data = out->mutable_data(ctx.GetPlace()); - int64_t size = x->numel(); - thrust::counting_iterator index_sequence_begin(0); + + int device_id = + BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId(); + auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); + auto seed_offset = gen_cuda->IncrementOffset(1); + int gen_offset = size * seed_offset.second; platform::Transform trans; + thrust::counting_iterator index_sequence_begin(0); auto* context = static_cast(&ctx.device_context()); + trans(*context, index_sequence_begin, index_sequence_begin + size, in_data, - out_data, BernoulliCudaFunctor(seed)); + out_data, + BernoulliCudaFunctor(static_cast(seed_offset.first), + static_cast(gen_offset))); } };