Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0059404e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0059404e
编写于
11月 05, 2019
作者:
Z
zhaoyuchen2018
提交者:
GitHub
11月 05, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix ce ocr_recognition test fails (#20987)
ocr_recognition fails, so add a path to handle small frame_size. test=develop
上级
f56967c4
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
39 addition
and
30 deletion
+39
-30
paddle/fluid/operators/math/detail/gru_gpu_kernel.h
paddle/fluid/operators/math/detail/gru_gpu_kernel.h
+4
-7
paddle/fluid/operators/math/gru_compute.cu
paddle/fluid/operators/math/gru_compute.cu
+35
-23
未找到文件。
paddle/fluid/operators/math/detail/gru_gpu_kernel.h
浏览文件 @
0059404e
...
@@ -105,7 +105,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
...
@@ -105,7 +105,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
* threads(tile_size, 1)
* threads(tile_size, 1)
* grid(frame_blocks, 1)
* grid(frame_blocks, 1)
*/
*/
template
<
class
T
>
template
<
class
T
,
int
Tiled_size
>
__global__
void
KeFastCollectiveGruGate
(
T
*
gate_value
,
T
*
prev_output_value
,
__global__
void
KeFastCollectiveGruGate
(
T
*
gate_value
,
T
*
prev_output_value
,
T
*
gate_weight
,
T
*
reset_output
,
T
*
gate_weight
,
T
*
reset_output
,
int
frame_size
,
int
frame_size
,
...
@@ -113,9 +113,7 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
...
@@ -113,9 +113,7 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
T
xt_0
=
0.0
f
;
T
xt_0
=
0.0
f
;
T
a0
=
0.0
f
;
T
a0
=
0.0
f
;
T
c0
=
0.0
f
;
T
c0
=
0.0
f
;
T
b0
[
Tiled_size
];
int
Tiled_size
=
blockDim
.
x
;
T
b0
[
16
];
int
COL
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
COL
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
Tiled_mask
=
((
1
<<
Tiled_size
)
-
1
);
int
Tiled_mask
=
((
1
<<
Tiled_size
)
-
1
);
...
@@ -165,7 +163,7 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
...
@@ -165,7 +163,7 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
* threads(tile_size, 1)
* threads(tile_size, 1)
* grid(frame_blocks, 1)
* grid(frame_blocks, 1)
*/
*/
template
<
class
T
>
template
<
class
T
,
int
Tiled_size
>
__global__
void
KeFastCollectiveGruOut
(
T
*
gate_weight
,
T
*
prev_out_value
,
__global__
void
KeFastCollectiveGruOut
(
T
*
gate_weight
,
T
*
prev_out_value
,
T
*
output_value
,
T
*
gate_value
,
T
*
output_value
,
T
*
gate_value
,
T
*
reset_value
,
int
frame_size
,
T
*
reset_value
,
int
frame_size
,
...
@@ -174,10 +172,9 @@ __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value,
...
@@ -174,10 +172,9 @@ __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value,
int
COL
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
COL
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
T
a0
=
0.0
f
;
T
a0
=
0.0
f
;
T
b0
[
16
];
T
b0
[
Tiled_size
];
T
c0
=
0.0
f
;
T
c0
=
0.0
f
;
int
Tiled_size
=
blockDim
.
x
;
int
Tiled_mask
=
((
1
<<
Tiled_size
)
-
1
);
int
Tiled_mask
=
((
1
<<
Tiled_size
)
-
1
);
//- Tiled matrix multiply with register shift
//- Tiled matrix multiply with register shift
if
(
prev_out_value
)
{
if
(
prev_out_value
)
{
...
...
paddle/fluid/operators/math/gru_compute.cu
浏览文件 @
0059404e
...
@@ -31,29 +31,41 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
...
@@ -31,29 +31,41 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
dim3
grid
;
dim3
grid
;
if
(
batch_size
==
1
)
{
if
(
batch_size
==
1
)
{
if
(
context
.
GetComputeCapability
()
>=
70
)
{
if
(
context
.
GetComputeCapability
()
>=
70
)
{
auto
ComputeTiledSize
=
[](
int
frame_size
)
{
if
(
frame_size
<
16
)
{
if
(
frame_size
>=
16
)
constexpr
int
tiled_size
=
8
;
return
16
;
int
frame_blocks
=
(
frame_size
*
2
+
tiled_size
-
1
)
/
tiled_size
;
else
if
(
frame_size
<
16
)
threads
=
dim3
(
tiled_size
,
1
);
return
8
;
grid
=
dim3
(
frame_blocks
,
1
);
};
detail
::
KeFastCollectiveGruGate
<
T
,
tiled_size
><<<
grid
,
threads
,
0
,
stream
>>>
(
auto
tiled_size
=
ComputeTiledSize
(
frame_size
);
value
.
gate_value
,
value
.
prev_out_value
,
value
.
gate_weight
,
int
frame_blocks
=
(
frame_size
*
2
+
tiled_size
-
1
)
/
tiled_size
;
value
.
reset_output_value
,
frame_size
,
active_gate
);
threads
=
dim3
(
tiled_size
,
1
);
grid
=
dim3
(
frame_blocks
,
1
);
frame_blocks
=
(
frame_size
+
tiled_size
-
1
)
/
tiled_size
;
grid
=
dim3
(
frame_blocks
,
1
);
detail
::
KeFastCollectiveGruGate
<
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
KeFastCollectiveGruOut
<
value
.
gate_value
,
value
.
prev_out_value
,
value
.
gate_weight
,
T
,
tiled_size
><<<
grid
,
threads
,
0
,
stream
>>>
(
value
.
reset_output_value
,
frame_size
,
active_gate
);
value
.
state_weight
,
value
.
prev_out_value
,
value
.
output_value
,
value
.
gate_value
,
value
.
reset_output_value
,
frame_size
,
frame_blocks
=
(
frame_size
+
tiled_size
-
1
)
/
tiled_size
;
active_node
,
origin_mode
);
grid
=
dim3
(
frame_blocks
,
1
);
}
else
{
detail
::
KeFastCollectiveGruOut
<
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
constexpr
int
tiled_size
=
16
;
value
.
state_weight
,
value
.
prev_out_value
,
value
.
output_value
,
int
frame_blocks
=
(
frame_size
*
2
+
tiled_size
-
1
)
/
tiled_size
;
value
.
gate_value
,
value
.
reset_output_value
,
frame_size
,
active_node
,
threads
=
dim3
(
tiled_size
,
1
);
origin_mode
);
grid
=
dim3
(
frame_blocks
,
1
);
detail
::
KeFastCollectiveGruGate
<
T
,
tiled_size
><<<
grid
,
threads
,
0
,
stream
>>>
(
value
.
gate_value
,
value
.
prev_out_value
,
value
.
gate_weight
,
value
.
reset_output_value
,
frame_size
,
active_gate
);
frame_blocks
=
(
frame_size
+
tiled_size
-
1
)
/
tiled_size
;
grid
=
dim3
(
frame_blocks
,
1
);
detail
::
KeFastCollectiveGruOut
<
T
,
tiled_size
><<<
grid
,
threads
,
0
,
stream
>>>
(
value
.
state_weight
,
value
.
prev_out_value
,
value
.
output_value
,
value
.
gate_value
,
value
.
reset_output_value
,
frame_size
,
active_node
,
origin_mode
);
}
return
;
return
;
}
else
{
}
else
{
int
frame_per_block
=
frame_size
<=
1024
?
frame_size
:
1024
;
int
frame_per_block
=
frame_size
<=
1024
?
frame_size
:
1024
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录