Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
796e2a57
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
337
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
796e2a57
编写于
7月 16, 2020
作者:
W
Wilber
提交者:
GitHub
7月 16, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CUDA] [Kernels] Add gru fp16 cuda kernel. (#3956)
上级
14397ca0
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
576 addition
and
14 deletion
+576
-14
lite/backends/cuda/math/bias.cu
lite/backends/cuda/math/bias.cu
+12
-0
lite/backends/cuda/math/gru_forward.cu
lite/backends/cuda/math/gru_forward.cu
+141
-0
lite/backends/cuda/math/gru_forward.h
lite/backends/cuda/math/gru_forward.h
+163
-0
lite/backends/cuda/math/sequence2batch.cu
lite/backends/cuda/math/sequence2batch.cu
+5
-0
lite/backends/cuda/math/sequence2batch.h
lite/backends/cuda/math/sequence2batch.h
+37
-0
lite/kernels/cuda/gru_compute.cu
lite/kernels/cuda/gru_compute.cu
+175
-14
lite/kernels/cuda/gru_compute_test.cc
lite/kernels/cuda/gru_compute_test.cc
+43
-0
未找到文件。
lite/backends/cuda/math/bias.cu
浏览文件 @
796e2a57
...
@@ -31,6 +31,17 @@ __global__ void RowwiseAddKernel(
...
@@ -31,6 +31,17 @@ __global__ void RowwiseAddKernel(
c
[
i
]
=
a
[
i
]
+
b
[
w
];
c
[
i
]
=
a
[
i
]
+
b
[
w
];
}
}
}
}
template
<
>
__global__
void
RowwiseAddKernel
(
const
half
*
a
,
const
half
*
b
,
half
*
c
,
int
width
,
int
num
)
{
CUDA_KERNEL_LOOP
(
i
,
num
)
{
int
h
=
i
/
width
;
int
w
=
i
-
h
*
width
;
c
[
i
]
=
__hadd
(
a
[
i
],
b
[
w
]);
}
}
template
<
typename
T
>
template
<
typename
T
>
void
RowwiseAdd
<
T
>::
operator
()(
const
T
*
input
,
void
RowwiseAdd
<
T
>::
operator
()(
const
T
*
input
,
const
T
*
bias
,
const
T
*
bias
,
...
@@ -44,6 +55,7 @@ void RowwiseAdd<T>::operator()(const T* input,
...
@@ -44,6 +55,7 @@ void RowwiseAdd<T>::operator()(const T* input,
}
}
template
struct
RowwiseAdd
<
float
>;
template
struct
RowwiseAdd
<
float
>;
template
struct
RowwiseAdd
<
half
>;
}
// namespace math
}
// namespace math
}
// namespace cuda
}
// namespace cuda
...
...
lite/backends/cuda/math/gru_forward.cu
浏览文件 @
796e2a57
...
@@ -22,6 +22,10 @@ namespace lite {
...
@@ -22,6 +22,10 @@ namespace lite {
namespace
cuda
{
namespace
cuda
{
namespace
math
{
namespace
math
{
/*
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template
<
typename
T
>
template
<
typename
T
>
__global__
void
GruForwardResetOutput
(
__global__
void
GruForwardResetOutput
(
T
*
gate_value
,
T
*
gate_value
,
...
@@ -33,6 +37,7 @@ __global__ void GruForwardResetOutput(
...
@@ -33,6 +37,7 @@ __global__ void GruForwardResetOutput(
bool
is_batch
)
{
bool
is_batch
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame_idx
>=
frame_size
)
return
;
if
(
frame_idx
>=
frame_size
)
return
;
int
batch_idx
=
0
;
int
batch_idx
=
0
;
if
(
is_batch
)
{
if
(
is_batch
)
{
batch_idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
batch_idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
...
@@ -44,12 +49,14 @@ __global__ void GruForwardResetOutput(
...
@@ -44,12 +49,14 @@ __global__ void GruForwardResetOutput(
T
reset_out_val
;
T
reset_out_val
;
T
update_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
0
];
T
update_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
0
];
T
reset_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
1
];
T
reset_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
1
];
if
(
prev_output_value
)
{
if
(
prev_output_value
)
{
if
(
is_batch
)
{
if
(
is_batch
)
{
prev_output_value
+=
batch_idx
*
frame_size
;
prev_output_value
+=
batch_idx
*
frame_size
;
}
}
prev_out
=
prev_output_value
[
frame_idx
];
prev_out
=
prev_output_value
[
frame_idx
];
}
}
if
(
active_gate
==
lite
::
cuda
::
math
::
ActivationType
::
kSigmoid
)
{
if
(
active_gate
==
lite
::
cuda
::
math
::
ActivationType
::
kSigmoid
)
{
update_gate_value
=
Sigmoid
(
update_gate_value
);
update_gate_value
=
Sigmoid
(
update_gate_value
);
reset_gate_value
=
Sigmoid
(
reset_gate_value
);
reset_gate_value
=
Sigmoid
(
reset_gate_value
);
...
@@ -60,12 +67,71 @@ __global__ void GruForwardResetOutput(
...
@@ -60,12 +67,71 @@ __global__ void GruForwardResetOutput(
update_gate_value
=
Tanh
(
update_gate_value
);
update_gate_value
=
Tanh
(
update_gate_value
);
reset_gate_value
=
Tanh
(
reset_gate_value
);
reset_gate_value
=
Tanh
(
reset_gate_value
);
}
}
reset_out_val
=
prev_out
*
reset_gate_value
;
reset_out_val
=
prev_out
*
reset_gate_value
;
gate_value
[
frame_idx
+
frame_size
*
0
]
=
update_gate_value
;
gate_value
[
frame_idx
+
frame_size
*
0
]
=
update_gate_value
;
gate_value
[
frame_idx
+
frame_size
*
1
]
=
reset_gate_value
;
gate_value
[
frame_idx
+
frame_size
*
1
]
=
reset_gate_value
;
reset_output_value
[
frame_idx
]
=
reset_out_val
;
reset_output_value
[
frame_idx
]
=
reset_out_val
;
}
}
template
<
>
__global__
void
GruForwardResetOutput
(
half
*
gate_value
,
half
*
reset_output_value
,
half
*
prev_output_value
,
int
frame_size
,
int
batch_size
,
lite
::
cuda
::
math
::
ActivationType
active_gate
,
bool
is_batch
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame_idx
>=
frame_size
)
return
;
int
batch_idx
=
0
;
if
(
is_batch
)
{
batch_idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batch_idx
>=
batch_size
)
return
;
gate_value
+=
batch_idx
*
3
*
frame_size
;
reset_output_value
+=
batch_idx
*
frame_size
;
}
half
prev_out
=
0
;
half
reset_out_val
;
half
update_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
0
];
half
reset_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
1
];
if
(
prev_output_value
)
{
if
(
is_batch
)
{
prev_output_value
+=
batch_idx
*
frame_size
;
}
prev_out
=
prev_output_value
[
frame_idx
];
}
if
(
active_gate
==
ActivationType
::
kSigmoid
)
{
update_gate_value
=
Sigmoid
(
update_gate_value
);
reset_gate_value
=
Sigmoid
(
reset_gate_value
);
}
else
if
(
active_gate
==
ActivationType
::
kReLU
)
{
update_gate_value
=
ReLU
(
update_gate_value
);
reset_gate_value
=
ReLU
(
reset_gate_value
);
}
else
if
(
active_gate
==
ActivationType
::
kTanh
)
{
update_gate_value
=
Tanh
(
update_gate_value
);
reset_gate_value
=
Tanh
(
reset_gate_value
);
}
#if __CUDA_ARCH__ >= 530
reset_out_val
=
__hmul
(
prev_out
,
reset_gate_value
);
#else
reset_out_val
=
__float2half
(
__half2float
(
prev_out
)
*
__half2float
(
reset_gate_value
));
#endif
gate_value
[
frame_idx
+
frame_size
*
0
]
=
update_gate_value
;
gate_value
[
frame_idx
+
frame_size
*
1
]
=
reset_gate_value
;
reset_output_value
[
frame_idx
]
=
reset_out_val
;
}
/*
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template
<
typename
T
>
template
<
typename
T
>
__global__
void
GruForwardFinalOutput
(
__global__
void
GruForwardFinalOutput
(
T
*
gate_value
,
T
*
gate_value
,
...
@@ -87,14 +153,17 @@ __global__ void GruForwardFinalOutput(
...
@@ -87,14 +153,17 @@ __global__ void GruForwardFinalOutput(
gate_value
+=
batch_idx
*
3
*
frame_size
;
gate_value
+=
batch_idx
*
3
*
frame_size
;
output_value
+=
batch_idx
*
frame_size
;
output_value
+=
batch_idx
*
frame_size
;
}
}
T
output
;
T
output
;
T
prev_out
=
0
;
T
prev_out
=
0
;
T
update_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
0
];
T
update_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
0
];
T
state_frame_value
=
gate_value
[
frame_idx
+
frame_size
*
2
];
T
state_frame_value
=
gate_value
[
frame_idx
+
frame_size
*
2
];
if
(
prev_output_value
)
{
if
(
prev_output_value
)
{
if
(
is_batch
)
prev_output_value
+=
batch_idx
*
frame_size
;
if
(
is_batch
)
prev_output_value
+=
batch_idx
*
frame_size
;
prev_out
=
prev_output_value
[
frame_idx
];
prev_out
=
prev_output_value
[
frame_idx
];
}
}
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kSigmoid
)
{
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kSigmoid
)
{
state_frame_value
=
Sigmoid
(
state_frame_value
);
state_frame_value
=
Sigmoid
(
state_frame_value
);
}
else
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kReLU
)
{
}
else
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kReLU
)
{
...
@@ -102,6 +171,7 @@ __global__ void GruForwardFinalOutput(
...
@@ -102,6 +171,7 @@ __global__ void GruForwardFinalOutput(
}
else
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kTanh
)
{
}
else
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kTanh
)
{
state_frame_value
=
Tanh
(
state_frame_value
);
state_frame_value
=
Tanh
(
state_frame_value
);
}
}
if
(
origin_mode
)
{
if
(
origin_mode
)
{
output
=
update_gate_value
*
prev_out
+
state_frame_value
-
output
=
update_gate_value
*
prev_out
+
state_frame_value
-
update_gate_value
*
state_frame_value
;
update_gate_value
*
state_frame_value
;
...
@@ -109,6 +179,76 @@ __global__ void GruForwardFinalOutput(
...
@@ -109,6 +179,76 @@ __global__ void GruForwardFinalOutput(
output
=
prev_out
-
update_gate_value
*
prev_out
+
output
=
prev_out
-
update_gate_value
*
prev_out
+
update_gate_value
*
state_frame_value
;
update_gate_value
*
state_frame_value
;
}
}
gate_value
[
frame_idx
+
frame_size
*
2
]
=
state_frame_value
;
output_value
[
frame_idx
]
=
output
;
}
template
<
>
__global__
void
GruForwardFinalOutput
(
half
*
gate_value
,
half
*
prev_output_value
,
half
*
output_value
,
int
frame_size
,
int
batch_size
,
lite
::
cuda
::
math
::
ActivationType
active_node
,
bool
origin_mode
,
bool
is_batch
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame_idx
>=
frame_size
)
return
;
int
batch_idx
=
0
;
if
(
is_batch
)
{
batch_idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batch_idx
>=
batch_size
)
{
return
;
}
gate_value
+=
batch_idx
*
3
*
frame_size
;
output_value
+=
batch_idx
*
frame_size
;
}
half
output
;
half
prev_out
=
0
;
half
update_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
0
];
half
state_frame_value
=
gate_value
[
frame_idx
+
frame_size
*
2
];
if
(
prev_output_value
)
{
if
(
is_batch
)
prev_output_value
+=
batch_idx
*
frame_size
;
prev_out
=
prev_output_value
[
frame_idx
];
}
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kSigmoid
)
{
state_frame_value
=
Sigmoid
(
state_frame_value
);
}
else
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kReLU
)
{
state_frame_value
=
ReLU
(
state_frame_value
);
}
else
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kTanh
)
{
state_frame_value
=
Tanh
(
state_frame_value
);
}
if
(
origin_mode
)
{
#if __CUDA_ARCH__ >= 530
output
=
__hsub
(
__hadd
(
__hmul
(
update_gate_value
,
prev_out
),
state_frame_value
),
__hmul
(
update_gate_value
,
state_frame_value
));
#else
output
=
__float2half
(
__half2float
(
update_gate_value
)
*
__half2float
(
prev_out
)
+
__half2float
(
state_frame_value
)
-
__half2float
(
update_gate_value
)
*
__half2float
(
state_frame_value
));
#endif
}
else
{
#if __CUDA_ARCH__ >= 530
output
=
prev_out
-
update_gate_value
*
prev_out
+
update_gate_value
*
state_frame_value
;
output
=
__hadd
(
__hsub
(
prev_out
,
__hmul
(
update_gate_value
,
prev_out
)),
__hmul
(
update_gate_value
,
state_frame_value
));
#else
output
=
__float2half
(
__half2float
(
prev_out
)
-
__half2float
(
update_gate_value
)
*
__half2float
(
prev_out
)
+
__half2float
(
update_gate_value
)
*
__half2float
(
state_frame_value
));
#endif
}
gate_value
[
frame_idx
+
frame_size
*
2
]
=
state_frame_value
;
gate_value
[
frame_idx
+
frame_size
*
2
]
=
state_frame_value
;
output_value
[
frame_idx
]
=
output
;
output_value
[
frame_idx
]
=
output
;
}
}
...
@@ -122,6 +262,7 @@ template __global__ void GruForwardFinalOutput<float>(
...
@@ -122,6 +262,7 @@ template __global__ void GruForwardFinalOutput<float>(
lite
::
cuda
::
math
::
ActivationType
active_node
,
lite
::
cuda
::
math
::
ActivationType
active_node
,
bool
origin_mode
,
bool
origin_mode
,
bool
is_batch
);
bool
is_batch
);
template
__global__
void
GruForwardResetOutput
<
float
>(
template
__global__
void
GruForwardResetOutput
<
float
>(
float
*
gate_value
,
float
*
gate_value
,
float
*
reset_output_value
,
float
*
reset_output_value
,
...
...
lite/backends/cuda/math/gru_forward.h
浏览文件 @
796e2a57
...
@@ -34,10 +34,32 @@ template <typename Dtype>
...
@@ -34,10 +34,32 @@ template <typename Dtype>
inline
__device__
Dtype
Sigmoid
(
const
Dtype
a
)
{
inline
__device__
Dtype
Sigmoid
(
const
Dtype
a
)
{
return
static_cast
<
Dtype
>
(
1.0
)
/
(
static_cast
<
Dtype
>
(
1.0
)
+
expf
(
-
a
));
return
static_cast
<
Dtype
>
(
1.0
)
/
(
static_cast
<
Dtype
>
(
1.0
)
+
expf
(
-
a
));
}
}
template
<
>
inline
__device__
half
Sigmoid
(
const
half
a
)
{
#if __CUDA_ARCH__ >= 530
const
half
tmp
=
__float2half
(
1.0
f
);
return
__hdiv
(
tmp
,
__hadd
(
tmp
,
hexp
(
__hmul
(
__float2half
(
-
1.
f
),
a
))));
#else
return
__float2half
(
1.0
f
/
(
expf
(
__half2float
(
a
)
*
-
1
)
+
1.0
f
));
#endif
}
template
<
typename
Dtype
>
template
<
typename
Dtype
>
inline
__device__
Dtype
ReLU
(
const
Dtype
a
)
{
inline
__device__
Dtype
ReLU
(
const
Dtype
a
)
{
return
a
>
static_cast
<
Dtype
>
(
0.
f
)
?
a
:
static_cast
<
Dtype
>
(
0.
f
);
return
a
>
static_cast
<
Dtype
>
(
0.
f
)
?
a
:
static_cast
<
Dtype
>
(
0.
f
);
}
}
template
<
>
inline
__device__
half
ReLU
(
const
half
a
)
{
const
half
tmp
=
__float2half
(
0.
f
);
#if __CUDA_ARCH__ >= 530
return
__hgt
(
a
,
tmp
)
?
a
:
tmp
;
#else
return
__float2half
(
__half2float
(
a
)
>
0.
f
?
__half2float
(
a
)
:
0.
f
);
#endif
}
template
<
typename
Dtype
>
template
<
typename
Dtype
>
inline
__device__
Dtype
Tanh
(
const
Dtype
a
)
{
inline
__device__
Dtype
Tanh
(
const
Dtype
a
)
{
Dtype
tmp
=
static_cast
<
Dtype
>
(
-
2.0
)
*
a
;
Dtype
tmp
=
static_cast
<
Dtype
>
(
-
2.0
)
*
a
;
...
@@ -45,6 +67,18 @@ inline __device__ Dtype Tanh(const Dtype a) {
...
@@ -45,6 +67,18 @@ inline __device__ Dtype Tanh(const Dtype a) {
static_cast
<
Dtype
>
(
1.0
);
static_cast
<
Dtype
>
(
1.0
);
}
}
template
<
>
inline
__device__
half
Tanh
(
const
half
a
)
{
#if __CUDA_ARCH__ >= 530
half
tmp
=
__float2half
(
1.0
f
);
half
numerator
=
__hmul
(
__float2half
(
-
2.0
f
),
a
);
return
__hsub
(
__hdiv
(
__float2half
(
2.0
f
),
__hadd
(
tmp
,
hexp
(
numerator
))),
tmp
);
#else
float
tmp
=
-
2.0
f
*
__half2float
(
a
);
return
__float2half
(
2.0
f
/
(
1.0
f
+
expf
(
tmp
))
-
1.0
f
);
#endif
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
GruForwardResetOutput
(
__global__
void
GruForwardResetOutput
(
T
*
gate_value
,
T
*
gate_value
,
...
@@ -54,6 +88,7 @@ __global__ void GruForwardResetOutput(
...
@@ -54,6 +88,7 @@ __global__ void GruForwardResetOutput(
int
batch_size
,
int
batch_size
,
lite
::
cuda
::
math
::
ActivationType
active_gate
,
lite
::
cuda
::
math
::
ActivationType
active_gate
,
bool
is_batch
);
bool
is_batch
);
template
<
typename
T
>
template
<
typename
T
>
__global__
void
GruForwardFinalOutput
(
__global__
void
GruForwardFinalOutput
(
T
*
gate_value
,
T
*
gate_value
,
...
@@ -65,6 +100,134 @@ __global__ void GruForwardFinalOutput(
...
@@ -65,6 +100,134 @@ __global__ void GruForwardFinalOutput(
bool
origin_mode
,
bool
origin_mode
,
bool
is_batch
);
bool
is_batch
);
/*
* threads(tile_size, 1)
* grids(frame_blocks, 1)
*/
template
<
class
T
,
int
TiledSize
>
__global__
void
FastCollectiveGruGate
(
T
*
gate_value
,
T
*
prev_output_value
,
T
*
gate_weight
,
T
*
reset_output
,
int
frame_size
,
ActivationType
active_node
)
{
T
xt_0
=
0.0
f
;
T
a0
=
0.0
f
;
T
c0
=
0.0
f
;
T
b0
[
TiledSize
];
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tiled_mask
=
((
1
<<
TiledSize
)
-
1
);
// tiled matrix multiply using register shift, faster than sm.
if
(
prev_output_value
)
{
for
(
int
k
=
0
;
k
<
(((
frame_size
-
1
)
/
TiledSize
)
+
1
);
++
k
)
{
a0
=
0
;
if
((
threadIdx
.
x
+
k
*
TiledSize
)
<
frame_size
)
{
a0
=
prev_output_value
[
threadIdx
.
x
+
(
k
*
TiledSize
)];
}
for
(
int
i
=
0
;
i
<
TiledSize
;
++
i
)
{
if
(
col
<
frame_size
*
2
&&
(
i
+
k
*
TiledSize
)
<
frame_size
)
{
b0
[
i
]
=
gate_weight
[(
i
+
k
*
TiledSize
)
*
frame_size
*
2
+
col
];
}
}
for
(
int
i
=
0
;
i
<
TiledSize
;
++
i
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
c0
=
c0
+
__shfl_sync
(
tiled_mask
,
a0
,
i
,
TiledSize
)
*
b0
[
i
];
#else
c0
=
c0
+
__shfl
(
a0
,
i
,
TiledSize
)
*
b0
[
i
];
#endif
}
}
}
__syncthreads
();
if
(
col
<
frame_size
*
2
)
{
xt_0
=
gate_value
[
col
];
c0
+=
xt_0
;
if
(
active_node
==
ActivationType
::
kSigmoid
)
{
c0
=
Sigmoid
(
c0
);
}
else
if
(
active_node
==
ActivationType
::
kReLU
)
{
c0
=
ReLU
(
c0
);
}
else
if
(
active_node
==
ActivationType
::
kTanh
)
{
c0
=
Tanh
(
c0
);
}
gate_value
[
col
]
=
c0
;
if
(
frame_size
<=
col
&&
col
<
frame_size
*
2
)
{
T
htp_0
=
0.0
;
if
(
prev_output_value
)
{
htp_0
=
prev_output_value
[
col
-
frame_size
];
}
reset_output
[
col
-
frame_size
]
=
c0
*
htp_0
;
}
else
if
(
col
<
frame_size
)
{
gate_value
[
col
]
=
c0
;
}
}
}
template
<
class
T
,
int
TiledSize
>
__global__
void
FastCollectiveGruOut
(
T
*
gate_weight
,
T
*
prev_out_value
,
T
*
output_value
,
T
*
gate_value
,
T
*
reset_value
,
int
frame_size
,
ActivationType
active_node
,
bool
origin_mode
)
{
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
T
a0
=
0.0
f
;
T
b0
[
TiledSize
];
T
c0
=
0.0
f
;
int
tiled_mask
=
((
1
<<
TiledSize
)
-
1
);
if
(
prev_out_value
)
{
for
(
int
k
=
0
;
k
<
((
frame_size
-
1
)
/
TiledSize
+
1
);
++
k
)
{
a0
=
0
;
if
((
threadIdx
.
x
+
k
*
TiledSize
)
<
frame_size
)
{
a0
=
reset_value
[
threadIdx
.
x
+
k
*
TiledSize
];
}
for
(
int
i
=
0
;
i
<
TiledSize
;
++
i
)
{
if
(
col
<
frame_size
&&
(
i
+
k
*
TiledSize
)
<
frame_size
)
{
b0
[
i
]
=
gate_weight
[(
i
+
k
*
TiledSize
)
*
frame_size
+
col
];
}
}
for
(
int
i
=
0
;
i
<
TiledSize
;
++
i
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
c0
=
c0
+
__shfl_sync
(
tiled_mask
,
a0
,
i
,
TiledSize
)
*
b0
[
i
];
#else
c0
=
c0
+
__shfl
(
a0
,
i
,
TiledSize
)
*
b0
[
i
];
#endif
}
}
}
__syncthreads
();
if
(
col
<
frame_size
)
{
T
xt_0
=
gate_value
[
col
+
2
*
frame_size
];
T
gta_0
=
gate_value
[
col
];
T
htp_0
=
0
;
if
(
prev_out_value
)
{
htp_0
=
prev_out_value
[
col
];
}
c0
+=
xt_0
;
if
(
active_node
==
ActivationType
::
kSigmoid
)
{
c0
=
Sigmoid
(
c0
);
}
else
if
(
active_node
==
ActivationType
::
kReLU
)
{
c0
=
ReLU
(
c0
);
}
else
if
(
active_node
==
ActivationType
::
kTanh
)
{
c0
=
Tanh
(
c0
);
}
gate_value
[
col
+
2
*
frame_size
]
=
c0
;
if
(
origin_mode
)
{
output_value
[
col
]
=
htp_0
*
gta_0
+
(
1
-
gta_0
)
*
c0
;
}
else
{
output_value
[
col
]
=
c0
*
gta_0
+
(
1
-
gta_0
)
*
htp_0
;
}
}
}
}
// namespace math
}
// namespace math
}
// namespace cuda
}
// namespace cuda
}
// namespace lite
}
// namespace lite
...
...
lite/backends/cuda/math/sequence2batch.cu
浏览文件 @
796e2a57
...
@@ -77,8 +77,13 @@ void CopyMatrixRowsFunctor<T>::operator()(
...
@@ -77,8 +77,13 @@ void CopyMatrixRowsFunctor<T>::operator()(
}
}
template
class
CopyMatrixRowsFunctor
<
float
>;
template
class
CopyMatrixRowsFunctor
<
float
>;
template
class
CopyMatrixRowsFunctor
<
half
>;
template
class
LoDTensor2BatchFunctor
<
float
>;
template
class
LoDTensor2BatchFunctor
<
float
>;
template
class
LoDTensor2BatchFunctor
<
half
>;
template
class
Batch2LoDTensorFunctor
<
float
>;
template
class
Batch2LoDTensorFunctor
<
float
>;
template
class
Batch2LoDTensorFunctor
<
half
>;
}
// namespace math
}
// namespace math
}
// namespace cuda
}
// namespace cuda
...
...
lite/backends/cuda/math/sequence2batch.h
浏览文件 @
796e2a57
...
@@ -32,6 +32,9 @@ namespace math {
...
@@ -32,6 +32,9 @@ namespace math {
template
<
typename
T
>
template
<
typename
T
>
class
CopyMatrixRowsFunctor
{
class
CopyMatrixRowsFunctor
{
public:
public:
// If is_src_index is true, copy the indexed rows of input src to the output
// dst. If is_src_index is false, copy the input src to the indexed of output
// dst. The indexes rows are based on the input index.
void
operator
()(
const
lite
::
Tensor
&
src
,
void
operator
()(
const
lite
::
Tensor
&
src
,
lite
::
Tensor
*
dst
,
lite
::
Tensor
*
dst
,
const
std
::
vector
<
uint64_t
>&
index_lod
,
const
std
::
vector
<
uint64_t
>&
index_lod
,
...
@@ -44,6 +47,11 @@ class CopyMatrixRowsFunctor {
...
@@ -44,6 +47,11 @@ class CopyMatrixRowsFunctor {
template
<
typename
T
>
template
<
typename
T
>
class
LoDTensor2BatchFunctor
{
class
LoDTensor2BatchFunctor
{
// Calculate the length of each sequence and
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
struct
SeqInfo
{
struct
SeqInfo
{
SeqInfo
(
size_t
start
,
size_t
length
,
size_t
seq_idx
)
SeqInfo
(
size_t
start
,
size_t
length
,
size_t
seq_idx
)
:
start_
(
start
),
length_
(
length
),
seq_idx_
(
seq_idx
)
{}
:
start_
(
start
),
length_
(
length
),
seq_idx_
(
seq_idx
)
{}
...
@@ -60,21 +68,49 @@ class LoDTensor2BatchFunctor {
...
@@ -60,21 +68,49 @@ class LoDTensor2BatchFunctor {
auto
lods
=
lod_tensor
.
lod
();
auto
lods
=
lod_tensor
.
lod
();
CHECK_EQ
(
lods
.
size
(),
1UL
)
<<
"Only support one level sequence now."
;
CHECK_EQ
(
lods
.
size
(),
1UL
)
<<
"Only support one level sequence now."
;
const
auto
&
lod
=
lods
[
0
];
const
auto
&
lod
=
lods
[
0
];
std
::
vector
<
SeqInfo
>
seq_info
;
std
::
vector
<
SeqInfo
>
seq_info
;
for
(
int
seq_id
=
0
;
seq_id
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
seq_id
)
{
for
(
int
seq_id
=
0
;
seq_id
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
seq_id
)
{
size_t
length
=
lod
[
seq_id
+
1
]
-
lod
[
seq_id
];
size_t
length
=
lod
[
seq_id
+
1
]
-
lod
[
seq_id
];
seq_info
.
emplace_back
(
lod
[
seq_id
],
length
,
seq_id
);
seq_info
.
emplace_back
(
lod
[
seq_id
],
length
,
seq_id
);
}
}
std
::
sort
(
seq_info
.
begin
(),
seq_info
.
end
(),
[](
SeqInfo
a
,
SeqInfo
b
)
{
std
::
sort
(
seq_info
.
begin
(),
seq_info
.
end
(),
[](
SeqInfo
a
,
SeqInfo
b
)
{
return
a
.
length_
>
b
.
length_
;
return
a
.
length_
>
b
.
length_
;
});
});
// Calculate the start position of each batch.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// max_seqlen = 5,
// batchIndex = {b0, b1, b2, b3, b4}
// b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// batch_start_positions[0] = 0
// batch_start_positions[1] = len(b0)
// batch_start_positions[2] = len(b0) + len(b1)
// ...
// seq2batch_idx[12] = {4, 0, 9,
// 5, 1, 10,
// 6, 2, 11,
// 7, 3,
// 8}
// seq_order = {1, 0, 2}, the sort order.
// where 1 is the second sequence,
// 0 is the first sequence,
// 2 is the third sequence.
LoD
batch_lods
;
LoD
batch_lods
;
batch_lods
.
emplace_back
(
std
::
vector
<
uint64_t
>
{
0
});
batch_lods
.
emplace_back
(
std
::
vector
<
uint64_t
>
{
0
});
batch_lods
.
emplace_back
(
std
::
vector
<
uint64_t
>
{
0
});
batch_lods
.
emplace_back
(
std
::
vector
<
uint64_t
>
{
0
});
batch_lods
.
emplace_back
(
std
::
vector
<
uint64_t
>
{
0
});
batch_lods
.
emplace_back
(
std
::
vector
<
uint64_t
>
{
0
});
// batch_lods[0] is the start positions for batch LoDTensor
size_t
max_seqlen
=
seq_info
[
0
].
length_
;
size_t
max_seqlen
=
seq_info
[
0
].
length_
;
batch_lods
[
0
].
resize
(
max_seqlen
+
1
);
batch_lods
[
0
].
resize
(
max_seqlen
+
1
);
// batch_lods[1] is the raw index in the input LoDTensor
batch_lods
[
1
].
resize
(
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
batch_lods
[
1
].
resize
(
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
// batch_lods[2] is the sort order for the input LoDTensor.
batch_lods
[
2
].
resize
(
seq_info
.
size
());
batch_lods
[
2
].
resize
(
seq_info
.
size
());
auto
*
batch_starts
=
batch_lods
[
0
].
data
();
auto
*
batch_starts
=
batch_lods
[
0
].
data
();
...
@@ -101,6 +137,7 @@ class LoDTensor2BatchFunctor {
...
@@ -101,6 +137,7 @@ class LoDTensor2BatchFunctor {
}
}
batch_tensor
->
set_lod
(
batch_lods
);
batch_tensor
->
set_lod
(
batch_lods
);
lite
::
cuda
::
math
::
CopyMatrixRowsFunctor
<
T
>
to_batch
;
lite
::
cuda
::
math
::
CopyMatrixRowsFunctor
<
T
>
to_batch
;
to_batch
(
lod_tensor
,
batch_tensor
,
batch_lods
[
1
],
true
,
stream
);
to_batch
(
lod_tensor
,
batch_tensor
,
batch_lods
[
1
],
true
,
stream
);
CUDA_POST_KERNEL_CHECK
;
CUDA_POST_KERNEL_CHECK
;
...
...
lite/kernels/cuda/gru_compute.cu
浏览文件 @
796e2a57
...
@@ -48,10 +48,69 @@ struct GRUUnitFunctor {
...
@@ -48,10 +48,69 @@ struct GRUUnitFunctor {
CUDAContext
*
context
)
{
CUDAContext
*
context
)
{
dim3
threads
,
grids
;
dim3
threads
,
grids
;
if
(
batch_size
==
1
)
{
if
(
batch_size
==
1
)
{
int
frame_per_block
=
frame_size
<=
1024
?
frame_size
:
1024
;
if
(
lite
::
TargetWrapperCuda
::
GetComputeCapability
()
>=
70
)
{
int
frame_blocks
=
(
frame_size
+
1024
-
1
)
/
1024
;
if
(
frame_size
<
16
)
{
threads
=
dim3
(
frame_per_block
,
1
);
constexpr
int
tiled_size
=
8
;
grids
=
dim3
(
frame_blocks
,
1
);
int
frame_blocks
=
(
frame_size
*
2
+
tiled_size
-
1
)
/
tiled_size
;
threads
=
dim3
(
tiled_size
,
1
);
grids
=
dim3
(
frame_blocks
,
1
);
lite
::
cuda
::
math
::
FastCollectiveGruGate
<
T
,
tiled_size
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
gate_value
,
value
.
prev_out_value
,
value
.
gate_weight
,
value
.
reset_output_value
,
frame_size
,
active_gate
);
frame_blocks
=
(
frame_size
+
tiled_size
-
1
)
/
tiled_size
;
grids
=
dim3
(
frame_blocks
,
1
);
lite
::
cuda
::
math
::
FastCollectiveGruOut
<
T
,
tiled_size
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
state_weight
,
value
.
prev_out_value
,
value
.
output_value
,
value
.
gate_value
,
value
.
reset_output_value
,
frame_size
,
active_node
,
origin_mode
);
}
else
{
constexpr
int
tiled_size
=
16
;
int
frame_blocks
=
(
frame_size
*
2
+
tiled_size
-
1
)
/
tiled_size
;
threads
=
dim3
(
tiled_size
,
1
);
grids
=
dim3
(
frame_blocks
,
1
);
lite
::
cuda
::
math
::
FastCollectiveGruGate
<
T
,
tiled_size
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
gate_value
,
value
.
prev_out_value
,
value
.
gate_weight
,
value
.
reset_output_value
,
frame_size
,
active_gate
);
frame_blocks
=
(
frame_size
+
tiled_size
-
1
)
/
tiled_size
;
grids
=
dim3
(
frame_blocks
,
1
);
lite
::
cuda
::
math
::
FastCollectiveGruOut
<
T
,
tiled_size
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
state_weight
,
value
.
prev_out_value
,
value
.
output_value
,
value
.
gate_value
,
value
.
reset_output_value
,
frame_size
,
active_node
,
origin_mode
);
}
return
;
}
else
{
int
frame_per_block
=
frame_size
<=
1024
?
frame_size
:
1024
;
int
frame_blocks
=
(
frame_size
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame_per_block
,
1
);
grids
=
dim3
(
frame_blocks
,
1
);
}
}
else
{
}
else
{
threads
=
dim3
(
32
,
32
);
threads
=
dim3
(
32
,
32
);
grids
=
dim3
((
frame_size
+
32
-
1
)
/
32
,
(
batch_size
+
32
-
1
)
/
32
);
grids
=
dim3
((
frame_size
+
32
-
1
)
/
32
,
(
batch_size
+
32
-
1
)
/
32
);
...
@@ -121,6 +180,90 @@ struct GRUUnitFunctor {
...
@@ -121,6 +180,90 @@ struct GRUUnitFunctor {
template
struct
GRUUnitFunctor
<
float
>;
template
struct
GRUUnitFunctor
<
float
>;
template
<
>
struct
GRUUnitFunctor
<
half
>
{
static
void
compute
(
GRUMetaValue
<
half
>
value
,
int
frame_size
,
int
batch_size
,
const
lite
::
cuda
::
math
::
ActivationType
&
active_node
,
const
lite
::
cuda
::
math
::
ActivationType
&
active_gate
,
bool
origin_mode
,
lite
::
cuda
::
math
::
Gemm
<
half
,
half
>*
blas
,
CUDAContext
*
context
)
{
dim3
threads
,
grids
;
if
(
batch_size
==
1
)
{
int
frame_per_block
=
frame_size
<=
1024
?
frame_size
:
1024
;
int
frame_blocks
=
(
frame_size
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame_per_block
,
1
);
grids
=
dim3
(
frame_blocks
,
1
);
}
else
{
threads
=
dim3
(
32
,
32
);
grids
=
dim3
((
frame_size
+
32
-
1
)
/
32
,
(
batch_size
+
32
-
1
)
/
32
);
}
if
(
value
.
prev_out_value
)
{
CHECK
(
blas
->
init
(
false
,
false
,
batch_size
,
frame_size
*
2
,
frame_size
,
frame_size
,
frame_size
*
2
,
frame_size
*
3
,
context
));
blas
->
run
(
1.0
f
,
1.0
f
,
value
.
prev_out_value
,
value
.
gate_weight
,
value
.
gate_value
,
context
);
}
CUDA_POST_KERNEL_CHECK
;
lite
::
cuda
::
math
::
GruForwardResetOutput
<
half
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
gate_value
,
value
.
reset_output_value
,
value
.
prev_out_value
,
frame_size
,
batch_size
,
active_gate
,
batch_size
==
1
);
CUDA_POST_KERNEL_CHECK
;
if
(
value
.
prev_out_value
)
{
CHECK
(
blas
->
init
(
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
frame_size
,
frame_size
,
frame_size
*
3
,
context
));
blas
->
run
(
1.0
f
,
1.0
f
,
value
.
reset_output_value
,
value
.
state_weight
,
value
.
gate_value
+
frame_size
*
2
,
context
);
}
CUDA_POST_KERNEL_CHECK
;
lite
::
cuda
::
math
::
GruForwardFinalOutput
<
half
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
gate_value
,
value
.
prev_out_value
,
value
.
output_value
,
frame_size
,
batch_size
,
active_node
,
origin_mode
,
batch_size
==
1
);
CUDA_POST_KERNEL_CHECK
;
}
};
template
<
typename
T
,
PrecisionType
PType
>
template
<
typename
T
,
PrecisionType
PType
>
void
GRUCompute
<
T
,
PType
>::
PrepareForRun
()
{
void
GRUCompute
<
T
,
PType
>::
PrepareForRun
()
{
gemm_impl_
.
reset
(
new
lite
::
cuda
::
math
::
Gemm
<
T
,
T
>
);
gemm_impl_
.
reset
(
new
lite
::
cuda
::
math
::
Gemm
<
T
,
T
>
);
...
@@ -141,18 +284,17 @@ void GRUCompute<T, PType>::Run() {
...
@@ -141,18 +284,17 @@ void GRUCompute<T, PType>::Run() {
if
(
param
.
bias
)
{
if
(
param
.
bias
)
{
bias
=
const_cast
<
lite
::
Tensor
*>
(
param
.
bias
);
bias
=
const_cast
<
lite
::
Tensor
*>
(
param
.
bias
);
}
}
auto
*
weight
=
param
.
weight
;
const
lite
::
Tensor
*
weight
=
param
.
weight
;
auto
*
weight_data
=
const_cast
<
T
*>
(
weight
->
template
data
<
T
>());
T
*
weight_data
=
const_cast
<
T
*>
(
weight
->
template
data
<
T
>());
auto
*
batch_gate
=
param
.
batch_gate
;
lite
::
Tensor
*
batch_gate
=
param
.
batch_gate
;
auto
*
batch_reset_hidden_prev
=
param
.
batch_reset_hidden_prev
;
lite
::
Tensor
*
batch_reset_hidden_prev
=
param
.
batch_reset_hidden_prev
;
auto
*
batch_hidden
=
param
.
batch_hidden
;
lite
::
Tensor
*
batch_hidden
=
param
.
batch_hidden
;
auto
*
hidden
=
param
.
hidden
;
lite
::
Tensor
*
hidden
=
param
.
hidden
;
auto
*
batch_reset_hidden_prev_data
=
T
*
batch_reset_hidden_prev_data
=
batch_reset_hidden_prev
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
batch_reset_hidden_prev
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
hidden
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
hidden
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
auto
*
batch_gate_data
=
batch_gate
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
T
*
batch_gate_data
=
batch_gate
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
auto
*
batch_hidden_data
=
T
*
batch_hidden_data
=
batch_hidden
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
batch_hidden
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
bool
is_reverse
=
param
.
is_reverse
;
bool
is_reverse
=
param
.
is_reverse
;
auto
active_node
=
lite
::
cuda
::
math
::
GetActiveType
(
param
.
activation
);
auto
active_node
=
lite
::
cuda
::
math
::
GetActiveType
(
param
.
activation
);
auto
active_gate
=
lite
::
cuda
::
math
::
GetActiveType
(
param
.
gate_activation
);
auto
active_gate
=
lite
::
cuda
::
math
::
GetActiveType
(
param
.
gate_activation
);
...
@@ -224,6 +366,8 @@ void GRUCompute<T, PType>::Run() {
...
@@ -224,6 +366,8 @@ void GRUCompute<T, PType>::Run() {
using
GRUFp32
=
using
GRUFp32
=
paddle
::
lite
::
kernels
::
cuda
::
GRUCompute
<
float
,
PRECISION
(
kFloat
)
>
;
paddle
::
lite
::
kernels
::
cuda
::
GRUCompute
<
float
,
PRECISION
(
kFloat
)
>
;
using
GRUFp16
=
paddle
::
lite
::
kernels
::
cuda
::
GRUCompute
<
half
,
PRECISION
(
kFP16
)
>
;
REGISTER_LITE_KERNEL
(
gru
,
kCUDA
,
kFloat
,
kNCHW
,
GRUFp32
,
def
)
REGISTER_LITE_KERNEL
(
gru
,
kCUDA
,
kFloat
,
kNCHW
,
GRUFp32
,
def
)
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"H0"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"H0"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
...
@@ -234,3 +378,20 @@ REGISTER_LITE_KERNEL(gru, kCUDA, kFloat, kNCHW, GRUFp32, def)
...
@@ -234,3 +378,20 @@ REGISTER_LITE_KERNEL(gru, kCUDA, kFloat, kNCHW, GRUFp32, def)
.
BindOutput
(
"BatchHidden"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"BatchHidden"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Hidden"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Hidden"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
.
Finalize
();
REGISTER_LITE_KERNEL
(
gru
,
kCUDA
,
kFP16
,
kNCHW
,
GRUFp16
,
def
)
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindInput
(
"H0"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindInput
(
"Weight"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"BatchGate"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"BatchResetHiddenPrev"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"BatchHidden"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"Hidden"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
Finalize
();
lite/kernels/cuda/gru_compute_test.cc
浏览文件 @
796e2a57
...
@@ -45,10 +45,13 @@ class GRUTest : public ::testing::Test {
...
@@ -45,10 +45,13 @@ class GRUTest : public ::testing::Test {
x_ref_
.
Resize
(
lite
::
DDim
(
x_shape_
));
x_ref_
.
Resize
(
lite
::
DDim
(
x_shape_
));
x_gpu_
.
Resize
(
lite
::
DDim
(
x_shape_
));
x_gpu_
.
Resize
(
lite
::
DDim
(
x_shape_
));
x_ref_
.
set_lod
(
lod_
);
x_ref_
.
set_lod
(
lod_
);
w_ref_
.
Resize
(
lite
::
DDim
(
w_shape_
));
w_ref_
.
Resize
(
lite
::
DDim
(
w_shape_
));
w_gpu_
.
Resize
(
lite
::
DDim
(
w_shape_
));
w_gpu_
.
Resize
(
lite
::
DDim
(
w_shape_
));
auto
x_ref_data
=
x_ref_
.
mutable_data
<
float
>
();
auto
x_ref_data
=
x_ref_
.
mutable_data
<
float
>
();
auto
w_ref_data
=
w_ref_
.
mutable_data
<
float
>
();
auto
w_ref_data
=
w_ref_
.
mutable_data
<
float
>
();
for
(
int64_t
i
=
0
;
i
<
x_ref_
.
numel
();
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
x_ref_
.
numel
();
i
++
)
{
x_ref_data
[
i
]
=
static_cast
<
float
>
(
i
%
10
*
0.2
);
x_ref_data
[
i
]
=
static_cast
<
float
>
(
i
%
10
*
0.2
);
}
}
...
@@ -63,6 +66,7 @@ class GRUTest : public ::testing::Test {
...
@@ -63,6 +66,7 @@ class GRUTest : public ::testing::Test {
batch_hidden_gpu_
.
Resize
(
lite
::
DDim
(
out_shape_
));
batch_hidden_gpu_
.
Resize
(
lite
::
DDim
(
out_shape_
));
batch_reset_hidden_gpu_
.
Resize
(
lite
::
DDim
(
out_shape_
));
batch_reset_hidden_gpu_
.
Resize
(
lite
::
DDim
(
out_shape_
));
RunBaseLine
();
RunBaseLine
();
InitParamAndContext
();
InitParamAndContext
();
}
}
...
@@ -91,6 +95,22 @@ class GRUTest : public ::testing::Test {
...
@@ -91,6 +95,22 @@ class GRUTest : public ::testing::Test {
w_gpu_
.
dims
());
w_gpu_
.
dims
());
}
}
void
InitHalfInput
()
{
x_half_
.
Resize
(
lite
::
DDim
(
x_shape_
));
auto
x_half_data
=
x_half_
.
mutable_data
<
half
>
();
for
(
int64_t
i
=
0
;
i
<
x_half_
.
numel
();
i
++
)
{
x_half_data
[
i
]
=
half
(
lite
::
float16
(
x_ref_
.
data
<
float
>
()[
i
]));
}
x_gpu_
.
Assign
<
half
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_half_data
,
x_gpu_
.
dims
());
x_gpu_
.
set_lod
(
x_ref_
.
lod
());
w_half_
.
Resize
(
w_ref_
.
dims
());
auto
w_half_data
=
w_half_
.
mutable_data
<
half
>
();
for
(
int64_t
i
=
0
;
i
<
w_half_
.
numel
();
i
++
)
{
w_half_data
[
i
]
=
half
(
lite
::
float16
(
w_ref_
.
data
<
float
>
()[
i
]));
}
w_gpu_
.
Assign
<
half
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
w_half_data
,
w_gpu_
.
dims
());
}
void
RunBaseLine
()
{}
void
RunBaseLine
()
{}
int
batch_
,
frame_size_
;
int
batch_
,
frame_size_
;
...
@@ -134,6 +154,29 @@ TEST_F(GRUTest, TestFP32) {
...
@@ -134,6 +154,29 @@ TEST_F(GRUTest, TestFP32) {
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
}
}
TEST_F
(
GRUTest
,
TestFP16
)
{
InitHalfInput
();
GRUCompute
<
half
,
PRECISION
(
kFP16
)
>
kernel
;
kernel
.
SetParam
(
param_
);
kernel
.
SetContext
(
std
::
move
(
ctx_
));
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
{
kernel
.
Launch
();
cudaDeviceSynchronize
();
}
auto
start
=
GetCurrentUS
();
kernel
.
PrepareForRun
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
kernel
.
Run
();
}
cudaDeviceSynchronize
();
auto
duration
=
(
GetCurrentUS
()
-
start
)
/
1000.0
;
LOG
(
INFO
)
<<
"fp16, warmup: "
<<
FLAGS_warmup
<<
", repeats: "
<<
FLAGS_repeats
<<
", spend "
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
}
}
// namespace cuda
}
// namespace cuda
}
// namespace kernels
}
// namespace kernels
}
// namespace lite
}
// namespace lite
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录