PaddlePaddle / Paddle-Lite, commit b65a6dc9
Commit b65a6dc9
Authored Dec 15, 2019 by Wilber; committed via GitHub on Dec 15, 2019.
Parent: 1dbcd51d

optimize search_grnn test=develop (#2608)

optimize search_grnn
Showing 10 changed files with 593 additions and 324 deletions (+593, −324).
lite/backends/cuda/math/transpose.cu        +59  −54
lite/backends/cuda/math/transpose.h         +19  −9
lite/kernels/cuda/CMakeLists.txt            +1   −1
lite/kernels/cuda/layout_compute.cc         +5   −3
lite/kernels/cuda/layout_compute.h          +13  −0
lite/kernels/cuda/search_grnn_compute.cu    +423 −247
lite/kernels/cuda/search_grnn_compute.h     +67  −3
lite/kernels/cuda/transpose_compute.cu      +4   −5
lite/kernels/cuda/transpose_compute.h       +1   −1
lite/kernels/cuda/transpose_compute_test.cc +1   −1
lite/backends/cuda/math/transpose.cu (+59, −54)

`BatchTranspose2DCUDAImpl` and `TransposeCUDAImpl` now take an explicit `cudaStream_t* stream` instead of a `CUDAContext* ctx`, and the per-type macro specializations are replaced by member functions of the new `Transpose<T>` class, which owns the scratch tensors for the generic path.

```diff
@@ -69,44 +69,16 @@ void BatchTranspose2DCUDAImpl(const int N,
                               const int W,
                               const T* input,
                               T* out,
-                              CUDAContext* ctx) {
+                              cudaStream_t* stream) {
   const int dh = (H + kTileDim - 1) / kTileDim;
   const int dw = (W + kTileDim - 1) / kTileDim;
   BatchTranspose2DCUDAKernel<
-      T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, ctx->exec_stream()>>>(
+      T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, *stream>>>(
       N, H, W, dh, dw, input, out);
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
-
-#define TYPE_SPECIALIZED_CUDA_NCHW2NHWC(T)             \
-  template <>                                          \
-  void NCHW2NHWC<T>(const int N,                       \
-                    const int C,                       \
-                    const int HxW,                     \
-                    const T* X,                        \
-                    T* Y,                              \
-                    CUDAContext* ctx) {                \
-    BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, ctx); \
-  }
-TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float)
-TYPE_SPECIALIZED_CUDA_NCHW2NHWC(int8_t)
-#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC
-
-#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T)             \
-  template <>                                          \
-  void NHWC2NCHW<T>(const int N,                       \
-                    const int C,                       \
-                    const int HxW,                     \
-                    const T* X,                        \
-                    T* Y,                              \
-                    CUDAContext* ctx) {                \
-    BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, ctx); \
-  }
-TYPE_SPECIALIZED_CUDA_NHWC2NCHW(float)
-TYPE_SPECIALIZED_CUDA_NHWC2NCHW(int8_t)
-#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW
 
 template <typename T>
 __global__ void TransposeCUDAKernel(const int size,
                                     const int ndim,
@@ -136,7 +108,9 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
                        const std::vector<int>& axes,
                        const T* X,
                        T* Y,
-                       CUDAContext* ctx) {
+                       lite::Tensor* Y_dims_,
+                       lite::Tensor* strides_,
+                       cudaStream_t* stream) {
   CHECK_EQ(X_dims.size(), axes.size()) << "dimension size should be equal";
   int ndim = X_dims.size();
   std::vector<int> strides(ndim, 0);
@@ -156,37 +130,68 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
     size *= X_dims[i];
   }
 
-  lite::Tensor Y_dims_, strides_;
-  Y_dims_.Resize(std::vector<int64_t>({ndim}));
-  int* d_y_dims = Y_dims_.mutable_data<int>(TARGET(kCUDA));
-  CopySync<TARGET(kCUDA)>(
-      d_y_dims, Y_dims.data(), sizeof(int) * Y_dims.size(), IoDirection::HtoD);
+  Y_dims_->Resize(std::vector<int64_t>({ndim}));
+  int* d_y_dims = Y_dims_->mutable_data<int>(TARGET(kCUDA));
+  TargetWrapperCuda::MemcpyAsync(d_y_dims,
+                                 Y_dims.data(),
+                                 sizeof(int) * Y_dims.size(),
+                                 IoDirection::HtoD,
+                                 *stream);
 
-  strides_.Resize(std::vector<int64_t>({ndim}));
-  int* d_strides = strides_.mutable_data<int>(TARGET(kCUDA));
-  CopySync<TARGET(kCUDA)>(d_strides,
-                          strides.data(),
-                          sizeof(int) * strides.size(),
-                          IoDirection::HtoD);
+  strides_->Resize(std::vector<int64_t>({ndim}));
+  int* d_strides = strides_->mutable_data<int>(TARGET(kCUDA));
+  TargetWrapperCuda::MemcpyAsync(d_strides,
+                                 strides.data(),
+                                 sizeof(int) * strides.size(),
+                                 IoDirection::HtoD,
+                                 *stream);
 
   const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
-  TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, ctx->exec_stream()>>>(
+  TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, *stream>>>(
       size, ndim, d_strides, d_y_dims, X, Y);
   auto e = cudaGetLastError();
   CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
 }
 
-#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T)              \
-  template <>                                           \
-  void Transpose<T>(const std::vector<int64_t>& X_dims, \
-                    const std::vector<int>& axes,       \
-                    const T* X,                         \
-                    T* Y,                               \
-                    CUDAContext* ctx) {                 \
-    TransposeCUDAImpl<T>(X_dims, axes, X, Y, ctx);      \
-  }
-TYPE_SPECIALIZED_CUDA_TRANSPOSE(float)
-#undef TYPE_SPECIALIZED_CUDA_TRANSPOSEF
+template <typename T>
+void Transpose<T>::NCHW2NHWC(
+    int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream) {
+  BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, stream);
+}
+
+template <typename T>
+void Transpose<T>::NHWC2NCHW(
+    int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream) {
+  BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, stream);
+}
+
+template <typename T>
+void Transpose<T>::transpose(T* dst,
+                             const T* src,
+                             const std::vector<int64_t>& src_dims,
+                             const std::vector<int>& axes,
+                             cudaStream_t* stream) {
+  TransposeCUDAImpl<T>(src_dims, axes, src, dst, &Y_dims_, &strides_, stream);
+}
+
+// template <typename T>
+// void Transpose<T>::transpose(T* dst,
+//                              const T* src,
+//                              const std::vector<int>& src_dims,
+//                              const std::vector<int>& axes,
+//                              cudaStream_t* stream) {
+//   std::vector<int64_t> _src_dims(src_dims.size(), 0);
+//   std::transform(
+//       src_dims.begin(),
+//       src_dims.end(),
+//       _src_dims.begin(),
+//       [](int data) -> int64_t { return static_cast<int64_t>(data); });
+//   TransposeCUDAImpl<T>(_src_dims, axes, src, dst, &Y_dims_, &strides_,
+//   stream);
+// }
+
+template class Transpose<int8_t>;
+template class Transpose<float>;
 
 }  // namespace math
 }  // namespace cuda
```
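The move from `ctx->exec_stream()` to a caller-supplied `cudaStream_t*` is the pattern this commit repeats everywhere. As a minimal standalone sketch (hypothetical code, not from the repository), the same ceil-division grid sizing and stream-qualified launch look like this:

```cpp
// Hypothetical sketch: launch a kernel on a caller-provided stream, with the
// ceil-div grid sizing and cudaGetLastError check used throughout the diff.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scale(const float* in, float* out, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) out[tid] = 2.f * in[tid];
}

void scale_on_stream(const float* in, float* out, int n, cudaStream_t* stream) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;  // ceil division
  scale<<<blocks, threads, 0, *stream>>>(in, out, n);
  cudaError_t e = cudaGetLastError();
  if (e != cudaSuccess) printf("CUDA: %s\n", cudaGetErrorString(e));
}

int main() {
  const int n = 1024;
  float *in, *out;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  scale_on_stream(in, out, n, &stream);  // work is queued on `stream`
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  cudaFree(in);
  cudaFree(out);
}
```

Passing the stream explicitly lets several kernels share one stream without each holding a context reference.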
lite/backends/cuda/math/transpose.h (+19, −9)

The free-function API is replaced by a stateful `Transpose<T>` class that owns the scratch tensors used by the generic transpose path.

```diff
@@ -26,17 +26,27 @@ namespace cuda {
 namespace math {
 
 template <typename T>
-void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);
-
-template <typename T>
-void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);
-
-template <typename T>
-void Transpose(const std::vector<int64_t>& X_dims,
-               const std::vector<int>& axes,
-               const T* X,
-               T* Y,
-               CUDAContext* ctx);
+class Transpose {
+ public:
+  void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream);
+
+  void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream);
+
+  void transpose(T* dst,
+                 const T* src,
+                 const std::vector<int64_t>& src_dims,
+                 const std::vector<int>& axes,
+                 cudaStream_t* stream);
+
+  // void transpose(T* dst,
+  //                const T* src,
+  //                const std::vector<int>& src_dims,
+  //                const std::vector<int>& axes,
+  //                cudaStream_t* stream);
+
+ private:
+  lite::Tensor Y_dims_, strides_;  // scratch for transpose.
+};
 
 }  // namespace math
 }  // namespace cuda
```
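Assuming the header above, a call site for the new class API might look like the following sketch. The wrapper function and its names are hypothetical; only `Transpose<T>` and the `transpose` signature come from the diff:

```cpp
// Hypothetical usage sketch of the new Transpose<T> member API.
// Assumes Paddle-Lite's transpose.h (which brings in lite::Tensor and CUDA types).
#include <vector>
#include "lite/backends/cuda/math/transpose.h"

void permute_chw_to_wch(const float* src, float* dst,
                        int C, int H, int W, cudaStream_t stream) {
  lite::cuda::math::Transpose<float> trans;  // owns its scratch tensors
  std::vector<int64_t> dims{C, H, W};
  std::vector<int> axes{2, 0, 1};            // CHW -> WCH
  trans.transpose(dst, src, dims, axes, &stream);
}
```

Making the object stateful means the device-side `Y_dims_`/`strides_` buffers are allocated once per kernel instance rather than on every call, which is the point of the refactor.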
lite/kernels/cuda/CMakeLists.txt (+1, −1)

`search_grnn_compute_cuda` now also depends on `${math_cuda}`:

```diff
@@ -26,7 +26,7 @@ add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps} cudnn_pool)
 add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
-add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
+add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm ${math_cuda})
 add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
```
lite/kernels/cuda/layout_compute.cc (+5, −3)

The loop index type is corrected, and the layout macros now fetch the exec stream once and dispatch through the kernel's `trans` member instead of the free functions:

```diff
@@ -29,7 +29,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
   }
   std::vector<int64_t> trim_dims;
   trim_dims.resize(actual_dims_size);
-  for (int i = 0; i < actual_dims_size; ++i) {
+  for (size_t i = 0; i < actual_dims_size; ++i) {
     trim_dims[i] = dims[i];
   }
   if (trim_dims.size() == 0) {
@@ -41,6 +41,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
 #define NCHWTONHWC(type)                                              \
   auto& param = this->template Param<param_t>();                      \
   auto& ctx = this->ctx_->template As<CUDAContext>();                 \
+  auto stream = ctx.exec_stream();                                    \
   auto input = param.x->template data<type>();                        \
   auto input_dim = param.x->dims();                                   \
   DDim input_trim_dim = trim_singular_dims(input_dim);                \
@@ -56,11 +57,12 @@ inline DDim trim_singular_dims(const DDim& dims) {
   int w = input_dim[3];                                               \
   param.y->Resize({n, h, w, c});                                      \
   auto output = param.y->template mutable_data<type>(TARGET(kCUDA));  \
-  lite::cuda::math::NCHW2NHWC<type>(n, c, h * w, input, output, &ctx);
+  trans.NCHW2NHWC(n, c, h * w, input, output, &stream);
 
 #define NHWCTONCHW(type)                                              \
   auto& param = this->template Param<param_t>();                      \
   auto& ctx = this->ctx_->template As<CUDAContext>();                 \
+  auto stream = ctx.exec_stream();                                    \
   auto input = param.x->template data<type>();                        \
   auto input_dim = param.x->dims();                                   \
   DDim input_trim_dim = trim_singular_dims(input_dim);                \
@@ -76,7 +78,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
   int c = input_dim[3];                                               \
   param.y->Resize({n, c, h, w});                                      \
   auto output = param.y->template mutable_data<type>(TARGET(kCUDA));  \
-  lite::cuda::math::NHWC2NCHW<type>(n, c, h * w, input, output, &ctx);
+  trans.NHWC2NCHW(n, c, h * w, input, output, &stream);
 
 void NCHWToNHWCCompute::Run() { NCHWTONHWC(float) }
```
lite/kernels/cuda/layout_compute.h (+13, −0)

Each layout kernel class gains a private `Transpose` member of the matching element type:

```diff
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+#include "lite/backends/cuda/math/transpose.h"
 #include "lite/core/kernel.h"
 
 namespace paddle {
@@ -25,6 +26,9 @@ class NCHWToNHWCCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
   using param_t = operators::LayoutParam;
   void Run() override;
   virtual ~NCHWToNHWCCompute() = default;
+
+ private:
+  lite::cuda::math::Transpose<float> trans;
 };
 
 class NCHWToNHWCComputeInt8
@@ -33,6 +37,9 @@ class NCHWToNHWCComputeInt8
   using param_t = operators::LayoutParam;
   void Run() override;
   virtual ~NCHWToNHWCComputeInt8() = default;
+
+ private:
+  lite::cuda::math::Transpose<int8_t> trans;
 };
 
 class NHWCToNCHWCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
@@ -40,6 +47,9 @@ class NHWCToNCHWCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
   using param_t = operators::LayoutParam;
   void Run() override;
   virtual ~NHWCToNCHWCompute() = default;
+
+ private:
+  lite::cuda::math::Transpose<float> trans;
 };
 
 class NHWCToNCHWComputeInt8
@@ -48,6 +58,9 @@ class NHWCToNCHWComputeInt8
   using param_t = operators::LayoutParam;
   void Run() override;
   virtual ~NHWCToNCHWComputeInt8() = default;
+
+ private:
+  lite::cuda::math::Transpose<int8_t> trans;
 };
 
 }  // namespace cuda
```
lite/kernels/cuda/search_grnn_compute.cu (+423, −247)

This is the core of the commit: the GRNN kernel is rewritten. The old implementation (a host-side `sigmoid`, the `PreComputeKernel`/`PostComputeKernel` device kernels, a `PrepareLayout` pass that sorted sequences by width on the CPU and round-tripped the input through host memory, and a `CopyBack` helper) is removed. The new implementation adds `#include "lite/backends/cuda/math/transpose.h"`, keeps data on the device, reorders words through a precomputed index map, computes the input and recurrent projections with GEMMs, and evaluates all three GRU gates in one fused kernel per time step.

Word reordering between the original and the length-sorted layout is done by two map kernels plus launch helpers. `trans_map2in` is the exact inverse of `trans_map2out` (it reads `input[map[seq] * lastdim + tid % lastdim]` into `output[tid]`), and `trans_map2in_cfunc` mirrors `trans_map2out_cfunc` with `hidden_size` in place of `word_size`:

```cpp
template <typename Dtype>
__global__ void trans_map2out(
    Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < count) {
    int seq = tid / lastdim;
    output[map[seq] * lastdim + tid % lastdim] = input[tid];
  }
}

template <typename Dtype>
__global__ void trans_map2in(
    Dtype* output, const Dtype* input, const int* map, int count, int lastdim) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < count) {
    int seq = tid / lastdim;
    output[tid] = input[map[seq] * lastdim + tid % lastdim];
  }
}

template <typename Dtype>
void trans_map2out_cfunc(const Dtype* input,
                         Dtype* output,
                         int word_size,
                         int seq_sum,
                         cudaStream_t stream,
                         int* dev_map_vec) {
  int count = seq_sum * word_size;
  int block_dim = count;
  int grid_dim = 1;
  if (count > 1024) {
    block_dim = 256;
    grid_dim = (count + block_dim - 1) / block_dim;
  }
  trans_map2out<<<grid_dim, block_dim, 0, stream>>>(
      output, input, dev_map_vec, count, word_size);
}
```

`SeqSortedseqTranseUtil::seq_2_sorted_seq` and `sorted_seq_2_seq` are thin wrappers that pass `_map_vec.size()` and `_dev_map_vec` to these helpers. The map itself is built on the host by `get_sorted_map`, which returns `false` (no reordering needed) when there is a single sequence or when every sequence has length 1:

```cpp
bool SeqSortedseqTranseUtil::get_sorted_map(const std::vector<int>& offset_vec,
                                            cudaStream_t stream_id) {
  int batch_size = offset_vec.size() - 1;
  int word_sum = offset_vec[offset_vec.size() - 1];
  std::vector<int> length_vec(batch_size);
  _length_index.resize(batch_size);
  int emit_length = 0;

  if (batch_size == 1) {
    emit_length = offset_vec[1] - offset_vec[0];
    _emit_offset_vec.resize(emit_length + 1);
    for (int i = 0; i <= emit_length; ++i) {
      _emit_offset_vec[i] = i;
    }
    return false;
  }

  int max_len = 0;
  for (int i = 0; i < offset_vec.size() - 1; ++i) {
    int len = offset_vec[i + 1] - offset_vec[i];
    max_len = max_len > len ? max_len : len;
    length_vec[i] = len;
    _length_index[i] = i;
  }

  emit_length = max_len;
  if (max_len == 1) {
    _emit_offset_vec.resize(2);
    _emit_offset_vec[0] = 0;
    _emit_offset_vec[1] = emit_length * batch_size;
    return false;
  }

  std::sort(_length_index.begin(),
            _length_index.end(),
            [&length_vec](int i1, int i2) {
              return length_vec[i1] > length_vec[i2];
            });

  _emit_offset_vec.resize(max_len + 1);
  _map_vec.resize(word_sum);

  if (word_sum > _dev_map_vec_length) {
    if (_dev_map_vec != nullptr) {
      TargetWrapperCuda::Free(static_cast<void*>(_dev_map_vec));
    }
    _dev_map_vec =
        static_cast<int*>(TargetWrapperCuda::Malloc(sizeof(int) * word_sum));
    _dev_map_vec_length = word_sum;
  }

  int target_word_id = 0;
  std::vector<int> length_vec_cnt = length_vec;
  int last_batch_size = batch_size;
  for (int word_id_in_seq = 0; word_id_in_seq < max_len; word_id_in_seq++) {
    _emit_offset_vec[word_id_in_seq] = target_word_id;
    for (int batch_id = 0; batch_id < last_batch_size; batch_id++) {
      int old_batch_id = _length_index[batch_id];
      if (length_vec_cnt[old_batch_id] > 0) {
        int inner_word_id_in_seq = word_id_in_seq;
        if (_is_reverse) {
          inner_word_id_in_seq = length_vec[old_batch_id] - 1 - word_id_in_seq;
        }
        int old_word_id = offset_vec[old_batch_id] + inner_word_id_in_seq;
        _map_vec[old_word_id] = target_word_id;
        length_vec_cnt[old_batch_id]--;
        target_word_id++;
      } else {
        last_batch_size--;
        break;
      }
    }
  }

  TargetWrapperCuda::MemcpyAsync(_dev_map_vec,
                                 _map_vec.data(),
                                 sizeof(int) * word_sum,
                                 IoDirection::HtoD,
                                 stream_id);
  _emit_offset_vec[max_len] = word_sum;
  _emit_length = emit_length;
  return true;
}
```

A `transpose_2d` kernel is also added; it is currently only referenced from commented-out code in `WeightsPreprocess`:

```cpp
template <typename Dtype>
__global__ void transpose_2d(Dtype* output, const Dtype* input, int m, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < m * n) {
    int i = tid / n;
    int j = tid % m;
    output[tid] = input[j * n + i];
  }
}
```

`WeightsPreprocess` rearranges the weights once so the GEMMs can run without per-step transposes: `wi` is transposed with axes `{2, 0, 1}`, and `wh` is split into its first block, transposed with `{1, 0}`, plus the remaining two blocks, transposed with `{2, 0, 1}`:

```cpp
void SearchGrnnCompute::WeightsPreprocess() {
  auto& param = this->Param<param_t>();
  auto& context = this->ctx_->template As<CUDAContext>();
  auto stream = context.exec_stream();

  DDim idims = param.wi->dims();
  DDim hdims = param.wh->dims();
  _wi.Resize({idims[2], idims[0], idims[1]});
  _wh.Resize({hdims[2], hdims[0], hdims[1]});
  lite::cuda::math::Transpose<float> trans;
  trans.transpose(_wi.mutable_data<float>(TARGET(kCUDA)),
                  param.wi->data<float>(),
                  idims.Vectorize(),
                  {2, 0, 1},
                  &stream);
  trans.transpose(_wh.mutable_data<float>(TARGET(kCUDA)) + hdims[1] * hdims[2],
                  param.wh->data<float>() + hdims[1] * hdims[2],
                  {hdims[0] - 1, hdims[1], hdims[2]},
                  {2, 0, 1},
                  &stream);
  trans.transpose(_wh.mutable_data<float>(TARGET(kCUDA)),
                  param.wh->data<float>(),
                  {hdims[1], hdims[2]},
                  {1, 0},
                  &stream);
  // int thread_num = 512;
  // int block_num = (hdims[1] * hdims[2] + thread_num - 1) / thread_num;
  // transpose_2d<<<block_num, thread_num, 0, stream>>>(
  //     _wh.mutable_data<float>(TARGET(kCUDA)),
  //     param.wh->data<float>(),
  //     hdims[1],
  //     hdims[2]);
}
```

`PrepareForRun` creates the GEMM implementation and the sequence utility, runs the weight preprocessing, and then reorders the recurrent weight matrix on the host into the per-gate block layout that the fused kernel expects:

```cpp
void SearchGrnnCompute::PrepareForRun() {
  auto& param = this->Param<param_t>();
  auto& context = this->ctx_->template As<CUDAContext>();
  auto stream = context.exec_stream();

  gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
  _seq_util = SeqSortedseqTranseUtil();
  WeightsPreprocess();

  int hidden_size = param.num_hidden;
  int word_size = param.num_input;
  int weights_h2h_size = hidden_size * hidden_size * 3;
  int weights_i2h_size = hidden_size * word_size * 3;

  lite::Tensor temp_weights_h2h_ori;
  lite::Tensor temp_weights_h2h_swarp;
  temp_weights_h2h_ori.Resize({weights_h2h_size});
  temp_weights_h2h_swarp.Resize({weights_h2h_size});

  TargetWrapperCuda::MemcpyAsync(temp_weights_h2h_ori.mutable_data<float>(),
                                 _wh.data<float>(),
                                 sizeof(float) * weights_h2h_size,
                                 IoDirection::DtoH,
                                 stream);
  cudaStreamSynchronize(stream);

  float* temp_tensor_ptr = temp_weights_h2h_swarp.mutable_data<float>();
  memcpy(temp_tensor_ptr,
         temp_weights_h2h_ori.data<float>(),
         sizeof(float) * hidden_size * hidden_size);

  float* rz_temp_tensor_ptr = temp_tensor_ptr + hidden_size * hidden_size;
  const float* rz_weights_tensor_ptr =
      temp_weights_h2h_ori.data<float>() + hidden_size * hidden_size;
  for (int row = 0; row < hidden_size; row++) {
    for (int block = 0; block < 2; block++) {
      int block_offset = block * hidden_size;
      for (int cow = 0; cow < hidden_size; cow++) {
        rz_temp_tensor_ptr[block * hidden_size * hidden_size +
                           row * hidden_size + cow] =
            rz_weights_tensor_ptr[row * (2 * hidden_size) + cow +
                                  block_offset];
      }
    }
  }

  float* orz_temp_tensor_ptr = temp_tensor_ptr;
  float* orz_weights_tensor_ptr = temp_weights_h2h_ori.mutable_data<float>();
  for (int row = 0; row < hidden_size; row++) {
    for (int block = 0; block < 3; block++) {
      int block_offset = block * hidden_size;
      for (int cow = 0; cow < hidden_size; cow++) {
        orz_weights_tensor_ptr[row * (3 * hidden_size) + cow + block_offset] =
            orz_temp_tensor_ptr[block * hidden_size * hidden_size +
                                row * hidden_size + cow];
      }
    }
  }

  _temp_weights_h2h.Resize({weights_h2h_size});
  TargetWrapperCuda::MemcpyAsync(
      _temp_weights_h2h.mutable_data<float>(TARGET(kCUDA)),
      temp_weights_h2h_ori.data<float>(),
      sizeof(float) * weights_h2h_size,
      IoDirection::HtoD,
      stream);
  cudaStreamSynchronize(stream);
}
```

The gate math lives in two device helpers and one fused kernel:

```cpp
template <typename Dtype>
static inline __device__ Dtype Sigmoid(const Dtype a) {
  return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-a));
}

template <typename Dtype>
static inline __device__ Dtype Tanh(const Dtype a) {
  Dtype tmp = static_cast<Dtype>(-2.0) * a;
  return (static_cast<Dtype>(2.0) / (static_cast<Dtype>(1.0) + expf(tmp))) -
         static_cast<Dtype>(1.0);
}

template <typename Dtype>
__global__ void cal_cudnn_kernel(const Dtype* w_x_r,
                                 const Dtype* w_x_z,
                                 const Dtype* w_x_o,
                                 const Dtype* w_h_r,
                                 const Dtype* w_h_z,
                                 const Dtype* w_h_o,
                                 int hidden_size,
                                 int batch_size,
                                 Dtype* output,
                                 const Dtype* hidden_pre) {
  const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int batch_id = thread_id / hidden_size;
  const int index = thread_id % hidden_size;
  if (index < hidden_size && batch_id < batch_size) {
    int w_base_index = batch_id * hidden_size * 3 + index;
    int h_base_index = batch_id * hidden_size + index;
    Dtype hidden_pre_value = hidden_pre[h_base_index];
    Dtype r = Sigmoid(w_x_r[w_base_index] + w_h_r[w_base_index]);
    Dtype z = Sigmoid(w_x_z[w_base_index] + w_h_z[w_base_index]);
    Dtype _h = Tanh(w_x_o[w_base_index] + w_h_o[w_base_index] * r);
    output[h_base_index] =
        (static_cast<Dtype>(1.0) - z) * _h + z * hidden_pre_value;
  }
}
```

Finally, the new `Run` resizes the output, optionally reorders the input into the length-sorted layout, performs one large input-projection GEMM into `_temp_wx`, and then walks the time steps: each step multiplies the previous hidden state by the preprocessed recurrent weights and launches `cal_cudnn_kernel` over the words active at that step:

```cpp
void SearchGrnnCompute::Run() {
  CHECK(ctx_) << "running context should be set first";
  auto& param = this->Param<param_t>();
  auto& context = this->ctx_->template As<CUDAContext>();
  auto stream = context.exec_stream();

  auto* x = param.x;
  LoD offset_vec_vec = x->lod();
  std::vector<int> offset(offset_vec_vec[offset_vec_vec.size() - 1].size());
  for (size_t i = 0; i < offset_vec_vec[offset_vec_vec.size() - 1].size();
       ++i) {
    offset[i] = static_cast<int>(offset_vec_vec[offset_vec_vec.size() - 1][i]);
  }
  const float* x_data = x->data<float>();
  auto* dout = param.out;
  std::vector<int64_t> out_dims_vec{x->dims()[0], param.num_hidden};
  dout->Resize(out_dims_vec);
  float* dout_data = dout->mutable_data<float>(TARGET(kCUDA));
  auto* wi = &_wi;
  auto* wh = &_wh;
  const float* weights_i2h = wi->data<float>();
  const float* weights_h2h = wh->data<float>();

  int batch_size = offset.size() - 1;
  int seq_sum = x->dims()[0];
  bool is_batched = offset.size() > 2;
  int hidden_size = param.num_hidden;
  int word_size = param.num_input;
  int o_offset = 0;
  int r_offset = 1;
  int z_offset = 2;

  is_batched = _seq_util.get_sorted_map(offset, stream);
  std::vector<int> emit_offset_vec = _seq_util.get_emit_offset_vec();
  int emit_length = emit_offset_vec.size() - 1;

  if (is_batched) {
    std::vector<int64_t> seq_shape{1, 1, seq_sum, word_size};
    _temp_tensor_in.Resize(seq_shape);
    std::vector<int64_t> seq_out_shape{1, 1, seq_sum, hidden_size};
    _temp_tensor_out.Resize(seq_out_shape);
    _seq_util.seq_2_sorted_seq(
        x_data,
        _temp_tensor_in.mutable_data<float>(TARGET(kCUDA)),
        word_size,
        stream);
    x_data = _temp_tensor_in.data<float>();
    dout_data = _temp_tensor_out.mutable_data<float>(TARGET(kCUDA));
  }

  std::vector<int64_t> shape_wx({seq_sum, 1, 3, hidden_size});
  _temp_wx.Resize(shape_wx);
  std::vector<int64_t> shape_wh({1, batch_size, 3, hidden_size});
  _temp_wh.Resize(shape_wh);

  gemm_impl_->init(
      false, false, seq_sum, 3 * hidden_size, word_size, &context);
  gemm_impl_->run(1.0f,
                  0.0f,
                  x_data,
                  weights_i2h,
                  _temp_wx.mutable_data<float>(TARGET(kCUDA)),
                  &context);

  std::vector<int64_t> shape_zero({batch_size * hidden_size});
  _temp_zero.Resize(shape_zero);
  TargetWrapperCuda::MemsetAsync(_temp_zero.mutable_data<float>(TARGET(kCUDA)),
                                 0,
                                 sizeof(float) * batch_size * hidden_size,
                                 stream);

  const float* h = _temp_zero.data<float>();
  for (int word_id = 0; word_id < emit_length; word_id++) {
    int real_word_id = word_id;
    int last_word_id = word_id - 1;
    int emit_word_id_start = emit_offset_vec[real_word_id];
    int emit_word_id_end = emit_offset_vec[real_word_id + 1];
    int emit_word_length = emit_word_id_end - emit_word_id_start;
    const float* hidden_in;
    float* hidden_out = dout_data + emit_offset_vec[real_word_id] * hidden_size;
    if (word_id == 0) {
      hidden_in = h;
    } else {
      hidden_in = dout_data + emit_offset_vec[last_word_id] * hidden_size;
    }

    float* w_x_r = _temp_wx.mutable_data<float>(TARGET(kCUDA)) +
                   r_offset * hidden_size +
                   emit_word_id_start * hidden_size * 3;
    float* w_x_z = _temp_wx.mutable_data<float>(TARGET(kCUDA)) +
                   z_offset * hidden_size +
                   emit_word_id_start * hidden_size * 3;
    float* w_x_o = _temp_wx.mutable_data<float>(TARGET(kCUDA)) +
                   o_offset * hidden_size +
                   emit_word_id_start * hidden_size * 3;

    float* w_h_r =
        _temp_wh.mutable_data<float>(TARGET(kCUDA)) + r_offset * hidden_size;
    float* w_h_z =
        _temp_wh.mutable_data<float>(TARGET(kCUDA)) + z_offset * hidden_size;
    float* w_h_o =
        _temp_wh.mutable_data<float>(TARGET(kCUDA)) + o_offset * hidden_size;

    gemm_impl_->init(
        false, false, emit_word_length, 3 * hidden_size, hidden_size, &context);
    gemm_impl_->run(1.0f,
                    0.0f,
                    hidden_in,
                    _temp_weights_h2h.data<float>(),
                    _temp_wh.mutable_data<float>(TARGET(kCUDA)),
                    &context);

    const int block_dim = 512;
    const int grid_dim =
        (emit_word_length * hidden_size + block_dim - 1) / block_dim;
    cal_cudnn_kernel<<<grid_dim, block_dim, 0, stream>>>(w_x_r,
                                                         w_x_z,
                                                         w_x_o,
                                                         w_h_r,
                                                         w_h_z,
                                                         w_h_o,
                                                         hidden_size,
                                                         emit_word_length,
                                                         hidden_out,
                                                         hidden_in);
  }

  if (is_batched) {
    _seq_util.sorted_seq_2_seq(_temp_tensor_out.data<float>(),
                               dout->mutable_data<float>(TARGET(kCUDA)),
                               hidden_size,
                               stream);
  }

  dout->set_lod(x->lod());
}

}  // namespace cuda
```
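The host-side mapping that `get_sorted_map` builds can be illustrated in isolation. This is a toy, self-contained sketch of the same idea (not the repository code): sort sequences by length, descending, so that the words processed at each time step sit contiguously in memory.

```cpp
// Illustrative sketch of the batching idea behind get_sorted_map:
// map[old_word_id] -> new slot, grouping the t-th word of every still-active
// sequence together so each time step is one dense batch.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // LoD-style offsets for 3 sequences of lengths 2, 4, 1.
  std::vector<int> offset{0, 2, 6, 7};
  int batch = static_cast<int>(offset.size()) - 1;
  int word_sum = offset.back();

  std::vector<int> len(batch), order(batch);
  for (int i = 0; i < batch; ++i) len[i] = offset[i + 1] - offset[i];
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](int a, int b) { return len[a] > len[b]; });  // longest first

  int max_len = *std::max_element(len.begin(), len.end());
  std::vector<int> map(word_sum);
  int target = 0;
  for (int t = 0; t < max_len; ++t)          // time step
    for (int b = 0; b < batch; ++b)          // sequences still active at t
      if (t < len[order[b]]) map[offset[order[b]] + t] = target++;

  for (int w = 0; w < word_sum; ++w) printf("word %d -> slot %d\n", w, map[w]);
}
```

With this layout, the per-step recurrent GEMM in `Run` operates on one contiguous block of `emit_word_length` rows instead of scattered rows.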
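For reference, the per-element recurrence computed by `cal_cudnn_kernel` is the standard GRU cell with the reset gate applied to the recurrent term of the candidate state. A scalar host sketch, illustrative only (the function name and sample values are hypothetical):

```cpp
// Host reference for the per-element math in cal_cudnn_kernel.
#include <cmath>
#include <cstdio>

float gru_cell(float wx_r, float wx_z, float wx_o,
               float wh_r, float wh_z, float wh_o, float h_prev) {
  auto sigmoid = [](float a) { return 1.f / (1.f + std::exp(-a)); };
  float r = sigmoid(wx_r + wh_r);             // reset gate
  float z = sigmoid(wx_z + wh_z);             // update gate
  float h_cand = std::tanh(wx_o + wh_o * r);  // candidate state
  return (1.f - z) * h_cand + z * h_prev;     // blend with previous hidden
}

int main() {
  printf("%f\n", gru_cell(0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f));
}
```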
lite/kernels/cuda/search_grnn_compute.h (+67, −3)

The header adds `#include <vector>`, the new `SeqSortedseqTranseUtil` helper class, and new private state on `SearchGrnnCompute`; the old `idx_sorted_by_width_cpu`, `PrepareLayout`, and `CopyBack` members are dropped.

```diff
@@ -14,6 +14,7 @@
 // limitations under the License.
 
 #pragma once
 #include <memory>
+#include <vector>
 #include "lite/backends/cuda/blas.h"
 #include "lite/backends/cuda/math/gemm.h"
 #include "lite/core/kernel.h"
@@ -23,6 +24,53 @@ namespace lite {
 namespace kernels {
 namespace cuda {
 
+class SeqSortedseqTranseUtil {
+ public:
+  explicit SeqSortedseqTranseUtil(bool is_reverse = false, bool is_bi = false)
+      : _is_reverse(is_reverse),
+        _is_bi(is_bi),
+        _dev_map_vec(nullptr),
+        _dev_map_vec_length(0) {}
+
+  ~SeqSortedseqTranseUtil() {
+    if (_dev_map_vec != nullptr) {
+      TargetWrapperCuda::Free(static_cast<void*>(_dev_map_vec));
+    }
+  }
+
+  std::vector<int>& get_length_index() { return _length_index; }
+  std::vector<int>& get_emit_offset_vec() { return _emit_offset_vec; }
+  std::vector<int>& get_map_vec() { return _map_vec; }
+  int* get_dev_map_vec() { return _dev_map_vec; }
+  int get_emit_length() { return _emit_length; }
+
+  template <typename Dtype>
+  void seq_2_sorted_seq(const Dtype* input,
+                        Dtype* output,
+                        int word_size,
+                        cudaStream_t stream);
+
+  template <typename Dtype>
+  void sorted_seq_2_seq(const Dtype* input,
+                        Dtype* output,
+                        int hidden_size,
+                        cudaStream_t stream);
+
+  bool get_sorted_map(const std::vector<int>& offset_vec,
+                      cudaStream_t stream_id);
+
+ private:
+  std::vector<int> _length_index;
+  std::vector<int> _emit_offset_vec;
+  std::vector<int> _map_vec;
+  int _emit_length;
+
+  bool _is_reverse;
+  bool _is_bi;
+  int* _dev_map_vec;
+  int _dev_map_vec_length;
+};
+
 class SearchGrnnCompute
     : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
  public:
@@ -34,10 +82,26 @@ class SearchGrnnCompute
   virtual ~SearchGrnnCompute() = default;
 
  private:
-  std::shared_ptr<Tensor> idx_sorted_by_width_cpu;
+  // Weights preprocess:
+  // wi needs to be transposed with axes {2, 0, 1};
+  // wh0 is transposed with axes {1, 0}, and {wh1, wh2} with axes {2, 0, 1}.
+  void WeightsPreprocess();
+
+ private:
   std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_;
-  void PrepareLayout(const Tensor* input);
-  void CopyBack(float* from, float* to, int step);
+  lite::Tensor _temp_tensor_in;
+  lite::Tensor _temp_tensor_out;
+  lite::Tensor _temp_wx;
+  lite::Tensor _temp_wh;
+  lite::Tensor _temp_zero;
+  lite::Tensor _temp_weights_h2h;
+  lite::Tensor _wi;
+  lite::Tensor _wh;
+  SeqSortedseqTranseUtil _seq_util;
 };
 
 }  // namespace cuda
```
lite/kernels/cuda/transpose_compute.cu (+4, −5)

`Run()` now obtains the exec stream once and dispatches through the `trans` member instead of the free functions:

```diff
@@ -25,6 +25,7 @@ namespace cuda {
 void TransposeCompute::Run() {
   auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto stream = ctx.exec_stream();
 
   const lite::Tensor* X = param.x;
   lite::Tensor* Out = param.output;
@@ -39,8 +40,7 @@ void TransposeCompute::Run() {
   // NCHW -> NHWC
   if (axes.size() == 4 && axes[0] == 0 && axes[1] == 2 && axes[2] == 3 &&
       axes[3] == 1) {
-    lite::cuda::math::NCHW2NHWC(
-        dims[0], dims[1], dims[2] * dims[3], in, out, &ctx);
+    trans.NCHW2NHWC(dims[0], dims[1], dims[2] * dims[3], in, out, &stream);
     cudaError_t error = cudaGetLastError();
     if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
     return;
@@ -49,14 +49,13 @@ void TransposeCompute::Run() {
   // NHWC -> NCHW
   if (axes.size() == 4 && axes[0] == 0 && axes[1] == 3 && axes[2] == 1 &&
       axes[3] == 2) {
-    lite::cuda::math::NHWC2NCHW(
-        dims[0], dims[3], dims[1] * dims[2], in, out, &ctx);
+    trans.NHWC2NCHW(dims[0], dims[3], dims[1] * dims[2], in, out, &stream);
     cudaError_t error = cudaGetLastError();
     if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
     return;
   }
 
-  lite::cuda::math::Transpose(dims, axes, in, out, &ctx);
+  trans.transpose(out, in, dims, axes, &stream);
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
```
lite/kernels/cuda/transpose_compute.h (+1, −1)

The unused tensor members are replaced by the `Transpose<float>` instance the kernel now dispatches through:

```diff
@@ -29,7 +29,7 @@ class TransposeCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
   virtual ~TransposeCompute() = default;
 
  private:
-  lite::Tensor axes_, dims_;
+  lite::cuda::math::Transpose<float> trans;
 };
 
 }  // namespace cuda
```
lite/kernels/cuda/transpose_compute_test.cc (+1, −1)

The normal-case test now exercises a larger shape:

```diff
@@ -238,7 +238,7 @@ TEST(transpose, normal) {
   lite::Tensor x, x_cpu, x_ref;
   lite::Tensor out, out_cpu, out_ref;
 
-  int C = 6, H = 7, W = 8;
+  int C = 3, H = 128, W = 128;
   std::vector<int> axes({2, 0, 1});
   x.Resize({C, H, W});
   out.Resize({W, C, H});
```
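For the axes `{2, 0, 1}` the test uses, a naive CPU reference makes the expected output layout explicit. This sketch is illustrative and independent of the test harness (the function name is hypothetical):

```cpp
// CPU reference for the permutation the test checks:
// out[w][c][h] = in[c][h][w] for axes {2, 0, 1}.
#include <cassert>
#include <vector>

std::vector<float> transpose_201(const std::vector<float>& in,
                                 int C, int H, int W) {
  std::vector<float> out(in.size());
  for (int c = 0; c < C; ++c)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        out[(w * C + c) * H + h] = in[(c * H + h) * W + w];
  return out;
}

int main() {
  int C = 3, H = 128, W = 128;  // the shape from the updated test
  std::vector<float> x(C * H * W);
  for (size_t i = 0; i < x.size(); ++i) x[i] = static_cast<float>(i);
  auto y = transpose_201(x, C, H, W);
  assert(y[0] == x[0]);  // element (c=0,h=0,w=0) maps to (w=0,c=0,h=0)
}
```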