Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
a105bbdf
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a105bbdf
编写于
7月 24, 2020
作者:
W
Wilber
提交者:
GitHub
7月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CUDA] Support model run correctly. (#3975)
上级
e45e6dc6
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
121 addition
and
90 deletion
+121
-90
lite/backends/cuda/math/gru_forward.h
lite/backends/cuda/math/gru_forward.h
+9
-1
lite/backends/cuda/math/scale.cu
lite/backends/cuda/math/scale.cu
+4
-11
lite/backends/cuda/math/scale.h
lite/backends/cuda/math/scale.h
+1
-2
lite/backends/cuda/math/sequence2batch.cu
lite/backends/cuda/math/sequence2batch.cu
+2
-2
lite/backends/cuda/math/sequence2batch.h
lite/backends/cuda/math/sequence2batch.h
+10
-10
lite/backends/cuda/math/sequence_padding.cu
lite/backends/cuda/math/sequence_padding.cu
+2
-4
lite/kernels/cuda/assign_value_compute.cu
lite/kernels/cuda/assign_value_compute.cu
+1
-1
lite/kernels/cuda/dropout_compute.cc
lite/kernels/cuda/dropout_compute.cc
+4
-1
lite/kernels/cuda/gru_compute.cu
lite/kernels/cuda/gru_compute.cu
+4
-7
lite/kernels/cuda/scale_compute.cc
lite/kernels/cuda/scale_compute.cc
+5
-2
lite/kernels/cuda/sequence_mask_compute.cu
lite/kernels/cuda/sequence_mask_compute.cu
+8
-5
lite/kernels/cuda/sequence_pad_compute.cu
lite/kernels/cuda/sequence_pad_compute.cu
+12
-2
lite/kernels/cuda/sequence_unpad_compute.cu
lite/kernels/cuda/sequence_unpad_compute.cu
+33
-1
lite/kernels/cuda/sequence_unpad_compute.h
lite/kernels/cuda/sequence_unpad_compute.h
+1
-0
lite/kernels/cuda/var_conv_2d_compute.cu
lite/kernels/cuda/var_conv_2d_compute.cu
+5
-0
lite/operators/gru_op.cc
lite/operators/gru_op.cc
+2
-3
lite/operators/sequence_pad_op.cc
lite/operators/sequence_pad_op.cc
+7
-6
lite/operators/sequence_unpad_op.cc
lite/operators/sequence_unpad_op.cc
+1
-26
lite/operators/var_conv_2d_op.cc
lite/operators/var_conv_2d_op.cc
+10
-6
未找到文件。
lite/backends/cuda/math/gru_forward.h
浏览文件 @
a105bbdf
...
...
@@ -30,9 +30,16 @@ namespace lite {
namespace
cuda
{
namespace
math
{
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
template
<
typename
Dtype
>
inline
__device__
Dtype
Sigmoid
(
const
Dtype
a
)
{
return
static_cast
<
Dtype
>
(
1.0
)
/
(
static_cast
<
Dtype
>
(
1.0
)
+
expf
(
-
a
));
const
Dtype
min
=
SIGMOID_THRESHOLD_MIN
;
const
Dtype
max
=
SIGMOID_THRESHOLD_MAX
;
Dtype
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
static_cast
<
Dtype
>
(
1.0
)
/
(
static_cast
<
Dtype
>
(
1.0
)
+
expf
(
-
tmp
));
}
template
<
>
...
...
@@ -63,6 +70,7 @@ inline __device__ half ReLU(const half a) {
template
<
typename
Dtype
>
inline
__device__
Dtype
Tanh
(
const
Dtype
a
)
{
Dtype
tmp
=
static_cast
<
Dtype
>
(
-
2.0
)
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
static_cast
<
Dtype
>
(
2.0
)
/
(
static_cast
<
Dtype
>
(
1.0
)
+
expf
(
tmp
)))
-
static_cast
<
Dtype
>
(
1.0
);
}
...
...
lite/backends/cuda/math/scale.cu
浏览文件 @
a105bbdf
...
...
@@ -22,10 +22,6 @@ namespace lite {
namespace
cuda
{
namespace
math
{
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template
<
typename
T
>
__global__
void
scale_kernel
(
int
count
,
const
T
*
in_data
,
...
...
@@ -48,7 +44,6 @@ __global__ void scale_kernel(int count,
template
<
typename
T
>
__global__
void
scale_kernel
(
int
count
,
const
T
*
in_data
,
T
*
out_data
,
const
T
scale
,
const
T
bias
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
CUDA_KERNEL_LOOP
(
tid
,
count
)
{
out_data
[
tid
]
=
scale
*
in_data
[
tid
]
+
bias
;
}
}
...
...
@@ -133,12 +128,11 @@ void fp32_scale_nhwc(int num,
}
template
<
typename
T
>
void
scale
(
int
num
,
const
T
*
in
,
T
*
out
,
T
scale
,
cudaStream_t
stream
,
T
bias
)
{
void
scale
(
int
num
,
const
T
*
in
,
T
*
out
,
T
scale
,
T
bias
,
cudaStream_t
stream
)
{
int
thread
=
256
;
int
block
=
(
num
+
thread
-
1
)
/
thread
;
scale_kernel
<<<
block
,
thread
,
0
,
stream
>>>
(
num
,
in
,
out
,
scale
,
bias
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
std
::
cout
<<
cudaGetErrorString
(
error
);
CUDA_POST_KERNEL_CHECK
;
}
template
<
typename
T
>
...
...
@@ -146,11 +140,10 @@ void scale(int num, const T* in, T* out, T scale, T bias) {
int
thread
=
256
;
int
block
=
(
num
+
thread
-
1
)
/
thread
;
scale_kernel
<<<
block
,
thread
>>>
(
num
,
in
,
out
,
scale
,
bias
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
std
::
cout
<<
cudaGetErrorString
(
error
);
CUDA_POST_KERNEL_CHECK
;
}
template
void
scale
(
int
num
,
const
float
*
,
float
*
,
float
,
cudaStream_t
,
floa
t
);
template
void
scale
(
int
num
,
const
float
*
,
float
*
,
float
,
float
,
cudaStream_
t
);
template
void
scale
(
int
num
,
const
float
*
,
float
*
,
float
,
float
);
}
// namespace math
...
...
lite/backends/cuda/math/scale.h
浏览文件 @
a105bbdf
...
...
@@ -32,8 +32,7 @@ void fp32_scale_nhwc(int num,
cudaStream_t
stream
);
template
<
typename
T
>
void
scale
(
int
num
,
const
T
*
in
,
T
*
out
,
T
scale
,
cudaStream_t
stream
,
T
bias
=
0
);
void
scale
(
int
num
,
const
T
*
in
,
T
*
out
,
T
scale
,
T
bias
,
cudaStream_t
stream
);
template
<
typename
T
>
void
scale
(
int
num
,
const
T
*
in
,
T
*
out
,
T
scale
,
T
bias
=
0
);
...
...
lite/backends/cuda/math/sequence2batch.cu
浏览文件 @
a105bbdf
...
...
@@ -32,7 +32,7 @@ __global__ void CopyMatrixRowsKernel(const T* src,
bool
is_src_index
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
threadIdx
.
y
;
int
row_id
=
blockDim
.
y
*
gridDim
.
x
+
idy
;
int
row_id
=
blockDim
.
y
*
blockIdx
.
x
+
idy
;
if
(
row_id
<
height
)
{
int
src_idx
=
is_src_index
?
index
[
row_id
]
:
row_id
;
int
dst_idx
=
is_src_index
?
row_id
:
index
[
row_id
];
...
...
@@ -72,7 +72,7 @@ void CopyMatrixRowsFunctor<T>::operator()(
dim3
threads
(
128
,
8
);
dim3
grids
((
height
+
threads
.
y
-
1
)
/
threads
.
y
);
CopyMatrixRowsKernel
<
T
><<<
grids
,
threads
,
0
,
stream
>>>
(
src_data
,
dst_data
,
index_tensor_data
,
height
,
width
,
true
);
src_data
,
dst_data
,
index_tensor_data
,
height
,
width
,
is_src_index
);
CUDA_POST_KERNEL_CHECK
;
}
...
...
lite/backends/cuda/math/sequence2batch.h
浏览文件 @
a105bbdf
...
...
@@ -53,11 +53,11 @@ class LoDTensor2BatchFunctor {
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
struct
SeqInfo
{
SeqInfo
(
size_t
start
,
size_t
length
,
size_t
seq_idx
)
:
start
_
(
start
),
length_
(
length
),
seq_idx_
(
seq_idx
)
{}
size_t
start
_
;
size_t
length
_
;
size_t
seq_idx
_
;
SeqInfo
(
size_t
start
_val
,
size_t
len_val
,
size_t
seq_val
)
:
start
(
start_val
),
length
(
len_val
),
seq_idx
(
seq_val
)
{}
size_t
start
;
size_t
length
;
size_t
seq_idx
;
};
public:
...
...
@@ -76,7 +76,7 @@ class LoDTensor2BatchFunctor {
}
std
::
sort
(
seq_info
.
begin
(),
seq_info
.
end
(),
[](
SeqInfo
a
,
SeqInfo
b
)
{
return
a
.
length
_
>
b
.
length_
;
return
a
.
length
>
b
.
length
;
});
// Calculate the start position of each batch.
...
...
@@ -106,7 +106,7 @@ class LoDTensor2BatchFunctor {
batch_lods
.
emplace_back
(
std
::
vector
<
uint64_t
>
{
0
});
// batch_lods[0] is the start positions for batch LoDTensor
size_t
max_seqlen
=
seq_info
[
0
].
length
_
;
size_t
max_seqlen
=
seq_info
[
0
].
length
;
batch_lods
[
0
].
resize
(
max_seqlen
+
1
);
// batch_lods[1] is the raw index in the input LoDTensor
batch_lods
[
1
].
resize
(
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
...
...
@@ -119,8 +119,8 @@ class LoDTensor2BatchFunctor {
for
(
size_t
n
=
0
;
n
<
max_seqlen
;
++
n
)
{
size_t
batch_id
=
batch_starts
[
n
];
for
(
size_t
i
=
0
;
i
<
seq_info
.
size
();
++
i
)
{
size_t
seq_len
=
seq_info
[
i
].
length
_
;
size_t
start
=
seq_info
[
i
].
start
_
;
size_t
seq_len
=
seq_info
[
i
].
length
;
size_t
start
=
seq_info
[
i
].
start
;
if
(
n
<
seq_len
)
{
seq2batch_idx
[
batch_id
]
=
is_reverse
?
start
+
seq_len
-
1
-
n
:
start
+
n
;
...
...
@@ -133,7 +133,7 @@ class LoDTensor2BatchFunctor {
}
auto
*
seq_order
=
batch_lods
[
2
].
data
();
for
(
size_t
i
=
0
;
i
<
seq_info
.
size
();
++
i
)
{
seq_order
[
i
]
=
seq_info
[
i
].
seq_idx
_
;
seq_order
[
i
]
=
seq_info
[
i
].
seq_idx
;
}
batch_tensor
->
set_lod
(
batch_lods
);
...
...
lite/backends/cuda/math/sequence_padding.cu
浏览文件 @
a105bbdf
...
...
@@ -86,8 +86,7 @@ void SequencePadding(T* pad_data,
seq_num
,
pad_seq_len
,
step_width
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
ERROR
)
<<
cudaGetErrorString
(
error
);
CUDA_POST_KERNEL_CHECK
;
}
template
<
typename
T
>
...
...
@@ -120,8 +119,7 @@ void SequenceUnpadding(T* seq_data,
seq_num
,
pad_seq_len
,
step_width
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
ERROR
)
<<
cudaGetErrorString
(
error
);
CUDA_POST_KERNEL_CHECK
;
}
template
void
SequencePadding
(
float
*
pad_data
,
...
...
lite/kernels/cuda/assign_value_compute.cu
浏览文件 @
a105bbdf
...
...
@@ -68,7 +68,7 @@ void AssignValueCompute::Run() {
REGISTER_LITE_KERNEL
(
assign_value
,
kCUDA
,
k
Any
,
k
Float
,
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
AssignValueCompute
,
def
)
...
...
lite/kernels/cuda/dropout_compute.cc
浏览文件 @
a105bbdf
...
...
@@ -23,6 +23,9 @@ namespace cuda {
void
DropoutCompute
::
Run
()
{
auto
&
param
=
Param
<
operators
::
DropoutParam
>
();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
const
float
*
x_data
=
param
.
x
->
data
<
float
>
();
float
*
out_data
=
param
.
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
int
num
=
param
.
x
->
dims
().
production
();
...
...
@@ -31,7 +34,7 @@ void DropoutCompute::Run() {
if
(
param
.
dropout_implementation
==
"downgrade_in_infer"
)
{
scale
=
1.0
f
-
prob_data
;
}
lite
::
cuda
::
math
::
scale
(
num
,
x_data
,
out_data
,
scale
,
0
);
lite
::
cuda
::
math
::
scale
(
num
,
x_data
,
out_data
,
scale
,
0
.
f
,
stream
);
}
}
// namespace cuda
...
...
lite/kernels/cuda/gru_compute.cu
浏览文件 @
a105bbdf
...
...
@@ -11,6 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/gru_compute.h"
#include <string>
#include "lite/backends/cuda/cuda_utils.h"
...
...
@@ -19,7 +21,6 @@
#include "lite/backends/cuda/math/sequence2batch.h"
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/gru_compute.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -133,7 +134,6 @@ struct GRUUnitFunctor {
value
.
gate_value
,
context
);
}
CUDA_POST_KERNEL_CHECK
;
lite
::
cuda
::
math
::
GruForwardResetOutput
<
T
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
...
...
@@ -143,7 +143,7 @@ struct GRUUnitFunctor {
frame_size
,
batch_size
,
active_gate
,
batch_size
=
=
1
);
batch_size
!
=
1
);
CUDA_POST_KERNEL_CHECK
;
if
(
value
.
prev_out_value
)
{
...
...
@@ -163,7 +163,6 @@ struct GRUUnitFunctor {
value
.
gate_value
+
frame_size
*
2
,
context
);
}
CUDA_POST_KERNEL_CHECK
;
lite
::
cuda
::
math
::
GruForwardFinalOutput
<
T
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
gate_value
,
...
...
@@ -173,7 +172,7 @@ struct GRUUnitFunctor {
batch_size
,
active_node
,
origin_mode
,
batch_size
=
=
1
);
batch_size
!
=
1
);
CUDA_POST_KERNEL_CHECK
;
}
};
...
...
@@ -218,7 +217,6 @@ struct GRUUnitFunctor<half> {
value
.
gate_value
,
context
);
}
CUDA_POST_KERNEL_CHECK
;
lite
::
cuda
::
math
::
GruForwardResetOutput
<
half
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
...
...
@@ -248,7 +246,6 @@ struct GRUUnitFunctor<half> {
value
.
gate_value
+
frame_size
*
2
,
context
);
}
CUDA_POST_KERNEL_CHECK
;
lite
::
cuda
::
math
::
GruForwardFinalOutput
<
half
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
...
...
lite/kernels/cuda/scale_compute.cc
浏览文件 @
a105bbdf
...
...
@@ -23,8 +23,11 @@ namespace cuda {
void
ScaleCompute
::
Run
()
{
auto
&
param
=
Param
<
operators
::
ScaleParam
>
();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
const
float
*
x_data
=
param
.
x
->
data
<
float
>
();
float
*
output_data
=
param
.
output
->
mutable_data
<
float
>
();
float
*
output_data
=
param
.
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
)
);
DDim
x_dims
=
param
.
x
->
dims
();
bool
bias_after_scale
=
param
.
bias_after_scale
;
float
scale
=
param
.
scale
;
...
...
@@ -33,7 +36,7 @@ void ScaleCompute::Run() {
bias
*=
scale
;
}
lite
::
cuda
::
math
::
scale
(
x_dims
.
production
(),
x_data
,
output_data
,
scale
,
bias
);
x_dims
.
production
(),
x_data
,
output_data
,
scale
,
bias
,
stream
);
}
}
// namespace cuda
...
...
lite/kernels/cuda/sequence_mask_compute.cu
浏览文件 @
a105bbdf
...
...
@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sequence_mask_compute.h"
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/sequence_mask_compute.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -44,7 +44,7 @@ void SequenceMaskCompute<T, Ptype>::Run() {
auto
stream
=
ctx
.
exec_stream
();
const
auto
*
x
=
param
.
X
;
auto
*
x_data
=
x
->
template
data
<
int64_t
>();
const
int64_t
*
x_data
=
x
->
template
data
<
int64_t
>();
auto
*
y
=
param
.
Y
;
int
maxlen
=
param
.
maxlen
;
...
...
@@ -57,8 +57,11 @@ void SequenceMaskCompute<T, Ptype>::Run() {
}
if
(
maxlen
<
0
)
{
maxlen
=
thrust
::
reduce
(
x_data
,
x_data
+
x
->
numel
(),
0
,
thrust
::
maximum
<
int64_t
>
());
maxlen
=
static_cast
<
int
>
(
thrust
::
reduce
(
thrust
::
device_pointer_cast
(
x_data
),
thrust
::
device_pointer_cast
(
x_data
)
+
x
->
numel
(),
static_cast
<
int64_t
>
(
0
),
thrust
::
maximum
<
int64_t
>
()));
}
auto
y_dim
=
x
->
dims
().
Vectorize
();
...
...
lite/kernels/cuda/sequence_pad_compute.cu
浏览文件 @
a105bbdf
...
...
@@ -32,9 +32,19 @@ void SequencePadCompute<T, Ptype>::Run() {
const
auto
*
pad_value
=
param
.
PadValue
;
auto
*
out
=
param
.
Out
;
auto
*
len_t
=
param
.
Length
;
int
padded_length
=
param
.
padded_length
;
int
seq_num
=
x
->
lod
()[
0
].
size
()
-
1
;
int
padded_length
;
if
(
param
.
padded_length
==
-
1
)
{
int
max_seq_len
=
0
;
for
(
int
i
=
0
;
i
<
seq_num
;
++
i
)
{
max_seq_len
=
std
::
max
(
max_seq_len
,
static_cast
<
int
>
(
x
->
lod
()[
0
][
i
+
1
]
-
x
->
lod
()[
0
][
i
]));
}
padded_length
=
max_seq_len
;
}
else
{
padded_length
=
param
.
padded_length
;
}
int
max_seq_len
=
0
;
int
step_width
=
x
->
numel
()
/
x
->
dims
()[
0
];
...
...
lite/kernels/cuda/sequence_unpad_compute.cu
浏览文件 @
a105bbdf
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include <algorithm>
#include "lite/backends/cuda/math/sequence_padding.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
...
...
@@ -29,8 +30,39 @@ void SequenceUnpadCompute<T, Ptype>::Run() {
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
auto
x_dims
=
param
.
X
->
dims
();
auto
len_dims
=
param
.
Length
->
dims
();
auto
*
seq_len_ptr
=
param
.
Length
->
template
data
<
int64_t
>();
seq_len_cpu_
.
Resize
(
param
.
Length
->
dims
());
TargetWrapperCuda
::
MemcpyAsync
(
seq_len_cpu_
.
mutable_data
<
int64_t
>
(),
seq_len_ptr
,
sizeof
(
int64_t
)
*
param
.
Length
->
numel
(),
IoDirection
::
DtoH
,
stream
);
TargetWrapperCuda
::
StreamSync
(
stream
);
int64_t
batch_size
=
len_dims
[
0
];
std
::
vector
<
uint64_t
>
out_lod0
(
batch_size
+
1
,
0
);
for
(
int64_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
out_lod0
[
i
+
1
]
=
out_lod0
[
i
]
+
seq_len_cpu_
.
data
<
int64_t
>
()[
i
];
}
paddle
::
lite
::
LoD
out_lod
;
out_lod
.
push_back
(
out_lod0
);
int64_t
out_dim0
=
out_lod0
.
back
();
std
::
vector
<
int64_t
>
out_dims
{
out_dim0
};
if
(
x_dims
.
size
()
==
2
)
{
out_dims
.
push_back
(
1
);
}
else
{
for
(
size_t
i
=
2
;
i
<
x_dims
.
size
();
++
i
)
{
out_dims
.
push_back
(
x_dims
[
i
]);
}
}
param
.
Out
->
Resize
(
out_dims
);
param
.
Out
->
set_lod
(
out_lod
);
const
auto
*
pad_tensor
=
param
.
X
;
const
auto
*
len_t
=
param
.
Length
;
auto
*
seq_tensor
=
param
.
Out
;
int
padded_length
=
pad_tensor
->
dims
()[
1
];
...
...
lite/kernels/cuda/sequence_unpad_compute.h
浏览文件 @
a105bbdf
...
...
@@ -31,6 +31,7 @@ class SequenceUnpadCompute : public KernelLite<TARGET(kCUDA), Ptype> {
private:
lite
::
Tensor
seq_offsets_
;
lite
::
Tensor
seq_len_cpu_
;
std
::
vector
<
size_t
>
seq_offsets_vec_
;
};
...
...
lite/kernels/cuda/var_conv_2d_compute.cu
浏览文件 @
a105bbdf
...
...
@@ -184,6 +184,8 @@ using VarConvFp16 =
REGISTER_LITE_KERNEL
(
var_conv_2d
,
kCUDA
,
kFloat
,
kNCHW
,
VarConvFp32
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"W"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"COLUMN"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"ROW"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Col"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
...
...
@@ -191,6 +193,9 @@ REGISTER_LITE_KERNEL(var_conv_2d, kCUDA, kFloat, kNCHW, VarConvFp32, def)
REGISTER_LITE_KERNEL
(
var_conv_2d
,
kCUDA
,
kFP16
,
kNCHW
,
VarConvFp16
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindInput
(
"W"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindInput
(
"COLUMN"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindInput
(
"ROW"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"Col"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
Finalize
();
lite/operators/gru_op.cc
浏览文件 @
a105bbdf
...
...
@@ -75,9 +75,8 @@ bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
auto
batch_reset_hidden_prev
=
op_desc
.
Output
(
"BatchResetHiddenPrev"
).
front
();
auto
batch_hidden
=
op_desc
.
Output
(
"BatchHidden"
).
front
();
auto
hidden
=
op_desc
.
Output
(
"Hidden"
).
front
();
param_
.
input
=
scope
->
FindVar
(
input
)
->
GetMutable
<
lite
::
Tensor
>
();
if
(
op_desc
.
Input
(
"H0"
).
size
())
{
if
(
!
op_desc
.
Input
(
"H0"
).
empty
())
{
auto
h0
=
op_desc
.
Input
(
"H0"
).
front
();
param_
.
h0
=
scope
->
FindVar
(
h0
)
->
GetMutable
<
lite
::
Tensor
>
();
}
...
...
@@ -90,7 +89,7 @@ bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
scope
->
FindVar
(
batch_hidden
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
hidden
=
scope
->
FindVar
(
hidden
)
->
GetMutable
<
lite
::
Tensor
>
();
if
(
op_desc
.
HasInput
(
"Bias"
))
{
if
(
!
op_desc
.
Input
(
"Bias"
).
empty
(
))
{
auto
bias
=
op_desc
.
Input
(
"Bias"
).
front
();
param_
.
bias
=
scope
->
FindVar
(
bias
)
->
GetMutable
<
lite
::
Tensor
>
();
}
...
...
lite/operators/sequence_pad_op.cc
浏览文件 @
a105bbdf
...
...
@@ -61,18 +61,19 @@ bool SequencePadOp::InferShapeImpl() const {
max_seq_len
=
std
::
max
(
max_seq_len
,
static_cast
<
int
>
(
x_lod_0
[
i
+
1
]
-
x_lod_0
[
i
]));
}
if
(
param_
.
padded_length
==
-
1
)
{
param_
.
padded_length
=
max_seq_len
;
int
real_padded_length
=
param_
.
padded_length
;
if
(
real_padded_length
==
-
1
)
{
real_padded_length
=
max_seq_len
;
}
CHECK_GE
(
param_
.
padded_length
,
max_seq_len
)
CHECK_GE
(
real_
padded_length
,
max_seq_len
)
<<
"The SequencePadOp Attr(padded_length) should be greater than or "
"equal to the length of the longest original sequence. But the "
"padded_length we received is "
<<
param_
.
padded_length
<<
real_
padded_length
<<
", the length of the longest original sequence is "
<<
max_seq_len
;
int
out_dim_0
=
seq_num
;
std
::
vector
<
int64_t
>
out_dims_vec
{
out_dim_0
,
param_
.
padded_length
};
std
::
vector
<
int64_t
>
out_dims_vec
{
out_dim_0
,
real_
padded_length
};
std
::
vector
<
int64_t
>
len_dims_vec
{
out_dim_0
};
auto
time_step_dims_vec
=
time_step_dims
.
Vectorize
();
out_dims_vec
.
insert
(
...
...
@@ -87,7 +88,7 @@ bool SequencePadOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
&
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
())
->
Get
<
lite
::
Tensor
>
());
param_
.
PadValue
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"PadValue"
).
front
())
->
Get
<
lite
::
Tensor
>
());
param_
.
Length
=
scope
->
FindVar
(
opdesc
.
In
put
(
"Length"
).
front
())
param_
.
Length
=
scope
->
FindVar
(
opdesc
.
Out
put
(
"Length"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
Out
=
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
...
...
lite/operators/sequence_unpad_op.cc
浏览文件 @
a105bbdf
...
...
@@ -32,32 +32,7 @@ bool SequenceUnpadOp::CheckShape() const {
return
true
;
}
bool
SequenceUnpadOp
::
InferShapeImpl
()
const
{
auto
x_dims
=
param_
.
X
->
dims
();
auto
len_dims
=
param_
.
Length
->
dims
();
auto
*
seq_len_ptr
=
param_
.
Length
->
data
<
int64_t
>
();
int64_t
batch_size
=
len_dims
[
0
];
std
::
vector
<
uint64_t
>
out_lod0
(
batch_size
+
1
,
0
);
for
(
int64_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
out_lod0
[
i
+
1
]
=
out_lod0
[
i
]
+
seq_len_ptr
[
i
];
}
paddle
::
lite
::
LoD
out_lod
;
out_lod
.
push_back
(
out_lod0
);
int64_t
out_dim0
=
out_lod0
.
back
();
std
::
vector
<
int64_t
>
out_dims
{
out_dim0
};
if
(
x_dims
.
size
()
==
2
)
{
out_dims
.
push_back
(
1
);
}
else
{
for
(
size_t
i
=
2
;
i
<
x_dims
.
size
();
++
i
)
{
out_dims
.
push_back
(
x_dims
[
i
]);
}
}
param_
.
Out
->
Resize
(
out_dims
);
param_
.
Out
->
set_lod
(
out_lod
);
return
true
;
}
bool
SequenceUnpadOp
::
InferShapeImpl
()
const
{
return
true
;
}
bool
SequenceUnpadOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
...
...
lite/operators/var_conv_2d_op.cc
浏览文件 @
a105bbdf
...
...
@@ -26,10 +26,16 @@ bool VarConv2dOp::InferShapeImpl() const { return true; }
bool
VarConv2dOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
X
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
())
->
Get
<
lite
::
Tensor
>
());
// param_.ROW = const_cast<lite::Tensor *>(
// &scope->FindVar(opdesc.Input("ROW").front())->Get<lite::Tensor>());
// param_.COLUMN = const_cast<lite::Tensor *>(
// &scope->FindVar(opdesc.Input("COLUMN").front())->Get<lite::Tensor>());
if
(
opdesc
.
HasInput
(
"ROW"
)
&&
!
opdesc
.
Input
(
"ROW"
).
empty
())
{
param_
.
ROW
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"ROW"
).
front
())
->
Get
<
lite
::
Tensor
>
());
CHECK
(
param_
.
ROW
)
<<
"Input(ROW) of VarConv2dOP should not be null."
;
}
if
(
opdesc
.
HasInput
(
"COLUMN"
)
&&
!
opdesc
.
Input
(
"COLUMN"
).
empty
())
{
param_
.
COLUMN
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"COLUMN"
).
front
())
->
Get
<
lite
::
Tensor
>
());
CHECK
(
param_
.
COLUMN
)
<<
"Input(COLUMN) of VarConv2dOP should not be null."
;
}
param_
.
W
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"W"
).
front
())
->
Get
<
lite
::
Tensor
>
());
param_
.
Out
=
...
...
@@ -37,8 +43,6 @@ bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_
.
Col
=
scope
->
FindVar
(
opdesc
.
Output
(
"Col"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
CHECK
(
param_
.
X
)
<<
"X(Input) of VarConv2dOP should not be null."
;
// CHECK(param_.ROW) << "Input(ROW) of VarConv2dOP should not be null.";
// CHECK(param_.COLUMN) << "Input(COLUMN) of VarConv2dOP should not be null.";
CHECK
(
param_
.
W
)
<<
"W(Input) of VarConv2dOP should not be null."
;
CHECK
(
param_
.
Out
)
<<
"Out(Output) of VarConv2dOP should not be null."
;
CHECK
(
param_
.
Col
)
<<
"Col(Output) of VarConv2dOP should not be null."
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录