Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
aa67c28e
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
aa67c28e
编写于
12月 04, 2019
作者:
W
Wilber
提交者:
GitHub
12月 04, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update cuda kernels to run content-dnn models test=develop (#2554)
update cuda kernels to run content-dnn model
上级
f7574646
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
216 addition
and
79 deletion
+216
-79
lite/api/paddle_api.cc
lite/api/paddle_api.cc
+1
-0
lite/core/op_registry.cc
lite/core/op_registry.cc
+2
-0
lite/kernels/cuda/attention_padding_mask_compute.cu
lite/kernels/cuda/attention_padding_mask_compute.cu
+39
-3
lite/kernels/cuda/feed_compute.cc
lite/kernels/cuda/feed_compute.cc
+36
-9
lite/kernels/cuda/feed_compute.h
lite/kernels/cuda/feed_compute.h
+2
-1
lite/kernels/cuda/match_matrix_tensor_compute.cu
lite/kernels/cuda/match_matrix_tensor_compute.cu
+24
-0
lite/kernels/cuda/search_aligned_mat_mul_compute.cc
lite/kernels/cuda/search_aligned_mat_mul_compute.cc
+3
-0
lite/kernels/cuda/search_fc_compute.cu
lite/kernels/cuda/search_fc_compute.cu
+0
-6
lite/kernels/cuda/search_group_padding_compute.cu
lite/kernels/cuda/search_group_padding_compute.cu
+11
-6
lite/kernels/cuda/search_seq_depadding_compute.cu
lite/kernels/cuda/search_seq_depadding_compute.cu
+7
-2
lite/kernels/cuda/sequence_arithmetic_compute.cu
lite/kernels/cuda/sequence_arithmetic_compute.cu
+1
-2
lite/kernels/cuda/sequence_concat_compute.cu
lite/kernels/cuda/sequence_concat_compute.cu
+20
-0
lite/kernels/cuda/sequence_pool_compute.cu
lite/kernels/cuda/sequence_pool_compute.cu
+1
-0
lite/kernels/cuda/sequence_reverse_compute.cu
lite/kernels/cuda/sequence_reverse_compute.cu
+26
-21
lite/kernels/cuda/sequence_reverse_compute.h
lite/kernels/cuda/sequence_reverse_compute.h
+2
-2
lite/kernels/cuda/sequence_reverse_compute_test.cc
lite/kernels/cuda/sequence_reverse_compute_test.cc
+1
-1
lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu
lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu
+34
-21
lite/kernels/cuda/softmax_compute.cu
lite/kernels/cuda/softmax_compute.cu
+5
-3
lite/operators/sequence_topk_avg_pooling_op.cc
lite/operators/sequence_topk_avg_pooling_op.cc
+1
-2
未找到文件。
lite/api/paddle_api.cc
浏览文件 @
aa67c28e
...
@@ -121,6 +121,7 @@ template void Tensor::CopyFromCpu<int, TargetType::kARM>(const int *);
...
@@ -121,6 +121,7 @@ template void Tensor::CopyFromCpu<int, TargetType::kARM>(const int *);
template
void
Tensor
::
CopyFromCpu
<
float
,
TargetType
::
kARM
>(
const
float
*
);
template
void
Tensor
::
CopyFromCpu
<
float
,
TargetType
::
kARM
>(
const
float
*
);
template
void
Tensor
::
CopyFromCpu
<
int8_t
,
TargetType
::
kARM
>(
const
int8_t
*
);
template
void
Tensor
::
CopyFromCpu
<
int8_t
,
TargetType
::
kARM
>(
const
int8_t
*
);
template
void
Tensor
::
CopyFromCpu
<
int
,
TargetType
::
kCUDA
>(
const
int
*
);
template
void
Tensor
::
CopyFromCpu
<
int
,
TargetType
::
kCUDA
>(
const
int
*
);
template
void
Tensor
::
CopyFromCpu
<
int64_t
,
TargetType
::
kCUDA
>(
const
int64_t
*
);
template
void
Tensor
::
CopyFromCpu
<
float
,
TargetType
::
kCUDA
>(
const
float
*
);
template
void
Tensor
::
CopyFromCpu
<
float
,
TargetType
::
kCUDA
>(
const
float
*
);
template
void
Tensor
::
CopyFromCpu
<
int8_t
,
TargetType
::
kCUDA
>(
const
int8_t
*
);
template
void
Tensor
::
CopyFromCpu
<
int8_t
,
TargetType
::
kCUDA
>(
const
int8_t
*
);
...
...
lite/core/op_registry.cc
浏览文件 @
aa67c28e
...
@@ -115,6 +115,8 @@ KernelRegistry::KernelRegistry()
...
@@ -115,6 +115,8 @@ KernelRegistry::KernelRegistry()
INIT_FOR
(
kCUDA
,
kAny
,
kNCHW
);
INIT_FOR
(
kCUDA
,
kAny
,
kNCHW
);
INIT_FOR
(
kCUDA
,
kAny
,
kAny
);
INIT_FOR
(
kCUDA
,
kAny
,
kAny
);
INIT_FOR
(
kCUDA
,
kInt8
,
kNHWC
);
INIT_FOR
(
kCUDA
,
kInt8
,
kNHWC
);
INIT_FOR
(
kCUDA
,
kInt64
,
kNCHW
);
INIT_FOR
(
kCUDA
,
kInt64
,
kNHWC
);
INIT_FOR
(
kHost
,
kFloat
,
kNCHW
);
INIT_FOR
(
kHost
,
kFloat
,
kNCHW
);
INIT_FOR
(
kHost
,
kAny
,
kNCHW
);
INIT_FOR
(
kHost
,
kAny
,
kNCHW
);
...
...
lite/kernels/cuda/attention_padding_mask_compute.cu
浏览文件 @
aa67c28e
...
@@ -40,6 +40,7 @@ __global__ void ker_attention_padding_mask(T* out_data,
...
@@ -40,6 +40,7 @@ __global__ void ker_attention_padding_mask(T* out_data,
const
int
attn_seq_len
,
const
int
attn_seq_len
,
const
int
src_seq_num
,
const
int
src_seq_num
,
const
int
src_seq_len
,
const
int
src_seq_len
,
const
T
*
pad_begin_data
,
const
T
mask
,
const
T
mask
,
const
int
count
)
{
const
int
count
)
{
CUDA_KERNEL_LOOP
(
tid
,
count
)
{
CUDA_KERNEL_LOOP
(
tid
,
count
)
{
...
@@ -49,7 +50,12 @@ __global__ void ker_attention_padding_mask(T* out_data,
...
@@ -49,7 +50,12 @@ __global__ void ker_attention_padding_mask(T* out_data,
int
attn_word_id
=
tmp_tid
%
attn_seq_len
;
int
attn_word_id
=
tmp_tid
%
attn_seq_len
;
int
src_seq_id
=
attn_seq_id
%
src_seq_num
;
int
src_seq_id
=
attn_seq_id
%
src_seq_num
;
int
cur_len
=
src_offset
[
src_seq_id
+
1
]
-
src_offset
[
src_seq_id
];
int
cur_len
=
src_offset
[
src_seq_id
+
1
]
-
src_offset
[
src_seq_id
];
if
(
src_word_id
>=
cur_len
)
{
int
k
=
static_cast
<
int
>
(
pad_begin_data
[
src_seq_id
]);
if
(
k
<
cur_len
&&
tid
>=
src_seq_len
*
(
attn_seq_len
*
attn_seq_id
+
attn_word_id
)
+
k
&&
tid
<
src_seq_len
*
(
attn_seq_len
*
attn_seq_id
+
attn_word_id
)
+
cur_len
)
{
out_data
[
tid
]
=
mask
;
out_data
[
tid
]
=
mask
;
}
else
{
}
else
{
out_data
[
tid
]
=
attn_data
[
tid
];
out_data
[
tid
]
=
attn_data
[
tid
];
...
@@ -79,6 +85,35 @@ void AttentionPaddingMaskCompute::Run() {
...
@@ -79,6 +85,35 @@ void AttentionPaddingMaskCompute::Run() {
auto
attn_data
=
attn
->
data
<
float
>
();
auto
attn_data
=
attn
->
data
<
float
>
();
auto
out_data
=
out
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
out_data
=
out
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
std
::
vector
<
float
>
src_cpu
(
src
->
numel
(),
0
);
TargetWrapperCuda
::
MemcpyAsync
(
src_cpu
.
data
(),
src
->
data
<
float
>
(),
sizeof
(
float
)
*
src
->
numel
(),
IoDirection
::
DtoH
,
stream
);
cudaStreamSynchronize
(
stream
);
std
::
vector
<
float
>
pad_begin
(
src_seq_num
,
0
);
auto
src_len
=
static_cast
<
int64_t
>
(
src
->
lod
()[
0
][
1
]);
int
_pad_id
=
param
.
pad_id
;
for
(
int
i
=
0
;
i
<
src_seq_num
;
++
i
)
{
const
auto
*
src_data
=
src_cpu
.
data
()
+
src_len
*
i
;
int
index
=
src_len
-
1
;
for
(;
index
>=
0
&&
_pad_id
==
static_cast
<
int
>
(
src_data
[
index
]);
--
index
)
{
}
pad_begin
[
i
]
=
static_cast
<
float
>
(
index
+
1
);
}
param
.
pad_begin
->
Resize
({
static_cast
<
int64_t
>
(
src_seq_num
)});
auto
pad_begin_cuda_data
=
param
.
pad_begin
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
TargetWrapperCuda
::
MemcpyAsync
(
pad_begin_cuda_data
,
pad_begin
.
data
(),
sizeof
(
float
)
*
src_seq_num
,
IoDirection
::
HtoD
,
stream
);
std
::
vector
<
int
>
src_offset_cpu
(
src_offset
.
size
(),
0
);
std
::
vector
<
int
>
src_offset_cpu
(
src_offset
.
size
(),
0
);
for
(
int
i
=
0
;
i
<
src_offset
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
src_offset
.
size
();
i
++
)
{
src_offset_cpu
[
i
]
=
src_offset
[
i
];
src_offset_cpu
[
i
]
=
src_offset
[
i
];
...
@@ -101,11 +136,12 @@ void AttentionPaddingMaskCompute::Run() {
...
@@ -101,11 +136,12 @@ void AttentionPaddingMaskCompute::Run() {
attn_seq_len
,
attn_seq_len
,
src_seq_num
,
src_seq_num
,
src_seq_len
,
src_seq_len
,
pad_begin_cuda_data
,
param
.
mask
,
param
.
mask
,
count
);
count
);
cudaError_t
error
=
cudaGetLastError
();
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
INFO
)
<<
cudaGetErrorString
(
error
);
if
(
error
!=
cudaSuccess
)
LOG
(
ERROR
)
<<
cudaGetErrorString
(
error
);
}
}
}
// namespace cuda
}
// namespace cuda
...
@@ -113,7 +149,7 @@ void AttentionPaddingMaskCompute::Run() {
...
@@ -113,7 +149,7 @@ void AttentionPaddingMaskCompute::Run() {
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
attention_padding_mask
,
REGISTER_LITE_KERNEL
(
search_
attention_padding_mask
,
kCUDA
,
kCUDA
,
kFloat
,
kFloat
,
kNCHW
,
kNCHW
,
...
...
lite/kernels/cuda/feed_compute.cc
浏览文件 @
aa67c28e
...
@@ -20,21 +20,22 @@ namespace lite {
...
@@ -20,21 +20,22 @@ namespace lite {
namespace
kernels
{
namespace
kernels
{
namespace
cuda
{
namespace
cuda
{
void
FeedCompute
::
Run
()
{
template
<
typename
T
,
PrecisionType
Ptype
>
auto
&
param
=
this
->
Param
<
param_t
>
();
void
FeedCompute
<
T
,
Ptype
>::
Run
()
{
auto
&
param
=
this
->
template
Param
<
param_t
>();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
auto
stream
=
ctx
.
exec_stream
();
VLOG
(
4
)
<<
"feed_list.size: "
<<
param
.
feed_list
->
size
();
VLOG
(
4
)
<<
"feed_list.size: "
<<
param
.
feed_list
->
size
();
const
lite
::
Tensor
&
feed_item
=
(
*
param
.
feed_list
)[
param
.
col
];
const
lite
::
Tensor
&
feed_item
=
(
*
param
.
feed_list
)[
param
.
col
];
int
num
=
static_cast
<
int
>
(
feed_item
.
numel
());
int
num
=
static_cast
<
int
>
(
feed_item
.
numel
());
auto
input
=
feed_item
.
data
<
float
>
();
auto
input
=
feed_item
.
data
<
T
>
();
param
.
out
->
Resize
(
feed_item
.
dims
());
param
.
out
->
Resize
(
feed_item
.
dims
());
auto
output
=
param
.
out
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
output
=
param
.
out
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
VLOG
(
4
)
<<
"col: "
<<
param
.
col
<<
" num:"
<<
num
;
VLOG
(
4
)
<<
"col: "
<<
param
.
col
<<
" num:"
<<
num
;
TargetW
::
MemcpyAsync
(
TargetW
::
MemcpyAsync
(
output
,
input
,
num
*
sizeof
(
float
),
IoDirection
::
HtoD
,
stream
);
output
,
input
,
num
*
sizeof
(
T
),
IoDirection
::
HtoD
,
stream
);
}
}
}
// namespace cuda
}
// namespace cuda
...
@@ -42,8 +43,13 @@ void FeedCompute::Run() {
...
@@ -42,8 +43,13 @@ void FeedCompute::Run() {
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
typedef
paddle
::
lite
::
kernels
::
cuda
::
FeedCompute
<
float
,
PRECISION
(
kFloat
)
>
feed
,
kCUDA
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
FeedCompute
,
nchw
)
FeedFp32
;
typedef
paddle
::
lite
::
kernels
::
cuda
::
FeedCompute
<
int64_t
,
PRECISION
(
kInt64
)
>
FeedInt64
;
REGISTER_LITE_KERNEL
(
feed
,
kCUDA
,
kFloat
,
kNCHW
,
FeedFp32
,
nchw
)
.
BindInput
(
"X"
,
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
),
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
),
PRECISION
(
kFloat
),
PRECISION
(
kFloat
),
...
@@ -54,8 +60,7 @@ REGISTER_LITE_KERNEL(
...
@@ -54,8 +60,7 @@ REGISTER_LITE_KERNEL(
DATALAYOUT
(
kNCHW
))})
DATALAYOUT
(
kNCHW
))})
.
Finalize
();
.
Finalize
();
REGISTER_LITE_KERNEL
(
REGISTER_LITE_KERNEL
(
feed
,
kCUDA
,
kFloat
,
kNHWC
,
FeedFp32
,
nhwc
)
feed
,
kCUDA
,
kFloat
,
kNHWC
,
paddle
::
lite
::
kernels
::
cuda
::
FeedCompute
,
nhwc
)
.
BindInput
(
"X"
,
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
),
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
),
PRECISION
(
kFloat
),
PRECISION
(
kFloat
),
...
@@ -65,3 +70,25 @@ REGISTER_LITE_KERNEL(
...
@@ -65,3 +70,25 @@ REGISTER_LITE_KERNEL(
PRECISION
(
kFloat
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNHWC
))})
DATALAYOUT
(
kNHWC
))})
.
Finalize
();
.
Finalize
();
REGISTER_LITE_KERNEL
(
feed
,
kCUDA
,
kInt64
,
kNCHW
,
FeedInt64
,
nchw
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
),
PRECISION
(
kInt64
),
DATALAYOUT
(
kNCHW
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kInt64
),
DATALAYOUT
(
kNCHW
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
feed
,
kCUDA
,
kInt64
,
kNHWC
,
FeedInt64
,
nhwc
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kHost
),
PRECISION
(
kInt64
),
DATALAYOUT
(
kNHWC
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kInt64
),
DATALAYOUT
(
kNHWC
))})
.
Finalize
();
lite/kernels/cuda/feed_compute.h
浏览文件 @
aa67c28e
...
@@ -20,7 +20,8 @@ namespace lite {
...
@@ -20,7 +20,8 @@ namespace lite {
namespace
kernels
{
namespace
kernels
{
namespace
cuda
{
namespace
cuda
{
class
FeedCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)
>
{
template
<
typename
T
,
PrecisionType
Ptype
>
class
FeedCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
Ptype
>
{
public:
public:
using
param_t
=
operators
::
FeedParam
;
using
param_t
=
operators
::
FeedParam
;
using
TargetW
=
TargetWrapper
<
TARGET
(
kCUDA
)
>
;
using
TargetW
=
TargetWrapper
<
TARGET
(
kCUDA
)
>
;
...
...
lite/kernels/cuda/match_matrix_tensor_compute.cu
浏览文件 @
aa67c28e
...
@@ -82,8 +82,32 @@ void MatchMatrixTensorCompute::Run() {
...
@@ -82,8 +82,32 @@ void MatchMatrixTensorCompute::Run() {
gemm_impl_
->
run
(
1.0
f
,
0.0
f
,
l_t_data
,
r_data
,
top_data
,
&
context
);
gemm_impl_
->
run
(
1.0
f
,
0.0
f
,
l_t_data
,
r_data
,
top_data
,
&
context
);
}
}
}
}
int
batch_size
=
x
->
lod
()[
0
].
size
()
-
1
;
int
lod_lv1_size
=
batch_size
*
dim_t
;
int
lod_lv2_size
=
x
->
lod
()[
0
].
back
()
*
dim_t
;
std
::
vector
<
size_t
>
out_lod0
(
batch_size
+
1
,
0
);
std
::
vector
<
size_t
>
out_lod1
(
lod_lv1_size
+
1
,
0
);
std
::
vector
<
size_t
>
out_lod2
(
lod_lv2_size
+
1
,
0
);
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
out_lod0
[
i
+
1
]
=
out_lod0
[
i
]
+
dim_t
;
int
len_l
=
offset_l
[
i
+
1
]
-
offset_l
[
i
];
for
(
int
j
=
0
;
j
<
dim_t
;
j
++
)
{
out_lod1
[
i
*
dim_t
+
j
+
1
]
=
out_lod1
[
i
*
dim_t
+
j
]
+
len_l
;
int
len_r
=
offset_r
[
i
+
1
]
-
offset_r
[
i
];
for
(
int
k
=
0
;
k
<
len_l
;
k
++
)
{
out_lod2
[
offset_l
[
i
]
*
dim_t
+
j
*
len_l
+
k
+
1
]
=
out_lod2
[
offset_l
[
i
]
*
dim_t
+
j
*
len_l
+
k
]
+
len_r
;
}
}
}
LoD
out_lod
;
LoD
out_lod
;
out_lod
.
push_back
(
top_offset
);
out_lod
.
push_back
(
top_offset
);
out_lod
.
push_back
(
offset_l
);
out_lod
.
push_back
(
offset_r
);
out
->
set_lod
(
out_lod
);
out
->
set_lod
(
out_lod
);
}
}
...
...
lite/kernels/cuda/search_aligned_mat_mul_compute.cc
浏览文件 @
aa67c28e
...
@@ -32,4 +32,7 @@ REGISTER_LITE_KERNEL(search_aligned_mat_mul,
...
@@ -32,4 +32,7 @@ REGISTER_LITE_KERNEL(search_aligned_mat_mul,
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"_a_addr"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"_b_addr"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"_c_addr"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
.
Finalize
();
lite/kernels/cuda/search_fc_compute.cu
浏览文件 @
aa67c28e
...
@@ -36,7 +36,6 @@ void anakin_NV_gemv<float>(cublasHandle_t handle,
...
@@ -36,7 +36,6 @@ void anakin_NV_gemv<float>(cublasHandle_t handle,
const
float
*
x
,
const
float
*
x
,
const
float
beta
,
const
float
beta
,
float
*
y
)
{
float
*
y
)
{
LOG
(
INFO
)
<<
"1"
;
cublasOperation_t
cuTransA
=
(
TransA
==
false
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransA
=
(
TransA
==
false
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
CUBLAS_CHECK
(
CUBLAS_CHECK
(
cublasSgemv
(
handle
,
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
x
,
1
,
&
beta
,
y
,
1
));
cublasSgemv
(
handle
,
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
x
,
1
,
&
beta
,
y
,
1
));
...
@@ -66,17 +65,13 @@ void anakin_NV_gemm<float>(cublasHandle_t handle,
...
@@ -66,17 +65,13 @@ void anakin_NV_gemm<float>(cublasHandle_t handle,
const
float
*
B
,
const
float
*
B
,
const
float
beta
,
const
float
beta
,
float
*
C
)
{
float
*
C
)
{
LOG
(
INFO
)
<<
"1"
;
// Note that cublas follows fortran order.
// Note that cublas follows fortran order.
int
lda
=
(
!
TransA
/* == CblasNoTrans*/
)
?
K
:
M
;
int
lda
=
(
!
TransA
/* == CblasNoTrans*/
)
?
K
:
M
;
int
ldb
=
(
!
TransB
/* == CblasNoTrans*/
)
?
N
:
K
;
int
ldb
=
(
!
TransB
/* == CblasNoTrans*/
)
?
N
:
K
;
LOG
(
INFO
)
<<
"1"
;
cublasOperation_t
cuTransA
=
cublasOperation_t
cuTransA
=
(
!
TransA
/* == CblasNoTrans*/
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
(
!
TransA
/* == CblasNoTrans*/
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
LOG
(
INFO
)
<<
"1"
;
cublasOperation_t
cuTransB
=
cublasOperation_t
cuTransB
=
(
!
TransB
/* == CblasNoTrans*/
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
(
!
TransB
/* == CblasNoTrans*/
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
LOG
(
INFO
)
<<
"1"
;
CUBLAS_CHECK
(
cublasSgemm
(
handle
,
CUBLAS_CHECK
(
cublasSgemm
(
handle
,
cuTransB
,
cuTransB
,
cuTransA
,
cuTransA
,
...
@@ -91,7 +86,6 @@ void anakin_NV_gemm<float>(cublasHandle_t handle,
...
@@ -91,7 +86,6 @@ void anakin_NV_gemm<float>(cublasHandle_t handle,
&
beta
,
&
beta
,
C
,
C
,
N
));
N
));
LOG
(
INFO
)
<<
"1"
;
}
}
template
<
>
template
<
>
...
...
lite/kernels/cuda/search_group_padding_compute.cu
浏览文件 @
aa67c28e
...
@@ -46,7 +46,9 @@ __global__ void ker_search_group_padding(Dtype* out_emb_padding_data,
...
@@ -46,7 +46,9 @@ __global__ void ker_search_group_padding(Dtype* out_emb_padding_data,
in_data
[(
offset
[
seq_id
]
+
word_id_in_seq
)
*
emb_size
+
emb_id
];
in_data
[(
offset
[
seq_id
]
+
word_id_in_seq
)
*
emb_size
+
emb_id
];
}
else
{
}
else
{
out_emb_padding_data
[
tid
]
=
0.
f
;
out_emb_padding_data
[
tid
]
=
0.
f
;
out_padding_data
[
word_id
]
=
pad_id
;
if
(
emb_id
==
0
)
{
out_padding_data
[
word_id
]
=
pad_id
;
}
}
}
}
}
}
}
...
@@ -61,12 +63,7 @@ void SearchGroupPaddingCompute::Run() {
...
@@ -61,12 +63,7 @@ void SearchGroupPaddingCompute::Run() {
Tensor
*
out_new
=
param
.
out_new
;
Tensor
*
out_new
=
param
.
out_new
;
Tensor
*
out_padding
=
param
.
out_padding
;
Tensor
*
out_padding
=
param
.
out_padding
;
const
float
pad_id
=
static_cast
<
float
>
(
param
.
pad_id
);
const
float
pad_id
=
static_cast
<
float
>
(
param
.
pad_id
);
const
float
*
in_data
=
x
->
data
<
float
>
();
const
float
*
in_data
=
x
->
data
<
float
>
();
float
*
out_emb_padding_data
=
out_emb_padding
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
float
*
out_new_data
=
out_new
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
float
*
out_padding_data
=
out_padding
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
const
auto
&
in_seq_offset
=
x
->
lod
()[
0
];
const
auto
&
in_seq_offset
=
x
->
lod
()[
0
];
int
batch
=
in_seq_offset
.
size
()
-
1
;
int
batch
=
in_seq_offset
.
size
()
-
1
;
int
max_seq
=
0
;
int
max_seq
=
0
;
...
@@ -85,16 +82,20 @@ void SearchGroupPaddingCompute::Run() {
...
@@ -85,16 +82,20 @@ void SearchGroupPaddingCompute::Run() {
out_emb_padding_lod
.
push_back
(
new_offset
);
out_emb_padding_lod
.
push_back
(
new_offset
);
out_emb_padding
->
set_lod
(
out_emb_padding_lod
);
out_emb_padding
->
set_lod
(
out_emb_padding_lod
);
out_emb_padding
->
Resize
({
batch
*
max_seq
,
x_dims
[
1
]});
out_emb_padding
->
Resize
({
batch
*
max_seq
,
x_dims
[
1
]});
float
*
out_emb_padding_data
=
out_emb_padding
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
LoD
out_new_lod
;
LoD
out_new_lod
;
out_new_lod
.
push_back
(
in_seq_offset
);
out_new_lod
.
push_back
(
in_seq_offset
);
out_new
->
set_lod
(
out_new_lod
);
out_new
->
set_lod
(
out_new_lod
);
out_new
->
Resize
({
x_dims
[
0
],
1
});
out_new
->
Resize
({
x_dims
[
0
],
1
});
float
*
out_new_data
=
out_new
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
LoD
out_padding_lod
;
LoD
out_padding_lod
;
out_padding_lod
.
push_back
(
new_offset
);
out_padding_lod
.
push_back
(
new_offset
);
out_padding
->
set_lod
(
out_padding_lod
);
out_padding
->
set_lod
(
out_padding_lod
);
out_padding
->
Resize
({
batch
*
max_seq
,
1
});
out_padding
->
Resize
({
batch
*
max_seq
,
1
});
float
*
out_padding_data
=
out_padding
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
const
int
count
=
out_emb_padding
->
numel
();
const
int
count
=
out_emb_padding
->
numel
();
const
auto
&
out_emb_padding_seq_offset
=
out_emb_padding
->
lod
()[
0
];
const
auto
&
out_emb_padding_seq_offset
=
out_emb_padding
->
lod
()[
0
];
...
@@ -112,6 +113,10 @@ void SearchGroupPaddingCompute::Run() {
...
@@ -112,6 +113,10 @@ void SearchGroupPaddingCompute::Run() {
TargetWrapperCuda
::
MemsetSync
(
TargetWrapperCuda
::
MemsetSync
(
out_new_data
,
0
,
out_new
->
dims
()[
0
]
*
out_new
->
dims
()[
1
]
*
sizeof
(
float
));
out_new_data
,
0
,
out_new
->
dims
()[
0
]
*
out_new
->
dims
()[
1
]
*
sizeof
(
float
));
TargetWrapperCuda
::
MemsetSync
(
out_padding_data
,
0
,
out_padding
->
dims
()[
0
]
*
out_padding
->
dims
()[
1
]
*
sizeof
(
float
));
ker_search_group_padding
<
ker_search_group_padding
<
float
><<<
CUDA_GET_BLOCKS
(
count
),
CUDA_NUM_THREADS
,
0
,
cuda_stream
>>>
(
float
><<<
CUDA_GET_BLOCKS
(
count
),
CUDA_NUM_THREADS
,
0
,
cuda_stream
>>>
(
...
...
lite/kernels/cuda/search_seq_depadding_compute.cu
浏览文件 @
aa67c28e
...
@@ -50,6 +50,7 @@ void SearchSeqDepaddingCompute::Run() {
...
@@ -50,6 +50,7 @@ void SearchSeqDepaddingCompute::Run() {
auto
*
out
=
param
.
out
;
auto
*
out
=
param
.
out
;
auto
*
in_data
=
pad
->
data
<
float
>
();
auto
*
in_data
=
pad
->
data
<
float
>
();
out
->
Resize
({
src
->
dims
()[
0
],
pad
->
dims
()[
1
]});
auto
*
out_data
=
out
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
*
out_data
=
out
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
const
int
count
=
out
->
numel
();
const
int
count
=
out
->
numel
();
...
@@ -59,6 +60,9 @@ void SearchSeqDepaddingCompute::Run() {
...
@@ -59,6 +60,9 @@ void SearchSeqDepaddingCompute::Run() {
int
seq_num
=
pad_seq_offset
.
size
()
-
1
;
int
seq_num
=
pad_seq_offset
.
size
()
-
1
;
int
emb_size
=
pad
->
dims
()[
1
];
int
emb_size
=
pad
->
dims
()[
1
];
LoD
out_lod
;
out_lod
.
push_back
(
src_seq_offset
);
out
->
set_lod
(
out_lod
);
std
::
vector
<
int
>
seq_id_map
;
std
::
vector
<
int
>
seq_id_map
;
for
(
int
i
=
0
;
i
<
seq_num
;
i
++
)
{
for
(
int
i
=
0
;
i
<
seq_num
;
i
++
)
{
int
cur_len
=
src_seq_offset
[
i
+
1
]
-
src_seq_offset
[
i
];
int
cur_len
=
src_seq_offset
[
i
+
1
]
-
src_seq_offset
[
i
];
...
@@ -77,11 +81,12 @@ void SearchSeqDepaddingCompute::Run() {
...
@@ -77,11 +81,12 @@ void SearchSeqDepaddingCompute::Run() {
cuda_stream
);
cuda_stream
);
int
threads
=
512
;
int
threads
=
512
;
ker_sequence_depadding_fwd
<<<
count
,
threads
,
0
,
cuda_stream
>>>
(
int
blocks
=
(
count
+
threads
-
1
)
/
threads
;
ker_sequence_depadding_fwd
<<<
blocks
,
threads
,
0
,
cuda_stream
>>>
(
out_data
,
in_data
,
seq_id_map_data
,
seq_num
,
max_len
,
emb_size
,
count
);
out_data
,
in_data
,
seq_id_map_data
,
seq_num
,
max_len
,
emb_size
,
count
);
cudaError_t
error
=
cudaGetLastError
();
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
INFO
)
<<
cudaGetErrorString
(
error
);
if
(
error
!=
cudaSuccess
)
LOG
(
ERROR
)
<<
cudaGetErrorString
(
error
);
}
}
}
// namespace cuda
}
// namespace cuda
...
...
lite/kernels/cuda/sequence_arithmetic_compute.cu
浏览文件 @
aa67c28e
...
@@ -120,7 +120,7 @@ void SequenceArithmeticCompute::Run() {
...
@@ -120,7 +120,7 @@ void SequenceArithmeticCompute::Run() {
auto
x_data
=
param
.
X
->
data
<
float
>
();
auto
x_data
=
param
.
X
->
data
<
float
>
();
auto
x_lod
=
param
.
X
->
lod
()[
0
];
auto
x_lod
=
param
.
X
->
lod
()[
0
];
auto
y_data
=
param
.
X
->
data
<
float
>
();
auto
y_data
=
param
.
Y
->
data
<
float
>
();
auto
y_lod
=
param
.
Y
->
lod
()[
0
];
auto
y_lod
=
param
.
Y
->
lod
()[
0
];
auto
out_data
=
param
.
Out
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
out_data
=
param
.
Out
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
...
@@ -174,7 +174,6 @@ void SequenceArithmeticCompute::Run() {
...
@@ -174,7 +174,6 @@ void SequenceArithmeticCompute::Run() {
int
seq_num
=
x_lod
.
size
()
-
1
;
int
seq_num
=
x_lod
.
size
()
-
1
;
int
count
=
param
.
X
->
numel
();
int
count
=
param
.
X
->
numel
();
int
inner_size
=
param
.
X
->
dims
()[
1
];
int
inner_size
=
param
.
X
->
dims
()[
1
];
switch
(
param
.
op_type
)
{
switch
(
param
.
op_type
)
{
case
1
:
// sum
case
1
:
// sum
ker_arithmetic_sum
<
ker_arithmetic_sum
<
...
...
lite/kernels/cuda/sequence_concat_compute.cu
浏览文件 @
aa67c28e
...
@@ -24,6 +24,24 @@ namespace cuda {
...
@@ -24,6 +24,24 @@ namespace cuda {
const
int
CUDA_NUM_THREADS
=
512
;
const
int
CUDA_NUM_THREADS
=
512
;
template
<
typename
T
>
inline
LoD
ConcatLoD
(
const
std
::
vector
<
lite
::
Tensor
*>&
xs
)
{
std
::
vector
<
size_t
>
result
;
result
.
resize
(
xs
[
0
]
->
lod
()[
0
].
size
());
for
(
size_t
i
=
1
;
i
<
result
.
size
();
++
i
)
{
size_t
sum
=
0
;
for
(
size_t
j
=
0
;
j
<
xs
.
size
();
++
j
)
{
auto
&
x_lod
=
xs
[
j
]
->
lod
()[
0
];
sum
+=
x_lod
[
i
];
}
result
[
i
]
=
sum
;
}
LoD
lod
;
lod
.
emplace_back
(
result
);
return
lod
;
}
template
<
typename
Dtype
>
template
<
typename
Dtype
>
__global__
void
ker_sequence_concat
(
Dtype
*
out_data
,
__global__
void
ker_sequence_concat
(
Dtype
*
out_data
,
const
uint64_t
*
in_locate_data
,
const
uint64_t
*
in_locate_data
,
...
@@ -96,6 +114,8 @@ void SequenceConcatCompute::Run() {
...
@@ -96,6 +114,8 @@ void SequenceConcatCompute::Run() {
IoDirection
::
HtoD
,
IoDirection
::
HtoD
,
stream
);
stream
);
param
.
Out
->
set_lod
(
ConcatLoD
<
float
>
(
param
.
X
));
int
count
=
param
.
X
[
0
]
->
numel
();
int
count
=
param
.
X
[
0
]
->
numel
();
for
(
int
i
=
1
;
i
<
param
.
X
.
size
();
++
i
)
{
for
(
int
i
=
1
;
i
<
param
.
X
.
size
();
++
i
)
{
count
+=
param
.
X
[
i
]
->
numel
();
count
+=
param
.
X
[
i
]
->
numel
();
...
...
lite/kernels/cuda/sequence_pool_compute.cu
浏览文件 @
aa67c28e
...
@@ -254,4 +254,5 @@ REGISTER_LITE_KERNEL(sequence_pool,
...
@@ -254,4 +254,5 @@ REGISTER_LITE_KERNEL(sequence_pool,
def
)
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"MaxIndex"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
.
Finalize
();
lite/kernels/cuda/sequence_reverse_compute.cu
浏览文件 @
aa67c28e
...
@@ -42,11 +42,9 @@ __host__ __device__ inline size_t UpperBound(const T* x,
...
@@ -42,11 +42,9 @@ __host__ __device__ inline size_t UpperBound(const T* x,
return
static_cast
<
size_t
>
(
first
-
x
);
return
static_cast
<
size_t
>
(
first
-
x
);
}
}
__global__
void
SequenceReverseKernelGridIsOne
(
const
float
*
x
,
template
<
typename
T
>
float
*
y
,
__global__
void
SequenceReverseKernelGridIsOne
(
const
int64_t
*
lod
,
const
T
*
x
,
T
*
y
,
const
int64_t
*
lod
,
size_t
lod_count
,
int64_t
row_numel
)
{
size_t
lod_count
,
int64_t
row_numel
)
{
int64_t
idx
=
static_cast
<
int64_t
>
(
threadIdx
.
x
);
int64_t
idx
=
static_cast
<
int64_t
>
(
threadIdx
.
x
);
auto
row_idx_x
=
idx
/
row_numel
;
auto
row_idx_x
=
idx
/
row_numel
;
auto
lod_idx
=
UpperBound
(
lod
,
lod_count
,
row_idx_x
);
auto
lod_idx
=
UpperBound
(
lod
,
lod_count
,
row_idx_x
);
...
@@ -55,8 +53,9 @@ __global__ void SequenceReverseKernelGridIsOne(const float* x,
...
@@ -55,8 +53,9 @@ __global__ void SequenceReverseKernelGridIsOne(const float* x,
y
[
idx_y
]
=
x
[
idx
];
y
[
idx_y
]
=
x
[
idx
];
}
}
__global__
void
SequenceReverseKernel
(
const
float
*
x
,
template
<
typename
T
>
float
*
y
,
__global__
void
SequenceReverseKernel
(
const
T
*
x
,
T
*
y
,
const
int64_t
*
lod
,
const
int64_t
*
lod
,
size_t
lod_count
,
size_t
lod_count
,
int64_t
row_numel
,
int64_t
row_numel
,
...
@@ -71,19 +70,20 @@ __global__ void SequenceReverseKernel(const float* x,
...
@@ -71,19 +70,20 @@ __global__ void SequenceReverseKernel(const float* x,
}
}
}
}
void
SequenceReverseCompute
::
Run
()
{
template
<
typename
T
,
PrecisionType
Ptype
>
auto
&
param
=
this
->
Param
<
param_t
>
();
void
SequenceReverseCompute
<
T
,
Ptype
>::
Run
()
{
auto
&
param
=
this
->
template
Param
<
param_t
>();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
auto
stream
=
ctx
.
exec_stream
();
size_t
limit
=
static_cast
<
size_t
>
(
param
.
X
->
numel
());
size_t
limit
=
static_cast
<
size_t
>
(
param
.
X
->
numel
());
int64_t
row_numel
=
static_cast
<
int64_t
>
(
limit
/
param
.
X
->
dims
()[
0
]);
int64_t
row_numel
=
static_cast
<
int64_t
>
(
limit
/
param
.
X
->
dims
()[
0
]);
const
auto
*
x_data
=
param
.
X
->
data
<
float
>
();
const
auto
*
x_data
=
param
.
X
->
template
data
<
T
>();
auto
y_data
=
param
.
Out
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
y_data
=
param
.
Out
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
CHECK_NE
(
x_data
,
y_data
)
CHECK_NE
(
x_data
,
y_data
)
<<
"SequenceReverse Op does not support in-place operation"
;
<<
"SequenceReverse Op does not support in-place operation"
;
const
auto
lod
=
param
.
X
->
lod
()[
param
.
X
->
lod
().
size
()
-
1
];
const
auto
lod
=
param
.
X
->
lod
()[
param
.
X
->
lod
().
size
()
-
1
];
const
size_t
lod_count
=
lod
.
size
();
const
size_t
lod_count
=
lod
.
size
();
param
.
Out
->
set_lod
(
param
.
X
->
lod
());
lod_cuda
.
Resize
({
static_cast
<
int64_t
>
(
lod
.
size
())});
lod_cuda
.
Resize
({
static_cast
<
int64_t
>
(
lod
.
size
())});
int64_t
*
lod_data
=
lod_cuda
.
mutable_data
<
int64_t
>
(
TARGET
(
kCUDA
));
int64_t
*
lod_data
=
lod_cuda
.
mutable_data
<
int64_t
>
(
TARGET
(
kCUDA
));
...
@@ -92,11 +92,9 @@ void SequenceReverseCompute::Run() {
...
@@ -92,11 +92,9 @@ void SequenceReverseCompute::Run() {
sizeof
(
int64_t
)
*
lod
.
size
(),
sizeof
(
int64_t
)
*
lod
.
size
(),
IoDirection
::
HtoD
,
IoDirection
::
HtoD
,
stream
);
stream
);
constexpr
int
num_threads
=
1024
;
constexpr
int
num_threads
=
1024
;
int
block_size
=
limit
<=
num_threads
?
limit
:
num_threads
;
int
block_size
=
limit
<=
num_threads
?
limit
:
num_threads
;
int
grid_size
=
(
limit
+
num_threads
-
1
)
/
num_threads
;
int
grid_size
=
(
limit
+
num_threads
-
1
)
/
num_threads
;
if
(
grid_size
==
1
)
{
if
(
grid_size
==
1
)
{
SequenceReverseKernelGridIsOne
<<<
1
,
block_size
,
0
,
stream
>>>
(
SequenceReverseKernelGridIsOne
<<<
1
,
block_size
,
0
,
stream
>>>
(
x_data
,
y_data
,
lod_data
,
lod_count
,
row_numel
);
x_data
,
y_data
,
lod_data
,
lod_count
,
row_numel
);
...
@@ -104,7 +102,6 @@ void SequenceReverseCompute::Run() {
...
@@ -104,7 +102,6 @@ void SequenceReverseCompute::Run() {
SequenceReverseKernel
<<<
grid_size
,
block_size
,
0
,
stream
>>>
(
SequenceReverseKernel
<<<
grid_size
,
block_size
,
0
,
stream
>>>
(
x_data
,
y_data
,
lod_data
,
lod_count
,
row_numel
,
limit
);
x_data
,
y_data
,
lod_data
,
lod_count
,
row_numel
,
limit
);
}
}
cudaError_t
error
=
cudaGetLastError
();
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
INFO
)
<<
cudaGetErrorString
(
error
);
if
(
error
!=
cudaSuccess
)
LOG
(
INFO
)
<<
cudaGetErrorString
(
error
);
}
}
...
@@ -114,12 +111,20 @@ void SequenceReverseCompute::Run() {
...
@@ -114,12 +111,20 @@ void SequenceReverseCompute::Run() {
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
sequence_reverse
,
typedef
paddle
::
lite
::
kernels
::
cuda
::
SequenceReverseCompute
<
float
,
kCUDA
,
PRECISION
(
kFloat
)
>
kFloat
,
ReverseFp32
;
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
SequenceReverseCompute
,
typedef
paddle
::
lite
::
kernels
::
cuda
::
SequenceReverseCompute
<
int64_t
,
def
)
PRECISION
(
kInt64
)
>
ReverseInt64
;
REGISTER_LITE_KERNEL
(
sequence_reverse
,
kCUDA
,
kFloat
,
kNCHW
,
ReverseFp32
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
.
Finalize
();
REGISTER_LITE_KERNEL
(
sequence_reverse
,
kCUDA
,
kInt64
,
kNCHW
,
ReverseInt64
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kInt64
))})
.
BindOutput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kInt64
))})
.
Finalize
();
lite/kernels/cuda/sequence_reverse_compute.h
浏览文件 @
aa67c28e
...
@@ -20,8 +20,8 @@ namespace lite {
...
@@ -20,8 +20,8 @@ namespace lite {
namespace
kernels
{
namespace
kernels
{
namespace
cuda
{
namespace
cuda
{
class
SequenceReverseCompute
template
<
typename
T
,
PrecisionType
Ptype
>
:
public
KernelLite
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)
>
{
class
SequenceReverseCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
Ptype
>
{
public:
public:
using
param_t
=
operators
::
SequenceReverseParam
;
using
param_t
=
operators
::
SequenceReverseParam
;
...
...
lite/kernels/cuda/sequence_reverse_compute_test.cc
浏览文件 @
aa67c28e
...
@@ -40,7 +40,7 @@ static void sequence_reverse_ref(const lite::Tensor* x, lite::Tensor* y) {
...
@@ -40,7 +40,7 @@ static void sequence_reverse_ref(const lite::Tensor* x, lite::Tensor* y) {
}
}
TEST
(
sequence_reverse_cuda
,
normal
)
{
TEST
(
sequence_reverse_cuda
,
normal
)
{
SequenceReverseCompute
seq_kernel
;
SequenceReverseCompute
<
float
,
PRECISION
(
kFloat
)
>
seq_kernel
;
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
...
...
lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu
浏览文件 @
aa67c28e
...
@@ -26,8 +26,6 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
...
@@ -26,8 +26,6 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
const
Dtype
*
input
,
const
Dtype
*
input
,
const
int
*
gpu_input_offset_l
,
const
int
*
gpu_input_offset_l
,
const
int
*
gpu_input_offset_r
,
const
int
*
gpu_input_offset_r
,
const
int
row_max
,
const
int
col_max
,
const
int
topk_size
,
const
int
topk_size
,
const
int
*
topks
,
const
int
*
topks
,
const
int
feat_map_num
)
{
const
int
feat_map_num
)
{
...
@@ -35,17 +33,20 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
...
@@ -35,17 +33,20 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
gpu_input_offset_l
[
blockIdx
.
x
+
1
]
-
gpu_input_offset_l
[
blockIdx
.
x
];
// 8
gpu_input_offset_l
[
blockIdx
.
x
+
1
]
-
gpu_input_offset_l
[
blockIdx
.
x
];
// 8
int
col
=
gpu_input_offset_r
[
blockIdx
.
x
+
1
]
-
int
col
=
gpu_input_offset_r
[
blockIdx
.
x
+
1
]
-
gpu_input_offset_r
[
blockIdx
.
x
];
// 30
gpu_input_offset_r
[
blockIdx
.
x
];
// 30
int
max_k
=
topks
[
topk_size
-
1
];
int
max_k
=
topks
[
topk_size
-
1
];
max_k
=
max_k
<
col
?
max_k
:
col
;
max_k
=
max_k
<
col
?
max_k
:
col
;
extern
__shared__
Dtype
smem
[];
// H*W
extern
__shared__
Dtype
smem
[];
// H*W
const
Dtype
*
fm_row_in_data
=
input
+
const
Dtype
*
fm_row_in_data
=
input
;
blockIdx
.
x
*
row_max
*
feat_map_num
*
col_max
+
for
(
int
i
=
0
;
i
<
blockIdx
.
x
;
++
i
)
{
blockIdx
.
y
*
row_max
*
col_max
;
int
tmp_row
=
gpu_input_offset_l
[
i
+
1
]
-
gpu_input_offset_l
[
i
];
int
tmp_col
=
gpu_input_offset_r
[
i
+
1
]
-
gpu_input_offset_r
[
i
];
fm_row_in_data
+=
tmp_row
*
feat_map_num
*
tmp_col
;
}
fm_row_in_data
+=
blockIdx
.
y
*
row
*
col
;
for
(
int
i
=
threadIdx
.
x
;
i
<
row
*
col
_max
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
row
*
col
;
i
+=
blockDim
.
x
)
{
smem
[
i
]
=
fm_row_in_data
[
i
];
smem
[
i
]
=
fm_row_in_data
[
i
];
}
}
__syncthreads
();
__syncthreads
();
...
@@ -56,7 +57,7 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
...
@@ -56,7 +57,7 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
(
gpu_input_offset_l
[
blockIdx
.
x
]
+
idx
)
*
feat_map_num
*
topk_size
+
(
gpu_input_offset_l
[
blockIdx
.
x
]
+
idx
)
*
feat_map_num
*
topk_size
+
blockIdx
.
y
*
topk_size
;
blockIdx
.
y
*
topk_size
;
Dtype
*
smem_start_col
=
smem
+
idx
*
col
_max
;
Dtype
*
smem_start_col
=
smem
+
idx
*
col
;
int
counter
=
max_k
;
// topk_size;
int
counter
=
max_k
;
// topk_size;
Dtype
last_max_val
=
-
20000.0
;
Dtype
last_max_val
=
-
20000.0
;
...
@@ -75,7 +76,7 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
...
@@ -75,7 +76,7 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
if
(
max_val
<
-
9999.0
)
{
// == -10000.0
if
(
max_val
<
-
9999.0
)
{
// == -10000.0
max_val
=
last_max_val
;
max_val
=
last_max_val
;
}
}
smem_start_col
[
max_pos
]
=
10000000.0
;
smem_start_col
[
max_pos
]
=
-
10000000.0
;
int
i
=
max_k
-
counter
;
int
i
=
max_k
-
counter
;
for
(
int
c
=
0
;
c
<
topk_size
;
c
++
)
{
for
(
int
c
=
0
;
c
<
topk_size
;
c
++
)
{
if
(
i
<=
topks
[
c
]
-
1
)
{
if
(
i
<=
topks
[
c
]
-
1
)
{
...
@@ -97,7 +98,6 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
...
@@ -97,7 +98,6 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
cuda_stream
=
ctx
.
exec_stream
();
auto
cuda_stream
=
ctx
.
exec_stream
();
int
topk_num
=
param
.
topks
.
size
();
int
topk_num
=
param
.
topks
.
size
();
lite
::
DDim
top_ks_shape
(
std
::
vector
<
int64_t
>
{
topk_num
,
1
,
1
,
1
});
lite
::
DDim
top_ks_shape
(
std
::
vector
<
int64_t
>
{
topk_num
,
1
,
1
,
1
});
_top_ks
.
Resize
(
top_ks_shape
);
_top_ks
.
Resize
(
top_ks_shape
);
...
@@ -107,12 +107,16 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
...
@@ -107,12 +107,16 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
cudaMemcpyHostToDevice
,
cudaMemcpyHostToDevice
,
cuda_stream
);
cuda_stream
);
int
width_offset_len
=
param
.
X
->
lod
()[
0
].
size
();
int
width_offset_len
=
param
.
COLUMN
->
lod
()[
0
].
size
();
lite
::
DDim
width_offset_shape
(
lite
::
DDim
width_offset_shape
(
std
::
vector
<
int64_t
>
{
width_offset_len
,
1
,
1
,
1
});
std
::
vector
<
int64_t
>
{
width_offset_len
,
1
,
1
,
1
});
_width_offset
.
Resize
(
width_offset_shape
);
_width_offset
.
Resize
(
width_offset_shape
);
std
::
vector
<
int
>
width_lod_0
(
width_offset_len
,
0
);
for
(
size_t
i
=
0
;
i
<
param
.
COLUMN
->
lod
()[
0
].
size
();
++
i
)
{
width_lod_0
[
i
]
=
static_cast
<
int
>
(
param
.
COLUMN
->
lod
()[
0
][
i
]);
}
cudaMemcpyAsync
(
_width_offset
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
)),
cudaMemcpyAsync
(
_width_offset
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
)),
&
(
param
.
X
->
lod
()[
0
][
0
])
,
&
width_lod_0
[
0
]
,
sizeof
(
int
)
*
width_offset_len
,
sizeof
(
int
)
*
width_offset_len
,
cudaMemcpyHostToDevice
,
cudaMemcpyHostToDevice
,
cuda_stream
);
cuda_stream
);
...
@@ -121,8 +125,12 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
...
@@ -121,8 +125,12 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
lite
::
DDim
height_offset_shape
(
lite
::
DDim
height_offset_shape
(
std
::
vector
<
int64_t
>
{
height_offset_len
,
1
,
1
,
1
});
std
::
vector
<
int64_t
>
{
height_offset_len
,
1
,
1
,
1
});
_height_offset
.
Resize
(
height_offset_shape
);
_height_offset
.
Resize
(
height_offset_shape
);
std
::
vector
<
int
>
height_lod_0
(
height_offset_len
,
0
);
for
(
size_t
i
=
0
;
i
<
param
.
ROW
->
lod
()[
0
].
size
();
++
i
)
{
height_lod_0
[
i
]
=
static_cast
<
int
>
(
param
.
ROW
->
lod
()[
0
][
i
]);
}
cudaMemcpyAsync
(
_height_offset
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
)),
cudaMemcpyAsync
(
_height_offset
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
)),
&
(
param
.
ROW
->
lod
()[
0
][
0
])
,
&
height_lod_0
[
0
]
,
sizeof
(
int
)
*
height_offset_len
,
sizeof
(
int
)
*
height_offset_len
,
cudaMemcpyHostToDevice
,
cudaMemcpyHostToDevice
,
cuda_stream
);
cuda_stream
);
...
@@ -136,16 +144,20 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
...
@@ -136,16 +144,20 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
sizeof
(
T
)
*
out_tensor
->
numel
(),
sizeof
(
T
)
*
out_tensor
->
numel
(),
cuda_stream
);
cuda_stream
);
auto
x_dims
=
x_tensor
->
dims
();
int
num
=
param
.
ROW
->
lod
()[
0
].
size
()
-
1
;
int
num
=
x_dims
[
0
];
int
channel
=
param
.
channel_num
;
int
channel
=
x_dims
[
1
];
int
height
=
x_dims
[
2
];
int
width
=
x_dims
[
3
];
const
int
*
height_offset
=
_height_offset
.
data
<
int
>
();
const
int
*
height_offset
=
_height_offset
.
data
<
int
>
();
const
int
*
width_offset
=
_width_offset
.
data
<
int
>
();
const
int
*
width_offset
=
_width_offset
.
data
<
int
>
();
int
feat_map_size
=
height
*
width
;
int
feat_map_size
=
0
;
for
(
size_t
i
=
0
;
i
<
height_lod_0
.
size
()
-
1
;
++
i
)
{
int
height
=
height_lod_0
[
i
+
1
]
-
height_lod_0
[
i
];
int
width
=
width_lod_0
[
i
+
1
]
-
width_lod_0
[
i
];
if
(
height
*
width
>
feat_map_size
)
{
feat_map_size
=
height
*
width
;
}
}
dim3
blocks
(
num
,
channel
);
dim3
blocks
(
num
,
channel
);
dim3
threads
(
32
,
1
);
dim3
threads
(
32
,
1
);
topk_avg_pooling_kernel_by_row_improve
<
topk_avg_pooling_kernel_by_row_improve
<
...
@@ -154,11 +166,12 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
...
@@ -154,11 +166,12 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
in_data
,
in_data
,
height_offset
,
height_offset
,
width_offset
,
width_offset
,
height
,
width
,
param
.
topks
.
size
(),
param
.
topks
.
size
(),
_top_ks
.
data
<
int
>
(),
_top_ks
.
data
<
int
>
(),
param
.
channel_num
);
param
.
channel_num
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
ERROR
)
<<
cudaGetErrorString
(
error
);
}
}
}
// namespace cuda
}
// namespace cuda
...
...
lite/kernels/cuda/softmax_compute.cu
浏览文件 @
aa67c28e
...
@@ -173,9 +173,10 @@ void SoftmaxCompute::Run() {
...
@@ -173,9 +173,10 @@ void SoftmaxCompute::Run() {
cudaGetDeviceProperties
(
&
deviceProp
,
device_id
);
cudaGetDeviceProperties
(
&
deviceProp
,
device_id
);
size_t
sharedmem_size
=
deviceProp
.
sharedMemPerBlock
;
size_t
sharedmem_size
=
deviceProp
.
sharedMemPerBlock
;
int
max_dimsize
=
sharedmem_size
/
sizeof
(
float
)
/
threads
;
int
max_dimsize
=
sharedmem_size
/
sizeof
(
float
)
/
threads
;
auto
input_data
=
param
.
x
->
data
<
float
>
();
auto
input_data
=
param
.
x
->
data
<
float
>
();
auto
output_data
=
param
.
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
output_data
=
param
.
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
TargetWrapperCuda
::
MemsetSync
(
output_data
,
0
,
param
.
output
->
numel
()
*
sizeof
(
float
));
if
(
axis_size
<=
max_dimsize
)
{
if
(
axis_size
<=
max_dimsize
)
{
int
use_sharemem_size
=
axis_size
*
threads
*
sizeof
(
float
);
int
use_sharemem_size
=
axis_size
*
threads
*
sizeof
(
float
);
sharemem_softmax_kernel
<<<
blocks
,
threads
,
use_sharemem_size
,
stream
>>>
(
sharemem_softmax_kernel
<<<
blocks
,
threads
,
use_sharemem_size
,
stream
>>>
(
...
@@ -194,7 +195,7 @@ void SoftmaxCompute::Run() {
...
@@ -194,7 +195,7 @@ void SoftmaxCompute::Run() {
auto
max_data
=
tmax_data
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
max_data
=
tmax_data
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
sum_data
=
tsum_data
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
sum_data
=
tsum_data
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
//! firstly, get maximum data
//! firstly, get maximum data
float
min_data
=
std
::
numeric_limits
<
float
>::
min
();
float
min_data
=
std
::
numeric_limits
<
float
>::
lowest
();
softmax_max_kernel
<
float
><<<
blocks
,
threads
,
0
,
stream
>>>
(
total_threads
,
softmax_max_kernel
<
float
><<<
blocks
,
threads
,
0
,
stream
>>>
(
total_threads
,
input_data
,
input_data
,
max_data
,
max_data
,
...
@@ -217,7 +218,7 @@ void SoftmaxCompute::Run() {
...
@@ -217,7 +218,7 @@ void SoftmaxCompute::Run() {
total_threads
,
output_data
,
sum_data
,
inner_num
,
outer_num
,
axis_size
);
total_threads
,
output_data
,
sum_data
,
inner_num
,
outer_num
,
axis_size
);
}
}
cudaError_t
error
=
cudaGetLastError
();
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
INFO
)
<<
cudaGetErrorString
(
error
);
if
(
error
!=
cudaSuccess
)
LOG
(
ERROR
)
<<
cudaGetErrorString
(
error
);
}
}
}
// namespace cuda
}
// namespace cuda
...
@@ -258,4 +259,5 @@ REGISTER_LITE_KERNEL(search_seq_softmax,
...
@@ -258,4 +259,5 @@ REGISTER_LITE_KERNEL(search_seq_softmax,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
DATALAYOUT
(
kNCHW
))})
.
BindOutput
(
"Out_log"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
.
Finalize
();
lite/operators/sequence_topk_avg_pooling_op.cc
浏览文件 @
aa67c28e
...
@@ -54,8 +54,7 @@ bool SequenceTopkAvgPoolingOpLite::InferShape() const {
...
@@ -54,8 +54,7 @@ bool SequenceTopkAvgPoolingOpLite::InferShape() const {
vec_out_shape
.
push_back
(
channel_num
*
num_k
);
vec_out_shape
.
push_back
(
channel_num
*
num_k
);
param_
.
Out
->
Resize
(
lite
::
DDim
(
vec_out_shape
));
param_
.
Out
->
Resize
(
lite
::
DDim
(
vec_out_shape
));
auto
out_lod
=
param_
.
Out
->
mutable_lod
();
param_
.
Out
->
set_lod
(
param_
.
ROW
->
lod
());
*
out_lod
=
param_
.
X
->
lod
();
return
true
;
return
true
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录