未验证 提交 21d98351 编写于 作者: Y yaoxuefeng 提交者: GitHub

fix shuffle batch op shuffle (#28533) (#28765)

上级 5d7e5e35
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <ctime> #include <ctime>
#include <random> #include <random>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
...@@ -67,7 +68,43 @@ class ShuffleBatchKernel : public framework::OpKernel<T> { ...@@ -67,7 +68,43 @@ class ShuffleBatchKernel : public framework::OpKernel<T> {
} }
std::default_random_engine engine; std::default_random_engine engine;
engine.seed(seed_int); engine.seed(seed_int);
std::shuffle(idx_vec.begin(), idx_vec.end(), engine);
// Fills idx_vec with a random permutation of [0, n), n = idx_vec.size(),
// built as a set of random cycles. For n >= 2 no position keeps its own
// index: every assigned successor is drawn from a pool that never contains
// the position being assigned, and a single left-over index is swapped with
// another position.
//
// The generator is seeded from std::random_device inside the lambda.
// NOTE(review): the seeded `engine` declared just above is not used here, so
// the resulting order presumably cannot be reproduced from the op's seed --
// confirm that this is intended.
auto custom_random_shuffle = [&idx_vec]() {
  std::random_device rnd;
  int64_t seed_tmp = rnd();
  std::default_random_engine rng(seed_tmp);
  const int n = idx_vec.size();
  if (n < 2) {
    // With fewer than two entries there is nothing to permute. Returning
    // here also keeps the single-left-over branch below from constructing
    // std::uniform_int_distribution<int>(0, -1), which violates the
    // distribution's a <= b precondition (undefined behaviour).
    if (n == 1) {
      idx_vec[0] = 0;  // same value the original single-element path stored
    }
    return;
  }
  // v is the pool of indices that may still be chosen as a successor.
  std::vector<int> v(n);
  std::iota(v.begin(), v.end(), 0);
  std::vector<bool> visit(n, false);
  while (!v.empty()) {
    // Pick a random start index for a new cycle.
    std::shuffle(v.begin(), v.end(), rng);
    int tmp = v.back();
    v.pop_back();
    if (v.empty()) {
      // Exactly one index is left and cannot form a cycle on its own:
      // store it in place, then swap with a random *different* position
      // (the offset lies in [1, n - 1], so the target is never tmp).
      std::uniform_int_distribution<int> distr(0, n - 2);
      idx_vec[tmp] = tmp;
      std::swap(idx_vec[tmp], idx_vec[(distr(rng) + tmp + 1) % n]);
      return;
    }
    visit[tmp] = true;
    // Choose the successor of tmp before tmp is returned to the pool, so
    // tmp can never be mapped to itself.
    std::shuffle(v.begin(), v.end(), rng);
    int curr = v.back();
    v.pop_back();
    // Put tmp back so a later draw can select it and close the cycle.
    v.push_back(tmp);
    idx_vec[tmp] = curr;
    // Extend the cycle until the start index tmp (already marked visited)
    // is drawn as a successor.
    while (!visit[curr]) {
      visit[curr] = true;
      std::shuffle(v.begin(), v.end(), rng);
      idx_vec[curr] = v.back();
      v.pop_back();
      curr = idx_vec[curr];
    }
  }
};
custom_random_shuffle();
// change shuffle to custom_random_shuffle
// std::shuffle(idx_vec.begin(), idx_vec.end(), engine);
// ShuffleIdx record shuffle order // ShuffleIdx record shuffle order
shuffleidx->Resize(framework::make_ddim({(int64_t)idx_vec.size()})); shuffleidx->Resize(framework::make_ddim({(int64_t)idx_vec.size()}));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册