Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
08b62f49
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
08b62f49
编写于
11月 20, 2020
作者:
Y
yaoxuefeng
提交者:
GitHub
11月 20, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix shuffle batch op shuffle (#28533)
上级
d3d1a6b6
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
38 addition
and
1 deletion
+38
-1
paddle/fluid/operators/shuffle_batch_op.h
paddle/fluid/operators/shuffle_batch_op.h
+38
-1
未找到文件。
paddle/fluid/operators/shuffle_batch_op.h
浏览文件 @
08b62f49
...
...
@@ -19,6 +19,7 @@
#include <ctime>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/eigen.h"
...
...
@@ -67,7 +68,43 @@ class ShuffleBatchKernel : public framework::OpKernel<T> {
}
std
::
default_random_engine
engine
;
engine
.
seed
(
seed_int
);
std
::
shuffle
(
idx_vec
.
begin
(),
idx_vec
.
end
(),
engine
);
auto
custom_random_shuffle
=
[
&
idx_vec
]()
{
std
::
random_device
rnd
;
int64_t
seed_tmp
=
rnd
();
std
::
default_random_engine
rng
(
seed_tmp
);
const
int
n
=
idx_vec
.
size
();
std
::
vector
<
int
>
v
(
n
);
std
::
iota
(
v
.
begin
(),
v
.
end
(),
0
);
std
::
vector
<
bool
>
visit
(
n
,
false
);
while
(
!
v
.
empty
())
{
std
::
shuffle
(
v
.
begin
(),
v
.
end
(),
rng
);
int
tmp
=
v
.
back
();
v
.
pop_back
();
if
(
v
.
empty
())
{
std
::
uniform_int_distribution
<
int
>
distr
(
0
,
n
-
2
);
idx_vec
[
tmp
]
=
tmp
;
std
::
swap
(
idx_vec
[
tmp
],
idx_vec
[(
distr
(
rng
)
+
tmp
+
1
)
%
n
]);
return
;
}
visit
[
tmp
]
=
true
;
std
::
shuffle
(
v
.
begin
(),
v
.
end
(),
rng
);
int
curr
=
v
.
back
();
v
.
pop_back
();
v
.
push_back
(
tmp
);
idx_vec
[
tmp
]
=
curr
;
while
(
!
visit
[
curr
])
{
visit
[
curr
]
=
true
;
std
::
shuffle
(
v
.
begin
(),
v
.
end
(),
rng
);
idx_vec
[
curr
]
=
v
.
back
();
v
.
pop_back
();
curr
=
idx_vec
[
curr
];
}
}
};
custom_random_shuffle
();
// change shuffle to custom_random_shuffle
// std::shuffle(idx_vec.begin(), idx_vec.end(), engine);
// ShuffleIdx record shuffle order
shuffleidx
->
Resize
(
framework
::
make_ddim
({(
int64_t
)
idx_vec
.
size
()}));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录