Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
a6794926
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a6794926
编写于
4月 27, 2022
作者:
zhouweiwei2014
提交者:
GitHub
4月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix randperm out of bound bug (#42057)
上级
b20683c0
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
21 addition
and
18 deletion
+21
-18
paddle/phi/kernels/gpu/randperm_kernel.cu
paddle/phi/kernels/gpu/randperm_kernel.cu
+21
-18
未找到文件。
paddle/phi/kernels/gpu/randperm_kernel.cu
浏览文件 @
a6794926
...
...
@@ -36,26 +36,29 @@ DECLARE_bool(use_curand);
namespace
phi
{
template
<
typename
T
>
__global__
void
SwapRepeatKernel
(
int
*
key
,
T
*
data
,
int
n
,
uint64_t
seed
,
uint64_t
offset
)
{
template
<
typename
keyT
,
typename
dataT
>
__global__
void
SwapRepeatKernel
(
keyT
*
key_out_data
,
dataT
*
out_data
,
int
n
,
uint64_t
seed
,
uint64_t
offset
)
{
size_t
idx
=
static_cast
<
size_t
>
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
if
(
idx
<
n
)
return
;
if
(
idx
>=
n
-
1
)
return
;
// out of range
bool
first_repeat
=
false
;
if
(
data
[
idx
]
==
data
[
idx
+
1
])
{
bool
is_
first_repeat
=
false
;
if
(
key_out_data
[
idx
]
==
key_out_
data
[
idx
+
1
])
{
if
(
idx
==
0
)
{
first_repeat
=
true
;
}
else
if
(
data
[
idx
]
!=
data
[
idx
-
1
])
{
first_repeat
=
true
;
is_
first_repeat
=
true
;
}
else
if
(
key_out_data
[
idx
]
!=
key_out_
data
[
idx
-
1
])
{
is_
first_repeat
=
true
;
}
}
if
(
!
first_repeat
)
return
;
if
(
!
is_
first_repeat
)
return
;
int
repeat_size
=
1
;
for
(
int
i
=
idx
;
i
<
n
;
++
i
)
{
if
(
data
[
i
]
==
data
[
i
+
1
])
{
if
(
key_out_data
[
i
]
==
key_out_
data
[
i
+
1
])
{
++
repeat_size
;
}
else
{
break
;
...
...
@@ -74,9 +77,9 @@ __global__ void SwapRepeatKernel(
uint32_t
r
=
hiprand
(
&
state
)
%
(
i
+
1
);
#endif
if
(
r
!=
i
)
{
T
tmp
=
data
[
idx
+
i
];
data
[
idx
+
i
]
=
data
[
idx
+
r
];
data
[
idx
+
r
]
=
tmp
;
dataT
tmp
=
out_
data
[
idx
+
i
];
out_data
[
idx
+
i
]
=
out_
data
[
idx
+
r
];
out_
data
[
idx
+
r
]
=
tmp
;
}
}
}
...
...
@@ -138,10 +141,10 @@ void RandpermRawKernel(
auto
seed_offset
=
gen_cuda
->
IncrementOffset
(
n
);
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
n
);
SwapRepeatKernel
<
T
><
<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
SwapRepeatKernel
<<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
key_out
.
data
<
int
>
(),
out_data
,
n
,
seed_offset
.
first
,
seed_offset
.
second
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录