From 9c63fef63ca721a8e69c723314040fb9e9a5ad3d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 6 Aug 2018 22:01:46 +0800 Subject: [PATCH] random optimize --- paddle/fluid/operators/sampling_id_op.h | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 3d724e3ae72..7ad25fa13ae 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -40,7 +40,7 @@ class SamplingIdKernel : public framework::OpKernel { std::vector ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { - double r = this->get_rand(); + double r = this->getRandReal(); int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) { @@ -60,17 +60,23 @@ class SamplingIdKernel : public framework::OpKernel { framework::TensorFromVector(ids, context.device_context(), output); } - double get_rand() const { + private: + double getRandReal() const { + std::call_once(init_flag_, &SamplingIdKernel::getRndInstance); + return rnd(); + } + + static void getRndInstance() { // Will be used to obtain a seed for the random number engine std::random_device rd; // Standard mersenne_twister_engine seeded with rd() std::mt19937 gen(rd()); std::uniform_real_distribution<> dis(0, 1); - return dis(gen); + rnd = std::bind(dis, gen); } - private: - unsigned int defaultSeed = 0; + static std::once_flag init_flag_; + static std::function<> rnd; }; } // namespace operators } // namespace paddle -- GitLab