Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
8eeaa0ac
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
337
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8eeaa0ac
编写于
7月 20, 2020
作者:
H
HappyAngel
提交者:
GitHub
7月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #138 from PaddlePaddle/develop
pull
上级
94d8cb65
85a12dab
变更
26
隐藏空白更改
内联
并排
Showing
26 changed file
with
1043 addition
and
46 deletion
+1043
-46
cmake/external/flatbuffers.cmake
cmake/external/flatbuffers.cmake
+1
-3
docs/demo_guides/opencl.md
docs/demo_guides/opencl.md
+14
-3
lite/api/paddle_api.cc
lite/api/paddle_api.cc
+13
-0
lite/api/paddle_api.h
lite/api/paddle_api.h
+3
-0
lite/backends/cuda/math/activation.cu
lite/backends/cuda/math/activation.cu
+71
-0
lite/backends/cuda/math/activation.h
lite/backends/cuda/math/activation.h
+3
-0
lite/backends/cuda/math/bias.cu
lite/backends/cuda/math/bias.cu
+12
-0
lite/backends/cuda/math/gru_forward.cu
lite/backends/cuda/math/gru_forward.cu
+141
-0
lite/backends/cuda/math/gru_forward.h
lite/backends/cuda/math/gru_forward.h
+163
-0
lite/backends/cuda/math/sequence2batch.cu
lite/backends/cuda/math/sequence2batch.cu
+5
-0
lite/backends/cuda/math/sequence2batch.h
lite/backends/cuda/math/sequence2batch.h
+37
-0
lite/backends/opencl/cl_runtime.cc
lite/backends/opencl/cl_runtime.cc
+22
-12
lite/backends/opencl/cl_runtime.h
lite/backends/opencl/cl_runtime.h
+28
-3
lite/backends/opencl/cl_wrapper.cc
lite/backends/opencl/cl_wrapper.cc
+9
-4
lite/backends/opencl/cl_wrapper.h
lite/backends/opencl/cl_wrapper.h
+8
-1
lite/core/mir/fusion/quant_dequant_op_fuser.cc
lite/core/mir/fusion/quant_dequant_op_fuser.cc
+6
-1
lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
+22
-0
lite/kernels/cuda/CMakeLists.txt
lite/kernels/cuda/CMakeLists.txt
+2
-0
lite/kernels/cuda/concat_compute_test.cc
lite/kernels/cuda/concat_compute_test.cc
+1
-1
lite/kernels/cuda/gru_compute.cu
lite/kernels/cuda/gru_compute.cu
+175
-14
lite/kernels/cuda/gru_compute_test.cc
lite/kernels/cuda/gru_compute_test.cc
+43
-0
lite/kernels/cuda/sigmoid_compute.cu
lite/kernels/cuda/sigmoid_compute.cu
+57
-0
lite/kernels/cuda/sigmoid_compute.h
lite/kernels/cuda/sigmoid_compute.h
+35
-0
lite/kernels/cuda/sigmoid_compute_test.cc
lite/kernels/cuda/sigmoid_compute_test.cc
+168
-0
lite/tools/build.sh
lite/tools/build.sh
+2
-2
lite/tools/ci_build.sh
lite/tools/ci_build.sh
+2
-2
未找到文件。
cmake/external/flatbuffers.cmake
浏览文件 @
8eeaa0ac
...
...
@@ -94,12 +94,10 @@ function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT)
message
(
STATUS
"SRC_FBS_DIR:
${
SRC_FBS_DIR
}
"
)
string
(
REGEX REPLACE
"
\\
.fbs$"
"_generated.h"
GEN_HEADER
${
SRC_FBS
}
)
add_custom_command
(
OUTPUT
${
GEN_HEADER
}
OUTPUT
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
GEN_HEADER
}
"
COMMAND
"
${
FLATBUFFERS_FLATC_EXECUTABLE
}
"
--cpp --gen-mutable --gen-object-api --reflect-names
--force-empty --force-empty-vectors
${
OPT
}
-I
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/tests/include_test"
-o
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
SRC_FBS_DIR
}
"
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
SRC_FBS
}
"
DEPENDS flatbuffers
...
...
docs/demo_guides/opencl.md
浏览文件 @
8eeaa0ac
...
...
@@ -37,14 +37,25 @@ rm ./lite/api/paddle_use_kernels.h
rm
./lite/api/paddle_use_ops.h
# 设置编译参数并开始编译
# android-armv7:cpu+gpu+cv+extra
./lite/tools/build_android.sh
\
--arch
=
armv7
\
--toolchain
=
clang
\
--with_cv
=
OFF
\
--with_log
=
OFF
\
--with_extra
=
OFF
\
--with_extra
=
ON
\
--with_cv
=
ON
\
--with_opencl
=
ON
# android-armv8:cpu+gpu+cv+extra
./lite/tools/build_android.sh
\
--arch
=
armv8
\
--toolchain
=
clang
\
--with_log
=
OFF
\
--with_extra
=
ON
\
--with_cv
=
ON
\
--with_opencl
=
ON
# 注:编译帮助请执行: ./lite/tools/build_android.sh help
```
...
...
@@ -206,7 +217,7 @@ adb shell "export GLOG_v=4; \
## 3. 如何在Code中使用
即编译产物
`demo/cxx/mobile_light`
目录下的代码,在线版参考GitHub仓库
[
./lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
](
https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
)
;
即编译产物
`demo/cxx/mobile_light`
目录下的代码,在线版参考GitHub仓库
[
./lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
](
https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
)
,其中也包括判断当前设备是否支持OpenCL的方法
;
注:这里给出的链接会跳转到线上最新develop分支的代码,很可能与您本地的代码存在差异,建议参考自己本地位于
`lite/demo/cxx/`
目录的代码,查看如何使用。
...
...
lite/api/paddle_api.cc
浏览文件 @
8eeaa0ac
...
...
@@ -32,9 +32,22 @@
#include "lite/backends/mlu/target_wrapper.h"
#endif
#ifdef LITE_WITH_OPENCL
#include "lite/backends/opencl/cl_runtime.h"
#endif
namespace
paddle
{
namespace
lite_api
{
bool
IsOpenCLBackendValid
()
{
bool
opencl_valid
=
false
;
#ifdef LITE_WITH_OPENCL
opencl_valid
=
paddle
::
lite
::
CLRuntime
::
Global
()
->
OpenCLAvaliableForDevice
();
#endif
LOG
(
INFO
)
<<
"opencl_valid:"
<<
opencl_valid
;
return
opencl_valid
;
}
Tensor
::
Tensor
(
void
*
raw
)
:
raw_tensor_
(
raw
)
{}
// TODO(Superjomn) refine this by using another `const void* const_raw`;
...
...
lite/api/paddle_api.h
浏览文件 @
8eeaa0ac
...
...
@@ -33,6 +33,9 @@ using lod_t = std::vector<std::vector<uint64_t>>;
enum
class
LiteModelType
{
kProtobuf
=
0
,
kNaiveBuffer
,
UNK
};
// return true if current device supports OpenCL model
LITE_API
bool
IsOpenCLBackendValid
();
struct
LITE_API
Tensor
{
explicit
Tensor
(
void
*
raw
);
explicit
Tensor
(
const
void
*
raw
);
...
...
lite/backends/cuda/math/activation.cu
浏览文件 @
8eeaa0ac
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include <iostream>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/utils.h"
...
...
@@ -484,6 +485,76 @@ template void relu(int, const half*, half*, float, cudaStream_t);
template
void
bias_relu
(
int
,
const
float
*
,
const
float
*
bias
,
float
*
,
float
,
cudaStream_t
);
// ------------- sigmoid -------------
template
<
typename
T
>
__global__
void
sigmoid_kernel
(
const
int
num
,
const
T
*
in
,
T
*
out
)
{
CUDA_KERNEL_LOOP
(
i
,
num
)
{
#if __CUDA_ARCH__ >= 350
out
[
i
]
=
static_cast
<
T
>
(
1.0
f
)
/
(
static_cast
<
T
>
(
1.0
f
)
+
expf
(
-
1
*
__ldg
(
in
+
i
)));
#else
out
[
i
]
=
static_cast
<
T
>
(
1.0
f
)
/
(
static_cast
<
T
>
(
1.0
f
)
+
expf
(
-
in
[
i
]));
#endif
}
}
template
<
>
__global__
void
sigmoid_kernel
(
const
int
num
,
const
half
*
in
,
half
*
out
)
{
CUDA_KERNEL_LOOP
(
i
,
num
)
{
half
tmp
=
__float2half
(
1.0
f
);
#if __CUDA_ARCH__ >= 530
out
[
i
]
=
__hdiv
(
tmp
,
__hadd
(
tmp
,
hexp
(
__hmul
(
__float2half
(
-
1.0
f
),
__ldg
(
in
+
i
)))));
#else
out
[
i
]
=
__float2half
(
1.0
f
/
(
1.0
f
+
expf
(
-
1
*
__half2float
(
in
[
i
]))));
#endif
}
}
template
<
>
__global__
void
sigmoid_kernel
(
const
int
num
,
const
half2
*
in
,
half2
*
out
)
{
CUDA_KERNEL_LOOP
(
i
,
num
)
{
half2
tmp
=
__floats2half2_rn
(
1.0
f
,
1.0
f
);
#if __CUDA_ARCH__ >= 530
out
[
i
]
=
__h2div
(
tmp
,
__hadd2
(
tmp
,
h2exp
(
__hmul2
(
__floats2half2_rn
(
-
1.0
f
,
-
1.0
f
),
__ldg
(
in
+
i
)))));
#else
out
[
i
].
x
=
__float2half
(
1.0
f
/
(
1.0
f
+
expf
(
-
1
*
__half2float
(
in
[
i
].
x
))));
out
[
i
].
y
=
__float2half
(
1.0
f
/
(
1.0
f
+
expf
(
-
1
*
__half2float
(
in
[
i
].
y
))));
#endif
}
}
template
<
typename
T
>
void
sigmoid
(
const
int
num
,
const
T
*
din
,
T
*
dout
,
cudaStream_t
stream
)
{
sigmoid_kernel
<
T
><<<
CUDA_GET_BLOCKS
(
num
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
num
,
din
,
dout
);
CUDA_POST_KERNEL_CHECK
;
}
template
<
>
void
sigmoid
(
const
int
num
,
const
half
*
din
,
half
*
dout
,
cudaStream_t
stream
)
{
if
(
num
%
2
==
0
)
{
const
half2
*
din2
=
reinterpret_cast
<
const
half2
*>
(
din
);
half2
*
dout2
=
reinterpret_cast
<
half2
*>
(
dout
);
sigmoid_kernel
<
half2
><<<
CUDA_GET_BLOCKS
(
num
/
2
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
num
/
2
,
din2
,
dout2
);
}
else
{
sigmoid_kernel
<
half
><<<
CUDA_GET_BLOCKS
(
num
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
num
,
din
,
dout
);
}
CUDA_POST_KERNEL_CHECK
;
}
template
void
sigmoid
(
const
int
num
,
const
float
*
din
,
float
*
dout
,
cudaStream_t
stream
);
}
// namespace math
}
// namespace cuda
}
// namespace lite
...
...
lite/backends/cuda/math/activation.h
浏览文件 @
8eeaa0ac
...
...
@@ -83,6 +83,9 @@ void bias_int8_nhwc(int num,
const
void
*
scale
,
cudaStream_t
stream
);
template
<
typename
T
>
void
sigmoid
(
const
int
num
,
const
T
*
din
,
T
*
dout
,
cudaStream_t
stream
);
}
// namespace math
}
// namespace cuda
}
// namespace lite
...
...
lite/backends/cuda/math/bias.cu
浏览文件 @
8eeaa0ac
...
...
@@ -31,6 +31,17 @@ __global__ void RowwiseAddKernel(
c
[
i
]
=
a
[
i
]
+
b
[
w
];
}
}
template
<
>
__global__
void
RowwiseAddKernel
(
const
half
*
a
,
const
half
*
b
,
half
*
c
,
int
width
,
int
num
)
{
CUDA_KERNEL_LOOP
(
i
,
num
)
{
int
h
=
i
/
width
;
int
w
=
i
-
h
*
width
;
c
[
i
]
=
__hadd
(
a
[
i
],
b
[
w
]);
}
}
template
<
typename
T
>
void
RowwiseAdd
<
T
>::
operator
()(
const
T
*
input
,
const
T
*
bias
,
...
...
@@ -44,6 +55,7 @@ void RowwiseAdd<T>::operator()(const T* input,
}
template
struct
RowwiseAdd
<
float
>;
template
struct
RowwiseAdd
<
half
>;
}
// namespace math
}
// namespace cuda
...
...
lite/backends/cuda/math/gru_forward.cu
浏览文件 @
8eeaa0ac
...
...
@@ -22,6 +22,10 @@ namespace lite {
namespace
cuda
{
namespace
math
{
/*
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template
<
typename
T
>
__global__
void
GruForwardResetOutput
(
T
*
gate_value
,
...
...
@@ -33,6 +37,7 @@ __global__ void GruForwardResetOutput(
bool
is_batch
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame_idx
>=
frame_size
)
return
;
int
batch_idx
=
0
;
if
(
is_batch
)
{
batch_idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
...
...
@@ -44,12 +49,14 @@ __global__ void GruForwardResetOutput(
T
reset_out_val
;
T
update_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
0
];
T
reset_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
1
];
if
(
prev_output_value
)
{
if
(
is_batch
)
{
prev_output_value
+=
batch_idx
*
frame_size
;
}
prev_out
=
prev_output_value
[
frame_idx
];
}
if
(
active_gate
==
lite
::
cuda
::
math
::
ActivationType
::
kSigmoid
)
{
update_gate_value
=
Sigmoid
(
update_gate_value
);
reset_gate_value
=
Sigmoid
(
reset_gate_value
);
...
...
@@ -60,12 +67,71 @@ __global__ void GruForwardResetOutput(
update_gate_value
=
Tanh
(
update_gate_value
);
reset_gate_value
=
Tanh
(
reset_gate_value
);
}
reset_out_val
=
prev_out
*
reset_gate_value
;
gate_value
[
frame_idx
+
frame_size
*
0
]
=
update_gate_value
;
gate_value
[
frame_idx
+
frame_size
*
1
]
=
reset_gate_value
;
reset_output_value
[
frame_idx
]
=
reset_out_val
;
}
template
<
>
__global__
void
GruForwardResetOutput
(
half
*
gate_value
,
half
*
reset_output_value
,
half
*
prev_output_value
,
int
frame_size
,
int
batch_size
,
lite
::
cuda
::
math
::
ActivationType
active_gate
,
bool
is_batch
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame_idx
>=
frame_size
)
return
;
int
batch_idx
=
0
;
if
(
is_batch
)
{
batch_idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batch_idx
>=
batch_size
)
return
;
gate_value
+=
batch_idx
*
3
*
frame_size
;
reset_output_value
+=
batch_idx
*
frame_size
;
}
half
prev_out
=
0
;
half
reset_out_val
;
half
update_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
0
];
half
reset_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
1
];
if
(
prev_output_value
)
{
if
(
is_batch
)
{
prev_output_value
+=
batch_idx
*
frame_size
;
}
prev_out
=
prev_output_value
[
frame_idx
];
}
if
(
active_gate
==
ActivationType
::
kSigmoid
)
{
update_gate_value
=
Sigmoid
(
update_gate_value
);
reset_gate_value
=
Sigmoid
(
reset_gate_value
);
}
else
if
(
active_gate
==
ActivationType
::
kReLU
)
{
update_gate_value
=
ReLU
(
update_gate_value
);
reset_gate_value
=
ReLU
(
reset_gate_value
);
}
else
if
(
active_gate
==
ActivationType
::
kTanh
)
{
update_gate_value
=
Tanh
(
update_gate_value
);
reset_gate_value
=
Tanh
(
reset_gate_value
);
}
#if __CUDA_ARCH__ >= 530
reset_out_val
=
__hmul
(
prev_out
,
reset_gate_value
);
#else
reset_out_val
=
__float2half
(
__half2float
(
prev_out
)
*
__half2float
(
reset_gate_value
));
#endif
gate_value
[
frame_idx
+
frame_size
*
0
]
=
update_gate_value
;
gate_value
[
frame_idx
+
frame_size
*
1
]
=
reset_gate_value
;
reset_output_value
[
frame_idx
]
=
reset_out_val
;
}
/*
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template
<
typename
T
>
__global__
void
GruForwardFinalOutput
(
T
*
gate_value
,
...
...
@@ -87,14 +153,17 @@ __global__ void GruForwardFinalOutput(
gate_value
+=
batch_idx
*
3
*
frame_size
;
output_value
+=
batch_idx
*
frame_size
;
}
T
output
;
T
prev_out
=
0
;
T
update_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
0
];
T
state_frame_value
=
gate_value
[
frame_idx
+
frame_size
*
2
];
if
(
prev_output_value
)
{
if
(
is_batch
)
prev_output_value
+=
batch_idx
*
frame_size
;
prev_out
=
prev_output_value
[
frame_idx
];
}
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kSigmoid
)
{
state_frame_value
=
Sigmoid
(
state_frame_value
);
}
else
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kReLU
)
{
...
...
@@ -102,6 +171,7 @@ __global__ void GruForwardFinalOutput(
}
else
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kTanh
)
{
state_frame_value
=
Tanh
(
state_frame_value
);
}
if
(
origin_mode
)
{
output
=
update_gate_value
*
prev_out
+
state_frame_value
-
update_gate_value
*
state_frame_value
;
...
...
@@ -109,6 +179,76 @@ __global__ void GruForwardFinalOutput(
output
=
prev_out
-
update_gate_value
*
prev_out
+
update_gate_value
*
state_frame_value
;
}
gate_value
[
frame_idx
+
frame_size
*
2
]
=
state_frame_value
;
output_value
[
frame_idx
]
=
output
;
}
template
<
>
__global__
void
GruForwardFinalOutput
(
half
*
gate_value
,
half
*
prev_output_value
,
half
*
output_value
,
int
frame_size
,
int
batch_size
,
lite
::
cuda
::
math
::
ActivationType
active_node
,
bool
origin_mode
,
bool
is_batch
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame_idx
>=
frame_size
)
return
;
int
batch_idx
=
0
;
if
(
is_batch
)
{
batch_idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batch_idx
>=
batch_size
)
{
return
;
}
gate_value
+=
batch_idx
*
3
*
frame_size
;
output_value
+=
batch_idx
*
frame_size
;
}
half
output
;
half
prev_out
=
0
;
half
update_gate_value
=
gate_value
[
frame_idx
+
frame_size
*
0
];
half
state_frame_value
=
gate_value
[
frame_idx
+
frame_size
*
2
];
if
(
prev_output_value
)
{
if
(
is_batch
)
prev_output_value
+=
batch_idx
*
frame_size
;
prev_out
=
prev_output_value
[
frame_idx
];
}
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kSigmoid
)
{
state_frame_value
=
Sigmoid
(
state_frame_value
);
}
else
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kReLU
)
{
state_frame_value
=
ReLU
(
state_frame_value
);
}
else
if
(
active_node
==
lite
::
cuda
::
math
::
ActivationType
::
kTanh
)
{
state_frame_value
=
Tanh
(
state_frame_value
);
}
if
(
origin_mode
)
{
#if __CUDA_ARCH__ >= 530
output
=
__hsub
(
__hadd
(
__hmul
(
update_gate_value
,
prev_out
),
state_frame_value
),
__hmul
(
update_gate_value
,
state_frame_value
));
#else
output
=
__float2half
(
__half2float
(
update_gate_value
)
*
__half2float
(
prev_out
)
+
__half2float
(
state_frame_value
)
-
__half2float
(
update_gate_value
)
*
__half2float
(
state_frame_value
));
#endif
}
else
{
#if __CUDA_ARCH__ >= 530
output
=
prev_out
-
update_gate_value
*
prev_out
+
update_gate_value
*
state_frame_value
;
output
=
__hadd
(
__hsub
(
prev_out
,
__hmul
(
update_gate_value
,
prev_out
)),
__hmul
(
update_gate_value
,
state_frame_value
));
#else
output
=
__float2half
(
__half2float
(
prev_out
)
-
__half2float
(
update_gate_value
)
*
__half2float
(
prev_out
)
+
__half2float
(
update_gate_value
)
*
__half2float
(
state_frame_value
));
#endif
}
gate_value
[
frame_idx
+
frame_size
*
2
]
=
state_frame_value
;
output_value
[
frame_idx
]
=
output
;
}
...
...
@@ -122,6 +262,7 @@ template __global__ void GruForwardFinalOutput<float>(
lite
::
cuda
::
math
::
ActivationType
active_node
,
bool
origin_mode
,
bool
is_batch
);
template
__global__
void
GruForwardResetOutput
<
float
>(
float
*
gate_value
,
float
*
reset_output_value
,
...
...
lite/backends/cuda/math/gru_forward.h
浏览文件 @
8eeaa0ac
...
...
@@ -34,10 +34,32 @@ template <typename Dtype>
inline
__device__
Dtype
Sigmoid
(
const
Dtype
a
)
{
return
static_cast
<
Dtype
>
(
1.0
)
/
(
static_cast
<
Dtype
>
(
1.0
)
+
expf
(
-
a
));
}
template
<
>
inline
__device__
half
Sigmoid
(
const
half
a
)
{
#if __CUDA_ARCH__ >= 530
const
half
tmp
=
__float2half
(
1.0
f
);
return
__hdiv
(
tmp
,
__hadd
(
tmp
,
hexp
(
__hmul
(
__float2half
(
-
1.
f
),
a
))));
#else
return
__float2half
(
1.0
f
/
(
expf
(
__half2float
(
a
)
*
-
1
)
+
1.0
f
));
#endif
}
template
<
typename
Dtype
>
inline
__device__
Dtype
ReLU
(
const
Dtype
a
)
{
return
a
>
static_cast
<
Dtype
>
(
0.
f
)
?
a
:
static_cast
<
Dtype
>
(
0.
f
);
}
template
<
>
inline
__device__
half
ReLU
(
const
half
a
)
{
const
half
tmp
=
__float2half
(
0.
f
);
#if __CUDA_ARCH__ >= 530
return
__hgt
(
a
,
tmp
)
?
a
:
tmp
;
#else
return
__float2half
(
__half2float
(
a
)
>
0.
f
?
__half2float
(
a
)
:
0.
f
);
#endif
}
template
<
typename
Dtype
>
inline
__device__
Dtype
Tanh
(
const
Dtype
a
)
{
Dtype
tmp
=
static_cast
<
Dtype
>
(
-
2.0
)
*
a
;
...
...
@@ -45,6 +67,18 @@ inline __device__ Dtype Tanh(const Dtype a) {
static_cast
<
Dtype
>
(
1.0
);
}
template
<
>
inline
__device__
half
Tanh
(
const
half
a
)
{
#if __CUDA_ARCH__ >= 530
half
tmp
=
__float2half
(
1.0
f
);
half
numerator
=
__hmul
(
__float2half
(
-
2.0
f
),
a
);
return
__hsub
(
__hdiv
(
__float2half
(
2.0
f
),
__hadd
(
tmp
,
hexp
(
numerator
))),
tmp
);
#else
float
tmp
=
-
2.0
f
*
__half2float
(
a
);
return
__float2half
(
2.0
f
/
(
1.0
f
+
expf
(
tmp
))
-
1.0
f
);
#endif
}
template
<
typename
T
>
__global__
void
GruForwardResetOutput
(
T
*
gate_value
,
...
...
@@ -54,6 +88,7 @@ __global__ void GruForwardResetOutput(
int
batch_size
,
lite
::
cuda
::
math
::
ActivationType
active_gate
,
bool
is_batch
);
template
<
typename
T
>
__global__
void
GruForwardFinalOutput
(
T
*
gate_value
,
...
...
@@ -65,6 +100,134 @@ __global__ void GruForwardFinalOutput(
bool
origin_mode
,
bool
is_batch
);
/*
* threads(tile_size, 1)
* grids(frame_blocks, 1)
*/
template
<
class
T
,
int
TiledSize
>
__global__
void
FastCollectiveGruGate
(
T
*
gate_value
,
T
*
prev_output_value
,
T
*
gate_weight
,
T
*
reset_output
,
int
frame_size
,
ActivationType
active_node
)
{
T
xt_0
=
0.0
f
;
T
a0
=
0.0
f
;
T
c0
=
0.0
f
;
T
b0
[
TiledSize
];
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tiled_mask
=
((
1
<<
TiledSize
)
-
1
);
// tiled matrix multiply using register shift, faster than sm.
if
(
prev_output_value
)
{
for
(
int
k
=
0
;
k
<
(((
frame_size
-
1
)
/
TiledSize
)
+
1
);
++
k
)
{
a0
=
0
;
if
((
threadIdx
.
x
+
k
*
TiledSize
)
<
frame_size
)
{
a0
=
prev_output_value
[
threadIdx
.
x
+
(
k
*
TiledSize
)];
}
for
(
int
i
=
0
;
i
<
TiledSize
;
++
i
)
{
if
(
col
<
frame_size
*
2
&&
(
i
+
k
*
TiledSize
)
<
frame_size
)
{
b0
[
i
]
=
gate_weight
[(
i
+
k
*
TiledSize
)
*
frame_size
*
2
+
col
];
}
}
for
(
int
i
=
0
;
i
<
TiledSize
;
++
i
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
c0
=
c0
+
__shfl_sync
(
tiled_mask
,
a0
,
i
,
TiledSize
)
*
b0
[
i
];
#else
c0
=
c0
+
__shfl
(
a0
,
i
,
TiledSize
)
*
b0
[
i
];
#endif
}
}
}
__syncthreads
();
if
(
col
<
frame_size
*
2
)
{
xt_0
=
gate_value
[
col
];
c0
+=
xt_0
;
if
(
active_node
==
ActivationType
::
kSigmoid
)
{
c0
=
Sigmoid
(
c0
);
}
else
if
(
active_node
==
ActivationType
::
kReLU
)
{
c0
=
ReLU
(
c0
);
}
else
if
(
active_node
==
ActivationType
::
kTanh
)
{
c0
=
Tanh
(
c0
);
}
gate_value
[
col
]
=
c0
;
if
(
frame_size
<=
col
&&
col
<
frame_size
*
2
)
{
T
htp_0
=
0.0
;
if
(
prev_output_value
)
{
htp_0
=
prev_output_value
[
col
-
frame_size
];
}
reset_output
[
col
-
frame_size
]
=
c0
*
htp_0
;
}
else
if
(
col
<
frame_size
)
{
gate_value
[
col
]
=
c0
;
}
}
}
template
<
class
T
,
int
TiledSize
>
__global__
void
FastCollectiveGruOut
(
T
*
gate_weight
,
T
*
prev_out_value
,
T
*
output_value
,
T
*
gate_value
,
T
*
reset_value
,
int
frame_size
,
ActivationType
active_node
,
bool
origin_mode
)
{
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
T
a0
=
0.0
f
;
T
b0
[
TiledSize
];
T
c0
=
0.0
f
;
int
tiled_mask
=
((
1
<<
TiledSize
)
-
1
);
if
(
prev_out_value
)
{
for
(
int
k
=
0
;
k
<
((
frame_size
-
1
)
/
TiledSize
+
1
);
++
k
)
{
a0
=
0
;
if
((
threadIdx
.
x
+
k
*
TiledSize
)
<
frame_size
)
{
a0
=
reset_value
[
threadIdx
.
x
+
k
*
TiledSize
];
}
for
(
int
i
=
0
;
i
<
TiledSize
;
++
i
)
{
if
(
col
<
frame_size
&&
(
i
+
k
*
TiledSize
)
<
frame_size
)
{
b0
[
i
]
=
gate_weight
[(
i
+
k
*
TiledSize
)
*
frame_size
+
col
];
}
}
for
(
int
i
=
0
;
i
<
TiledSize
;
++
i
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
c0
=
c0
+
__shfl_sync
(
tiled_mask
,
a0
,
i
,
TiledSize
)
*
b0
[
i
];
#else
c0
=
c0
+
__shfl
(
a0
,
i
,
TiledSize
)
*
b0
[
i
];
#endif
}
}
}
__syncthreads
();
if
(
col
<
frame_size
)
{
T
xt_0
=
gate_value
[
col
+
2
*
frame_size
];
T
gta_0
=
gate_value
[
col
];
T
htp_0
=
0
;
if
(
prev_out_value
)
{
htp_0
=
prev_out_value
[
col
];
}
c0
+=
xt_0
;
if
(
active_node
==
ActivationType
::
kSigmoid
)
{
c0
=
Sigmoid
(
c0
);
}
else
if
(
active_node
==
ActivationType
::
kReLU
)
{
c0
=
ReLU
(
c0
);
}
else
if
(
active_node
==
ActivationType
::
kTanh
)
{
c0
=
Tanh
(
c0
);
}
gate_value
[
col
+
2
*
frame_size
]
=
c0
;
if
(
origin_mode
)
{
output_value
[
col
]
=
htp_0
*
gta_0
+
(
1
-
gta_0
)
*
c0
;
}
else
{
output_value
[
col
]
=
c0
*
gta_0
+
(
1
-
gta_0
)
*
htp_0
;
}
}
}
}
// namespace math
}
// namespace cuda
}
// namespace lite
...
...
lite/backends/cuda/math/sequence2batch.cu
浏览文件 @
8eeaa0ac
...
...
@@ -77,8 +77,13 @@ void CopyMatrixRowsFunctor<T>::operator()(
}
template
class
CopyMatrixRowsFunctor
<
float
>;
template
class
CopyMatrixRowsFunctor
<
half
>;
template
class
LoDTensor2BatchFunctor
<
float
>;
template
class
LoDTensor2BatchFunctor
<
half
>;
template
class
Batch2LoDTensorFunctor
<
float
>;
template
class
Batch2LoDTensorFunctor
<
half
>;
}
// namespace math
}
// namespace cuda
...
...
lite/backends/cuda/math/sequence2batch.h
浏览文件 @
8eeaa0ac
...
...
@@ -32,6 +32,9 @@ namespace math {
template
<
typename
T
>
class
CopyMatrixRowsFunctor
{
public:
// If is_src_index is true, copy the indexed rows of input src to the output
// dst. If is_src_index is false, copy the input src to the indexed of output
// dst. The indexes rows are based on the input index.
void
operator
()(
const
lite
::
Tensor
&
src
,
lite
::
Tensor
*
dst
,
const
std
::
vector
<
uint64_t
>&
index_lod
,
...
...
@@ -44,6 +47,11 @@ class CopyMatrixRowsFunctor {
template
<
typename
T
>
class
LoDTensor2BatchFunctor
{
// Calculate the length of each sequence and
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
struct
SeqInfo
{
SeqInfo
(
size_t
start
,
size_t
length
,
size_t
seq_idx
)
:
start_
(
start
),
length_
(
length
),
seq_idx_
(
seq_idx
)
{}
...
...
@@ -60,21 +68,49 @@ class LoDTensor2BatchFunctor {
auto
lods
=
lod_tensor
.
lod
();
CHECK_EQ
(
lods
.
size
(),
1UL
)
<<
"Only support one level sequence now."
;
const
auto
&
lod
=
lods
[
0
];
std
::
vector
<
SeqInfo
>
seq_info
;
for
(
int
seq_id
=
0
;
seq_id
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
seq_id
)
{
size_t
length
=
lod
[
seq_id
+
1
]
-
lod
[
seq_id
];
seq_info
.
emplace_back
(
lod
[
seq_id
],
length
,
seq_id
);
}
std
::
sort
(
seq_info
.
begin
(),
seq_info
.
end
(),
[](
SeqInfo
a
,
SeqInfo
b
)
{
return
a
.
length_
>
b
.
length_
;
});
// Calculate the start position of each batch.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// max_seqlen = 5,
// batchIndex = {b0, b1, b2, b3, b4}
// b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// batch_start_positions[0] = 0
// batch_start_positions[1] = len(b0)
// batch_start_positions[2] = len(b0) + len(b1)
// ...
// seq2batch_idx[12] = {4, 0, 9,
// 5, 1, 10,
// 6, 2, 11,
// 7, 3,
// 8}
// seq_order = {1, 0, 2}, the sort order.
// where 1 is the second sequence,
// 0 is the first sequence,
// 2 is the third sequence.
LoD
batch_lods
;
batch_lods
.
emplace_back
(
std
::
vector
<
uint64_t
>
{
0
});
batch_lods
.
emplace_back
(
std
::
vector
<
uint64_t
>
{
0
});
batch_lods
.
emplace_back
(
std
::
vector
<
uint64_t
>
{
0
});
// batch_lods[0] is the start positions for batch LoDTensor
size_t
max_seqlen
=
seq_info
[
0
].
length_
;
batch_lods
[
0
].
resize
(
max_seqlen
+
1
);
// batch_lods[1] is the raw index in the input LoDTensor
batch_lods
[
1
].
resize
(
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
// batch_lods[2] is the sort order for the input LoDTensor.
batch_lods
[
2
].
resize
(
seq_info
.
size
());
auto
*
batch_starts
=
batch_lods
[
0
].
data
();
...
...
@@ -101,6 +137,7 @@ class LoDTensor2BatchFunctor {
}
batch_tensor
->
set_lod
(
batch_lods
);
lite
::
cuda
::
math
::
CopyMatrixRowsFunctor
<
T
>
to_batch
;
to_batch
(
lod_tensor
,
batch_tensor
,
batch_lods
[
1
],
true
,
stream
);
CUDA_POST_KERNEL_CHECK
;
...
...
lite/backends/opencl/cl_runtime.cc
浏览文件 @
8eeaa0ac
...
...
@@ -38,17 +38,20 @@ CLRuntime::~CLRuntime() {
}
bool
CLRuntime
::
Init
()
{
if
(
initialized_
)
{
if
(
i
s_cl_runtime_i
nitialized_
)
{
return
true
;
}
bool
is_platform_init
=
InitializePlatform
();
bool
is_device_init
=
InitializeDevice
();
is_init_success_
=
is_platform_init
&&
is_device_init
;
initialized_
=
true
;
context_
=
CreateContext
();
command_queue_
=
CreateCommandQueue
(
context
());
return
initialized_
;
LOG
(
INFO
)
<<
"is_platform_init:"
<<
is_platform_init
;
LOG
(
INFO
)
<<
"is_device_init:"
<<
is_device_init
;
if
((
is_platform_init
==
true
)
&&
(
is_device_init
==
true
))
{
is_platform_device_init_success_
=
true
;
context_
=
CreateContext
();
command_queue_
=
CreateCommandQueue
(
context
());
is_cl_runtime_initialized_
=
true
;
}
return
is_cl_runtime_initialized_
;
}
cl
::
Platform
&
CLRuntime
::
platform
()
{
...
...
@@ -64,7 +67,9 @@ cl::Context& CLRuntime::context() {
}
cl
::
Device
&
CLRuntime
::
device
()
{
CHECK
(
device_
!=
nullptr
)
<<
"device_ is not initialized!"
;
if
(
device_
==
nullptr
)
{
LOG
(
ERROR
)
<<
"device_ is not initialized!"
;
}
return
*
device_
;
}
...
...
@@ -150,6 +155,14 @@ GpuType CLRuntime::ParseGpuTypeFromDeviceName(std::string device_name) {
}
bool
CLRuntime
::
InitializeDevice
()
{
VLOG
(
3
)
<<
"device_info_.size():"
<<
device_info_
.
size
();
for
(
auto
i
:
device_info_
)
{
VLOG
(
3
)
<<
">>> "
<<
i
.
first
<<
" "
<<
i
.
second
;
}
if
(
device_info_
.
size
()
>
0
&&
device_info_
.
size
()
<=
2
)
{
return
false
;
}
device_info_
[
"PLACEHOLDER"
]
=
1
;
// ===================== BASIC =====================
// CL_DEVICE_TYPE_GPU
// CL_DEVICE_NAME
...
...
@@ -160,7 +173,7 @@ bool CLRuntime::InitializeDevice() {
status_
=
platform_
->
getDevices
(
CL_DEVICE_TYPE_GPU
,
&
all_devices
);
CL_CHECK_ERROR
(
status_
);
if
(
all_devices
.
empty
())
{
LOG
(
FATAL
)
<<
"No
OpenCL GPU device found!"
;
LOG
(
ERROR
)
<<
"No available
OpenCL GPU device found!"
;
return
false
;
}
device_
=
std
::
make_shared
<
cl
::
Device
>
();
...
...
@@ -313,9 +326,6 @@ bool CLRuntime::InitializeDevice() {
}
std
::
map
<
std
::
string
,
size_t
>&
CLRuntime
::
GetDeviceInfo
()
{
if
(
0
!=
device_info_
.
size
())
{
return
device_info_
;
}
InitializeDevice
();
return
device_info_
;
}
...
...
lite/backends/opencl/cl_runtime.h
浏览文件 @
8eeaa0ac
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "lite/backends/opencl/cl_include.h"
#include "lite/backends/opencl/cl_utility.h"
#include "lite/backends/opencl/cl_wrapper.h"
typedef
enum
{
UNKNOWN
=
0
,
...
...
@@ -68,6 +69,28 @@ class CLRuntime {
public:
static
CLRuntime
*
Global
();
bool
OpenCLAvaliableForDevice
()
{
bool
opencl_lib_found
=
paddle
::
lite
::
CLWrapper
::
Global
()
->
OpenclLibFound
();
LOG
(
INFO
)
<<
"opencl_lib_found:"
<<
opencl_lib_found
;
if
(
opencl_lib_found
==
false
)
return
false
;
bool
dlsym_success
=
paddle
::
lite
::
CLWrapper
::
Global
()
->
DlsymSuccess
();
LOG
(
INFO
)
<<
"dlsym_success:"
<<
dlsym_success
;
if
(
opencl_lib_found
==
false
)
return
false
;
InitializeDevice
();
bool
support_fp16
=
static_cast
<
bool
>
(
device_info_
[
"CL_DEVICE_EXTENSIONS_FP16"
]);
LOG
(
INFO
)
<<
"support_fp16:"
<<
support_fp16
;
if
(
support_fp16
==
false
)
return
false
;
is_device_avaliable_for_opencl_
=
dlsym_success
&&
opencl_lib_found
&&
support_fp16
;
LOG
(
INFO
)
<<
"is_device_avaliable_for_opencl_:"
<<
is_device_avaliable_for_opencl_
;
return
is_device_avaliable_for_opencl_
;
}
bool
Init
();
cl
::
Platform
&
platform
();
...
...
@@ -85,7 +108,7 @@ class CLRuntime {
bool
BuildProgram
(
cl
::
Program
*
program
,
const
std
::
string
&
options
=
""
);
bool
IsInitSuccess
()
{
return
is_init_success_
;
}
bool
IsInitSuccess
()
{
return
is_
platform_device_
init_success_
;
}
std
::
string
cl_path
()
{
return
cl_path_
;
}
...
...
@@ -167,9 +190,11 @@ class CLRuntime {
cl_int
status_
{
CL_SUCCESS
};
bool
initialized_
{
false
};
bool
is_device_avaliable_for_opencl_
{
false
};
bool
is_cl_runtime_initialized_
{
false
};
bool
is_init_success_
{
false
};
bool
is_
platform_device_
init_success_
{
false
};
};
}
// namespace lite
...
...
lite/backends/opencl/cl_wrapper.cc
浏览文件 @
8eeaa0ac
...
...
@@ -19,14 +19,16 @@ limitations under the License. */
namespace
paddle
{
namespace
lite
{
CLWrapper
*
CLWrapper
::
Global
()
{
static
CLWrapper
wrapper
;
return
&
wrapper
;
}
CLWrapper
::
CLWrapper
()
{
CHECK
(
InitHandle
())
<<
"Fail to initialize the OpenCL library!"
;
InitFunctions
();
opencl_lib_found_
=
InitHandle
();
CHECK
(
opencl_lib_found_
)
<<
"Fail to initialize the OpenCL library!"
;
dlsym_success_
=
InitFunctions
();
}
bool
CLWrapper
::
InitHandle
()
{
...
...
@@ -68,15 +70,17 @@ bool CLWrapper::InitHandle() {
}
}
void
CLWrapper
::
InitFunctions
()
{
bool
CLWrapper
::
InitFunctions
()
{
CHECK
(
handle_
!=
nullptr
)
<<
"The library handle can't be null!"
;
bool
dlsym_success
=
true
;
#define PADDLE_DLSYM(cl_func) \
do { \
cl_func##_ = (cl_func##Type)dlsym(handle_, #cl_func); \
if (cl_func##_ == nullptr) { \
LOG(
FATAL
) << "Cannot find the " << #cl_func \
LOG(
ERROR
) << "Cannot find the " << #cl_func \
<< " symbol in libOpenCL.so!"; \
dlsym_success = false; \
break; \
} \
VLOG(4) << "Loaded the " << #cl_func << " symbol successfully."; \
...
...
@@ -137,6 +141,7 @@ void CLWrapper::InitFunctions() {
PADDLE_DLSYM
(
clEnqueueCopyImage
);
#undef PADDLE_DLSYM
return
dlsym_success
;
}
}
// namespace lite
...
...
lite/backends/opencl/cl_wrapper.h
浏览文件 @
8eeaa0ac
...
...
@@ -508,13 +508,20 @@ class CLWrapper final {
return
clEnqueueCopyImage_
;
}
bool
OpenclLibFound
()
{
return
opencl_lib_found_
;
}
bool
DlsymSuccess
()
{
return
dlsym_success_
;
}
private:
CLWrapper
();
CLWrapper
(
const
CLWrapper
&
)
=
delete
;
CLWrapper
&
operator
=
(
const
CLWrapper
&
)
=
delete
;
bool
InitHandle
();
void
InitFunctions
();
bool
InitFunctions
();
bool
opencl_lib_found_
{
true
};
bool
dlsym_success_
{
true
};
void
*
handle_
{
nullptr
};
clGetPlatformIDsType
clGetPlatformIDs_
{
nullptr
};
clGetPlatformInfoType
clGetPlatformInfo_
{
nullptr
};
clBuildProgramType
clBuildProgram_
{
nullptr
};
...
...
lite/core/mir/fusion/quant_dequant_op_fuser.cc
浏览文件 @
8eeaa0ac
...
...
@@ -175,7 +175,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
for
(
int
i
=
0
;
i
<
weight_scale_size
;
i
++
)
{
weight_scale
.
push_back
(
whole_weight_scale
);
}
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
// Arm CPU does not support conv2d_transpose
if
(
quantized_op_type_
!=
"conv2d_transpose"
)
{
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
}
op_desc
.
SetInputScale
(
weight_name
,
weight_scale
);
// change the weight from the float type to int8 type.
...
...
@@ -280,6 +284,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
op_desc
.
SetInput
(
"X"
,
{
quantized_op_input
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Out"
,
{
dequant_op_out
->
arg
()
->
name
});
}
// Arm CPU does not support conv2d_transpose
if
(
quantized_op_type_
!=
"conv2d_transpose"
)
{
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
}
...
...
lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
浏览文件 @
8eeaa0ac
...
...
@@ -78,6 +78,28 @@ void RunModel(std::string model_dir,
// 1. Set MobileConfig
MobileConfig
config
;
config
.
set_model_from_file
(
model_dir
);
// NOTE: Use android gpu with opencl, you should ensure:
// first, [compile **cpu+opencl** paddlelite
// lib](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/docs/demo_guides/opencl.md);
// second, [convert and use opencl nb
// model](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/docs/user_guides/opt/opt_bin.md).
//
/* Uncomment code below to enable OpenCL
bool is_opencl_backend_valid = ::IsOpenCLBackendValid();
std::cout << "is_opencl_backend_valid:" << is_opencl_backend_valid <<
std::endl;
if (is_opencl_backend_valid) {
// give opencl nb model dir
config.set_model_from_file(model_dir);
} else {
std::cout << "Unsupport opencl nb model." << std::endl;
exit(1);
// you can give backup cpu nb model instead
// config.set_model_from_file(cpu_nb_model_dir);
}
*/
// NOTE: To load model transformed by model_optimize_tool before
// release/v2.3.0, plese use `set_model_dir` API as listed below.
// config.set_model_dir(model_dir);
...
...
lite/kernels/cuda/CMakeLists.txt
浏览文件 @
8eeaa0ac
...
...
@@ -15,6 +15,7 @@ add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${
add_kernel
(
abs_compute_cuda CUDA basic SRCS abs_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
tanh_compute_cuda CUDA basic SRCS tanh_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sigmoid_compute_cuda CUDA basic SRCS sigmoid_compute.cu DEPS
${
lite_kernel_deps
}
${
math_cuda
}
)
add_kernel
(
yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sequence_pool_concat_compute_cuda CUDA extra SRCS sequence_pool_concat_compute.cu DEPS
${
lite_kernel_deps
}
)
...
...
@@ -61,6 +62,7 @@ nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_
nv_test
(
abs_compute_cuda_test SRCS abs_compute_test.cc DEPS abs_compute_cuda
)
nv_test
(
tanh_compute_cuda_test SRCS tanh_compute_test.cc DEPS tanh_compute_cuda
)
nv_test
(
relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda
)
nv_test
(
sigmoid_compute_cuda_test SRCS sigmoid_compute_test.cc DEPS sigmoid_compute_cuda
)
nv_test
(
yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda
)
nv_test
(
transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda
)
nv_test
(
search_group_padding_compute_cuda_test SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_cuda
)
...
...
lite/kernels/cuda/concat_compute_test.cc
浏览文件 @
8eeaa0ac
...
...
@@ -69,7 +69,7 @@ void concat_compute_ref(const operators::ConcatParam& param) {
std
::
vector
<
int
>
input_cols
(
input
.
size
());
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
input_i_numel
=
input
[
i
]
->
dims
().
size
()
==
0
?
0
:
1
;
for
(
in
t
didx
=
0
;
didx
<
input
[
i
]
->
dims
().
size
();
++
didx
)
{
for
(
size_
t
didx
=
0
;
didx
<
input
[
i
]
->
dims
().
size
();
++
didx
)
{
input_i_numel
*=
input
[
i
]
->
dims
()[
didx
];
}
int
t_cols
=
input_i_numel
/
rows
;
...
...
lite/kernels/cuda/gru_compute.cu
浏览文件 @
8eeaa0ac
...
...
@@ -48,10 +48,69 @@ struct GRUUnitFunctor {
CUDAContext
*
context
)
{
dim3
threads
,
grids
;
if
(
batch_size
==
1
)
{
int
frame_per_block
=
frame_size
<=
1024
?
frame_size
:
1024
;
int
frame_blocks
=
(
frame_size
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame_per_block
,
1
);
grids
=
dim3
(
frame_blocks
,
1
);
if
(
lite
::
TargetWrapperCuda
::
GetComputeCapability
()
>=
70
)
{
if
(
frame_size
<
16
)
{
constexpr
int
tiled_size
=
8
;
int
frame_blocks
=
(
frame_size
*
2
+
tiled_size
-
1
)
/
tiled_size
;
threads
=
dim3
(
tiled_size
,
1
);
grids
=
dim3
(
frame_blocks
,
1
);
lite
::
cuda
::
math
::
FastCollectiveGruGate
<
T
,
tiled_size
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
gate_value
,
value
.
prev_out_value
,
value
.
gate_weight
,
value
.
reset_output_value
,
frame_size
,
active_gate
);
frame_blocks
=
(
frame_size
+
tiled_size
-
1
)
/
tiled_size
;
grids
=
dim3
(
frame_blocks
,
1
);
lite
::
cuda
::
math
::
FastCollectiveGruOut
<
T
,
tiled_size
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
state_weight
,
value
.
prev_out_value
,
value
.
output_value
,
value
.
gate_value
,
value
.
reset_output_value
,
frame_size
,
active_node
,
origin_mode
);
}
else
{
constexpr
int
tiled_size
=
16
;
int
frame_blocks
=
(
frame_size
*
2
+
tiled_size
-
1
)
/
tiled_size
;
threads
=
dim3
(
tiled_size
,
1
);
grids
=
dim3
(
frame_blocks
,
1
);
lite
::
cuda
::
math
::
FastCollectiveGruGate
<
T
,
tiled_size
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
gate_value
,
value
.
prev_out_value
,
value
.
gate_weight
,
value
.
reset_output_value
,
frame_size
,
active_gate
);
frame_blocks
=
(
frame_size
+
tiled_size
-
1
)
/
tiled_size
;
grids
=
dim3
(
frame_blocks
,
1
);
lite
::
cuda
::
math
::
FastCollectiveGruOut
<
T
,
tiled_size
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
state_weight
,
value
.
prev_out_value
,
value
.
output_value
,
value
.
gate_value
,
value
.
reset_output_value
,
frame_size
,
active_node
,
origin_mode
);
}
return
;
}
else
{
int
frame_per_block
=
frame_size
<=
1024
?
frame_size
:
1024
;
int
frame_blocks
=
(
frame_size
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame_per_block
,
1
);
grids
=
dim3
(
frame_blocks
,
1
);
}
}
else
{
threads
=
dim3
(
32
,
32
);
grids
=
dim3
((
frame_size
+
32
-
1
)
/
32
,
(
batch_size
+
32
-
1
)
/
32
);
...
...
@@ -121,6 +180,90 @@ struct GRUUnitFunctor {
template
struct
GRUUnitFunctor
<
float
>;
template
<
>
struct
GRUUnitFunctor
<
half
>
{
static
void
compute
(
GRUMetaValue
<
half
>
value
,
int
frame_size
,
int
batch_size
,
const
lite
::
cuda
::
math
::
ActivationType
&
active_node
,
const
lite
::
cuda
::
math
::
ActivationType
&
active_gate
,
bool
origin_mode
,
lite
::
cuda
::
math
::
Gemm
<
half
,
half
>*
blas
,
CUDAContext
*
context
)
{
dim3
threads
,
grids
;
if
(
batch_size
==
1
)
{
int
frame_per_block
=
frame_size
<=
1024
?
frame_size
:
1024
;
int
frame_blocks
=
(
frame_size
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame_per_block
,
1
);
grids
=
dim3
(
frame_blocks
,
1
);
}
else
{
threads
=
dim3
(
32
,
32
);
grids
=
dim3
((
frame_size
+
32
-
1
)
/
32
,
(
batch_size
+
32
-
1
)
/
32
);
}
if
(
value
.
prev_out_value
)
{
CHECK
(
blas
->
init
(
false
,
false
,
batch_size
,
frame_size
*
2
,
frame_size
,
frame_size
,
frame_size
*
2
,
frame_size
*
3
,
context
));
blas
->
run
(
1.0
f
,
1.0
f
,
value
.
prev_out_value
,
value
.
gate_weight
,
value
.
gate_value
,
context
);
}
CUDA_POST_KERNEL_CHECK
;
lite
::
cuda
::
math
::
GruForwardResetOutput
<
half
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
gate_value
,
value
.
reset_output_value
,
value
.
prev_out_value
,
frame_size
,
batch_size
,
active_gate
,
batch_size
==
1
);
CUDA_POST_KERNEL_CHECK
;
if
(
value
.
prev_out_value
)
{
CHECK
(
blas
->
init
(
false
,
false
,
batch_size
,
frame_size
,
frame_size
,
frame_size
,
frame_size
,
frame_size
*
3
,
context
));
blas
->
run
(
1.0
f
,
1.0
f
,
value
.
reset_output_value
,
value
.
state_weight
,
value
.
gate_value
+
frame_size
*
2
,
context
);
}
CUDA_POST_KERNEL_CHECK
;
lite
::
cuda
::
math
::
GruForwardFinalOutput
<
half
><<<
grids
,
threads
,
0
,
context
->
exec_stream
()
>>>
(
value
.
gate_value
,
value
.
prev_out_value
,
value
.
output_value
,
frame_size
,
batch_size
,
active_node
,
origin_mode
,
batch_size
==
1
);
CUDA_POST_KERNEL_CHECK
;
}
};
template
<
typename
T
,
PrecisionType
PType
>
void
GRUCompute
<
T
,
PType
>::
PrepareForRun
()
{
gemm_impl_
.
reset
(
new
lite
::
cuda
::
math
::
Gemm
<
T
,
T
>
);
...
...
@@ -141,18 +284,17 @@ void GRUCompute<T, PType>::Run() {
if
(
param
.
bias
)
{
bias
=
const_cast
<
lite
::
Tensor
*>
(
param
.
bias
);
}
auto
*
weight
=
param
.
weight
;
auto
*
weight_data
=
const_cast
<
T
*>
(
weight
->
template
data
<
T
>());
auto
*
batch_gate
=
param
.
batch_gate
;
auto
*
batch_reset_hidden_prev
=
param
.
batch_reset_hidden_prev
;
auto
*
batch_hidden
=
param
.
batch_hidden
;
auto
*
hidden
=
param
.
hidden
;
auto
*
batch_reset_hidden_prev_data
=
const
lite
::
Tensor
*
weight
=
param
.
weight
;
T
*
weight_data
=
const_cast
<
T
*>
(
weight
->
template
data
<
T
>());
lite
::
Tensor
*
batch_gate
=
param
.
batch_gate
;
lite
::
Tensor
*
batch_reset_hidden_prev
=
param
.
batch_reset_hidden_prev
;
lite
::
Tensor
*
batch_hidden
=
param
.
batch_hidden
;
lite
::
Tensor
*
hidden
=
param
.
hidden
;
T
*
batch_reset_hidden_prev_data
=
batch_reset_hidden_prev
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
hidden
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
auto
*
batch_gate_data
=
batch_gate
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
auto
*
batch_hidden_data
=
batch_hidden
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
T
*
batch_gate_data
=
batch_gate
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
T
*
batch_hidden_data
=
batch_hidden
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
bool
is_reverse
=
param
.
is_reverse
;
auto
active_node
=
lite
::
cuda
::
math
::
GetActiveType
(
param
.
activation
);
auto
active_gate
=
lite
::
cuda
::
math
::
GetActiveType
(
param
.
gate_activation
);
...
...
@@ -224,6 +366,8 @@ void GRUCompute<T, PType>::Run() {
using
GRUFp32
=
paddle
::
lite
::
kernels
::
cuda
::
GRUCompute
<
float
,
PRECISION
(
kFloat
)
>
;
using
GRUFp16
=
paddle
::
lite
::
kernels
::
cuda
::
GRUCompute
<
half
,
PRECISION
(
kFP16
)
>
;
REGISTER_LITE_KERNEL
(
gru
,
kCUDA
,
kFloat
,
kNCHW
,
GRUFp32
,
def
)
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"H0"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
...
...
@@ -234,3 +378,20 @@ REGISTER_LITE_KERNEL(gru, kCUDA, kFloat, kNCHW, GRUFp32, def)
.
BindOutput
(
"BatchHidden"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Hidden"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
gru
,
kCUDA
,
kFP16
,
kNCHW
,
GRUFp16
,
def
)
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindInput
(
"H0"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindInput
(
"Weight"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"BatchGate"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"BatchResetHiddenPrev"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"BatchHidden"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"Hidden"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
Finalize
();
lite/kernels/cuda/gru_compute_test.cc
浏览文件 @
8eeaa0ac
...
...
@@ -45,10 +45,13 @@ class GRUTest : public ::testing::Test {
x_ref_
.
Resize
(
lite
::
DDim
(
x_shape_
));
x_gpu_
.
Resize
(
lite
::
DDim
(
x_shape_
));
x_ref_
.
set_lod
(
lod_
);
w_ref_
.
Resize
(
lite
::
DDim
(
w_shape_
));
w_gpu_
.
Resize
(
lite
::
DDim
(
w_shape_
));
auto
x_ref_data
=
x_ref_
.
mutable_data
<
float
>
();
auto
w_ref_data
=
w_ref_
.
mutable_data
<
float
>
();
for
(
int64_t
i
=
0
;
i
<
x_ref_
.
numel
();
i
++
)
{
x_ref_data
[
i
]
=
static_cast
<
float
>
(
i
%
10
*
0.2
);
}
...
...
@@ -63,6 +66,7 @@ class GRUTest : public ::testing::Test {
batch_hidden_gpu_
.
Resize
(
lite
::
DDim
(
out_shape_
));
batch_reset_hidden_gpu_
.
Resize
(
lite
::
DDim
(
out_shape_
));
RunBaseLine
();
InitParamAndContext
();
}
...
...
@@ -91,6 +95,22 @@ class GRUTest : public ::testing::Test {
w_gpu_
.
dims
());
}
void
InitHalfInput
()
{
x_half_
.
Resize
(
lite
::
DDim
(
x_shape_
));
auto
x_half_data
=
x_half_
.
mutable_data
<
half
>
();
for
(
int64_t
i
=
0
;
i
<
x_half_
.
numel
();
i
++
)
{
x_half_data
[
i
]
=
half
(
lite
::
float16
(
x_ref_
.
data
<
float
>
()[
i
]));
}
x_gpu_
.
Assign
<
half
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_half_data
,
x_gpu_
.
dims
());
x_gpu_
.
set_lod
(
x_ref_
.
lod
());
w_half_
.
Resize
(
w_ref_
.
dims
());
auto
w_half_data
=
w_half_
.
mutable_data
<
half
>
();
for
(
int64_t
i
=
0
;
i
<
w_half_
.
numel
();
i
++
)
{
w_half_data
[
i
]
=
half
(
lite
::
float16
(
w_ref_
.
data
<
float
>
()[
i
]));
}
w_gpu_
.
Assign
<
half
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
w_half_data
,
w_gpu_
.
dims
());
}
void
RunBaseLine
()
{}
int
batch_
,
frame_size_
;
...
...
@@ -134,6 +154,29 @@ TEST_F(GRUTest, TestFP32) {
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
}
TEST_F
(
GRUTest
,
TestFP16
)
{
InitHalfInput
();
GRUCompute
<
half
,
PRECISION
(
kFP16
)
>
kernel
;
kernel
.
SetParam
(
param_
);
kernel
.
SetContext
(
std
::
move
(
ctx_
));
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
{
kernel
.
Launch
();
cudaDeviceSynchronize
();
}
auto
start
=
GetCurrentUS
();
kernel
.
PrepareForRun
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
kernel
.
Run
();
}
cudaDeviceSynchronize
();
auto
duration
=
(
GetCurrentUS
()
-
start
)
/
1000.0
;
LOG
(
INFO
)
<<
"fp16, warmup: "
<<
FLAGS_warmup
<<
", repeats: "
<<
FLAGS_repeats
<<
", spend "
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
...
...
lite/kernels/cuda/sigmoid_compute.cu
0 → 100644
浏览文件 @
8eeaa0ac
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/sigmoid_compute.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
template
<
typename
T
,
PrecisionType
Ptype
>
void
SigmoidCompute
<
T
,
Ptype
>::
Run
()
{
auto
&
param
=
this
->
template
Param
<
param_t
>();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
int
num
=
static_cast
<
int
>
(
param
.
X
->
numel
());
auto
input
=
param
.
X
->
template
data
<
T
>();
auto
output
=
param
.
Out
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
lite
::
cuda
::
math
::
sigmoid
<
T
>
(
num
,
input
,
output
,
stream
);
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
using
SigmoidFp32
=
paddle
::
lite
::
kernels
::
cuda
::
SigmoidCompute
<
float
,
PRECISION
(
kFloat
)
>
;
using
SigmoidFp16
=
paddle
::
lite
::
kernels
::
cuda
::
SigmoidCompute
<
half
,
PRECISION
(
kFP16
)
>
;
REGISTER_LITE_KERNEL
(
sigmoid
,
kCUDA
,
kFloat
,
kNCHW
,
SigmoidFp32
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
sigmoid
,
kCUDA
,
kFP16
,
kNCHW
,
SigmoidFp16
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
Finalize
();
lite/kernels/cuda/sigmoid_compute.h
0 → 100644
浏览文件 @
8eeaa0ac
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
template
<
typename
T
,
PrecisionType
Ptype
>
class
SigmoidCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
Ptype
>
{
public:
using
param_t
=
operators
::
ActivationParam
;
void
Run
()
override
;
virtual
~
SigmoidCompute
()
=
default
;
};
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/kernels/cuda/sigmoid_compute_test.cc
0 → 100644
浏览文件 @
8eeaa0ac
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sigmoid_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/utils/float16.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
class
SigmoidTest
:
public
::
testing
::
Test
{
protected:
SigmoidTest
()
:
m_
(
8
),
n_
(
64
),
shape_
({
m_
,
n_
})
{
x_ref_
.
Resize
(
lite
::
DDim
(
shape_
));
x_gpu_
.
Resize
(
lite
::
DDim
(
shape_
));
auto
x_ref_data
=
x_ref_
.
mutable_data
<
float
>
();
for
(
int64_t
i
=
0
;
i
<
x_ref_
.
numel
();
i
++
)
{
x_ref_data
[
i
]
=
static_cast
<
float
>
(
i
%
10
*
0.2
);
}
out_ref_
.
Resize
(
lite
::
DDim
(
shape_
));
out_cpu_
.
Resize
(
out_ref_
.
dims
());
out_gpu_
.
Resize
(
out_ref_
.
dims
());
RunBaseLine
();
InitParamAndContext
();
}
void
InitParamAndContext
()
{
ctx_
.
reset
(
new
KernelContext
);
cudaStreamCreate
(
&
stream_
);
auto
&
context
=
ctx_
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream_
);
param_
.
X
=
&
x_gpu_
;
param_
.
Out
=
&
out_gpu_
;
}
void
InitFloatInput
()
{
x_gpu_
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_ref_
.
data
<
float
>
(),
x_gpu_
.
dims
());
}
void
InitHalfInput
()
{
x_half_
.
Resize
(
lite
::
DDim
(
shape_
));
auto
x_half_data
=
x_half_
.
mutable_data
<
half
>
();
for
(
int64_t
i
=
0
;
i
<
x_half_
.
numel
();
i
++
)
{
x_half_data
[
i
]
=
half
(
lite
::
float16
(
x_ref_
.
data
<
float
>
()[
i
]));
}
x_gpu_
.
Assign
<
half
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_half_data
,
x_gpu_
.
dims
());
}
void
RunBaseLine
()
{
for
(
int64_t
i
=
0
;
i
<
x_ref_
.
numel
();
++
i
)
{
out_ref_
.
mutable_data
<
float
>
()[
i
]
=
1.
f
/
(
1.
f
+
expf
(
-
1
*
x_ref_
.
data
<
float
>
()[
i
]));
}
}
int
m_
,
n_
;
std
::
vector
<
int64_t
>
shape_
;
lite
::
Tensor
x_ref_
,
out_ref_
;
lite
::
Tensor
x_gpu_
;
lite
::
Tensor
x_half_
;
lite
::
Tensor
out_cpu_
,
out_gpu_
;
operators
::
ActivationParam
param_
;
std
::
unique_ptr
<
KernelContext
>
ctx_
;
cudaStream_t
stream_
;
};
TEST_F
(
SigmoidTest
,
TestFP32
)
{
InitFloatInput
();
SigmoidCompute
<
float
,
PRECISION
(
kFloat
)
>
kernel
;
kernel
.
SetParam
(
param_
);
kernel
.
SetContext
(
std
::
move
(
ctx_
));
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
{
kernel
.
Launch
();
cudaDeviceSynchronize
();
}
auto
start
=
GetCurrentUS
();
kernel
.
PrepareForRun
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
kernel
.
Run
();
}
cudaDeviceSynchronize
();
auto
duration
=
(
GetCurrentUS
()
-
start
)
/
1000.0
;
LOG
(
INFO
)
<<
"fp32, warmup: "
<<
FLAGS_warmup
<<
", repeats: "
<<
FLAGS_repeats
<<
", spend "
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
CopySync
<
TARGET
(
kCUDA
)
>
(
out_cpu_
.
mutable_data
<
float
>
(),
out_gpu_
.
data
<
float
>
(),
sizeof
(
float
)
*
out_gpu_
.
numel
(),
IoDirection
::
DtoH
);
for
(
int
i
=
0
;
i
<
out_gpu_
.
numel
();
++
i
)
{
float
res
=
out_cpu_
.
data
<
float
>
()[
i
];
float
ref
=
out_ref_
.
data
<
float
>
()[
i
];
EXPECT_NEAR
(
fabs
(
res
-
ref
)
/
ref
,
0.
f
,
1e-5
);
}
}
TEST_F
(
SigmoidTest
,
TestFP16
)
{
InitHalfInput
();
SigmoidCompute
<
half
,
PRECISION
(
kFP16
)
>
kernel
;
kernel
.
SetParam
(
param_
);
kernel
.
SetContext
(
std
::
move
(
ctx_
));
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
{
kernel
.
Launch
();
cudaDeviceSynchronize
();
}
auto
start
=
GetCurrentUS
();
kernel
.
PrepareForRun
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
kernel
.
Run
();
}
cudaDeviceSynchronize
();
auto
duration
=
(
GetCurrentUS
()
-
start
)
/
1000.0
;
LOG
(
INFO
)
<<
"fp16, warmup: "
<<
FLAGS_warmup
<<
", repeats: "
<<
FLAGS_repeats
<<
", spend "
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
const
half
*
out_gpu_data
=
out_gpu_
.
data
<
half
>
();
half
*
out_cpu_data
=
out_cpu_
.
mutable_data
<
half
>
();
CopySync
<
TARGET
(
kCUDA
)
>
(
out_cpu_data
,
out_gpu_data
,
sizeof
(
half
)
*
out_gpu_
.
numel
(),
IoDirection
::
DtoH
);
for
(
int
i
=
0
;
i
<
out_gpu_
.
numel
();
++
i
)
{
float
res
=
static_cast
<
float
>
(
lite
::
float16
(
out_cpu_data
[
i
]));
float
ref
=
out_ref_
.
data
<
float
>
()[
i
];
EXPECT_NEAR
(
fabs
(
res
-
ref
)
/
(
ref
+
1e-5
),
0.
,
2e-2
);
}
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/tools/build.sh
浏览文件 @
8eeaa0ac
...
...
@@ -39,8 +39,8 @@ readonly THIRDPARTY_TAR=https://paddle-inference-dist.bj.bcebos.com/PaddleLite/t
readonly
workspace
=
$PWD
# if operating in mac env, we should expand the maximum file num
os_n
ma
e
=
`
uname
-s
`
if
[
${
os_n
ma
e
}
==
"Darwin"
]
;
then
os_n
am
e
=
`
uname
-s
`
if
[
${
os_n
am
e
}
==
"Darwin"
]
;
then
ulimit
-n
1024
fi
...
...
lite/tools/ci_build.sh
浏览文件 @
8eeaa0ac
...
...
@@ -21,8 +21,8 @@ USE_ADB_EMULATOR=ON
LITE_WITH_COVERAGE
=
OFF
# if operating in mac env, we should expand the maximum file num
os_n
ma
e
=
`
uname
-s
`
if
[
${
os_n
ma
e
}
==
"Darwin"
]
;
then
os_n
am
e
=
`
uname
-s
`
if
[
${
os_n
am
e
}
==
"Darwin"
]
;
then
ulimit
-n
1024
fi
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录