未验证 提交 789743e1 编写于 作者: L Leo Chen 提交者: GitHub

use cuda generator in bernoulli cuda kernel (#30199)

上级 8696335f
......@@ -172,8 +172,7 @@ std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
PADDLE_THROW(platform::errors::PermissionDenied(
"Increment Offset only support in CUDA place"));
#endif
return std::make_pair(static_cast<int>(this->state_.current_seed),
cur_offset);
return std::make_pair(this->state_.current_seed, cur_offset);
}
void Generator::SetIsInitPy(bool is_init_py) {
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/bernoulli_op.h"
......@@ -27,7 +28,10 @@ namespace operators {
// Per-element functor for the CUDA Bernoulli kernel: given an element index n
// and a probability p, it returns static_cast<T>(draw < p) for a single draw.
// NOTE(review): this listing looks like a diff rendered without +/- markers,
// so superseded and replacement lines sit next to each other (the two
// constructors, the two discard calls) and the hunk separator inside
// operator() elides part of its body. It is not compilable as shown.
template <typename T>
struct BernoulliCudaFunctor {
// Value passed to rng.seed() on every call of operator().
unsigned int seed_;
// Presumably the superseded single-argument constructor; it initializes
// seed_ only and leaves offset_ unset.
__host__ __device__ BernoulliCudaFunctor(int seed) : seed_(seed) {}
// Extra number of engine steps skipped in addition to the element index.
unsigned int offset_;
// Stores the seed and the extra discard count.
__host__ __device__ BernoulliCudaFunctor(unsigned int seed,
unsigned int offset)
: seed_(seed), offset_(offset) {}
// n: element index; p: probability the draw is compared against.
__host__ __device__ T operator()(const unsigned int n, const T p) const {
// NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
......@@ -37,7 +41,7 @@ struct BernoulliCudaFunctor {
// A fresh engine is built and seeded with the same seed_ for every element;
// only the discard count below differs between elements.
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(0.0, 1.0);
// Presumably the superseded form: skip by element index only.
rng.discard(n);
// NOTE(review): n + offset_ is unsigned int arithmetic, so a large offset_
// wraps around silently instead of extending the skip distance — confirm
// that this is acceptable for the intended offsets.
rng.discard(n + offset_);
// true/false of (draw < p), converted to T.
return static_cast<T>(dist(rng) < p);
}
};
......@@ -47,20 +51,26 @@ class BernoulliOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
// Fills "Out" with one BernoulliCudaFunctor result per element of "X": the
// input element is passed as the probability p and a counting iterator
// supplies the element index n. The work is dispatched through
// platform::Transform on the CUDA device context.
// NOTE(review): this listing looks like a diff rendered without +/- markers;
// superseded and replacement lines are interleaved (two seeding schemes, a
// repeated index_sequence_begin declaration, two sets of trailing arguments
// to trans), so it is not compilable as shown.
void Compute(const framework::ExecutionContext& ctx) const override {
// Presumably the superseded seeding: a new std::random_device value per call.
std::random_device rd;
auto seed = rd();
const auto x = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
auto* in_data = x->data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
int64_t size = x->numel();
// Declared a second time further down; presumably this is the superseded
// position of the declaration.
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
// Seed and offset are taken from the default CUDA generator for the device
// id read from ctx.GetPlace(); IncrementOffset(1) is called once per Compute.
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
auto seed_offset = gen_cuda->IncrementOffset(1);
// NOTE(review): size is int64_t; if seed_offset.second is a 64-bit value the
// product is narrowed when stored in int (and narrowed again by the cast to
// unsigned int below), so high bits are lost for large tensors or offsets —
// confirm the intended range.
int gen_offset = size * seed_offset.second;
platform::Transform<platform::CUDADeviceContext> trans;
// Starts at 0; each position is handed to the functor as the element index n.
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
auto* context =
static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());
trans(*context, index_sequence_begin, index_sequence_begin + size, in_data,
// Presumably the superseded trailing arguments: functor built from seed only.
out_data, BernoulliCudaFunctor<T>(seed));
// Presumably the replacement trailing arguments: functor built from the
// generator's seed and the computed offset, both cast to unsigned int.
out_data,
BernoulliCudaFunctor<T>(static_cast<unsigned int>(seed_offset.first),
static_cast<unsigned int>(gen_offset)));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册