Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
aac00f6a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
aac00f6a
编写于
11月 09, 2021
作者:
H
Haohongxiang
提交者:
GitHub
11月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize backward (#37055)
上级
71816707
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
27 addition
and
14 deletion
+27
-14
paddle/fluid/operators/index_select_op.cu
paddle/fluid/operators/index_select_op.cu
+27
-14
未找到文件。
paddle/fluid/operators/index_select_op.cu
浏览文件 @
aac00f6a
...
...
@@ -54,14 +54,18 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
int64_t
pre_idx
=
idx
/
(
stride
*
size
);
int64_t
dim_idx
=
idx
%
(
stride
*
size
)
/
stride
;
int64_t
begin_idx
=
idx
+
(
delta
*
pre_idx
-
dim_idx
)
*
stride
;
IndexT
src_dim_idx
=
index
[
dim_idx
];
int64_t
input_idx
=
idx
+
(
delta
*
pre_idx
+
src_dim_idx
-
dim_idx
)
*
stride
;
paddle
::
platform
::
CudaAtomicAdd
(
&
input_grad
[
input_idx
],
output_grad
[
idx
]);
}
input_grad
[
idx
]
=
0.0
;
for
(
int64_t
i
=
0
;
i
<
nums
;
i
++
)
{
if
(
index
[
i
]
==
dim_idx
)
{
input_grad
[
idx
]
+=
output_grad
[
begin_idx
+
i
*
stride
];
}
template
<
typename
T
>
__global__
void
index_select_grad_init
(
T
*
input_grad
,
int64_t
N
)
{
int64_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
>=
N
)
{
return
;
}
input_grad
[
idx
]
=
0.0
;
}
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -143,8 +147,8 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
dim
=
dim
>=
0
?
dim
:
dim
+
input_dim
.
size
();
auto
stride_dim
=
framework
::
stride
(
input_dim
);
int64_t
stride
=
stride_dim
[
dim
];
int64_t
size
=
in
put_dim
[
dim
];
int64_t
delta
=
out
put_dim
[
dim
]
-
size
;
int64_t
size
=
out
put_dim
[
dim
];
int64_t
delta
=
in
put_dim
[
dim
]
-
size
;
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT64
||
...
...
@@ -161,17 +165,22 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
int64_t
numel
=
in_grad
->
numel
();
int64_t
index_nums
=
index
->
numel
();
int64_t
out_nums
=
output_grad
->
numel
();
auto
stream
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
();
index_select_grad_init
<
T
><<<
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_grad_data
,
numel
);
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
const
int64_t
*
index_data
=
index
->
data
<
int64_t
>
();
index_select_grad_cuda_kernel
<
T
,
int64_t
><<<
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
(
out_nums
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
output_grad_data
,
in_grad_data
,
index_data
,
index_nums
,
numel
,
stride
,
size
,
delta
);
index_data
,
index_nums
,
out_nums
,
stride
,
size
,
delta
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
...
...
@@ -180,10 +189,10 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
}
else
{
const
int
*
index_data
=
index
->
data
<
int
>
();
index_select_grad_cuda_kernel
<
T
,
int
><<<
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
(
out_nums
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
output_grad_data
,
in_grad_data
,
index_data
,
index_nums
,
numel
,
stride
,
size
,
delta
);
index_data
,
index_nums
,
out_nums
,
stride
,
size
,
delta
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
...
...
@@ -201,12 +210,16 @@ REGISTER_OP_CUDA_KERNEL(
index_select
,
ops
::
IndexSelectCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
IndexSelectCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
IndexSelectCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
,
ops
::
IndexSelectCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
IndexSelectCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
index_select_grad
,
ops
::
IndexSelectGradCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
IndexSelectGradCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
IndexSelectGradCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
,
ops
::
IndexSelectGradCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
IndexSelectGradCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录