Commit 498f147d (unverified)
[CUDA] [Kernel] Add matmul cuda kernel. (#3897)
Authored by Wilber on Jul 08, 2020; committed via GitHub on Jul 08, 2020.
Parent: 4776f8f4

Showing 12 changed files with 846 additions and 15 deletions (+846, -15).
CMakeLists.txt                                 +2    -1
cmake/cuda.cmake                               +8    -0
lite/backends/cuda/math/CMakeLists.txt         +2    -0
lite/backends/cuda/math/strided_gemm.cc        +136  -0
lite/backends/cuda/math/strided_gemm.h         +72   -0
lite/kernels/cuda/CMakeLists.txt               +2    -0
lite/kernels/cuda/fc_compute.cu                +178  -1
lite/kernels/cuda/fc_compute_test.cc           +38   -2
lite/kernels/cuda/matmul_compute.cc            +156  -0
lite/kernels/cuda/matmul_compute.h             +50   -0
lite/kernels/cuda/matmul_compute_test.cc       +193  -0
lite/kernels/cuda/transpose_compute_test.cc    +9    -11
CMakeLists.txt  (+2, -1)

@@ -106,7 +106,8 @@ lite_option(LITE_BUILD_EXTRA "Enable extra algorithm support in Lite, both kerne
 lite_option(LITE_BUILD_TAILOR     "Enable tailoring library according to model" OFF)
 # cv build options
 lite_option(LITE_WITH_CV "Enable build cv image in lite" OFF)
-lite_option(LITE_WITH_STATIC_CUDA "Statically link cuda libraries." ON)
+lite_option(LITE_WITH_STATIC_CUDA "Statically link cuda libraries." OFF)
+lite_option(CUDA_WITH_FP16        "Compile with cuda half support" OFF)
 lite_option(LITE_WITH_ARM_CLANG   "when arm lang is clang, its ON." OFF)
 # TODO(Superjomn) Remove WITH_ANAKIN option if not needed latter.
...
cmake/cuda.cmake  (+8, -0)

@@ -2,6 +2,10 @@ if(NOT LITE_WITH_CUDA)
   return()
 endif()

+if(WITH_CUDA_FP16)
+  add_definitions("-DCUDA_WITH_FP16")
+endif()
+
 set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
 set(paddle_known_gpu_archs7 "30 35 50 52")
 set(paddle_known_gpu_archs8 "30 35 50 52 53 60 61 62")
...
@@ -167,6 +171,10 @@ elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x
   add_definitions("-DPADDLE_CUDA_BINVER=\"100\"")
 endif()

+if(CUDA_WITH_FP16)
+  STRING(REGEX REPLACE "30|35|50|52" "" paddle_known_gpu_archs ${paddle_known_gpu_archs})
+endif()
+
 include_directories(${CUDA_INCLUDE_DIRS})

 if(NOT WITH_DSO)
   if(WIN32)
...
lite/backends/cuda/math/CMakeLists.txt  (+2, -0)

@@ -13,6 +13,7 @@ nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
 nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
 nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
 nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
+nv_library(cuda_strided_gemm SRCS strided_gemm.cc DEPS ${cuda_static_deps})
 nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps})
 set(
...
@@ -26,6 +27,7 @@ set (
     cudnn_pool
     cuda_gemm
     cuda_batched_gemm
+    cuda_strided_gemm
     cuda_sequence_padding
     )
...
lite/backends/cuda/math/strided_gemm.cc  (new file, +136)

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/cuda/math/strided_gemm.h"

#include <iostream>

#include "lite/core/device_info.h"

namespace paddle {
namespace lite {
namespace cuda {
namespace math {

template <typename PtypeIn, typename PtypeOut>
bool StridedGemm<PtypeIn, PtypeOut>::init(const bool trans_a,
                                          const bool trans_b,
                                          Context<TARGET(kCUDA)>* ctx) {
  if (cu_handle_ == nullptr) {
    this->exe_stream_ = ctx->exec_stream();
    CUBLAS_CALL(cublasCreate(&cu_handle_));
    CUBLAS_CALL(cublasSetStream(cu_handle_, this->exe_stream_));
  }
  cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
  cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
  return true;
}

template <>
bool StridedGemm<float, float>::run(const float alpha,
                                    const float beta,
                                    const int m,
                                    const int n,
                                    const int k,
                                    const float* a_data,
                                    const float* b_data,
                                    float* c_data,
                                    const int batch_size,
                                    const int64_t stride_a,
                                    const int64_t stride_b) {
  lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m;
  ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k;
  ldc_ = n;
  m_ = m;
  n_ = n;
  k_ = k;
  const int64_t stride_c = m_ * n_;
  CUBLAS_CALL(cublasGemmStridedBatchedEx(cu_handle_,
                                         cu_trans_b_, cu_trans_a_,
                                         n_, m_, k_,
                                         &alpha,
                                         b_data, CUDA_R_32F, ldb_, stride_b,
                                         a_data, CUDA_R_32F, lda_, stride_a,
                                         &beta,
                                         c_data, CUDA_R_32F, ldc_, stride_c,
                                         batch_size,
                                         CUDA_R_32F, algo_));
  return true;
}

template <>
bool StridedGemm<half, half>::run(const half alpha,
                                  const half beta,
                                  const int m,
                                  const int n,
                                  const int k,
                                  const half* a_data,
                                  const half* b_data,
                                  half* c_data,
                                  const int batch_size,
                                  const int64_t stride_a,
                                  const int64_t stride_b) {
  lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m;
  ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k;
  ldc_ = n;
  m_ = m;
  n_ = n;
  k_ = k;
  const int64_t stride_c = m_ * n_;
  CUBLAS_CALL(cublasGemmStridedBatchedEx(cu_handle_,
                                         cu_trans_b_, cu_trans_a_,
                                         n_, m_, k_,
                                         &alpha,
                                         b_data, CUDA_R_16F, ldb_, stride_b,
                                         a_data, CUDA_R_16F, lda_, stride_a,
                                         &beta,
                                         c_data, CUDA_R_16F, ldc_, stride_c,
                                         batch_size,
                                         CUDA_R_16F, algo_));
  return true;
}

template class StridedGemm<float, float>;
template class StridedGemm<half, half>;

}  // namespace math
}  // namespace cuda
}  // namespace lite
}  // namespace paddle
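Note that run() passes b_data before a_data and swaps the transpose flags: cuBLAS stores matrices column-major, so evaluating B^T * A^T there yields A * B laid out row-major in c_data. The stride arguments give the per-batch offset into each operand, and a stride of zero reuses the same matrix for every batch, which is how the matmul kernel below broadcasts a 2-D operand against a batched one. A minimal CPU reference with the same semantics (a hypothetical helper for illustration only, row-major, no transposes, not part of this commit):

#include <cstdint>

// Row-major strided-batched GEMM reference:
//   C[b] = alpha * A[b] * B[b] + beta * C[b]
// stride_a == 0 (or stride_b == 0) broadcasts that operand across the batch.
void StridedGemmRef(float alpha, float beta, int m, int n, int k,
                    const float* a, const float* b, float* c,
                    int batch, int64_t stride_a, int64_t stride_b) {
  const int64_t stride_c = static_cast<int64_t>(m) * n;
  for (int bi = 0; bi < batch; ++bi) {
    const float* A = a + bi * stride_a;
    const float* B = b + bi * stride_b;
    float* C = c + bi * stride_c;
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.f;
        for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
        C[i * n + j] = alpha * acc + beta * C[i * n + j];
      }
    }
  }
}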
lite/backends/cuda/math/strided_gemm.h  (new file, +72)

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cudnn.h>

#include <string>
#include <vector>

#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"

namespace paddle {
namespace lite {
namespace cuda {
namespace math {

template <typename PtypeIn, typename PtypeOut>
class StridedGemm {
 public:
  StridedGemm() : cu_handle_(nullptr) {}
  ~StridedGemm() {}

  bool init(const bool trans_a,
            const bool trans_b,
            Context<TARGET(kCUDA)>* ctx);

  bool run(const PtypeIn alpha,
           const PtypeIn beta,
           const int m,
           const int n,
           const int k,
           const PtypeIn* a_data,
           const PtypeIn* b_data,
           PtypeOut* c_data,
           const int batch_size,
           const int64_t stride_a,
           const int64_t stride_b);

 private:
  cudaStream_t exe_stream_;
  cublasHandle_t cu_handle_;
  cublasOperation_t cu_trans_a_;
  cublasOperation_t cu_trans_b_;
  int m_{-1};
  int n_{-1};
  int k_{-1};
  int lda_{-1};
  int ldb_{-1};
  int ldc_{-1};
  cublasGemmAlgo_t algo_{CUBLAS_GEMM_DEFAULT_TENSOR_OP};
};

}  // namespace math
}  // namespace cuda
}  // namespace lite
}  // namespace paddle
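A minimal usage sketch of the class (assuming a prepared Context<TARGET(kCUDA)> whose exec stream is already set, as the kernels below arrange; BatchedMatMulSketch is a hypothetical helper, not part of this commit):

#include "lite/backends/cuda/math/strided_gemm.h"

// Hypothetical helper: multiply a batched x [batch, m, k] by a shared
// y [k, n] (broadcast via stride 0), writing out [batch, m, n].
void BatchedMatMulSketch(const float* x_data, const float* y_data,
                         float* out_data, int batch, int m, int n, int k,
                         paddle::lite::Context<TARGET(kCUDA)>* ctx) {
  paddle::lite::cuda::math::StridedGemm<float, float> sgemm;
  sgemm.init(/*trans_a=*/false, /*trans_b=*/false, ctx);
  sgemm.run(1.f, 0.f, m, n, k, x_data, y_data, out_data, batch,
            /*stride_a=*/static_cast<int64_t>(m) * k, /*stride_b=*/0);
}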
lite/kernels/cuda/CMakeLists.txt  (+2, -0)

@@ -7,6 +7,7 @@ message(STATUS "compile with lite CUDA kernels")
 # basic kernels
 add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
 add_kernel(fc_compute_cuda CUDA basic SRCS fc_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
+add_kernel(matmul_compute_cuda CUDA basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
 add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
...
@@ -68,6 +69,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
 #nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
 nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
 nv_test(fc_compute_cuda_test SRCS fc_compute_test.cc DEPS fc_compute_cuda)
+nv_test(matmul_compute_cuda_test SRCS matmul_compute_test.cc DEPS matmul_compute_cuda)
 nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda)
 nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
 #nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda)
...
lite/kernels/cuda/fc_compute.cu  (+178, -1)

@@ -11,7 +11,6 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "lite/kernels/cuda/fc_compute.h"
 #include <string>
...
@@ -32,6 +31,74 @@ struct FcTypeTraits<float> {
  typedef float4 Type;
};

template <typename T>
__global__ void AddBiasV2(const int num, const T* bias, T* data, int K) {
  CUDA_KERNEL_LOOP(index, num) {
    int bias_idx = index % K;
    const T bias_ptr = bias[bias_idx];
    const T in_ptr = data[index];
    T packed_val;
    packed_val.x = in_ptr.x + bias_ptr.x;
    packed_val.y = in_ptr.y + bias_ptr.y;
    data[index] = packed_val;
  }
}

template <>
__global__ void AddBiasV2(const int num, const half2* bias, half2* data, int K) {
  CUDA_KERNEL_LOOP(index, num) {
    int bias_idx = index % K;
    const half2 bias_ptr = bias[bias_idx];
    const half2 in_ptr = data[index];
#if __CUDA_ARCH__ >= 530
    data[index] = __hadd2(in_ptr, bias_ptr);
#else
    half2 packed_val;
    packed_val.x = __hadd(in_ptr.x, bias_ptr.x);
    packed_val.y = __hadd(in_ptr.y, bias_ptr.y);
    data[index] = packed_val;
#endif
  }
}

template <typename T>
__global__ void AddBiasReluV2(const int num, const T* bias, T* data, int K) {
  CUDA_KERNEL_LOOP(index, num) {
    int bias_idx = index % K;
    const T bias_ptr = bias[bias_idx];
    const T in_ptr = data[index];
    T packed_val;
    packed_val.x = fmaxf(0.f, in_ptr.x + bias_ptr.x);
    packed_val.y = fmaxf(0.f, in_ptr.y + bias_ptr.y);
    data[index] = packed_val;
  }
}

template <>
__global__ void AddBiasReluV2(const int num, const half2* bias, half2* data, int K) {
  CUDA_KERNEL_LOOP(index, num) {
    int bias_idx = index % K;
    const half2 bias_ptr = bias[bias_idx];
    const half2 in_ptr = data[index];
#if __CUDA_ARCH__ >= 530
    data[index] = __hmul2(__hgt2(in_ptr + bias_ptr, __float2half2_rn(0.f)),
                          in_ptr + bias_ptr);
#else
    const float2 bias = __half22float2(bias_ptr);
    const float2 in = __half22float2(in_ptr);
    data[index] = __floats2half2_rn(
        bias.x + in.x > 0.0f ? static_cast<float>(bias.x + in.x) : 0.0f,
        bias.y + in.y > 0.0f ? static_cast<float>(bias.y + in.y) : 0.0f);
#endif
  }
}

template <typename T>
__global__ void AddBiasV4(const int num, const T* bias, T* data, int K) {
  CUDA_KERNEL_LOOP(index, num) {
...
@@ -77,6 +144,21 @@ __global__ void AddBias(const int num, const T* bias, T* data) {
  }
}

template <>
__global__ void AddBias(const int num, const half* bias, half* data) {
  int offset = blockIdx.x * num;
  for (int i = threadIdx.x; i < num; i += blockDim.x) {
    half temp;
#if __CUDA_ARCH__ >= 350
    temp = __hadd(__ldg(data + offset + i), __ldg(bias + i));
#else
    temp = __hadd(data[offset + i], bias[i]);
#endif
    data[offset + i] = temp;
  }
}

template <typename T>
__global__ void AddBiasRelu(const int num, const T* bias, T* data) {
  int offset = blockIdx.x * num;
...
@@ -92,6 +174,28 @@ __global__ void AddBiasRelu(const int num, const T* bias, T* data) {
  }
}

template <>
__global__ void AddBiasRelu<half>(const int num, const half* bias, half* data) {
  int offset = blockIdx.x * num;
  for (int i = threadIdx.x; i < num; i += blockDim.x) {
    half temp;
#if __CUDA_ARCH__ >= 350
    temp = __hadd(__ldg(data + offset + i), __ldg(bias + i));
#else
    temp = __hadd(data[offset + i], bias[i]);
#endif
#if __CUDA_ARCH__ >= 530
    data[offset + i] =
        __hgt(temp, __float2half(0.0f)) ? temp : __float2half(0.0f);
#else
    data[offset + i] =
        __float2half(__half2float(temp) > 0.f ? __half2float(temp) : 0.f);
#endif
  }
}

template <typename T, PrecisionType PType>
void FcCompute<T, PType>::PrepareForRun() {
  gemm_impl_.reset(new lite::cuda::math::Gemm<T, T>);
...
@@ -161,6 +265,69 @@ void FcCompute<T, PType>::Run() {
  }
}

template <>
void FcCompute<half, PRECISION(kFP16)>::Run() {
  auto& context = this->ctx_->template As<CUDAContext>();
  auto stream = context.exec_stream();

  auto& param = this->template Param<param_t>();
  const auto* x_data = param.input->template data<half>();
  const auto* w_data = param.w->template data<half>();
  const auto* b_data = param.bias ? param.bias->template data<half>() : nullptr;

  auto out_vec = param.output->dims().Vectorize();
  out_vec.back() = param.w->dims()[1];
  param.output->Resize(out_vec);
  auto* out_data = param.output->template mutable_data<half>(TARGET(kCUDA));

  int in_num_col_dims = param.in_num_col_dims;
  int M = static_cast<int>(
      param.input->dims().Slice(0, param.in_num_col_dims).production());
  int K = static_cast<int>(
      param.input->dims()
          .Slice(param.in_num_col_dims, param.input->dims().size())
          .production());
  int K2 = static_cast<int>(param.w->dims()[0]);
  int N = static_cast<int>(param.w->dims()[1]);
  CHECK_EQ(K, K2) << "x_w must be equal with y_h";

  CHECK(gemm_impl_->init(false, false, M, N, K, &context));
  gemm_impl_->run(1.0f, 0.0f, x_data, w_data, out_data, &context);

  if (b_data == nullptr) {
    return;
  }

  std::string activation_type = param.activation_type;
  if (N % 2 == 0) {
    const int threads = 256;
    const int num = M * N / 2;
    const int blocks = (num + threads - 1) / threads;
    const auto* bias_ptr_v2 = reinterpret_cast<const half2*>(b_data);
    auto* data_ptr_v2 = reinterpret_cast<half2*>(out_data);
    if (activation_type == "relu") {
      AddBiasReluV2<half2><<<blocks, threads, 0, stream>>>(
          num, bias_ptr_v2, data_ptr_v2, N / 2);
    } else if (activation_type == "") {
      AddBiasV2<half2><<<blocks, threads, 0, stream>>>(
          num, bias_ptr_v2, data_ptr_v2, N / 2);
    } else {
      LOG(FATAL) << "not supported activation type: " << activation_type;
    }
  } else {
    const int threads = 256;
    const int blocks = M;
    if (activation_type == "relu") {
      AddBiasRelu<half><<<blocks, threads, 0, stream>>>(N, b_data, out_data);
    } else if (activation_type == "") {
      AddBias<half><<<blocks, threads, 0, stream>>>(N, b_data, out_data);
    } else {
      LOG(FATAL) << "not supported activation type: " << activation_type;
    }
  }
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
...
@@ -168,9 +335,19 @@ void FcCompute<T, PType>::Run() {
using FcFp32 = paddle::lite::kernels::cuda::FcCompute<float, PRECISION(kFloat)>;
using FcFp16 = paddle::lite::kernels::cuda::FcCompute<half, PRECISION(kFP16)>;

REGISTER_LITE_KERNEL(fc, kCUDA, kFloat, kNCHW, FcFp32, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .Finalize();

REGISTER_LITE_KERNEL(fc, kCUDA, kFP16, kNCHW, FcFp16, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
    .BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
    .Finalize();
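The FP16 path flattens the input exactly like the FP32 one: M is the product of the first in_num_col_dims dimensions and K is the product of the remaining ones, which must equal w->dims()[0]; the half2 bias kernels are only used when N is even, so out_data can be reinterpreted as half2. A self-contained sketch of that flattening (hypothetical helper, plain C++, not part of the commit):

#include <cstdint>
#include <utility>
#include <vector>

// Flatten an N-D input of shape dims into an [M, K] matrix for the FC gemm:
// M = prod(dims[0 : in_num_col_dims]), K = prod(dims[in_num_col_dims :]).
std::pair<int, int> FlattenForFc(const std::vector<int64_t>& dims,
                                 int in_num_col_dims) {
  auto prod = [&](size_t lo, size_t hi) {
    int64_t p = 1;
    for (size_t i = lo; i < hi; ++i) p *= dims[i];
    return static_cast<int>(p);
  };
  return {prod(0, in_num_col_dims), prod(in_num_col_dims, dims.size())};
}

// Example: dims = {8, 16}, in_num_col_dims = 1  ->  M = 8, K = 16,
// the shapes used by the updated fc_compute_test.cc below.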
lite/kernels/cuda/fc_compute_test.cc  (+38, -2)

@@ -31,8 +31,8 @@ namespace cuda {
 class FcTest : public ::testing::Test {
  protected:
   FcTest()
-      : m_(128),
-        k_(512),
+      : m_(8),
+        k_(16),
         n_(64),
         in_num_col_dims_(1),
         act_type_("relu"),
...
@@ -189,6 +189,42 @@ TEST_F(FcTest, TestFP32) {
  }
}

TEST_F(FcTest, TestFP16) {
  InitHalfInput();
  FcCompute<half, PRECISION(kFP16)> kernel;
  kernel.SetParam(param_);
  kernel.SetContext(std::move(ctx_));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();
    cudaDeviceSynchronize();
  }

  auto start = GetCurrentUS();
  kernel.PrepareForRun();
  for (int i = 0; i < FLAGS_repeats; ++i) {
    kernel.Run();
  }
  cudaDeviceSynchronize();
  auto duration = (GetCurrentUS() - start) / 1000.0;
  LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  const half* out_gpu_data = out_gpu_.data<half>();
  half* out_cpu_data = out_cpu_.mutable_data<half>();
  CopySync<TARGET(kCUDA)>(out_cpu_data,
                          out_gpu_data,
                          sizeof(half) * out_gpu_.numel(),
                          IoDirection::DtoH);

  for (int i = 0; i < out_gpu_.numel(); ++i) {
    float res = static_cast<float>(lite::float16(out_cpu_data[i]));
    float ref = out_ref_.data<float>()[i];
    EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 2e-2);
  }
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
...
lite/kernels/cuda/matmul_compute.cc  (new file, +156)

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/matmul_compute.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

template <typename T, PrecisionType PType>
void MatMulCompute<T, PType>::Run() {
  auto& context = this->ctx_->template As<CUDAContext>();
  auto& param = this->template Param<param_t>();

  const auto* x_data = param.X->template data<T>();
  const auto* y_data = param.Y->template data<T>();
  auto* out_data = param.Out->template mutable_data<T>(TARGET(kCUDA));

  bool transpose_x = param.transpose_X;
  bool transpose_y = param.transpose_Y;
  float alpha = param.alpha;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();

  int m = 0;
  int k = 0;
  int n = 0;
  int batch = 0;
  int64_t stride_x = 0;
  int64_t stride_y = 0;
  if (x_dims.size() >= 2 && y_dims.size() >= 2 &&
      (x_dims.size() != 2 || y_dims.size() != 2)) {
    // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
    // x: [B, M, K], y: [K, N], out: [B, M, N]
    // or
    // x: [M, K], y: [B, ..., K, N], out: [B, ..., M, N]
    // x: [M, K], y: [B, K, N], out: [B, M, N]
    strided_gemm_impl_->init(transpose_x, transpose_y, &context);
    m = transpose_x ? x_dims[x_dims.size() - 1] : x_dims[x_dims.size() - 2];
    k = transpose_x ? x_dims[x_dims.size() - 2] : x_dims[x_dims.size() - 1];
    n = transpose_y ? y_dims[y_dims.size() - 2] : y_dims[y_dims.size() - 1];
    int batch_x = x_dims.size() == 2 ? 0 : x_dims.count(0, x_dims.size() - 2);
    int batch_y = y_dims.size() == 2 ? 0 : y_dims.count(0, y_dims.size() - 2);
    CHECK(batch_x == batch_y || batch_x == 0 || batch_y == 0)
        << "batch_size x should be equal to batch_size y, or "
           "one of batch_size x and batch_size y should be 0. "
           "But got batch_size x = "
        << batch_x << ", batch_size y = " << batch_y;
    batch = batch_x == 0 ? batch_y : batch_x;
    stride_x = x_dims.size() == 2 ? 0 : m * k;
    stride_y = y_dims.size() == 2 ? 0 : k * n;
    strided_gemm_impl_->run(
        alpha, 0.f, m, n, k, x_data, y_data, out_data, batch, stride_x, stride_y);
  } else if (x_dims.size() == 2 && y_dims.size() == 2) {
    // x: [M, K], y: [K, N], out: [M, N]
    m = transpose_x ? x_dims[1] : x_dims[0];
    k = transpose_x ? x_dims[0] : x_dims[1];
    n = transpose_y ? y_dims[0] : y_dims[1];
    gemm_impl_->init(transpose_x, transpose_y, m, n, k, &context);
    gemm_impl_->run(alpha, 0.0f, x_data, y_data, out_data, &context);
  } else if (x_dims.size() > 2 && y_dims.size() == 1) {
    // x: [B, M, K], y: [K], out: [B, M]
    strided_gemm_impl_->init(transpose_x, transpose_y, &context);
    m = transpose_x ? x_dims[x_dims.size() - 1] : x_dims[x_dims.size() - 2];
    k = transpose_x ? x_dims[x_dims.size() - 2] : x_dims[x_dims.size() - 1];
    n = 1;
    batch = x_dims.count(0, x_dims.size() - 2);
    stride_x = m * k;
    stride_y = 0;
    strided_gemm_impl_->run(
        alpha, 0.f, m, n, k, x_data, y_data, out_data, batch, stride_x, stride_y);
  } else if (x_dims.size() == 1 && y_dims.size() == 1) {
    if (!transpose_x && !transpose_y) {
      // x: [K], y: [K], out: [1]
      m = 1;
      k = x_dims[0];
      n = 1;
      CHECK_EQ(x_dims[0], y_dims[0]) << "x_dims[0] should be equal to y_dims[0]";
      gemm_impl_->init(false, false, m, n, k, &context);
      gemm_impl_->run(alpha, 0.0f, x_data, y_data, out_data, &context);
    } else if (transpose_x && transpose_y) {
      // x: [M], y: [N], x_transpose: true, y_transpose: true, out: [M, N]
      m = x_dims[0];
      k = 1;
      n = y_dims[0];
      gemm_impl_->init(false, false, m, n, k, &context);
      gemm_impl_->run(alpha, 0.0f, x_data, y_data, out_data, &context);
    } else {
      LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims("
                 << y_dims << "), transpose_x(" << transpose_x
                 << "), transpose_y(" << transpose_y << ")";
    }
  } else {
    LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims("
               << y_dims << ")";
  }
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

using MatMulFp32 =
    paddle::lite::kernels::cuda::MatMulCompute<float, PRECISION(kFloat)>;
using MatMulFp16 =
    paddle::lite::kernels::cuda::MatMulCompute<half, PRECISION(kFP16)>;

REGISTER_LITE_KERNEL(matmul, kCUDA, kFloat, kNCHW, MatMulFp32, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .Finalize();

REGISTER_LITE_KERNEL(matmul, kCUDA, kFP16, kNCHW, MatMulFp16, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
    .Finalize();
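The broadcast branches all come down to the stride choice: a 2-D operand keeps stride 0 so the same matrix is reused for every batch, while a batched operand advances by m*k (or k*n) elements per step. A small self-contained sketch of that bookkeeping for the first branch above (transposes omitted; a hypothetical helper that only mirrors the shape arithmetic, not part of the commit):

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the shape/stride bookkeeping of MatMulCompute::Run() for the
// "both ranks >= 2, not both exactly 2" branch.
void MatMulStrides(const std::vector<int64_t>& x_dims,
                   const std::vector<int64_t>& y_dims) {
  auto count = [](const std::vector<int64_t>& d, size_t lo, size_t hi) {
    int64_t c = 1;
    for (size_t i = lo; i < hi; ++i) c *= d[i];
    return c;
  };
  int64_t m = x_dims[x_dims.size() - 2];
  int64_t k = x_dims[x_dims.size() - 1];
  int64_t n = y_dims[y_dims.size() - 1];
  int64_t batch_x = x_dims.size() == 2 ? 0 : count(x_dims, 0, x_dims.size() - 2);
  int64_t batch_y = y_dims.size() == 2 ? 0 : count(y_dims, 0, y_dims.size() - 2);
  int64_t batch = batch_x == 0 ? batch_y : batch_x;
  int64_t stride_x = x_dims.size() == 2 ? 0 : m * k;  // 0 broadcasts x
  int64_t stride_y = y_dims.size() == 2 ? 0 : k * n;  // 0 broadcasts y
  std::cout << "m=" << m << " n=" << n << " k=" << k << " batch=" << batch
            << " stride_x=" << stride_x << " stride_y=" << stride_y << "\n";
}

int main() {
  MatMulStrides({4, 3, 2}, {2, 5});  // x: [B, M, K], y: [K, N] -> stride_y = 0
  return 0;
}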
lite/kernels/cuda/matmul_compute.h  (new file, +50)

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>

#include "lite/backends/cuda/math/gemm.h"
#include "lite/backends/cuda/math/strided_gemm.h"
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

template <typename T, PrecisionType Ptype>
class MatMulCompute : public KernelLite<TARGET(kCUDA), Ptype> {
 public:
  using param_t = operators::MatMulParam;

  void PrepareForRun() override {
    strided_gemm_impl_.reset(new lite::cuda::math::StridedGemm<T, T>);
    gemm_impl_.reset(new lite::cuda::math::Gemm<T, T>);
  }

  void Run() override;
  virtual ~MatMulCompute() = default;

 private:
  std::unique_ptr<lite::cuda::math::StridedGemm<T, T>> strided_gemm_impl_{nullptr};
  std::unique_ptr<lite::cuda::math::Gemm<T, T>> gemm_impl_{nullptr};
};

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle
lite/kernels/cuda/matmul_compute_test.cc  (new file, +193)

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/cuda/matmul_compute.h"

#include <gtest/gtest.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "lite/api/test_helper.h"
#include "lite/utils/float16.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

class MatMulTest : public ::testing::Test {
 protected:
  MatMulTest()
      : x_trans_(false),
        y_trans_(true),
        alpha_(1.0f),
        x_shape_({4, 1, 2}),
        y_shape_({4, 1, 2}),
        out_shape_({4, 1, 1}) {
    x_ref_.Resize(lite::DDim(x_shape_));
    x_gpu_.Resize(x_ref_.dims());

    y_ref_.Resize(lite::DDim(y_shape_));
    y_gpu_.Resize(y_ref_.dims());

    auto x_ref_data = x_ref_.mutable_data<float>();
    auto y_ref_data = y_ref_.mutable_data<float>();

    // prepare input
    for (int64_t i = 0; i < x_ref_.numel(); i++) {
      x_ref_data[i] = static_cast<float>(1);
    }
    for (int64_t i = 0; i < y_ref_.numel(); i++) {
      y_ref_data[i] = static_cast<float>(1);
    }

    out_ref_.Resize(lite::DDim(out_shape_));
    out_cpu_.Resize(out_ref_.dims());
    out_gpu_.Resize(out_ref_.dims());
    RunBaseLine();

    InitParamAndContext();
  }

  void InitParamAndContext() {
    ctx_.reset(new KernelContext);
    cudaStreamCreate(&stream_);
    auto& context = ctx_->As<CUDAContext>();
    context.SetExecStream(stream_);
    param_.X = &x_gpu_;
    param_.Y = &y_gpu_;
    param_.transpose_X = x_trans_;
    param_.transpose_Y = y_trans_;
    param_.alpha = alpha_;
    param_.Out = &out_gpu_;
  }

  void InitFloatInput() {
    x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
                                                    x_gpu_.dims());
    y_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(y_ref_.data<float>(),
                                                    y_gpu_.dims());
  }

  void InitHalfInput() {
    x_half_.Resize(x_ref_.dims());
    auto x_half_data = x_half_.mutable_data<half>();
    for (int64_t i = 0; i < x_half_.numel(); ++i) {
      x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
    }
    x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
    y_half_.Resize(y_ref_.dims());
    auto y_half_data = y_half_.mutable_data<half>();
    for (int64_t i = 0; i < y_half_.numel(); i++) {
      y_half_data[i] = half(lite::float16(y_ref_.data<float>()[i]));
    }
    y_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(y_half_data, y_gpu_.dims());
  }

  void RunBaseLine() {
    auto* out_data = out_ref_.mutable_data<float>();
    for (int64_t i = 0; i < out_ref_.numel(); ++i) {
      out_data[i] = 2;
    }
  }

  bool x_trans_, y_trans_;
  float alpha_;
  std::vector<int64_t> x_shape_, y_shape_, out_shape_;
  lite::Tensor x_ref_, y_ref_, out_ref_;
  lite::Tensor x_gpu_, y_gpu_;
  lite::Tensor x_half_, y_half_;
  lite::Tensor out_cpu_, out_gpu_;

  operators::MatMulParam param_;
  std::unique_ptr<KernelContext> ctx_;
  cudaStream_t stream_;
};

TEST_F(MatMulTest, TestFP32) {
  InitFloatInput();
  MatMulCompute<float, PRECISION(kFloat)> kernel;
  kernel.SetParam(param_);
  kernel.SetContext(std::move(ctx_));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();
    cudaDeviceSynchronize();
  }

  auto start = GetCurrentUS();
  kernel.PrepareForRun();
  for (int i = 0; i < FLAGS_repeats; ++i) {
    kernel.Run();
  }
  cudaDeviceSynchronize();
  auto duration = (GetCurrentUS() - start) / 1000.0;
  LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
                          out_gpu_.data<float>(),
                          sizeof(float) * out_gpu_.numel(),
                          IoDirection::DtoH);

  for (int i = 0; i < out_gpu_.numel(); ++i) {
    float res = out_cpu_.data<float>()[i];
    float ref = out_ref_.data<float>()[i];
    EXPECT_NEAR(fabs(res - ref) / ref, 0.f, 1e-5);
  }
}

TEST_F(MatMulTest, TestFP16) {
  InitHalfInput();
  MatMulCompute<half, PRECISION(kFP16)> kernel;
  kernel.SetParam(param_);
  kernel.SetContext(std::move(ctx_));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();
    cudaDeviceSynchronize();
  }

  auto start = GetCurrentUS();
  kernel.PrepareForRun();
  for (int i = 0; i < FLAGS_repeats; ++i) {
    kernel.Run();
  }
  cudaDeviceSynchronize();
  auto duration = (GetCurrentUS() - start) / 1000.0;
  LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  const half* out_gpu_data = out_gpu_.data<half>();
  half* out_cpu_data = out_cpu_.mutable_data<half>();
  CopySync<TARGET(kCUDA)>(out_cpu_data,
                          out_gpu_data,
                          sizeof(half) * out_gpu_.numel(),
                          IoDirection::DtoH);

  for (int i = 0; i < out_gpu_.numel(); ++i) {
    float res = static_cast<float>(lite::float16(out_cpu_data[i]));
    float ref = out_ref_.data<float>()[i];
    EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 1e-2);
  }
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle
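The hard-coded baseline in RunBaseLine() is just the closed form for this fixture: with x and y all ones of shape [4, 1, 2], transpose_Y = true and alpha = 1, every output element is the dot product of two length-2 ones vectors, i.e. 2. A one-line check of that arithmetic (standalone snippet, not part of the commit):

#include <cassert>

int main() {
  const int k = 2;          // inner dimension of the fixture ([4, 1, 2] with Y transposed)
  const float alpha = 1.f;  // param_.alpha in the test
  float dot = 0.f;
  for (int i = 0; i < k; ++i) dot += 1.f * 1.f;  // all-ones rows of x and y
  assert(alpha * dot == 2.f);                    // matches out_data[i] = 2 in RunBaseLine()
  return 0;
}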
lite/kernels/cuda/transpose_compute_test.cc  (+9, -11)

@@ -71,7 +71,7 @@ void Nchw2nhwcBaseLine(lite::Tensor* input,
                        n * output_h * output_w * output_c]
 void Nhwc2nchwBaseLine(lite::Tensor* input,
                        lite::Tensor* output,
-                       const std::vector<int> axies) {
+                       const std::vector<int>& axies) {
   auto* input_data = input->data<float>();
   auto* output_data = output->mutable_data<float>();
...
@@ -175,7 +175,6 @@ TEST(transpose_nchw, normal) {
       out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
   Nchw2nhwcBaseLine(&x_ref, &out_ref, axes);
   auto* out_ref_data = out_ref.mutable_data<float>();
-  // TransBaseLine(&x_ref, &out_ref, axes);
   for (int i = 0; i < out.numel(); i++) {
     EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
   }
...
@@ -226,7 +225,6 @@ TEST(transpose_nhwc, normal) {
   CopySync<TARGET(kCUDA)>(
       out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
   Nhwc2nchwBaseLine(&x_ref, &out_ref, axes);
-  // TransBaseLine(&x_ref, &out_ref, axes);
   auto* out_ref_data = out_ref.mutable_data<float>();
   for (int i = 0; i < out.numel(); i++) {
     EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
...
@@ -277,11 +275,11 @@ class TransposeTest : public ::testing::Test {
   void InitHalfInput() {
     x_half_.Resize(lite::DDim(x_ref_.dims()));
-    auto X_half__data = x_half_.mutable_data<half>();
+    auto x_half_data = x_half_.mutable_data<half>();
     for (int64_t i = 0; i < x_half_.numel(); i++) {
-      X_half__data[i] = half(lite::float16(x_ref_.data<float>()[i]));
+      x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
     }
-    x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(X_half__data, x_gpu_.dims());
+    x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
   }

   void RunBaseLine(const lite::Tensor* x, lite::Tensor* out) {
...
@@ -355,15 +353,15 @@ TEST_F(TransposeTest, TestFP16) {
             << ", repeats: " << FLAGS_repeats << ", spend "
             << duration / FLAGS_repeats << " ms in average.";

-  const half* Out_gpu__data = out_gpu_.data<half>();
-  half* Out_cpu__data = out_cpu_.mutable_data<half>();
-  CopySync<TARGET(kCUDA)>(Out_cpu__data,
-                          Out_gpu__data,
+  const half* out_gpu_data = out_gpu_.data<half>();
+  half* out_cpu_data = out_cpu_.mutable_data<half>();
+  CopySync<TARGET(kCUDA)>(out_cpu_data,
+                          out_gpu_data,
                           sizeof(half) * out_gpu_.numel(),
                           IoDirection::DtoH);

   for (int i = 0; i < out_cpu_.numel(); ++i) {
-    float res = static_cast<float>(lite::float16(Out_cpu__data[i]));
+    float res = static_cast<float>(lite::float16(out_cpu_data[i]));
     float ref = out_ref_.data<float>()[i];
     EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 1e-2);
   }
...