Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c5b415bf
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c5b415bf
编写于
1月 07, 2021
作者:
1
123malin
提交者:
GitHub
1月 07, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Improve Index select cuda kernel (#30139)
* test=develop, add index_select_cuda kernel
上级
7dd551e0
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
175 addition
and
8 deletion
+175
-8
paddle/fluid/operators/index_select_op.cu
paddle/fluid/operators/index_select_op.cu
+175
-8
未找到文件。
paddle/fluid/operators/index_select_op.cu
浏览文件 @
c5b415bf
...
...
@@ -12,18 +12,185 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_select_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
operators
{
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
,
typename
IndexT
>
__global__
void
index_select_cuda_kernel
(
const
T
*
input
,
T
*
output
,
const
IndexT
*
index
,
int64_t
N
,
int64_t
stride
,
int64_t
size
,
int64_t
delta
)
{
int64_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
>=
N
)
{
return
;
}
int64_t
pre_idx
=
idx
/
(
stride
*
size
);
int64_t
dim_idx
=
idx
%
(
stride
*
size
)
/
stride
;
IndexT
src_dim_idx
=
index
[
dim_idx
];
int64_t
input_idx
=
idx
+
(
delta
*
pre_idx
+
src_dim_idx
-
dim_idx
)
*
stride
;
output
[
idx
]
=
input
[
input_idx
];
}
template
<
typename
T
,
typename
IndexT
>
__global__
void
index_select_grad_cuda_kernel
(
const
T
*
output_grad
,
T
*
input_grad
,
const
IndexT
*
index
,
int64_t
nums
,
int64_t
N
,
int64_t
stride
,
int64_t
size
,
int64_t
delta
)
{
int64_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
>=
N
)
{
return
;
}
int64_t
pre_idx
=
idx
/
(
stride
*
size
);
int64_t
dim_idx
=
idx
%
(
stride
*
size
)
/
stride
;
int64_t
begin_idx
=
idx
+
(
delta
*
pre_idx
-
dim_idx
)
*
stride
;
input_grad
[
idx
]
=
0.0
;
for
(
int64_t
i
=
0
;
i
<
nums
;
i
++
)
{
if
(
index
[
i
]
==
dim_idx
)
{
input_grad
[
idx
]
+=
output_grad
[
begin_idx
+
i
*
stride
];
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
IndexSelectCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
index
=
context
.
Input
<
LoDTensor
>
(
"Index"
);
auto
*
out
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
int
dim
=
context
.
Attr
<
int
>
(
"dim"
);
auto
input_dim
=
in
->
dims
();
auto
output_dim
=
out
->
dims
();
dim
=
dim
>=
0
?
dim
:
dim
+
input_dim
.
size
();
auto
stride_dim
=
framework
::
stride
(
input_dim
);
int64_t
stride
=
stride_dim
[
dim
];
int64_t
size
=
output_dim
[
dim
];
int64_t
delta
=
input_dim
[
dim
]
-
size
;
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT64
||
index_type
==
framework
::
proto
::
VarType
::
INT32
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
auto
*
in_data
=
in
->
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int64_t
numel
=
out
->
numel
();
auto
stream
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
();
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
const
int64_t
*
index_data
=
index
->
data
<
int64_t
>
();
index_select_cuda_kernel
<
T
,
int64_t
><<<
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_data
,
out_data
,
index_data
,
numel
,
stride
,
size
,
delta
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
}
else
{
const
int
*
index_data
=
index
->
data
<
int
>
();
index_select_cuda_kernel
<
T
,
int
><<<
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_data
,
out_data
,
index_data
,
numel
,
stride
,
size
,
delta
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
IndexSelectGradCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
output_grad
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
in_grad
=
context
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
index
=
context
.
Input
<
LoDTensor
>
(
"Index"
);
auto
*
output_grad_data
=
output_grad
->
data
<
T
>
();
auto
*
in_grad_data
=
in_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
dim
=
context
.
Attr
<
int
>
(
"dim"
);
auto
input_dim
=
in_grad
->
dims
();
auto
output_dim
=
output_grad
->
dims
();
dim
=
dim
>=
0
?
dim
:
dim
+
input_dim
.
size
();
auto
stride_dim
=
framework
::
stride
(
input_dim
);
int64_t
stride
=
stride_dim
[
dim
];
int64_t
size
=
input_dim
[
dim
];
int64_t
delta
=
output_dim
[
dim
]
-
size
;
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT64
||
index_type
==
framework
::
proto
::
VarType
::
INT32
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
int64_t
numel
=
in_grad
->
numel
();
int64_t
index_nums
=
index
->
numel
();
auto
stream
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
();
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
const
int64_t
*
index_data
=
index
->
data
<
int64_t
>
();
index_select_grad_cuda_kernel
<
T
,
int64_t
><<<
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
output_grad_data
,
in_grad_data
,
index_data
,
index_nums
,
numel
,
stride
,
size
,
delta
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
}
else
{
const
int
*
index_data
=
index
->
data
<
int
>
();
index_select_grad_cuda_kernel
<
T
,
int
><<<
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
output_grad_data
,
in_grad_data
,
index_data
,
index_nums
,
numel
,
stride
,
size
,
delta
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
index_select
,
ops
::
IndexSelectKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
IndexSelectKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
IndexSelectKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
IndexSelectKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
ops
::
IndexSelect
CUDA
Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
IndexSelect
CUDA
Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
IndexSelect
CUDA
Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
IndexSelect
CUDA
Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
index_select_grad
,
ops
::
IndexSelectGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
IndexSelectGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
IndexSelectGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
IndexSelectGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
ops
::
IndexSelectGradCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
IndexSelectGradCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
IndexSelectGradCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
IndexSelectGradCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录