Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
ab7b2855
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ab7b2855
编写于
8月 06, 2020
作者:
W
Wilber
提交者:
GitHub
8月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CUDA] [Kernel] Optimize gru. (#4062)
上级
be13a60a
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
448 addition
and
59 deletion
+448
-59
lite/backends/cuda/math/CMakeLists.txt
lite/backends/cuda/math/CMakeLists.txt
+2
-0
lite/backends/cuda/math/cudnn_conv.cc
lite/backends/cuda/math/cudnn_conv.cc
+14
-13
lite/backends/cuda/math/cudnn_helper.cc
lite/backends/cuda/math/cudnn_helper.cc
+1
-11
lite/backends/cuda/math/cudnn_helper.h
lite/backends/cuda/math/cudnn_helper.h
+89
-2
lite/backends/cuda/math/cudnn_softmax.cc
lite/backends/cuda/math/cudnn_softmax.cc
+2
-2
lite/backends/cuda/math/sequence2batch.cu
lite/backends/cuda/math/sequence2batch.cu
+8
-14
lite/backends/cuda/math/sequence_helper.cu
lite/backends/cuda/math/sequence_helper.cu
+215
-0
lite/backends/cuda/math/sequence_helper.h
lite/backends/cuda/math/sequence_helper.h
+77
-0
lite/core/CMakeLists.txt
lite/core/CMakeLists.txt
+1
-1
lite/kernels/cuda/gru_compute.cu
lite/kernels/cuda/gru_compute.cu
+35
-16
lite/kernels/cuda/gru_compute.h
lite/kernels/cuda/gru_compute.h
+2
-0
lite/operators/softmax_op.cc
lite/operators/softmax_op.cc
+2
-0
未找到文件。
lite/backends/cuda/math/CMakeLists.txt
浏览文件 @
ab7b2855
...
...
@@ -20,6 +20,7 @@ nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
nv_library
(
cuda_strided_gemm SRCS strided_gemm.cc DEPS
${
cuda_static_deps
}
)
nv_library
(
cuda_sequence_padding SRCS sequence_padding.cu DEPS
${
cuda_static_deps
}
)
nv_library
(
cuda_bias SRCS bias.cu DEPS
${
cuda_static_deps
}
)
nv_library
(
cuda_sequence_helper SRCS sequence_helper.cu DEPS
${
cuda_static_deps
}
)
set
(
math_cuda
...
...
@@ -39,6 +40,7 @@ set (
cuda_sequence_padding
cuda_bias
cudnn_helper
cuda_sequence_helper
)
set
(
math_cuda
"
${
math_cuda
}
"
CACHE GLOBAL
"math cuda"
)
lite/backends/cuda/math/cudnn_conv.cc
浏览文件 @
ab7b2855
...
...
@@ -55,31 +55,32 @@ bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
CUDNN_CHECK
(
cudnnSetTensor4dDescriptor
(
this
->
input_desc_
,
CUDNN_TENSOR_NCHW
,
GetCudnnDataType
<
Ptype_out
>
()
,
cudnn
::
cudnnTypeWrapper
<
T
>::
type
,
batch
,
ic
,
ih
,
iw
));
CUDNN_CHECK
(
cudnnSetFilter4dDescriptor
(
this
->
filter_desc_
,
GetCudnnDataType
<
Ptype_out
>
()
,
cudnn
::
cudnnTypeWrapper
<
T
>::
type
,
CUDNN_TENSOR_NCHW
,
oc
,
ic
/
param
.
groups
,
kh
,
kw
));
CUDNN_CHECK
(
cudnnSetConvolution2dDescriptor
(
this
->
conv_desc_
,
ph
,
pw
,
sh
,
sw
,
dh
,
dw
,
CUDNN_CROSS_CORRELATION
,
GetCudnnDataType
<
Ptype_out
>
()));
CUDNN_CHECK
(
cudnnSetConvolution2dDescriptor
(
this
->
conv_desc_
,
ph
,
pw
,
sh
,
sw
,
dh
,
dw
,
CUDNN_CROSS_CORRELATION
,
cudnn
::
cudnnTypeWrapper
<
T
>::
type
));
CUDNN_CHECK
(
cudnnSetConvolutionGroupCount
(
this
->
conv_desc_
,
param
.
groups
));
CUDNN_CHECK
(
cudnnSetTensor4dDescriptor
(
this
->
output_desc_
,
CUDNN_TENSOR_NCHW
,
GetCudnnDataType
<
Ptype_out
>
()
,
cudnn
::
cudnnTypeWrapper
<
T
>::
type
,
batch
,
oc
,
oh
,
...
...
@@ -179,7 +180,7 @@ bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
int
dim_bias
[]
=
{
1
,
oc
,
1
,
1
};
int
stride_bias
[]
=
{
oc
,
1
,
1
,
1
};
cudnnSetTensorNdDescriptor
(
this
->
bias_desc_
,
GetCudnnDataType
<
Ptype_out
>
()
,
cudnn
::
cudnnTypeWrapper
<
T
>::
type
,
4
,
dim_bias
,
stride_bias
);
...
...
lite/backends/cuda/math/cudnn_helper.cc
浏览文件 @
ab7b2855
...
...
@@ -21,17 +21,7 @@ namespace paddle {
namespace
lite
{
namespace
cuda
{
namespace
math
{
template
<
>
cudnnDataType_t
GetCudnnDataType
<
PRECISION
(
kFloat
)
>
()
{
return
CUDNN_DATA_FLOAT
;
}
template
<
>
cudnnDataType_t
GetCudnnDataType
<
PRECISION
(
kFP16
)
>
()
{
return
CUDNN_DATA_HALF
;
}
namespace
cudnn
{}
// namespace cudnn
}
// namespace math
}
// namespace cuda
}
// namespace lite
...
...
lite/backends/cuda/math/cudnn_helper.h
浏览文件 @
ab7b2855
...
...
@@ -25,10 +25,97 @@ namespace paddle {
namespace
lite
{
namespace
cuda
{
namespace
math
{
namespace
cudnn
{
template
<
lite_api
::
PrecisionType
PType
>
c
udnnDataType_t
GetCudnnDataType
()
;
template
<
typename
T
>
c
lass
cudnnTypeWrapper
;
template
<
>
class
cudnnTypeWrapper
<
float
>
{
public:
static
const
cudnnDataType_t
type
=
CUDNN_DATA_FLOAT
;
typedef
const
float
ScalingParamType
;
static
ScalingParamType
*
kOne
()
{
static
ScalingParamType
v
=
1.0
f
;
return
&
v
;
}
static
ScalingParamType
*
kZero
()
{
static
ScalingParamType
v
=
0.0
f
;
return
&
v
;
}
};
template
<
>
class
cudnnTypeWrapper
<
half
>
{
public:
static
const
cudnnDataType_t
type
=
CUDNN_DATA_HALF
;
typedef
const
half
ScalingParamType
;
static
ScalingParamType
*
kOne
()
{
static
ScalingParamType
v
=
__float2half
(
1.0
f
);
return
&
v
;
}
static
ScalingParamType
*
kZero
()
{
static
ScalingParamType
v
=
__float2half
(
0.0
f
);
return
&
v
;
}
};
struct
ParamsRegion
{
ParamsRegion
()
:
offset_
(
nullptr
),
size_
(
0
)
{}
ParamsRegion
(
void
*
offset
,
size_t
size
)
:
offset_
(
offset
),
size_
(
size
)
{}
~
ParamsRegion
()
{}
ParamsRegion
&
operator
=
(
const
ParamsRegion
&
right
)
{
offset_
=
right
.
offset_
;
size_
=
right
.
size_
;
return
*
this
;
}
bool
operator
==
(
const
ParamsRegion
&
right
)
{
bool
comp_eq
=
true
;
comp_eq
=
comp_eq
&&
(
offset_
==
right
.
offset_
);
comp_eq
=
comp_eq
&&
(
size_
=
right
.
size_
);
return
comp_eq
;
}
void
*
offset_
;
size_t
size_
;
};
template
<
typename
T
>
class
TensorDescriptors
{
public:
TensorDescriptors
(
size_t
n
,
const
std
::
vector
<
std
::
vector
<
int
>>&
dim
,
const
std
::
vector
<
std
::
vector
<
int
>>&
stride
)
{
descs_
.
resize
(
n
);
CHECK_EQ
(
dim
.
size
(),
stride
.
size
())
<<
"dim size should be equal to stride size"
;
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
CUDNN_CHECK
(
cudnnCreateTensorDescriptor
(
&
descs_
[
i
]));
CUDNN_CHECK
(
cudnnSetTensorNdDescriptor
(
descs_
[
i
],
cudnnTypeWrapper
<
T
>::
type
,
dim
[
i
].
size
(),
dim
[
i
].
data
(),
stride
[
i
].
data
()));
}
}
~
TensorDescriptors
()
{
for
(
auto
desc
:
descs_
)
{
CUDNN_CHECK
(
cudnnDestroyTensorDescriptor
(
desc
));
}
}
const
cudnnTensorDescriptor_t
*
descs
()
const
{
return
descs_
.
data
();
}
int
size
()
const
{
return
descs_
.
size
();
}
private:
std
::
vector
<
cudnnTensorDescriptor_t
>
descs_
;
};
}
// namespace cudnn
}
// namespace math
}
// namespace cuda
}
// namespace lite
...
...
lite/backends/cuda/math/cudnn_softmax.cc
浏览文件 @
ab7b2855
...
...
@@ -54,7 +54,7 @@ bool CudnnSoftmax<T, Ptype>::Create(const operators::SoftmaxParam& param,
const
int
stride_c
=
H
*
stride_h
;
const
int
stride_n
=
C
*
stride_c
;
CUDNN_CHECK
(
cudnnSetTensor4dDescriptorEx
(
bottom_desc_
,
GetCudnnDataType
<
Ptype
>
()
,
cudnn
::
cudnnTypeWrapper
<
T
>::
type
,
N
,
C
,
H
,
...
...
@@ -64,7 +64,7 @@ bool CudnnSoftmax<T, Ptype>::Create(const operators::SoftmaxParam& param,
stride_h
,
stride_w
));
CUDNN_CHECK
(
cudnnSetTensor4dDescriptorEx
(
top_desc_
,
GetCudnnDataType
<
Ptype
>
()
,
cudnn
::
cudnnTypeWrapper
<
T
>::
type
,
N
,
C
,
H
,
...
...
lite/backends/cuda/math/sequence2batch.cu
浏览文件 @
ab7b2855
...
...
@@ -30,17 +30,12 @@ __global__ void CopyMatrixRowsKernel(const T* src,
int
height
,
int
width
,
bool
is_src_index
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
threadIdx
.
y
;
int
row_id
=
blockDim
.
y
*
blockIdx
.
x
+
idy
;
if
(
row_id
<
height
)
{
int
src_idx
=
is_src_index
?
index
[
row_id
]
:
row_id
;
int
dst_idx
=
is_src_index
?
row_id
:
index
[
row_id
];
const
T
*
src_data
=
src
+
src_idx
*
width
;
T
*
dst_data
=
dst
+
dst_idx
*
width
;
for
(
int
i
=
idx
;
i
<
width
;
i
+=
blockDim
.
x
)
{
dst_data
[
i
]
=
src_data
[
i
];
}
CUDA_KERNEL_LOOP
(
tid
,
height
*
width
)
{
int
row
=
tid
/
width
;
int
idx
=
tid
%
width
;
int
src_row
=
is_src_index
?
index
[
row
]
:
row
;
int
dst_row
=
is_src_index
?
row
:
index
[
row
];
dst
[
dst_row
*
width
+
idx
]
=
src
[
src_row
*
width
+
idx
];
}
}
...
...
@@ -69,9 +64,8 @@ void CopyMatrixRowsFunctor<T>::operator()(
sizeof
(
uint64_t
)
*
index_lod
.
size
(),
IoDirection
::
HtoD
,
stream
);
dim3
threads
(
128
,
8
);
dim3
grids
((
height
+
threads
.
y
-
1
)
/
threads
.
y
);
CopyMatrixRowsKernel
<
T
><<<
grids
,
threads
,
0
,
stream
>>>
(
CopyMatrixRowsKernel
<
T
><<<
CUDA_GET_BLOCKS
(
height
*
width
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
src_data
,
dst_data
,
index_tensor_data
,
height
,
width
,
is_src_index
);
CUDA_POST_KERNEL_CHECK
;
}
...
...
lite/backends/cuda/math/sequence_helper.cu
0 → 100644
浏览文件 @
ab7b2855
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/sequence_helper.h"
#include "lite/backends/cuda/math/utils.h"
namespace
paddle
{
namespace
lite
{
namespace
cuda
{
namespace
math
{
template
<
typename
Dtype
>
__global__
void
Map2Out
(
Dtype
*
output
,
const
Dtype
*
input
,
const
int
*
map
,
int
count
,
int
lastdim
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
count
)
{
int
seq
=
tid
/
lastdim
;
output
[
map
[
seq
]
*
lastdim
+
tid
%
lastdim
]
=
input
[
tid
];
}
}
template
<
typename
Dtype
>
__global__
void
Map2In
(
Dtype
*
output
,
const
Dtype
*
input
,
const
int
*
map
,
int
count
,
int
lastdim
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
count
)
{
int
seq
=
tid
/
lastdim
;
output
[
tid
]
=
input
[
map
[
seq
]
*
lastdim
+
tid
%
lastdim
];
}
}
template
<
typename
Dtype
>
void
Map2OutFunc
(
const
Dtype
*
input
,
Dtype
*
output
,
int
word_size
,
int
seq_sum
,
cudaStream_t
stream
,
int
*
dev_map_vec
)
{
int
count
=
seq_sum
*
word_size
;
int
block_dim
=
count
;
int
grid_dim
=
1
;
if
(
count
>
1024
)
{
block_dim
=
256
;
grid_dim
=
(
count
+
block_dim
-
1
)
/
block_dim
;
}
Map2Out
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
output
,
input
,
dev_map_vec
,
count
,
word_size
);
}
template
<
typename
Dtype
>
void
Map2InFunc
(
const
Dtype
*
input
,
Dtype
*
output
,
int
hidden_size
,
int
seq_sum
,
cudaStream_t
stream
,
int
*
dev_map_vec
)
{
int
count
=
seq_sum
*
hidden_size
;
int
block_dim
=
count
;
int
grid_dim
=
1
;
if
(
count
>
1024
)
{
block_dim
=
256
;
grid_dim
=
(
count
+
block_dim
-
1
)
/
block_dim
;
}
Map2In
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
output
,
input
,
dev_map_vec
,
count
,
hidden_size
);
}
template
<
typename
Dtype
>
void
SeqSortedseqTranseUtil
::
Seq2SortedSeq
(
const
Dtype
*
input
,
Dtype
*
output
,
int
word_size
,
cudaStream_t
stream
)
{
int
seq_sum
=
map_vec_
.
size
();
Map2OutFunc
(
input
,
output
,
word_size
,
seq_sum
,
stream
,
dev_map_vec_
);
}
template
<
typename
Dtype
>
void
SeqSortedseqTranseUtil
::
SortedSeq2Seq
(
const
Dtype
*
input
,
Dtype
*
output
,
int
hidden_size
,
cudaStream_t
stream
)
{
int
seq_sum
=
map_vec_
.
size
();
Map2InFunc
(
input
,
output
,
hidden_size
,
seq_sum
,
stream
,
dev_map_vec_
);
}
bool
SeqSortedseqTranseUtil
::
GetSortedMap
(
const
std
::
vector
<
int
>&
offset_vec
,
cudaStream_t
stream_id
)
{
int
batch_size
=
offset_vec
.
size
()
-
1
;
int
word_sum
=
offset_vec
[
offset_vec
.
size
()
-
1
];
std
::
vector
<
int
>
length_vec
(
batch_size
);
length_index_
.
resize
(
batch_size
);
int
emit_length
=
0
;
if
(
batch_size
==
1
)
{
emit_length
=
offset_vec
[
1
]
-
offset_vec
[
0
];
emit_offset_vec_
.
resize
(
emit_length
+
1
);
for
(
int
i
=
0
;
i
<=
emit_length
;
++
i
)
{
emit_offset_vec_
[
i
]
=
i
;
}
return
false
;
}
int
max_len
=
0
;
for
(
int
i
=
0
;
i
<
offset_vec
.
size
()
-
1
;
++
i
)
{
int
len
=
offset_vec
[
i
+
1
]
-
offset_vec
[
i
];
max_len
=
max_len
>
len
?
max_len
:
len
;
length_vec
[
i
]
=
len
;
length_index_
[
i
]
=
i
;
}
emit_length
=
max_len
;
if
(
max_len
==
1
)
{
emit_offset_vec_
.
resize
(
2
);
emit_offset_vec_
[
0
]
=
0
;
emit_offset_vec_
[
1
]
=
emit_length
*
batch_size
;
return
false
;
}
std
::
stable_sort
(
length_index_
.
begin
(),
length_index_
.
end
(),
[
&
length_vec
](
int
i1
,
int
i2
)
{
return
length_vec
[
i1
]
>
length_vec
[
i2
];
});
emit_offset_vec_
.
resize
(
max_len
+
1
);
map_vec_
.
resize
(
word_sum
);
if
(
word_sum
>
dev_map_vec_length_
)
{
if
(
dev_map_vec_
!=
nullptr
)
{
TargetWrapperCuda
::
Free
(
static_cast
<
void
*>
(
dev_map_vec_
));
}
dev_map_vec_
=
static_cast
<
int
*>
(
TargetWrapperCuda
::
Malloc
(
sizeof
(
int
)
*
word_sum
));
dev_map_vec_length_
=
word_sum
;
}
int
target_word_id
=
0
;
std
::
vector
<
int
>
length_vec_cnt
=
length_vec
;
int
last_batch_size
=
batch_size
;
for
(
int
word_id_in_seq
=
0
;
word_id_in_seq
<
max_len
;
word_id_in_seq
++
)
{
emit_offset_vec_
[
word_id_in_seq
]
=
target_word_id
;
for
(
int
batch_id
=
0
;
batch_id
<
last_batch_size
;
batch_id
++
)
{
int
old_batch_id
=
length_index_
[
batch_id
];
if
(
length_vec_cnt
[
old_batch_id
]
>
0
)
{
int
inner_word_id_in_seq
=
word_id_in_seq
;
if
(
is_reverse_
)
{
inner_word_id_in_seq
=
length_vec
[
old_batch_id
]
-
1
-
word_id_in_seq
;
}
int
old_word_id
=
offset_vec
[
old_batch_id
]
+
inner_word_id_in_seq
;
map_vec_
[
old_word_id
]
=
target_word_id
;
length_vec_cnt
[
old_batch_id
]
--
;
target_word_id
++
;
}
else
{
last_batch_size
--
;
break
;
}
}
}
TargetWrapperCuda
::
MemcpyAsync
(
dev_map_vec_
,
map_vec_
.
data
(),
sizeof
(
int
)
*
word_sum
,
IoDirection
::
HtoD
,
stream_id
);
emit_offset_vec_
[
max_len
]
=
word_sum
;
emit_length_
=
emit_length
;
return
true
;
}
template
void
SeqSortedseqTranseUtil
::
Seq2SortedSeq
(
const
float
*
input
,
float
*
output
,
int
word_size
,
cudaStream_t
stream
);
template
void
SeqSortedseqTranseUtil
::
SortedSeq2Seq
(
const
float
*
input
,
float
*
output
,
int
hidden_size
,
cudaStream_t
stream
);
template
void
SeqSortedseqTranseUtil
::
Seq2SortedSeq
(
const
half
*
input
,
half
*
output
,
int
word_size
,
cudaStream_t
stream
);
template
void
SeqSortedseqTranseUtil
::
SortedSeq2Seq
(
const
half
*
input
,
half
*
output
,
int
hidden_size
,
cudaStream_t
stream
);
}
// namespace math
}
// namespace cuda
}
// namespace lite
}
// namespace paddle
lite/backends/cuda/math/sequence_helper.h
0 → 100644
浏览文件 @
ab7b2855
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include "lite/backends/cuda/target_wrapper.h"
namespace
paddle
{
namespace
lite
{
namespace
cuda
{
namespace
math
{
class
SeqSortedseqTranseUtil
{
public:
explicit
SeqSortedseqTranseUtil
(
bool
is_reverse
=
false
,
bool
is_bi
=
false
)
:
is_reverse_
(
is_reverse
),
is_bi_
(
is_bi
),
dev_map_vec_
(
nullptr
),
dev_map_vec_length_
(
0
)
{}
~
SeqSortedseqTranseUtil
()
{
if
(
dev_map_vec_
!=
nullptr
)
{
TargetWrapperCuda
::
Free
(
static_cast
<
void
*>
(
dev_map_vec_
));
}
}
std
::
vector
<
int
>&
GetLengthIndex
()
{
return
length_index_
;
}
std
::
vector
<
int
>&
GetEmitOffsetVec
()
{
return
emit_offset_vec_
;
}
std
::
vector
<
int
>&
GetMapVec
()
{
return
map_vec_
;
}
int
*
GetDevMapVec
()
{
return
dev_map_vec_
;
}
int
GetEmitLength
()
{
return
emit_length_
;
}
template
<
typename
Dtype
>
void
Seq2SortedSeq
(
const
Dtype
*
input
,
Dtype
*
output
,
int
word_size
,
cudaStream_t
stream
);
template
<
typename
Dtype
>
void
SortedSeq2Seq
(
const
Dtype
*
input
,
Dtype
*
output
,
int
hidden_size
,
cudaStream_t
stream
);
bool
GetSortedMap
(
const
std
::
vector
<
int
>&
offset_vec
,
cudaStream_t
stream_id
);
private:
std
::
vector
<
int
>
length_index_
;
std
::
vector
<
int
>
emit_offset_vec_
;
std
::
vector
<
int
>
map_vec_
;
int
emit_length_
;
bool
is_reverse_
;
bool
is_bi_
;
int
*
dev_map_vec_
;
int
dev_map_vec_length_
;
};
}
// namespace math
}
// namespace cuda
}
// namespace lite
}
// namespace paddle
lite/core/CMakeLists.txt
浏览文件 @
ab7b2855
...
...
@@ -133,7 +133,7 @@ lite_cc_library(type_system SRCS type_system.cc DEPS tensor target_wrapper)
lite_cc_library
(
program SRCS program.cc
DEPS op kernel model_parser
${
ops
}
${
cpp_wrapper
}
PROFILE_DEPS lite_profiler
CUDA_DEPS nvtx_wrapper
)
CUDA_DEPS nvtx_wrapper
cuda_type_trans
)
if
(
NOT LITE_ON_TINY_PUBLISH
)
lite_cc_library
(
optimizer SRCS optimizer.cc DEPS mir_pass_manager model_parser program
)
...
...
lite/kernels/cuda/gru_compute.cu
浏览文件 @
ab7b2855
...
...
@@ -14,6 +14,7 @@
#include "lite/kernels/cuda/gru_compute.h"
#include <string>
#include <vector>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/bias.h"
...
...
@@ -273,6 +274,8 @@ void GRUCompute<T, PType>::Run() {
auto
&
param
=
this
->
template
Param
<
param_t
>();
auto
*
input
=
param
.
input
;
T
*
x_data
=
const_cast
<
lite
::
Tensor
*>
(
input
)
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
lite
::
Tensor
*
h0
{
nullptr
};
if
(
param
.
h0
)
{
h0
=
const_cast
<
lite
::
Tensor
*>
(
param
.
h0
);
...
...
@@ -289,7 +292,7 @@ void GRUCompute<T, PType>::Run() {
lite
::
Tensor
*
hidden
=
param
.
hidden
;
T
*
batch_reset_hidden_prev_data
=
batch_reset_hidden_prev
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
hidden
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
T
*
out_data
=
hidden
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
T
*
batch_gate_data
=
batch_gate
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
T
*
batch_hidden_data
=
batch_hidden
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
bool
is_reverse
=
param
.
is_reverse
;
...
...
@@ -300,14 +303,28 @@ void GRUCompute<T, PType>::Run() {
auto
hidden_dims
=
hidden
->
dims
();
int
frame_size
=
hidden_dims
[
1
];
lite
::
cuda
::
math
::
LoDTensor2BatchFunctor
<
T
>
batch_func
;
batch_func
(
*
input
,
batch_gate
,
is_reverse
,
stream
);
LoD
offset_vec_vec
=
input
->
lod
();
std
::
vector
<
int
>
offset
(
offset_vec_vec
[
offset_vec_vec
.
size
()
-
1
].
size
());
for
(
size_t
i
=
0
;
i
<
offset_vec_vec
[
offset_vec_vec
.
size
()
-
1
].
size
();
++
i
)
{
offset
[
i
]
=
static_cast
<
int
>
(
offset_vec_vec
[
offset_vec_vec
.
size
()
-
1
][
i
]);
}
bool
need_process
=
seq_utils_
.
GetSortedMap
(
offset
,
stream
);
int
emit_length
=
seq_utils_
.
GetEmitOffsetVec
().
size
()
-
1
;
auto
emit_offset_vec
=
seq_utils_
.
GetEmitOffsetVec
();
if
(
need_process
)
{
seq_utils_
.
Seq2SortedSeq
(
input
->
template
data
<
T
>(),
batch_gate_data
,
3
*
frame_size
,
stream
);
x_data
=
batch_gate_data
;
out_data
=
batch_hidden_data
;
}
if
(
bias
)
{
// TODO(wilber): validate when bias is not nullptr
lite
::
cuda
::
math
::
RowwiseAdd
<
T
>
add_bias
;
add_bias
(
batch_gate
_data
,
add_bias
(
x
_data
,
bias
->
template
data
<
T
>(),
batch_gate
_data
,
x
_data
,
frame_size
,
batch_gate
->
numel
(),
stream
);
...
...
@@ -320,6 +337,7 @@ void GRUCompute<T, PType>::Run() {
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
// TODO(wilber): validate when h0 is not nullptr
ordered_h0_
.
Resize
(
h0
->
dims
());
lite
::
cuda
::
math
::
CopyMatrixRowsFunctor
<
T
>
row_shuffle
;
row_shuffle
(
*
h0
,
&
ordered_h0_
,
batch_gate
->
lod
()[
2
],
true
,
stream
);
...
...
@@ -327,15 +345,13 @@ void GRUCompute<T, PType>::Run() {
}
else
{
gru_value
.
prev_out_value
=
nullptr
;
}
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
for
(
size_t
n
=
0
;
n
<
num_batch
;
++
n
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
for
(
size_t
n
=
0
;
n
<
emit_length
;
++
n
)
{
int
bstart
=
emit_offset_vec
[
n
];
int
bend
=
emit_offset_vec
[
n
+
1
];
int
cur_batch_size
=
bend
-
bstart
;
gru_value
.
output_value
=
batch_hidden
_data
+
bstart
*
frame_size
;
gru_value
.
gate_value
=
batch_gate
_data
+
bstart
*
frame_size
*
3
;
gru_value
.
output_value
=
out
_data
+
bstart
*
frame_size
;
gru_value
.
gate_value
=
x
_data
+
bstart
*
frame_size
*
3
;
gru_value
.
reset_output_value
=
batch_reset_hidden_prev_data
+
bstart
*
frame_size
;
...
...
@@ -349,10 +365,13 @@ void GRUCompute<T, PType>::Run() {
&
context
);
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
}
lite
::
cuda
::
math
::
Batch2LoDTensorFunctor
<
T
>
to_seq
;
batch_hidden
->
set_lod
(
batch_gate
->
lod
());
to_seq
(
*
batch_hidden
,
hidden
,
stream
);
if
(
need_process
)
{
seq_utils_
.
SortedSeq2Seq
(
batch_hidden_data
,
hidden
->
mutable_data
<
T
>
(
TARGET
(
kCUDA
)),
frame_size
,
stream
);
}
hidden
->
set_lod
(
input
->
lod
());
}
}
// namespace cuda
...
...
lite/kernels/cuda/gru_compute.h
浏览文件 @
ab7b2855
...
...
@@ -16,6 +16,7 @@
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/backends/cuda/math/sequence_helper.h"
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"
...
...
@@ -38,6 +39,7 @@ class GRUCompute : public KernelLite<TARGET(kCUDA), PType> {
private:
std
::
unique_ptr
<
lite
::
cuda
::
math
::
Gemm
<
T
,
T
>>
gemm_impl_
{
nullptr
};
lite
::
Tensor
ordered_h0_
;
lite
::
cuda
::
math
::
SeqSortedseqTranseUtil
seq_utils_
;
};
}
// namespace cuda
...
...
lite/operators/softmax_op.cc
浏览文件 @
ab7b2855
...
...
@@ -55,6 +55,8 @@ bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
if
(
opdesc
.
HasAttr
(
"use_cudnn"
))
{
param_
.
use_cudnn
=
opdesc
.
GetAttr
<
bool
>
(
"use_cudnn"
);
}
// TODO(wilber): use cudnn default when compile with cuda.
param_
.
use_cudnn
=
true
;
return
true
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录