Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6e1e036a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
6e1e036a
编写于
2月 03, 2021
作者:
J
JamesLim
提交者:
GitHub
2月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement cuda kernel for index_sample. (#30380)
上级
666efc23
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
165 addition
and
2 deletion
+165
-2
paddle/fluid/operators/index_sample_op.cu
paddle/fluid/operators/index_sample_op.cu
+163
-0
python/paddle/fluid/tests/unittests/test_index_sample_op.py
python/paddle/fluid/tests/unittests/test_index_sample_op.py
+2
-2
未找到文件。
paddle/fluid/operators/index_sample_op.cu
浏览文件 @
6e1e036a
...
...
@@ -12,7 +12,170 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_sample_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
,
typename
IndexT
=
int
>
__global__
void
IndexSampleForward
(
const
IndexT
*
index
,
const
T
*
in_data
,
T
*
out_data
,
size_t
index_length
,
size_t
input_length
,
size_t
batch_size
)
{
int
index_i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
index_j
=
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
;
int
index_idx
=
index_j
*
index_length
+
index_i
;
int
in_idx
=
index_j
*
input_length
+
index_i
;
if
(
index_i
<
index_length
&
index_j
<
batch_size
)
{
IndexT
sample_idx
=
index
[
index_idx
];
out_data
[
index_idx
]
=
in_data
[
in_idx
-
index_i
+
sample_idx
];
}
}
template
<
typename
T
,
typename
IndexT
=
int
>
__global__
void
IndexSampleGrad
(
const
IndexT
*
index
,
T
*
in_grad
,
const
T
*
out_grad
,
size_t
index_length
,
size_t
input_length
,
size_t
batch_size
,
bool
same_data_in_row
=
true
)
{
int
index_i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
index_j
=
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
;
int
index_idx
=
index_j
*
index_length
+
index_i
;
int
in_idx
=
index_j
*
input_length
+
index_i
;
if
(
index_i
<
index_length
&
index_j
<
batch_size
)
{
IndexT
sample_idx
=
index
[
index_idx
];
if
(
same_data_in_row
)
{
platform
::
CudaAtomicAdd
(
&
(
in_grad
[
in_idx
-
index_i
+
sample_idx
]),
out_grad
[
sample_idx
]);
}
else
{
in_grad
[
in_idx
-
index_i
+
sample_idx
]
=
out_grad
[
sample_idx
];
}
}
}
template
<
typename
T
>
class
IndexSampleKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
index
=
ctx
.
Input
<
LoDTensor
>
(
"Index"
);
auto
*
output
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT64
||
index_type
==
framework
::
proto
::
VarType
::
INT32
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
const
auto
*
in_data
=
input
->
data
<
T
>
();
auto
*
out_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
stream
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
();
auto
input_dim
=
input
->
dims
();
auto
index_dim
=
index
->
dims
();
size_t
batch_size
=
input_dim
[
0
];
size_t
input_length
=
input_dim
[
1
];
size_t
index_length
=
index_dim
[
1
];
auto
block_width
=
platform
::
RoundToPowerOfTwo
(
index_length
);
int
block_height
=
platform
::
RoundToPowerOfTwo
(
index_length
*
batch_size
)
/
block_width
;
dim3
block_dim
(
block_width
,
block_height
);
dim3
grid_dim
((
index_length
+
block_dim
.
x
-
1
)
/
block_dim
.
x
,
(
batch_size
+
block_dim
.
y
-
1
)
/
block_dim
.
y
);
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
const
int64_t
*
index_data
=
index
->
data
<
int64_t
>
();
IndexSampleForward
<
T
,
int64_t
><<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
index_data
,
in_data
,
out_data
,
index_length
,
input_length
,
batch_size
);
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
const
int
*
index_data
=
index
->
data
<
int
>
();
IndexSampleForward
<
T
,
int
><<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
index_data
,
in_data
,
out_data
,
index_length
,
input_length
,
batch_size
);
}
}
};
template
<
typename
T
>
class
IndexSampleGradKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
output_grad
=
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
input_grad
=
ctx
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
index
=
ctx
.
Input
<
LoDTensor
>
(
"Index"
);
const
auto
*
output_grad_data
=
output_grad
->
data
<
T
>
();
auto
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT64
||
index_type
==
framework
::
proto
::
VarType
::
INT32
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
auto
stream
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
();
auto
input_num
=
input_grad
->
numel
();
auto
input_dim
=
input_grad
->
dims
();
auto
index_dim
=
index
->
dims
();
size_t
batch_size
=
index_dim
[
0
];
size_t
input_length
=
input_dim
[
1
];
size_t
index_length
=
index_dim
[
1
];
bool
same_data_in_index_row
=
index_length
==
1
?
false
:
true
;
auto
block_width
=
platform
::
RoundToPowerOfTwo
(
index_length
);
auto
block_height
=
platform
::
RoundToPowerOfTwo
(
index_length
*
batch_size
)
/
block_width
;
dim3
block_dim
(
block_width
,
block_height
);
dim3
grid_dim
((
index_length
+
block_dim
.
x
-
1
)
/
block_dim
.
x
,
(
batch_size
+
block_dim
.
y
-
1
)
/
block_dim
.
y
);
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
set_zero
(
dev_ctx
,
input_grad
,
static_cast
<
T
>
(
0
));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
const
int64_t
*
index_data
=
index
->
data
<
int64_t
>
();
IndexSampleGrad
<
T
,
int64_t
><<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
index_data
,
input_grad_data
,
output_grad_data
,
index_length
,
input_length
,
batch_size
,
same_data_in_index_row
);
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
const
int
*
index_data
=
index
->
data
<
int
>
();
IndexSampleGrad
<
T
,
int
><<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
index_data
,
input_grad_data
,
output_grad_data
,
index_length
,
input_length
,
batch_size
,
same_data_in_index_row
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
...
...
python/paddle/fluid/tests/unittests/test_index_sample_op.py
浏览文件 @
6e1e036a
...
...
@@ -92,9 +92,9 @@ class TestCase4(TestIndexSampleOp):
"""
For int64 index type
"""
self
.
x_shape
=
(
10
,
1
00
)
self
.
x_shape
=
(
10
,
1
28
)
self
.
x_type
=
"float64"
self
.
index_shape
=
(
10
,
10
)
self
.
index_shape
=
(
10
,
64
)
self
.
index_type
=
"int64"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录