Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
3d73dea9
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
332
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3d73dea9
编写于
11月 18, 2019
作者:
P
Pei Yang
提交者:
GitHub
11月 18, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sequence_pool cuda kernel, test=develop (#2430)
add sequence_pool cuda kernel
上级
1e88d1e8
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
436 addition
and
0 deletion
+436
-0
lite/kernels/cuda/CMakeLists.txt
lite/kernels/cuda/CMakeLists.txt
+2
-0
lite/kernels/cuda/sequence_pool_compute.cu
lite/kernels/cuda/sequence_pool_compute.cu
+265
-0
lite/kernels/cuda/sequence_pool_compute.h
lite/kernels/cuda/sequence_pool_compute.h
+35
-0
lite/kernels/cuda/sequence_pool_compute_test.cc
lite/kernels/cuda/sequence_pool_compute_test.cc
+134
-0
未找到文件。
lite/kernels/cuda/CMakeLists.txt
浏览文件 @
3d73dea9
...
...
@@ -9,6 +9,7 @@ add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_k
add_kernel
(
leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS
${
lite_kernel_deps
}
${
math_cuda
}
cuda_transpose
)
add_kernel
(
nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS
${
lite_kernel_deps
}
${
math_cuda
}
)
...
...
@@ -38,6 +39,7 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c
nv_test
(
transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda
)
nv_test
(
concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda
)
nv_test
(
elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda
)
nv_test
(
sequence_pool_compute_cuda_test SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_cuda sequence_pooling
)
nv_test
(
softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda
)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test
(
mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda
)
...
...
lite/kernels/cuda/sequence_pool_compute.cu
0 → 100644
浏览文件 @
3d73dea9
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_pool_compute.h"
const
int
CUDA_NUM_THREADS
=
512
;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
/// CUDA: number of blocks for threads.
inline
int
CUDA_GET_BLOCKS
(
const
int
N
)
{
return
(
N
+
CUDA_NUM_THREADS
-
1
)
/
CUDA_NUM_THREADS
;
}
inline
int
CUDA_GET_BLOCKS
(
const
int
N
,
const
int
base
)
{
return
(
N
+
base
-
1
)
/
base
;
}
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
template
<
typename
Dtype
>
__global__
void
seq_pool_average_kernel
(
Dtype
*
dst
,
const
Dtype
*
src_in
,
const
int
batch_size
,
const
uint64_t
*
seq_offset
,
const
int
slice_size
)
{
int
total
=
slice_size
*
batch_size
;
CUDA_KERNEL_LOOP
(
tid
,
total
)
{
int
out_batch_id
=
tid
/
slice_size
;
int
out_id
=
tid
%
slice_size
;
int
in_slice_num
=
static_cast
<
int
>
(
seq_offset
[
out_batch_id
+
1
]
-
seq_offset
[
out_batch_id
]);
int
in_offset
=
static_cast
<
int
>
(
seq_offset
[
out_batch_id
]
*
slice_size
);
src_in
+=
in_offset
+
out_id
;
Dtype
sum
=
(
Dtype
)
0
;
for
(
int
i
=
0
;
i
<
in_slice_num
;
++
i
)
{
sum
+=
src_in
[
i
*
slice_size
];
}
dst
[
out_batch_id
*
slice_size
+
out_id
]
=
sum
/
in_slice_num
;
}
}
template
<
typename
Dtype
>
__global__
void
seq_pool_sum_kernel
(
Dtype
*
dst
,
const
Dtype
*
src_in
,
const
int
batch_size
,
const
uint64_t
*
seq_offset
,
const
int
slice_size
)
{
int
total
=
slice_size
*
batch_size
;
CUDA_KERNEL_LOOP
(
tid
,
total
)
{
int
out_batch_id
=
tid
/
slice_size
;
int
out_id
=
tid
%
slice_size
;
int
in_slice_num
=
static_cast
<
int
>
(
seq_offset
[
out_batch_id
+
1
]
-
seq_offset
[
out_batch_id
]);
int
in_offset
=
static_cast
<
int
>
(
seq_offset
[
out_batch_id
]
*
slice_size
);
src_in
+=
in_offset
+
out_id
;
Dtype
sum
=
(
Dtype
)
0
;
for
(
int
i
=
0
;
i
<
in_slice_num
;
++
i
)
{
sum
+=
src_in
[
i
*
slice_size
];
}
dst
[
out_batch_id
*
slice_size
+
out_id
]
=
sum
;
}
}
template
<
typename
Dtype
>
__global__
void
seq_pool_sqrt_kernel
(
Dtype
*
dst
,
const
Dtype
*
src_in
,
const
int
batch_size
,
const
uint64_t
*
seq_offset
,
const
int
slice_size
)
{
int
total
=
slice_size
*
batch_size
;
CUDA_KERNEL_LOOP
(
tid
,
total
)
{
int
out_batch_id
=
tid
/
slice_size
;
int
out_id
=
tid
%
slice_size
;
int
in_slice_num
=
static_cast
<
int
>
(
seq_offset
[
out_batch_id
+
1
]
-
seq_offset
[
out_batch_id
]);
int
in_offset
=
static_cast
<
int
>
(
seq_offset
[
out_batch_id
]
*
slice_size
);
src_in
+=
in_offset
+
out_id
;
Dtype
sum
=
(
Dtype
)
0
;
for
(
int
i
=
0
;
i
<
in_slice_num
;
++
i
)
{
sum
+=
src_in
[
i
*
slice_size
];
}
dst
[
out_batch_id
*
slice_size
+
out_id
]
=
sum
*
rsqrtf
(
in_slice_num
);
}
}
template
<
typename
Dtype
>
__global__
void
seq_pool_max_kernel
(
Dtype
*
dst
,
const
Dtype
*
src_in
,
const
int
batch_size
,
const
uint64_t
*
seq_offset
,
const
int
slice_size
)
{
int
total
=
slice_size
*
batch_size
;
CUDA_KERNEL_LOOP
(
tid
,
total
)
{
int
out_batch_id
=
tid
/
slice_size
;
int
out_id
=
tid
%
slice_size
;
int
in_slice_num
=
static_cast
<
int
>
(
seq_offset
[
out_batch_id
+
1
]
-
seq_offset
[
out_batch_id
]);
int
in_offset
=
static_cast
<
int
>
(
seq_offset
[
out_batch_id
]
*
slice_size
);
src_in
+=
in_offset
+
out_id
;
Dtype
max
=
src_in
[
0
];
for
(
int
i
=
1
;
i
<
in_slice_num
;
++
i
)
{
Dtype
val
=
src_in
[
i
*
slice_size
];
if
(
val
>
max
)
{
max
=
val
;
}
}
dst
[
out_batch_id
*
slice_size
+
out_id
]
=
max
;
}
}
template
<
typename
Dtype
>
__global__
void
seq_pool_last_kernel
(
Dtype
*
dst
,
const
Dtype
*
src_in
,
const
int
batch_size
,
const
uint64_t
*
seq_offset
,
const
int
slice_size
)
{
int
total
=
slice_size
*
batch_size
;
CUDA_KERNEL_LOOP
(
tid
,
total
)
{
int
out_batch_id
=
tid
/
slice_size
;
int
out_id
=
tid
%
slice_size
;
int
in_offset
=
(
static_cast
<
int
>
(
seq_offset
[
out_batch_id
+
1
])
-
1
)
*
slice_size
;
dst
[
tid
]
=
src_in
[
in_offset
+
out_id
];
}
}
template
<
typename
Dtype
>
__global__
void
seq_pool_first_kernel
(
Dtype
*
dst
,
const
Dtype
*
src_in
,
const
int
batch_size
,
const
uint64_t
*
seq_offset
,
const
int
slice_size
)
{
int
total
=
slice_size
*
batch_size
;
CUDA_KERNEL_LOOP
(
tid
,
total
)
{
int
out_batch_id
=
tid
/
slice_size
;
int
out_id
=
tid
%
slice_size
;
int
in_offset
=
static_cast
<
int
>
(
seq_offset
[
out_batch_id
]
*
slice_size
);
dst
[
tid
]
=
src_in
[
in_offset
+
out_id
];
}
}
void
SequencePoolCompute
::
Run
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
std
::
vector
<
uint64_t
>
seq_offset
=
param
.
X
->
lod
()[
0
];
int
slice_size
=
param
.
Out
->
dims
()[
1
]
*
param
.
Out
->
dims
()[
2
]
*
param
.
Out
->
dims
()[
3
];
float
*
out_data
=
param
.
Out
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
const
float
*
in_data
=
param
.
X
->
data
<
float
>
();
int
batch_size
=
param
.
X
->
lod
().
size
()
-
1
;
lite
::
Tensor
seq_offset_D
;
seq_offset_D
.
Resize
({
static_cast
<
int64_t
>
(
seq_offset
.
size
())});
TargetWrapperCuda
::
MemcpyAsync
(
seq_offset_D
.
mutable_data
<
uint64_t
>
(),
seq_offset
.
data
(),
sizeof
(
uint64_t
)
*
seq_offset
.
size
(),
IoDirection
::
HtoD
,
stream
);
if
(
param
.
pool_type
==
"MAX"
)
{
seq_pool_max_kernel
<
float
><<<
CUDA_GET_BLOCKS
(
batch_size
*
slice_size
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
out_data
,
in_data
,
batch_size
,
seq_offset_D
.
data
<
uint64_t
>
(),
slice_size
);
}
else
if
(
param
.
pool_type
==
"AVERAGE "
)
{
seq_pool_average_kernel
<
float
><<<
CUDA_GET_BLOCKS
(
batch_size
*
slice_size
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
out_data
,
in_data
,
batch_size
,
seq_offset_D
.
data
<
uint64_t
>
(),
slice_size
);
}
else
if
(
param
.
pool_type
==
"SUM"
)
{
seq_pool_sum_kernel
<
float
><<<
CUDA_GET_BLOCKS
(
batch_size
*
slice_size
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
out_data
,
in_data
,
batch_size
,
seq_offset_D
.
data
<
uint64_t
>
(),
slice_size
);
}
else
if
(
param
.
pool_type
==
"SQRT"
)
{
seq_pool_sqrt_kernel
<
float
><<<
CUDA_GET_BLOCKS
(
batch_size
*
slice_size
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
out_data
,
in_data
,
batch_size
,
seq_offset_D
.
data
<
uint64_t
>
(),
slice_size
);
}
else
if
(
param
.
pool_type
==
"FIRST"
)
{
seq_pool_first_kernel
<
float
><<<
CUDA_GET_BLOCKS
(
batch_size
*
slice_size
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
out_data
,
in_data
,
batch_size
,
seq_offset_D
.
data
<
uint64_t
>
(),
slice_size
);
}
else
if
(
param
.
pool_type
==
"LAST"
)
{
seq_pool_last_kernel
<
float
><<<
CUDA_GET_BLOCKS
(
batch_size
*
slice_size
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
out_data
,
in_data
,
batch_size
,
seq_offset_D
.
data
<
uint64_t
>
(),
slice_size
);
}
else
{
LOG
(
ERROR
)
<<
"pool type "
<<
param
.
pool_type
<<
" is not supoorted."
;
}
std
::
vector
<
uint64_t
>
offset_new
(
static_cast
<
uint64_t
>
(
batch_size
+
1
));
for
(
int
i
=
0
;
i
<=
batch_size
;
++
i
)
{
offset_new
[
i
]
=
i
;
}
std
::
vector
<
std
::
vector
<
uint64_t
>>
voffset_new
;
voffset_new
.
push_back
(
offset_new
);
param
.
Out
->
set_lod
(
voffset_new
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
INFO
)
<<
cudaGetErrorString
(
error
);
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
sequence_pool
,
kCUDA
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
SequencePoolCompute
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
lite/kernels/cuda/sequence_pool_compute.h
0 → 100644
浏览文件 @
3d73dea9
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
class
SequencePoolCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)
>
{
public:
using
param_t
=
operators
::
SequencePoolParam
;
void
Run
()
override
;
virtual
~
SequencePoolCompute
()
=
default
;
};
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/kernels/cuda/sequence_pool_compute_test.cc
0 → 100644
浏览文件 @
3d73dea9
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sequence_pool_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/backends/x86/math/sequence_pooling.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
namespace
{
static
void
sequence_pool_ref
(
const
operators
::
SequencePoolParam
&
param
,
)
{
auto
*
x
=
param
.
X
;
auto
*
out
=
param
.
Out
;
auto
dims
=
x
->
dims
();
auto
lod
=
x
->
lod
();
CHECK_EQ
(
lod
.
size
(),
1UL
);
CHECK_GE
(
dims
[
0
],
static_cast
<
int64_t
>
(
lod
[
0
].
size
()
-
1
));
dims
[
0
]
=
lod
[
0
].
size
()
-
1
;
out
->
Resize
({
dims
});
out
->
mutable_data
<
float
>
();
lite
::
Tensor
*
index
=
nullptr
;
const
bool
is_test
=
true
;
float
pad_value
=
0.0
;
lite
::
x86
::
math
::
SequencePoolFunctor
<
lite
::
TargetType
::
kX86
,
float
>
pool
;
pool
(
context
,
param
.
pool_type
,
pad_value
,
*
x
,
out
,
is_test
,
index
);
}
#define PREPARE_INPUT_DATA(name) \
name.Resize({name##_lod_len, feature_len}); \
name##_cpu.Resize({name##_lod_len, feature_len}); \
name##_ref.Resize({name##_lod_len, feature_len}); \
name.set_lod(lod_info_##name); \
name##_cpu.set_lod(lod_info_##name); \
name##_ref.set_lod(lod_info_##name); \
float* name##_cpu_data = name##_cpu.mutable_data<float>(); \
float* name##_ref_data = name##_ref.mutable_data<float>(); \
for (int i = 0; i < name##_cpu.numel(); ++i) { \
name##_cpu_data[i] = (i - 2.0) * 1.0; \
name##_ref_data[i] = (i - 2.0) * 1.0; \
} \
name.Assign<float, lite::DDim, TARGET(kCUDA)>(name##_cpu_data, \
name##_cpu.dims());
#define PREPARE_OUTPUT_INFO(name) \
name##_cpu.Resize({y_lod_len, feature_len}); \
name##_ref.Resize({y_lod_len, feature_len}); \
name.Resize({y_lod_len, feature_len}); \
float* name##_cpu_data = name##_cpu.mutable_data<float>();
}
// namespace
TEST
(
sequence_pool_cuda
,
normal
)
{
SequencePoolCompute
seq_kernel
;
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
std
::
unique_ptr
<
KernelContext
>
ctx_ref
(
new
KernelContext
);
auto
&
context_ref
=
ctx_ref
->
As
<
X86Context
>
();
operators
::
SequencePoolParam
param
;
lite
::
Tensor
x1
,
x2
,
x3
,
x1_cpu
,
x2_cpu
,
x3_cpu
,
x1_ref
,
x2_ref
,
x3_ref
;
lite
::
Tensor
y
,
y_cpu
,
y_ref
;
int32_t
x1_lod_len
=
10
,
feature_len
=
4
;
int32_t
x2_lod_len
=
4
,
x3_lod_len
=
8
;
int32_t
y_lod_len
=
x1_lod_len
+
x2_lod_len
+
x3_lod_len
;
LoD
lod_info_x1
{{
0
,
3
,
5
,
6
,
10
}};
LoD
lod_info_x2
{{
0
,
1
,
2
,
3
,
4
}};
LoD
lod_info_x3
{{
0
,
2
,
4
,
6
,
8
}};
LoD
lod_info_y
{{
0
,
0
,
0
,
0
,
0
}};
for
(
size_t
i
=
0
;
i
<
lod_info_x1
[
0
].
size
();
++
i
)
{
lod_info_y
[
0
][
i
]
=
lod_info_x1
[
0
][
i
]
+
lod_info_x2
[
0
][
i
]
+
lod_info_x3
[
0
][
i
];
}
PREPARE_INPUT_DATA
(
x1
);
PREPARE_INPUT_DATA
(
x2
);
PREPARE_INPUT_DATA
(
x3
);
PREPARE_OUTPUT_INFO
(
y
);
param
.
X
=
&
x1
;
param
.
Out
=
&
y
;
param
.
pool_type
=
"AVERAGE"
;
seq_kernel
.
SetParam
(
param
);
cudaStream_t
stream
;
cudaStreamCreate
(
&
stream
);
context
.
SetExecStream
(
stream
);
seq_kernel
.
SetContext
(
std
::
move
(
ctx
));
seq_kernel
.
Run
();
cudaDeviceSynchronize
();
auto
*
y_data
=
y
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
CopySync
<
TARGET
(
kCUDA
)
>
(
y_cpu_data
,
y_data
,
sizeof
(
float
)
*
y
.
numel
(),
IoDirection
::
DtoH
);
param
.
X
=
&
x1_ref
;
param
.
Out
=
&
y_ref
;
sequence_pool_ref
(
param
);
lite
::
x86
::
math
::
SequencePoolFunctor
<
lite
::
TargetType
::
kX86
,
float
>
pool
;
pool
(
context
,
param
.
pool_type
,
pad_value
,
*
x
,
out
,
is_test
,
index
);
float
*
y_ref_data
=
y_ref
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
y
.
numel
();
i
++
)
{
EXPECT_NEAR
(
y_cpu_data
[
i
],
y_ref_data
[
i
],
1e-5
);
}
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录