提交 9c63fef6 编写于 作者: T tangwei12

random optimize

上级 5b9716d1
......@@ -40,7 +40,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
std::vector<T> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
double r = this->get_rand();
double r = this->getRandReal();
int idx = width - 1;
for (int j = 0; j < width; ++j) {
if ((r -= ins_vector[i * width + j]) < 0) {
......@@ -60,17 +60,23 @@ class SamplingIdKernel : public framework::OpKernel<T> {
framework::TensorFromVector(ids, context.device_context(), output);
}
double get_rand() const {
private:
double getRandReal() const {
std::call_once(init_flag_, &SamplingIdKernel::getRndInstance);
return rnd();
}
static void getRndInstance() {
// Will be used to obtain a seed for the random number engine
std::random_device rd;
// Standard mersenne_twister_engine seeded with rd()
std::mt19937 gen(rd());
std::uniform_real_distribution<> dis(0, 1);
return dis(gen);
rnd = std::bind(dis, gen);
}
private:
unsigned int defaultSeed = 0;
static std::once_flag init_flag_;
static std::function<> rnd;
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册