diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 3d724e3ae726fb5b91843d603d434103ae6201d8..7ad25fa13ae8ace963dba5c66a57118e4267737e 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -40,7 +40,7 @@ class SamplingIdKernel : public framework::OpKernel { std::vector ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { - double r = this->get_rand(); + double r = this->getRandReal(); int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) { @@ -60,17 +60,23 @@ class SamplingIdKernel : public framework::OpKernel { framework::TensorFromVector(ids, context.device_context(), output); } - double get_rand() const { + private: + double getRandReal() const { + std::call_once(init_flag_, &SamplingIdKernel::getRndInstance); + return rnd(); + } + + static void getRndInstance() { // Will be used to obtain a seed for the random number engine std::random_device rd; // Standard mersenne_twister_engine seeded with rd() std::mt19937 gen(rd()); std::uniform_real_distribution<> dis(0, 1); - return dis(gen); + rnd = std::bind(dis, gen); } - private: - unsigned int defaultSeed = 0; + static std::once_flag init_flag_; + static std::function<> rnd; }; } // namespace operators } // namespace paddle