提交 9c63fef6 编写于 作者: T tangwei12

random optimize

上级 5b9716d1
...@@ -40,7 +40,7 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -40,7 +40,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
std::vector<T> ids(batch_size); std::vector<T> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
double r = this->get_rand(); double r = this->getRandReal();
int idx = width - 1; int idx = width - 1;
for (int j = 0; j < width; ++j) { for (int j = 0; j < width; ++j) {
if ((r -= ins_vector[i * width + j]) < 0) { if ((r -= ins_vector[i * width + j]) < 0) {
...@@ -60,17 +60,23 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -60,17 +60,23 @@ class SamplingIdKernel : public framework::OpKernel<T> {
framework::TensorFromVector(ids, context.device_context(), output); framework::TensorFromVector(ids, context.device_context(), output);
} }
double get_rand() const { private:
double getRandReal() const {
std::call_once(init_flag_, &SamplingIdKernel::getRndInstance);
return rnd();
}
static void getRndInstance() {
// Will be used to obtain a seed for the random number engine // Will be used to obtain a seed for the random number engine
std::random_device rd; std::random_device rd;
// Standard mersenne_twister_engine seeded with rd() // Standard mersenne_twister_engine seeded with rd()
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::uniform_real_distribution<> dis(0, 1); std::uniform_real_distribution<> dis(0, 1);
return dis(gen); rnd = std::bind(dis, gen);
} }
private: static std::once_flag init_flag_;
unsigned int defaultSeed = 0; static std::function<> rnd;
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册