From 21d98351e51c53ca55d10b946ea0e7e95726896b Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Tue, 24 Nov 2020 10:57:47 +0800 Subject: [PATCH] fix shuffle batch op shuffle (#28533) (#28765) --- paddle/fluid/operators/shuffle_batch_op.h | 39 ++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/shuffle_batch_op.h b/paddle/fluid/operators/shuffle_batch_op.h index ad3fab0bdbc..ac8e3f0538f 100644 --- a/paddle/fluid/operators/shuffle_batch_op.h +++ b/paddle/fluid/operators/shuffle_batch_op.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include "glog/logging.h" #include "paddle/fluid/framework/eigen.h" @@ -67,7 +68,43 @@ class ShuffleBatchKernel : public framework::OpKernel { } std::default_random_engine engine; engine.seed(seed_int); - std::shuffle(idx_vec.begin(), idx_vec.end(), engine); + + auto custom_random_shuffle = [&idx_vec]() { + std::random_device rnd; + int64_t seed_tmp = rnd(); + std::default_random_engine rng(seed_tmp); + const int n = idx_vec.size(); + std::vector v(n); + std::iota(v.begin(), v.end(), 0); + std::vector visit(n, false); + while (!v.empty()) { + std::shuffle(v.begin(), v.end(), rng); + int tmp = v.back(); + v.pop_back(); + if (v.empty()) { + std::uniform_int_distribution distr(0, n - 2); + idx_vec[tmp] = tmp; + std::swap(idx_vec[tmp], idx_vec[(distr(rng) + tmp + 1) % n]); + return; + } + visit[tmp] = true; + std::shuffle(v.begin(), v.end(), rng); + int curr = v.back(); + v.pop_back(); + v.push_back(tmp); + idx_vec[tmp] = curr; + while (!visit[curr]) { + visit[curr] = true; + std::shuffle(v.begin(), v.end(), rng); + idx_vec[curr] = v.back(); + v.pop_back(); + curr = idx_vec[curr]; + } + } + }; + custom_random_shuffle(); + // change shuffle to custom_random_shuffle + // std::shuffle(idx_vec.begin(), idx_vec.end(), engine); // ShuffleIdx record shuffle order shuffleidx->Resize(framework::make_ddim({(int64_t)idx_vec.size()})); -- GitLab