Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
ef5dbd1d
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ef5dbd1d
编写于
11月 22, 2019
作者:
P
Pei Yang
提交者:
GitHub
11月 22, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add search_group_padding cuda kernel, test=develop (#2472)
上级
874a5af4
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
327 addition
and
1 deletion
+327
-1
lite/kernels/cuda/CMakeLists.txt
lite/kernels/cuda/CMakeLists.txt
+3
-1
lite/kernels/cuda/search_group_padding_compute.cu
lite/kernels/cuda/search_group_padding_compute.cu
+159
-0
lite/kernels/cuda/search_group_padding_compute.h
lite/kernels/cuda/search_group_padding_compute.h
+38
-0
lite/kernels/cuda/search_group_padding_compute_test.cc
lite/kernels/cuda/search_group_padding_compute_test.cc
+127
-0
未找到文件。
lite/kernels/cuda/CMakeLists.txt
浏览文件 @
ef5dbd1d
...
...
@@ -5,6 +5,7 @@ endif()
message
(
STATUS
"compile with lite CUDA kernels"
)
add_kernel
(
mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS
${
lite_kernel_deps
}
context
)
add_kernel
(
search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS
${
lite_kernel_deps
}
)
add_kernel
(
leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS
${
lite_kernel_deps
}
)
...
...
@@ -44,11 +45,12 @@ nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_
nv_test
(
relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda
)
nv_test
(
yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda
)
nv_test
(
transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda
)
nv_test
(
search_group_padding_compute_cuda_test SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_cuda
)
nv_test
(
concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda
)
nv_test
(
elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda
)
nv_test
(
softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda
)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test
(
mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda
)
nv_test
(
mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda
)
nv_test
(
dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda
)
nv_test
(
bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda
)
nv_test
(
sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda
)
...
...
lite/kernels/cuda/search_group_padding_compute.cu
0 → 100644
浏览文件 @
ef5dbd1d
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/search_group_padding_compute.h"
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
using
Tensor
=
lite
::
Tensor
;
template
<
typename
Dtype
>
__global__
void
ker_search_group_padding
(
Dtype
*
out_emb_padding_data
,
Dtype
*
out_padding_data
,
const
Dtype
*
in_data
,
const
uint64_t
*
offset
,
const
int
seq_num
,
const
int
max_len
,
const
int
emb_size
,
const
Dtype
pad_id
,
const
int
count
)
{
CUDA_KERNEL_LOOP
(
tid
,
count
)
{
int
emb_id
=
tid
%
emb_size
;
int
word_id
=
tid
/
emb_size
;
int
seq_id
=
word_id
/
max_len
;
int
word_id_in_seq
=
word_id
%
max_len
;
int
cur_len
=
offset
[
seq_id
+
1
]
-
offset
[
seq_id
];
if
(
word_id_in_seq
<
cur_len
)
{
out_emb_padding_data
[
tid
]
=
in_data
[(
offset
[
seq_id
]
+
word_id_in_seq
)
*
emb_size
+
emb_id
];
}
else
{
out_emb_padding_data
[
tid
]
=
0.
f
;
out_padding_data
[
word_id
]
=
pad_id
;
}
}
}
void
SearchGroupPaddingCompute
::
Run
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
cuda_stream
=
ctx
.
exec_stream
();
const
Tensor
*
x
=
param
.
x
;
Tensor
*
out_emb_padding
=
param
.
out_emb_padding
;
Tensor
*
out_new
=
param
.
out_new
;
Tensor
*
out_padding
=
param
.
out_padding
;
const
float
pad_id
=
static_cast
<
float
>
(
param
.
pad_id
);
const
float
*
in_data
=
x
->
data
<
float
>
();
float
*
out_emb_padding_data
=
out_emb_padding
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
float
*
out_new_data
=
out_new
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
float
*
out_padding_data
=
out_padding
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
const
auto
&
in_seq_offset
=
x
->
lod
()[
0
];
int
batch
=
in_seq_offset
.
size
()
-
1
;
int
max_seq
=
0
;
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
if
(
in_seq_offset
[
i
+
1
]
-
in_seq_offset
[
i
]
>
max_seq
)
{
max_seq
=
in_seq_offset
[
i
+
1
]
-
in_seq_offset
[
i
];
}
}
std
::
vector
<
size_t
>
new_offset
;
new_offset
.
resize
(
batch
+
1
);
for
(
int
i
=
0
;
i
<
batch
+
1
;
++
i
)
{
new_offset
[
i
]
=
i
*
max_seq
;
}
std
::
vector
<
int64_t
>
x_dims
=
x
->
dims
().
Vectorize
();
LoD
out_emb_padding_lod
;
out_emb_padding_lod
.
push_back
(
new_offset
);
out_emb_padding
->
set_lod
(
out_emb_padding_lod
);
out_emb_padding
->
Resize
({
batch
*
max_seq
,
x_dims
[
1
]});
LoD
out_new_lod
;
out_new_lod
.
push_back
(
in_seq_offset
);
out_new
->
set_lod
(
out_new_lod
);
out_new
->
Resize
({
x_dims
[
0
],
1
});
LoD
out_padding_lod
;
out_padding_lod
.
push_back
(
new_offset
);
out_padding
->
set_lod
(
out_padding_lod
);
out_padding
->
Resize
({
batch
*
max_seq
,
1
});
const
int
count
=
out_emb_padding
->
numel
();
const
auto
&
out_emb_padding_seq_offset
=
out_emb_padding
->
lod
()[
0
];
int
max_len
=
out_emb_padding_seq_offset
[
1
];
int
seq_num
=
out_emb_padding_seq_offset
.
size
()
-
1
;
int
emb_size
=
x
->
dims
()[
1
];
_in_seq_offset
.
Resize
({
seq_num
+
1
,
1
,
1
,
1
});
uint64_t
*
offset_data
=
_in_seq_offset
.
mutable_data
<
uint64_t
>
(
TARGET
(
kCUDA
));
TargetWrapperCuda
::
MemcpyAsync
(
offset_data
,
in_seq_offset
.
data
(),
sizeof
(
uint64_t
)
*
in_seq_offset
.
size
(),
IoDirection
::
HtoD
,
cuda_stream
);
TargetWrapperCuda
::
MemsetSync
(
out_new_data
,
0
,
out_new
->
dims
()[
0
]
*
out_new
->
dims
()[
1
]
*
sizeof
(
float
));
ker_search_group_padding
<
float
><<<
CUDA_GET_BLOCKS
(
count
),
CUDA_NUM_THREADS
,
0
,
cuda_stream
>>>
(
out_emb_padding_data
,
out_padding_data
,
in_data
,
offset_data
,
seq_num
,
max_len
,
emb_size
,
pad_id
,
count
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
INFO
)
<<
cudaGetErrorString
(
error
);
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
search_group_padding
,
kCUDA
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
SearchGroupPaddingCompute
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
BindOutput
(
"Out_emb_padding"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
BindOutput
(
"Out_new"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
BindOutput
(
"Out_padding"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
Finalize
();
lite/kernels/cuda/search_group_padding_compute.h
0 → 100644
浏览文件 @
ef5dbd1d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
class
SearchGroupPaddingCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)
>
{
public:
using
param_t
=
operators
::
SearchGroupPaddingParam
;
void
Run
()
override
;
virtual
~
SearchGroupPaddingCompute
()
=
default
;
private:
lite
::
Tensor
_in_seq_offset
;
};
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/kernels/cuda/search_group_padding_compute_test.cc
0 → 100644
浏览文件 @
ef5dbd1d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/search_group_padding_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
TEST
(
search_group_padding_cuda
,
run_test
)
{
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
lite
::
Tensor
x
,
x_cpu
,
x_ref
;
lite
::
Tensor
out_emb_padding
,
out_emb_padding_cpu
,
out_emb_padding_ref
;
lite
::
Tensor
out_new
,
out_new_cpu
,
out_new_ref
;
lite
::
Tensor
out_padding
,
out_padding_cpu
,
out_padding_ref
;
int
x_dims0
=
2
;
int
x_dims1
=
3
;
x
.
Resize
({
x_dims0
,
x_dims1
});
x_cpu
.
Resize
({
x_dims0
,
x_dims1
});
x_ref
.
Resize
({
x_dims0
,
x_dims1
});
out_emb_padding
.
Resize
({
1
,
x_dims1
});
out_emb_padding_cpu
.
Resize
({
1
,
x_dims1
});
out_emb_padding_ref
.
Resize
({
1
,
x_dims1
});
out_new
.
Resize
({
x_dims0
,
1
});
out_new_cpu
.
Resize
({
x_dims0
,
1
});
out_new_ref
.
Resize
({
x_dims0
,
1
});
out_padding
.
Resize
({
1
,
1
});
out_padding_cpu
.
Resize
({
1
,
1
});
out_padding_ref
.
Resize
({
1
,
1
});
LoD
x_lod
{};
x_lod
.
push_back
({
0
,
1
});
x
.
set_lod
(
x_lod
);
auto
*
x_cpu_data
=
x_cpu
.
mutable_data
<
float
>
();
auto
*
x_ref_data
=
x_ref
.
mutable_data
<
float
>
();
auto
*
out_emb_padding_data
=
out_emb_padding
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
*
out_emb_padding_cpu_data
=
out_emb_padding_cpu
.
mutable_data
<
float
>
();
auto
*
out_emb_padding_ref_data
=
out_emb_padding_ref
.
mutable_data
<
float
>
();
auto
*
out_new_data
=
out_new
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
*
out_new_cpu_data
=
out_new_cpu
.
mutable_data
<
float
>
();
auto
*
out_new_ref_data
=
out_new_ref
.
mutable_data
<
float
>
();
auto
*
out_padding_data
=
out_padding
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
*
out_padding_cpu_data
=
out_padding_cpu
.
mutable_data
<
float
>
();
auto
*
out_padding_ref_data
=
out_padding_ref
.
mutable_data
<
float
>
();
for
(
int64_t
i
=
0
;
i
<
x_cpu
.
dims
().
production
();
i
++
)
{
x_cpu_data
[
i
]
=
static_cast
<
float
>
(
i
);
x_ref_data
[
i
]
=
static_cast
<
float
>
(
i
);
}
x
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_cpu_data
,
x_cpu
.
dims
());
out_emb_padding_ref_data
[
0
]
=
0.
f
;
out_emb_padding_ref_data
[
1
]
=
1.
f
;
out_emb_padding_ref_data
[
2
]
=
2.
f
;
out_new_ref_data
[
0
]
=
0.
f
;
out_new_ref_data
[
1
]
=
0.
f
;
out_padding_ref_data
[
0
]
=
0.
f
;
SearchGroupPaddingCompute
sgp_kernel
;
operators
::
SearchGroupPaddingParam
param
;
param
.
x
=
&
x
;
param
.
out_emb_padding
=
&
out_emb_padding
;
param
.
out_new
=
&
out_new
;
param
.
out_padding
=
&
out_padding
;
sgp_kernel
.
SetParam
(
param
);
cudaStream_t
stream
;
cudaStreamCreate
(
&
stream
);
context
.
SetExecStream
(
stream
);
sgp_kernel
.
SetContext
(
std
::
move
(
ctx
));
sgp_kernel
.
Launch
();
cudaDeviceSynchronize
();
CopySync
<
TARGET
(
kCUDA
)
>
(
out_emb_padding_cpu_data
,
out_emb_padding_data
,
sizeof
(
float
)
*
out_emb_padding
.
numel
(),
IoDirection
::
DtoH
);
CopySync
<
TARGET
(
kCUDA
)
>
(
out_new_cpu_data
,
out_new_data
,
sizeof
(
float
)
*
out_new
.
numel
(),
IoDirection
::
DtoH
);
CopySync
<
TARGET
(
kCUDA
)
>
(
out_padding_cpu_data
,
out_padding_data
,
sizeof
(
float
)
*
out_padding
.
numel
(),
IoDirection
::
DtoH
);
for
(
int
i
=
0
;
i
<
out_emb_padding_cpu
.
dims
().
production
();
i
++
)
{
EXPECT_NEAR
(
out_emb_padding_cpu_data
[
i
],
out_emb_padding_ref_data
[
i
],
1e-5
);
}
for
(
int
i
=
0
;
i
<
out_new_cpu
.
dims
().
production
();
i
++
)
{
EXPECT_NEAR
(
out_new_cpu_data
[
i
],
out_new_ref_data
[
i
],
1e-5
);
}
for
(
int
i
=
0
;
i
<
out_padding_cpu
.
dims
().
production
();
i
++
)
{
EXPECT_NEAR
(
out_padding_cpu_data
[
i
],
out_padding_ref_data
[
i
],
1e-5
);
}
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
USE_LITE_KERNEL
(
search_group_padding
,
kCUDA
,
kFloat
,
kNCHW
,
def
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录