Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8f7c02f2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
8f7c02f2
编写于
3月 30, 2022
作者:
H
Haohongxiang
提交者:
GitHub
3月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Op] Fix uncontrolled randomness of index_select op (#41078)
* fix uncontrolled randomness of op * fix bugs
上级
eef46770
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
48 addition
and
33 deletion
+48
-33
paddle/phi/kernels/gpu/index_select_grad_kernel.cu
paddle/phi/kernels/gpu/index_select_grad_kernel.cu
+37
-33
python/paddle/fluid/tests/unittests/test_index_select_op.py
python/paddle/fluid/tests/unittests/test_index_select_op.py
+11
-0
未找到文件。
paddle/phi/kernels/gpu/index_select_grad_kernel.cu
浏览文件 @
8f7c02f2
...
@@ -19,6 +19,8 @@
...
@@ -19,6 +19,8 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/core/utils/data_type.h"
DECLARE_bool
(
cudnn_deterministic
);
namespace
phi
{
namespace
phi
{
using
paddle
::
platform
::
PADDLE_CUDA_NUM_THREADS
;
using
paddle
::
platform
::
PADDLE_CUDA_NUM_THREADS
;
...
@@ -32,16 +34,14 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
...
@@ -32,16 +34,14 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
int64_t
stride
,
int64_t
stride
,
int64_t
size
,
int64_t
size
,
int64_t
delta
)
{
int64_t
delta
)
{
int64_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
CUDA_KERNEL_LOOP
(
idx
,
N
)
{
if
(
idx
>=
N
)
{
int64_t
pre_idx
=
idx
/
(
stride
*
size
);
return
;
int64_t
dim_idx
=
idx
%
(
stride
*
size
)
/
stride
;
IndexT
src_dim_idx
=
index
[
dim_idx
];
int64_t
input_idx
=
idx
+
(
delta
*
pre_idx
+
src_dim_idx
-
dim_idx
)
*
stride
;
paddle
::
platform
::
CudaAtomicAdd
(
&
input_grad
[
input_idx
],
output_grad
[
idx
]);
}
}
int64_t
pre_idx
=
idx
/
(
stride
*
size
);
int64_t
dim_idx
=
idx
%
(
stride
*
size
)
/
stride
;
IndexT
src_dim_idx
=
index
[
dim_idx
];
int64_t
input_idx
=
idx
+
(
delta
*
pre_idx
+
src_dim_idx
-
dim_idx
)
*
stride
;
paddle
::
platform
::
CudaAtomicAdd
(
&
input_grad
[
input_idx
],
output_grad
[
idx
]);
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -95,34 +95,38 @@ void IndexSelectGradKernel(const Context& ctx,
...
@@ -95,34 +95,38 @@ void IndexSelectGradKernel(const Context& ctx,
0
,
0
,
stream
>>>
(
in_grad_data
,
numel
);
stream
>>>
(
in_grad_data
,
numel
);
int
blocks
=
(
out_nums
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
;
int
threads
=
PADDLE_CUDA_NUM_THREADS
;
if
(
FLAGS_cudnn_deterministic
)
{
VLOG
(
2
)
<<
"Run grad kernel of index_select with single thread."
;
blocks
=
1
;
threads
=
1
;
}
if
(
index_type
==
phi
::
DataType
::
INT64
)
{
if
(
index_type
==
phi
::
DataType
::
INT64
)
{
const
int64_t
*
index_data
=
index
.
data
<
int64_t
>
();
const
int64_t
*
index_data
=
index
.
data
<
int64_t
>
();
index_select_grad_cuda_kernel
<
T
,
int64_t
><<<
index_select_grad_cuda_kernel
<
T
,
int64_t
><<<
blocks
,
threads
,
0
,
stream
>>>
(
(
out_nums
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
output_grad_data
,
PADDLE_CUDA_NUM_THREADS
,
in_grad_data
,
0
,
index_data
,
stream
>>>
(
output_grad_data
,
index_nums
,
in_grad_data
,
out_nums
,
index_data
,
stride
,
index_nums
,
size
,
out_nums
,
delta
);
stride
,
size
,
delta
);
}
else
{
}
else
{
const
int
*
index_data
=
index
.
data
<
int
>
();
const
int
*
index_data
=
index
.
data
<
int
>
();
index_select_grad_cuda_kernel
<
T
,
int
><<<
index_select_grad_cuda_kernel
<
T
,
int
><<<
blocks
,
threads
,
0
,
stream
>>>
(
(
out_nums
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
output_grad_data
,
PADDLE_CUDA_NUM_THREADS
,
in_grad_data
,
0
,
index_data
,
stream
>>>
(
output_grad_data
,
index_nums
,
in_grad_data
,
out_nums
,
index_data
,
stride
,
index_nums
,
size
,
out_nums
,
delta
);
stride
,
size
,
delta
);
}
}
}
}
...
...
python/paddle/fluid/tests/unittests/test_index_select_op.py
浏览文件 @
8f7c02f2
...
@@ -69,6 +69,17 @@ class TestIndexSelectOpCase2(TestIndexSelectOp):
...
@@ -69,6 +69,17 @@ class TestIndexSelectOpCase2(TestIndexSelectOp):
self
.
index_size
=
10
self
.
index_size
=
10
class
TestIndexSelectOpCaseSingleThread
(
TestIndexSelectOp
):
def
init_dtype_type
(
self
):
if
fluid
.
is_compiled_with_cuda
():
fluid
.
set_flags
({
'FLAGS_cudnn_deterministic'
:
True
})
self
.
x_type
=
np
.
float32
self
.
index_type
=
np
.
int32
self
.
dim
=
-
2
self
.
x_shape
=
(
10
,
10
,
4
,
10
)
self
.
index_size
=
10
class
TestIndexSelectAPI
(
unittest
.
TestCase
):
class
TestIndexSelectAPI
(
unittest
.
TestCase
):
def
input_data
(
self
):
def
input_data
(
self
):
self
.
data_x
=
np
.
array
([[
1.0
,
2.0
,
3.0
,
4.0
],
[
5.0
,
6.0
,
7.0
,
8.0
],
self
.
data_x
=
np
.
array
([[
1.0
,
2.0
,
3.0
,
4.0
],
[
5.0
,
6.0
,
7.0
,
8.0
],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录