Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
192be07b
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
192be07b
编写于
8月 14, 2020
作者:
W
Wilber
提交者:
GitHub
8月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize sequence_mask. test=develop (#4120)
上级
9e38adc8
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
73 addition
and
7 deletion
+73
-7
lite/kernels/cuda/sequence_mask_compute.cu
lite/kernels/cuda/sequence_mask_compute.cu
+62
-5
lite/kernels/cuda/sequence_mask_compute.h
lite/kernels/cuda/sequence_mask_compute.h
+3
-0
lite/operators/fc_op.cc
lite/operators/fc_op.cc
+8
-2
未找到文件。
lite/kernels/cuda/sequence_mask_compute.cu
浏览文件 @
192be07b
...
...
@@ -37,6 +37,40 @@ __global__ void SequenceMaskKernel(T* dst,
}
}
template
<
typename
T
>
__global__
void
VecMaxKernel
(
const
T
*
in_data
,
T
*
out
,
const
int
count
)
{
extern
__shared__
T
cache
[];
int
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
cache_index
=
threadIdx
.
x
;
T
tmp
=
-
1
;
while
(
i
<
count
)
{
if
(
in_data
[
i
]
>
tmp
)
{
tmp
=
in_data
[
i
];
}
i
+=
blockDim
.
x
*
gridDim
.
x
;
}
cache
[
cache_index
]
=
tmp
;
__syncthreads
();
// perform parallel reduction, blockDim.x must be 2^n
int
ib
=
blockDim
.
x
/
2
;
while
(
ib
!=
0
)
{
if
(
cache_index
<
ib
&&
cache
[
cache_index
+
ib
]
>
cache
[
cache_index
])
{
cache
[
cache_index
]
=
cache
[
cache_index
+
ib
];
}
__syncthreads
();
ib
/=
2
;
}
if
(
cache_index
==
0
)
{
out
[
blockIdx
.
x
]
=
cache
[
0
];
}
}
template
<
typename
T
,
PrecisionType
Ptype
>
void
SequenceMaskCompute
<
T
,
Ptype
>::
Run
()
{
auto
&
param
=
this
->
template
Param
<
param_t
>();
...
...
@@ -57,11 +91,34 @@ void SequenceMaskCompute<T, Ptype>::Run() {
}
if
(
maxlen
<
0
)
{
maxlen
=
static_cast
<
int
>
(
thrust
::
reduce
(
thrust
::
device_pointer_cast
(
x_data
),
thrust
::
device_pointer_cast
(
x_data
)
+
x
->
numel
(),
static_cast
<
int64_t
>
(
0
),
thrust
::
maximum
<
int64_t
>
()));
// choose algorithm according to magic_num.
const
int
magic_num
=
256
;
std
::
vector
<
int64_t
>
h_max_data
;
if
(
x
->
numel
()
<
magic_num
)
{
h_max_data
.
resize
(
x
->
numel
());
TargetWrapperCuda
::
MemcpySync
(
h_max_data
.
data
(),
x_data
,
x
->
numel
()
*
sizeof
(
int64_t
),
IoDirection
::
DtoH
);
}
else
{
const
int
threads
=
256
;
const
int
blocks
=
(
x
->
numel
()
+
threads
-
1
)
/
threads
;
max_tensor_
.
Resize
({
blocks
});
auto
*
max_data
=
max_tensor_
.
mutable_data
<
int64_t
>
(
TARGET
(
kCUDA
));
VecMaxKernel
<
int64_t
><<<
blocks
,
threads
,
threads
*
sizeof
(
int64_t
),
stream
>>>
(
x_data
,
max_data
,
x
->
numel
());
h_max_data
.
resize
(
blocks
);
TargetWrapperCuda
::
MemcpyAsync
(
h_max_data
.
data
(),
max_data
,
sizeof
(
int64_t
)
*
blocks
,
IoDirection
::
DtoH
,
stream
);
TargetWrapperCuda
::
StreamSync
(
stream
);
}
auto
maxlen_iterator
=
std
::
max_element
(
h_max_data
.
begin
(),
h_max_data
.
end
());
maxlen
=
h_max_data
[
std
::
distance
(
h_max_data
.
begin
(),
maxlen_iterator
)];
}
auto
y_dim
=
x
->
dims
().
Vectorize
();
...
...
lite/kernels/cuda/sequence_mask_compute.h
浏览文件 @
192be07b
...
...
@@ -28,6 +28,9 @@ class SequenceMaskCompute : public KernelLite<TARGET(kCUDA), Ptype> {
void
Run
()
override
;
virtual
~
SequenceMaskCompute
()
=
default
;
private:
lite
::
Tensor
max_tensor_
;
};
}
// namespace cuda
...
...
lite/operators/fc_op.cc
浏览文件 @
192be07b
...
...
@@ -50,9 +50,15 @@ bool FcOpLite::CheckShape() const {
bool
FcOpLite
::
InferShapeImpl
()
const
{
const
auto
&
input_dims
=
param_
.
input
->
dims
();
int64_t
w_dims_1
;
if
(
param_
.
w_dims
.
empty
())
{
const
auto
&
w_dims
=
param_
.
w
->
dims
();
w_dims_1
=
param_
.
padding_weights
?
w_dims
[
1
]
-
4
:
w_dims
[
1
];
}
else
{
const
auto
&
w_dims
=
param_
.
w_dims
;
w_dims_1
=
param_
.
padding_weights
?
w_dims
[
1
]
-
4
:
w_dims
[
1
];
}
int
in_num_col_dims
=
param_
.
in_num_col_dims
;
int64_t
w_dims_1
=
param_
.
padding_weights
?
w_dims
[
1
]
-
4
:
w_dims
[
1
];
// Set output dims
std
::
vector
<
DDim
::
value_type
>
output_dims
(
in_num_col_dims
+
1
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录