Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
8f593443
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8f593443
编写于
12月 25, 2019
作者:
W
Wilber
提交者:
GitHub
12月 25, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize softmax cuda kernel test=develop (#2660)
optimize softmax cuda kernel
上级
00fee283
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
18 addition
and
18 deletion
+18
-18
lite/kernels/cuda/softmax_compute.cu
lite/kernels/cuda/softmax_compute.cu
+13
-15
lite/kernels/cuda/softmax_compute.h
lite/kernels/cuda/softmax_compute.h
+5
-3
未找到文件。
lite/kernels/cuda/softmax_compute.cu
浏览文件 @
8f593443
...
@@ -156,8 +156,8 @@ void SoftmaxCompute::PrepareForRun() {
...
@@ -156,8 +156,8 @@ void SoftmaxCompute::PrepareForRun() {
cudaGetDevice
(
&
device_id
);
cudaGetDevice
(
&
device_id
);
cudaDeviceProp
deviceProp
;
cudaDeviceProp
deviceProp
;
cudaGetDeviceProperties
(
&
deviceProp
,
device_id
);
cudaGetDeviceProperties
(
&
deviceProp
,
device_id
);
sharedmem_size
=
deviceProp
.
sharedMemPerBlock
;
sharedmem_size
_
=
deviceProp
.
sharedMemPerBlock
;
max_dimsize
=
sharedmem_size
/
sizeof
(
float
)
/
CUDA_NUM_THREADS
;
max_dimsize
_
=
sharedmem_size_
/
sizeof
(
float
)
/
CUDA_NUM_THREADS
;
}
}
void
SoftmaxCompute
::
Run
()
{
void
SoftmaxCompute
::
Run
()
{
...
@@ -174,29 +174,27 @@ void SoftmaxCompute::Run() {
...
@@ -174,29 +174,27 @@ void SoftmaxCompute::Run() {
int
outer_num
=
x_dims
.
Slice
(
0
,
axis
).
production
();
int
outer_num
=
x_dims
.
Slice
(
0
,
axis
).
production
();
int
inner_num
=
x_dims
.
Slice
(
axis
+
1
,
x_rank
).
production
();
int
inner_num
=
x_dims
.
Slice
(
axis
+
1
,
x_rank
).
production
();
int
total_threads
=
inner_num
*
outer_num
;
int
total_threads
=
inner_num
*
outer_num
;
int
axis_size
=
x_dims
[
axis
];
axis_size_
=
x_dims
[
axis
];
const
int
threads
=
CUDA_NUM_THREADS
;
const
int
threads
=
CUDA_NUM_THREADS
;
const
int
blocks
=
(
total_threads
+
threads
-
1
)
/
threads
;
const
int
blocks
=
(
total_threads
+
threads
-
1
)
/
threads
;
auto
input_data
=
param
.
x
->
data
<
float
>
();
auto
input_data
=
param
.
x
->
data
<
float
>
();
auto
output_data
=
param
.
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
output_data
=
param
.
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
if
(
axis_size
<=
max_dimsize
)
{
if
(
axis_size
_
<=
max_dimsize_
)
{
int
use_sharemem_size
=
axis_size
*
threads
*
sizeof
(
float
);
int
use_sharemem_size
=
axis_size
_
*
threads
*
sizeof
(
float
);
sharemem_softmax_kernel
<<<
blocks
,
threads
,
use_sharemem_size
,
stream
>>>
(
sharemem_softmax_kernel
<<<
blocks
,
threads
,
use_sharemem_size
,
stream
>>>
(
total_threads
,
total_threads
,
input_data
,
input_data
,
output_data
,
output_data
,
inner_num
,
inner_num
,
outer_num
,
outer_num
,
axis_size
);
axis_size
_
);
}
else
{
}
else
{
//! re_alloc device memory
//! re_alloc device memory
Tensor
tmax_data
;
tmax_data_
.
Resize
({
1
,
1
,
1
,
outer_num
*
inner_num
});
Tensor
tsum_data
;
tsum_data_
.
Resize
({
1
,
1
,
1
,
outer_num
*
inner_num
});
tmax_data
.
Resize
({
1
,
1
,
1
,
outer_num
*
inner_num
});
auto
max_data
=
tmax_data_
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
tsum_data
.
Resize
({
1
,
1
,
1
,
outer_num
*
inner_num
});
auto
sum_data
=
tsum_data_
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
max_data
=
tmax_data
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
sum_data
=
tsum_data
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
//! firstly, get maximum data
//! firstly, get maximum data
float
min_data
=
std
::
numeric_limits
<
float
>::
lowest
();
float
min_data
=
std
::
numeric_limits
<
float
>::
lowest
();
softmax_max_kernel
<
float
><<<
blocks
,
threads
,
0
,
stream
>>>
(
total_threads
,
softmax_max_kernel
<
float
><<<
blocks
,
threads
,
0
,
stream
>>>
(
total_threads
,
...
@@ -205,7 +203,7 @@ void SoftmaxCompute::Run() {
...
@@ -205,7 +203,7 @@ void SoftmaxCompute::Run() {
min_data
,
min_data
,
inner_num
,
inner_num
,
outer_num
,
outer_num
,
axis_size
);
axis_size
_
);
//! then, compute exp and sum data
//! then, compute exp and sum data
softmax_sub_exp_sum_kernel
<
float
><<<
blocks
,
threads
,
0
,
stream
>>>
(
softmax_sub_exp_sum_kernel
<
float
><<<
blocks
,
threads
,
0
,
stream
>>>
(
total_threads
,
total_threads
,
...
@@ -215,10 +213,10 @@ void SoftmaxCompute::Run() {
...
@@ -215,10 +213,10 @@ void SoftmaxCompute::Run() {
sum_data
,
sum_data
,
inner_num
,
inner_num
,
outer_num
,
outer_num
,
axis_size
);
axis_size
_
);
//! last, compute divided output
//! last, compute divided output
softmax_divid_output_kernel
<
float
><<<
blocks
,
threads
,
0
,
stream
>>>
(
softmax_divid_output_kernel
<
float
><<<
blocks
,
threads
,
0
,
stream
>>>
(
total_threads
,
output_data
,
sum_data
,
inner_num
,
outer_num
,
axis_size
);
total_threads
,
output_data
,
sum_data
,
inner_num
,
outer_num
,
axis_size
_
);
}
}
cudaError_t
error
=
cudaGetLastError
();
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
ERROR
)
<<
cudaGetErrorString
(
error
);
if
(
error
!=
cudaSuccess
)
LOG
(
ERROR
)
<<
cudaGetErrorString
(
error
);
...
...
lite/kernels/cuda/softmax_compute.h
浏览文件 @
8f593443
...
@@ -30,9 +30,11 @@ class SoftmaxCompute
...
@@ -30,9 +30,11 @@ class SoftmaxCompute
virtual
~
SoftmaxCompute
()
=
default
;
virtual
~
SoftmaxCompute
()
=
default
;
private:
private:
size_t
sharedmem_size
;
lite
::
Tensor
tmax_data_
;
int
num_threads
;
lite
::
Tensor
tsum_data_
;
int
max_dimsize
;
size_t
sharedmem_size_
;
int
max_dimsize_
;
int
axis_size_
;
};
};
}
// namespace cuda
}
// namespace cuda
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录