未验证 提交 789743e1 编写于 作者: L Leo Chen 提交者: GitHub

use cuda generator in bernoulli cuda kernel (#30199)

上级 8696335f
...@@ -172,8 +172,7 @@ std::pair<uint64_t, uint64_t> Generator::IncrementOffset( ...@@ -172,8 +172,7 @@ std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
"Increment Offset only support in CUDA place")); "Increment Offset only support in CUDA place"));
#endif #endif
return std::make_pair(static_cast<int>(this->state_.current_seed), return std::make_pair(this->state_.current_seed, cur_offset);
cur_offset);
} }
void Generator::SetIsInitPy(bool is_init_py) { void Generator::SetIsInitPy(bool is_init_py) {
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <thrust/random.h> #include <thrust/random.h>
#include <thrust/transform.h> #include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/bernoulli_op.h" #include "paddle/fluid/operators/bernoulli_op.h"
...@@ -27,7 +28,10 @@ namespace operators { ...@@ -27,7 +28,10 @@ namespace operators {
template <typename T> template <typename T>
struct BernoulliCudaFunctor { struct BernoulliCudaFunctor {
unsigned int seed_; unsigned int seed_;
__host__ __device__ BernoulliCudaFunctor(int seed) : seed_(seed) {} unsigned int offset_;
__host__ __device__ BernoulliCudaFunctor(unsigned int seed,
unsigned int offset)
: seed_(seed), offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n, const T p) const { __host__ __device__ T operator()(const unsigned int n, const T p) const {
// NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several // NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
...@@ -37,7 +41,7 @@ struct BernoulliCudaFunctor { ...@@ -37,7 +41,7 @@ struct BernoulliCudaFunctor {
thrust::minstd_rand rng; thrust::minstd_rand rng;
rng.seed(seed_); rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(0.0, 1.0); thrust::uniform_real_distribution<T> dist(0.0, 1.0);
rng.discard(n); rng.discard(n + offset_);
return static_cast<T>(dist(rng) < p); return static_cast<T>(dist(rng) < p);
} }
}; };
...@@ -47,20 +51,26 @@ class BernoulliOpKernel<platform::CUDADeviceContext, T> ...@@ -47,20 +51,26 @@ class BernoulliOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
std::random_device rd;
auto seed = rd();
const auto x = ctx.Input<framework::Tensor>("X"); const auto x = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out"); auto out = ctx.Output<framework::Tensor>("Out");
auto* in_data = x->data<T>(); auto* in_data = x->data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace()); auto* out_data = out->mutable_data<T>(ctx.GetPlace());
int64_t size = x->numel(); int64_t size = x->numel();
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
auto seed_offset = gen_cuda->IncrementOffset(1);
int gen_offset = size * seed_offset.second;
platform::Transform<platform::CUDADeviceContext> trans; platform::Transform<platform::CUDADeviceContext> trans;
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
auto* context = auto* context =
static_cast<const platform::CUDADeviceContext*>(&ctx.device_context()); static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());
trans(*context, index_sequence_begin, index_sequence_begin + size, in_data, trans(*context, index_sequence_begin, index_sequence_begin + size, in_data,
out_data, BernoulliCudaFunctor<T>(seed)); out_data,
BernoulliCudaFunctor<T>(static_cast<unsigned int>(seed_offset.first),
static_cast<unsigned int>(gen_offset)));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册