Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
46980d68
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
46980d68
编写于
8月 24, 2020
作者:
W
Wilber
提交者:
GitHub
8月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ltgemm gemv. test=develop (#4155)
上级
89cfa8e6
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
551 addition
and
50 deletion
+551
-50
lite/backends/cuda/cuda_utils.h
lite/backends/cuda/cuda_utils.h
+4
-0
lite/backends/cuda/math/CMakeLists.txt
lite/backends/cuda/math/CMakeLists.txt
+2
-0
lite/backends/cuda/math/gemm.cc
lite/backends/cuda/math/gemm.cc
+150
-0
lite/backends/cuda/math/gemm.h
lite/backends/cuda/math/gemm.h
+73
-1
lite/backends/cuda/math/gemv.cc
lite/backends/cuda/math/gemv.cc
+73
-0
lite/backends/cuda/math/gemv.h
lite/backends/cuda/math/gemv.h
+67
-0
lite/kernels/cuda/mul_compute.cc
lite/kernels/cuda/mul_compute.cc
+12
-2
lite/kernels/cuda/mul_compute.h
lite/kernels/cuda/mul_compute.h
+8
-9
lite/kernels/cuda/mul_compute_test.cc
lite/kernels/cuda/mul_compute_test.cc
+162
-38
未找到文件。
lite/backends/cuda/cuda_utils.h
浏览文件 @
46980d68
...
...
@@ -21,6 +21,10 @@
#include <cudnn.h>
#include "lite/utils/cp_logging.h"
#if (CUBLAS_VER_MAJOR * 10 + CUBLAS_VER_MINOR) >= 101
#include <cublasLt.h>
#endif
/*
* This file contains some CUDA specific utils.
*/
...
...
lite/backends/cuda/math/CMakeLists.txt
浏览文件 @
46980d68
...
...
@@ -16,6 +16,7 @@ nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
nv_library
(
cuda_gru_forward SRCS gru_forward.cu DEPS cuda_activation
${
cuda_static_deps
}
)
nv_library
(
cuda_sequence2batch SRCS sequence2batch.cu DEPS
${
cuda_static_deps
}
)
nv_library
(
cuda_gemm SRCS gemm.cc DEPS
${
cuda_static_deps
}
)
nv_library
(
cuda_gemv SRCS gemv.cc DEPS
${
cuda_static_deps
}
)
nv_library
(
cuda_batched_gemm SRCS batched_gemm.cc DEPS
${
cuda_static_deps
}
)
nv_library
(
cuda_strided_gemm SRCS strided_gemm.cc DEPS
${
cuda_static_deps
}
)
nv_library
(
cuda_sequence_padding SRCS sequence_padding.cu DEPS
${
cuda_static_deps
}
)
...
...
@@ -35,6 +36,7 @@ set (
cuda_gru_forward
cuda_sequence2batch
cuda_gemm
cuda_gemv
cuda_batched_gemm
cuda_strided_gemm
cuda_sequence_padding
...
...
lite/backends/cuda/math/gemm.cc
浏览文件 @
46980d68
...
...
@@ -123,6 +123,156 @@ bool Gemm<half, half>::run(const half alpha,
template
class
Gemm
<
float
,
float
>;
template
class
Gemm
<
half
,
half
>;
// LtGemm
template
<
typename
T
>
class
cublasTypeWrapper
;
template
<
>
class
cublasTypeWrapper
<
float
>
{
public:
static
const
cudaDataType_t
type
=
CUDA_R_32F
;
};
template
<
>
class
cublasTypeWrapper
<
half
>
{
public:
static
const
cudaDataType_t
type
=
CUDA_R_16F
;
};
#if (CUBLAS_VER_MAJOR * 10 + CUBLAS_VER_MINOR) >= 101
template
<
typename
PTypeIn
,
typename
PTypeOut
>
bool
LtGemm
<
PTypeIn
,
PTypeOut
>::
init
(
const
bool
trans_a
,
const
bool
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
Context
<
TARGET
(
kCUDA
)
>
*
ctx
)
{
int
lda
=
(
!
trans_a
)
?
k
:
m
;
int
ldb
=
(
!
trans_b
)
?
n
:
k
;
int
ldc
=
n
;
return
this
->
init
(
trans_a
,
trans_b
,
m
,
n
,
k
,
lda
,
ldb
,
ldc
,
ctx
);
}
template
<
typename
PTypeIn
,
typename
PTypeOut
>
bool
LtGemm
<
PTypeIn
,
PTypeOut
>::
init
(
const
bool
trans_a
,
const
bool
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
const
int
lda
,
const
int
ldb
,
const
int
ldc
,
Context
<
TARGET
(
kCUDA
)
>
*
ctx
)
{
if
(
handle_
==
nullptr
)
{
this
->
exe_stream_
=
ctx
->
exec_stream
();
CUBLAS_CALL
(
cublasLtCreate
(
&
handle_
));
}
m_
=
m
;
n_
=
n
;
k_
=
k
;
lda_
=
lda
;
ldb_
=
ldb
;
ldc_
=
ldc
;
cu_trans_a_
=
trans_a
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cu_trans_b_
=
trans_b
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
// create operation desciriptor; see cublasLtMatmulDescAttributes_t for
// details about defaults; here we just need to set the transforms for A and B
CUBLAS_CALL
(
cublasLtMatmulDescCreate
(
&
matmul_desc_
,
cublasTypeWrapper
<
PTypeOut
>::
type
));
CUBLAS_CALL
(
cublasLtMatmulDescSetAttribute
(
matmul_desc_
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
cu_trans_b_
,
sizeof
(
cu_trans_b_
)));
CUBLAS_CALL
(
cublasLtMatmulDescSetAttribute
(
matmul_desc_
,
CUBLASLT_MATMUL_DESC_TRANSA
,
&
cu_trans_a_
,
sizeof
(
cu_trans_a_
)));
// create matrix descriptors, we are good with the details here so no need to
// set any extra attributes
CUBLAS_CALL
(
cublasLtMatrixLayoutCreate
(
&
a_desc_
,
cublasTypeWrapper
<
PTypeOut
>::
type
,
trans_a
==
false
?
k
:
m
,
trans_a
==
false
?
m
:
k
,
lda
));
CUBLAS_CALL
(
cublasLtMatrixLayoutCreate
(
&
b_desc_
,
cublasTypeWrapper
<
PTypeOut
>::
type
,
trans_b
==
false
?
n
:
k
,
trans_b
==
false
?
k
:
n
,
ldb
));
CUBLAS_CALL
(
cublasLtMatrixLayoutCreate
(
&
c_desc_
,
cublasTypeWrapper
<
PTypeOut
>::
type
,
n
,
m
,
ldc
));
// create preference handle; here we could use extra attributes to disable
// tensor ops or to make sure algo selected will work with badly aligned A, B,
// C; here for simplicity we just assume A,B,C are always well aligned (e.g.
// directly come from cudaMalloc)
CUBLAS_CALL
(
cublasLtMatmulPreferenceCreate
(
&
preference_
));
if
(
!
workspace_
)
{
CUDA_CALL
(
cudaMalloc
(
&
this
->
workspace_
,
workspace_size_
));
}
CUBLAS_CALL
(
cublasLtMatmulPreferenceSetAttribute
(
preference_
,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
,
&
workspace_size_
,
sizeof
(
workspace_size_
)));
// we just need the best available heuristic to try and run matmul. There is
// no guarantee this will work, e.g. if A is badly aligned, you can request
// more (e.g. 32) algos and try to run them one by one until something works
CUBLAS_CALL
(
cublasLtMatmulAlgoGetHeuristic
(
handle_
,
matmul_desc_
,
b_desc_
,
a_desc_
,
c_desc_
,
c_desc_
,
preference_
,
1
,
&
heuristic_result_
,
&
returned_results_
));
if
(
returned_results_
==
0
)
{
LOG
(
FATAL
)
<<
"cuBLAS API failed with status "
<<
CUBLAS_STATUS_NOT_SUPPORTED
;
}
return
true
;
}
template
<
typename
PTypeIn
,
typename
PTypeOut
>
bool
LtGemm
<
PTypeIn
,
PTypeOut
>::
run
(
const
PTypeOut
alpha
,
const
PTypeOut
beta
,
const
PTypeIn
*
a
,
const
PTypeIn
*
b
,
PTypeOut
*
c
,
Context
<
TARGET
(
kCUDA
)
>
*
ctx
)
{
CUBLAS_CALL
(
cublasLtMatmul
(
handle_
,
matmul_desc_
,
&
alpha
,
b
,
b_desc_
,
a
,
a_desc_
,
&
beta
,
c
,
c_desc_
,
c
,
c_desc_
,
&
heuristic_result_
.
algo
,
workspace_
,
workspace_size_
,
this
->
exe_stream_
));
return
true
;
}
template
class
LtGemm
<
float
,
float
>;
template
class
LtGemm
<
half
,
half
>;
#endif
}
// namespace math
}
// namespace cuda
}
// namespace lite
...
...
lite/backends/cuda/math/gemm.h
浏览文件 @
46980d68
...
...
@@ -13,7 +13,6 @@
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
...
...
@@ -70,6 +69,79 @@ class Gemm {
int
ldc_
{
-
1
};
};
#if (CUBLAS_VER_MAJOR * 10 + CUBLAS_VER_MINOR) >= 101
template
<
typename
PtypeIn
,
typename
PtypeOut
>
class
LtGemm
{
public:
LtGemm
()
:
handle_
(
nullptr
),
matmul_desc_
(
nullptr
),
a_desc_
(
nullptr
),
b_desc_
(
nullptr
),
c_desc_
(
nullptr
),
preference_
(
nullptr
),
returned_results_
(
0
),
workspace_size_
(
4
*
1024
*
1024
),
workspace_
{
nullptr
}
{}
~
LtGemm
()
{
if
(
this
->
workspace_
)
{
CUDA_CALL
(
cudaFree
(
this
->
workspace_
));
}
this
->
workspace_
=
nullptr
;
}
bool
init
(
const
bool
trans_a
,
const
bool
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
Context
<
TARGET
(
kCUDA
)
>*
ctx
);
bool
init
(
const
bool
trans_a
,
const
bool
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
const
int
lda
,
const
int
ldb
,
const
int
ldc
,
Context
<
TARGET
(
kCUDA
)
>*
ctx
);
bool
run
(
const
PtypeOut
alpha
,
const
PtypeOut
beta
,
const
PtypeIn
*
a
,
const
PtypeIn
*
b
,
PtypeOut
*
c
,
Context
<
TARGET
(
kCUDA
)
>*
ctx
);
cublasLtHandle_t
get_handle
()
const
{
return
handle_
;
}
private:
cudaStream_t
exe_stream_
;
cublasLtHandle_t
handle_
;
cublasLtMatmulDesc_t
matmul_desc_
;
cublasLtMatrixLayout_t
a_desc_
;
cublasLtMatrixLayout_t
b_desc_
;
cublasLtMatrixLayout_t
c_desc_
;
cublasLtMatmulPreference_t
preference_
;
int
returned_results_
;
cublasLtMatmulHeuristicResult_t
heuristic_result_
{};
cublasOperation_t
cu_trans_a_
;
cublasOperation_t
cu_trans_b_
;
int
m_
{
-
1
};
int
n_
{
-
1
};
int
k_
{
-
1
};
int
lda_
{
-
1
};
int
ldb_
{
-
1
};
int
ldc_
{
-
1
};
size_t
workspace_size_
;
void
*
workspace_
;
};
#endif
}
// namespace math
}
// namespace cuda
}
// namespace lite
...
...
lite/backends/cuda/math/gemv.cc
0 → 100644
浏览文件 @
46980d68
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/gemv.h"
#include <iostream>
#include "lite/core/device_info.h"
namespace
paddle
{
namespace
lite
{
namespace
cuda
{
namespace
math
{
template
<
typename
PTypeIn
,
typename
PTypeOut
>
bool
Gemv
<
PTypeIn
,
PTypeOut
>::
init
(
const
bool
trans
,
const
int
m
,
const
int
n
,
const
int
lda
,
const
int
ldb
,
const
int
ldc
,
Context
<
TARGET
(
kCUDA
)
>
*
ctx
)
{
if
(
cu_handle_
==
nullptr
)
{
this
->
exe_stream_
=
ctx
->
exec_stream
();
CUBLAS_CALL
(
cublasCreate
(
&
cu_handle_
));
CUBLAS_CALL
(
cublasSetMathMode
(
cu_handle_
,
CUBLAS_TENSOR_OP_MATH
));
CUBLAS_CALL
(
cublasSetStream
(
cu_handle_
,
this
->
exe_stream_
));
}
m_
=
m
;
n_
=
n
;
lda_
=
lda
;
ldb_
=
ldb
;
ldc_
=
ldc
;
cu_trans_
=
trans
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
return
true
;
}
template
<
>
bool
Gemv
<
float
,
float
>::
run
(
const
float
alpha
,
const
float
beta
,
const
float
*
a
,
const
float
*
b
,
float
*
c
)
{
CUBLAS_CALL
(
cublasSgemv
(
cu_handle_
,
cu_trans_
,
n_
,
m_
,
&
alpha
,
a
,
lda_
,
b
,
ldb_
,
&
beta
,
c
,
ldc_
));
return
true
;
}
template
<
>
bool
Gemv
<
half
,
half
>::
run
(
const
half
alpha
,
const
half
beta
,
const
half
*
a
,
const
half
*
b
,
half
*
c
)
{
LOG
(
FATAL
)
<<
"not supported"
;
return
false
;
}
template
class
Gemv
<
float
,
float
>;
template
class
Gemv
<
half
,
half
>;
}
// namespace math
}
// namespace cuda
}
// namespace lite
}
// namespace paddle
lite/backends/cuda/math/gemv.h
0 → 100644
浏览文件 @
46980d68
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/cp_logging.h"
namespace
paddle
{
namespace
lite
{
namespace
cuda
{
namespace
math
{
template
<
typename
PtypeIn
,
typename
PtypeOut
>
class
Gemv
{
public:
Gemv
()
:
cu_handle_
(
nullptr
)
{}
~
Gemv
()
{}
bool
init
(
const
bool
trans_
,
const
int
m
,
const
int
n
,
const
int
lda
,
const
int
ldb
,
const
int
ldc
,
Context
<
TARGET
(
kCUDA
)
>*
ctx
);
bool
run
(
const
PtypeOut
alpha
,
const
PtypeOut
beta
,
const
PtypeIn
*
a
,
const
PtypeIn
*
b
,
PtypeOut
*
c
);
cublasHandle_t
get_handle
()
const
{
return
cu_handle_
;
}
private:
cudaStream_t
exe_stream_
;
cublasHandle_t
cu_handle_
;
cublasOperation_t
cu_trans_
;
int
m_
{
-
1
};
int
n_
{
-
1
};
int
lda_
{
-
1
};
int
ldb_
{
-
1
};
int
ldc_
{
-
1
};
};
}
// namespace math
}
// namespace cuda
}
// namespace lite
}
// namespace paddle
lite/kernels/cuda/mul_compute.cc
浏览文件 @
46980d68
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/cuda/mul_compute.h"
#include "lite/core/op_registry.h"
namespace
paddle
{
...
...
@@ -23,9 +24,18 @@ namespace cuda {} // namespace cuda
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
mul
,
kCUDA
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
MulCompute
,
def
)
using
MulFp32
=
paddle
::
lite
::
kernels
::
cuda
::
MulCompute
<
float
,
PRECISION
(
kFloat
)
>
;
using
MulFp16
=
paddle
::
lite
::
kernels
::
cuda
::
MulCompute
<
half
,
PRECISION
(
kFP16
)
>
;
REGISTER_LITE_KERNEL
(
mul
,
kCUDA
,
kFloat
,
kNCHW
,
MulFp32
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
mul
,
kCUDA
,
kFP16
,
kNCHW
,
MulFp16
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
Finalize
();
lite/kernels/cuda/mul_compute.h
浏览文件 @
46980d68
...
...
@@ -23,22 +23,21 @@ namespace lite {
namespace
kernels
{
namespace
cuda
{
class
MulCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)
>
{
template
<
typename
T
,
PrecisionType
PType
>
class
MulCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
PType
>
{
public:
using
param_t
=
operators
::
MulParam
;
void
PrepareForRun
()
override
{
gemm_impl_
.
reset
(
new
lite
::
cuda
::
math
::
Gemm
<
float
,
float
>
);
gemm_impl_
.
reset
(
new
lite
::
cuda
::
math
::
Gemm
<
T
,
T
>
);
}
void
Run
()
override
{
CHECK
(
ctx_
)
<<
"running context should be set first"
;
auto
&
context
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
&
param
=
this
->
Param
<
param_t
>
();
const
auto
*
x_data
=
param
.
x
->
data
<
float
>
();
const
auto
*
y_data
=
param
.
y
->
data
<
float
>
();
auto
*
out_data
=
param
.
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
auto
&
param
=
this
->
template
Param
<
param_t
>();
const
auto
*
x_data
=
param
.
x
->
template
data
<
T
>();
const
auto
*
y_data
=
param
.
y
->
template
data
<
T
>();
auto
*
out_data
=
param
.
output
->
template
mutable_data
<
T
>(
TARGET
(
kCUDA
));
int
x_h
=
static_cast
<
int
>
(
param
.
x
->
dims
().
Slice
(
0
,
param
.
x_num_col_dims
).
production
());
...
...
@@ -61,7 +60,7 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
virtual
~
MulCompute
()
=
default
;
private:
std
::
unique_ptr
<
lite
::
cuda
::
math
::
Gemm
<
float
,
float
>>
gemm_impl_
{
nullptr
};
std
::
unique_ptr
<
lite
::
cuda
::
math
::
Gemm
<
T
,
T
>>
gemm_impl_
{
nullptr
};
};
}
// namespace cuda
...
...
lite/kernels/cuda/mul_compute_test.cc
浏览文件 @
46980d68
...
...
@@ -16,58 +16,182 @@
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include "lite/backends/cuda/blas.h"
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/utils/float16.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
TEST
(
mul_compute
,
normal
)
{
MulCompute
mul_kernel
;
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
Tensor
x
,
y
,
out
,
x_cpu
,
y_cpu
,
out_cpu
;
int
x_h
=
2
,
x_w_y_h
=
3
,
y_w
=
4
;
out
.
Resize
({
x_h
,
y_w
});
x_cpu
.
Resize
({
x_h
,
x_w_y_h
});
y_cpu
.
Resize
({
x_w_y_h
,
y_w
});
out_cpu
.
Resize
({
x_h
,
y_w
});
auto
*
out_data
=
out
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
float
*
x_cpu_data
=
x_cpu
.
mutable_data
<
float
>
();
float
*
y_cpu_data
=
y_cpu
.
mutable_data
<
float
>
();
float
*
out_cpu_data
=
out_cpu
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x_cpu
.
numel
();
i
++
)
{
x_cpu_data
[
i
]
=
i
+
1.0
;
class
MulTest
:
public
::
testing
::
Test
{
protected:
MulTest
()
:
m_
(
2
),
k_
(
3
),
n_
(
4
),
x_shape_
({
m_
,
k_
}),
y_shape_
({
k_
,
n_
}),
out_shape_
({
m_
,
n_
})
{
x_gpu_
.
Resize
(
lite
::
DDim
(
x_shape_
));
x_ref_
.
Resize
(
lite
::
DDim
(
x_shape_
));
y_gpu_
.
Resize
(
lite
::
DDim
(
y_shape_
));
y_ref_
.
Resize
(
lite
::
DDim
(
y_shape_
));
auto
x_ref_data
=
x_ref_
.
mutable_data
<
float
>
();
auto
y_ref_data
=
y_ref_
.
mutable_data
<
float
>
();
// prepare input
for
(
int64_t
i
=
0
;
i
<
x_ref_
.
numel
();
i
++
)
{
x_ref_data
[
i
]
=
static_cast
<
float
>
(
i
%
10
*
0.2
);
}
for
(
int64_t
i
=
0
;
i
<
y_ref_
.
numel
();
i
++
)
{
y_ref_data
[
i
]
=
static_cast
<
float
>
(
i
%
10
*
0.2
);
}
out_ref_
.
Resize
(
lite
::
DDim
(
out_shape_
));
out_cpu_
.
Resize
(
lite
::
DDim
(
out_shape_
));
out_gpu_
.
Resize
(
lite
::
DDim
(
out_shape_
));
RunBaseLine
(
&
x_ref_
,
&
y_ref_
,
&
out_ref_
);
InitParamAndContext
();
}
void
InitParamAndContext
()
{
ctx_
.
reset
(
new
KernelContext
);
cudaStreamCreate
(
&
stream_
);
auto
&
context
=
ctx_
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream_
);
param_
.
x
=
&
x_gpu_
;
param_
.
y
=
&
y_gpu_
;
param_
.
output
=
&
out_gpu_
;
}
void
InitFloatInput
()
{
x_gpu_
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_ref_
.
data
<
float
>
(),
x_gpu_
.
dims
());
y_gpu_
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
y_ref_
.
data
<
float
>
(),
y_gpu_
.
dims
());
}
void
InitHalfInput
()
{
x_half_
.
Resize
(
lite
::
DDim
(
x_ref_
.
dims
()));
auto
x_half_data
=
x_half_
.
mutable_data
<
half
>
();
for
(
int64_t
i
=
0
;
i
<
x_half_
.
numel
();
i
++
)
{
x_half_data
[
i
]
=
half
(
lite
::
float16
(
x_ref_
.
data
<
float
>
()[
i
]));
}
x_gpu_
.
Assign
<
half
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_half_data
,
x_gpu_
.
dims
());
y_half_
.
Resize
(
y_ref_
.
dims
());
auto
y_half_data
=
y_half_
.
mutable_data
<
half
>
();
for
(
int64_t
i
=
0
;
i
<
y_half_
.
numel
();
i
++
)
{
y_half_data
[
i
]
=
half
(
lite
::
float16
(
y_ref_
.
data
<
float
>
()[
i
]));
}
y_gpu_
.
Assign
<
half
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
y_half_data
,
y_gpu_
.
dims
());
}
void
RunBaseLine
(
const
lite
::
Tensor
*
x
,
const
lite
::
Tensor
*
w
,
lite
::
Tensor
*
out
)
{
const
float
*
data_in
=
x
->
data
<
float
>
();
const
float
*
weights
=
w
->
data
<
float
>
();
float
*
data_out
=
out
->
mutable_data
<
float
>
();
int
out_rows
=
x
->
dims
()[
0
];
int
in_cols
=
x
->
numel
()
/
out_rows
;
int
out_cols
=
w
->
numel
()
/
in_cols
;
int
index_out
;
for
(
int
i
=
0
;
i
<
out_rows
;
i
++
)
{
for
(
int
j
=
0
;
j
<
out_cols
;
j
++
)
{
index_out
=
i
*
out_cols
+
j
;
data_out
[
index_out
]
=
0
;
for
(
int
k
=
0
;
k
<
in_cols
;
k
++
)
{
data_out
[
index_out
]
+=
data_in
[
i
*
in_cols
+
k
]
*
weights
[
k
*
out_cols
+
j
];
}
}
}
}
for
(
int
i
=
0
;
i
<
y_cpu
.
numel
();
i
++
)
{
y_cpu_data
[
i
]
=
i
+
1.0
;
int
m_
,
k_
,
n_
;
std
::
vector
<
int64_t
>
x_shape_
,
y_shape_
,
out_shape_
;
lite
::
Tensor
x_ref_
,
y_ref_
,
out_ref_
;
lite
::
Tensor
x_gpu_
,
y_gpu_
;
lite
::
Tensor
x_half_
,
y_half_
;
lite
::
Tensor
out_cpu_
,
out_gpu_
;
operators
::
MulParam
param_
;
std
::
unique_ptr
<
KernelContext
>
ctx_
;
cudaStream_t
stream_
;
};
TEST_F
(
MulTest
,
TestFP32
)
{
InitFloatInput
();
MulCompute
<
float
,
PRECISION
(
kFloat
)
>
mul_kernel
;
mul_kernel
.
SetParam
(
param_
);
mul_kernel
.
SetContext
(
std
::
move
(
ctx_
));
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
{
mul_kernel
.
Launch
();
cudaDeviceSynchronize
();
}
x
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_cpu_data
,
x_cpu
.
dims
());
y
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
y_cpu_data
,
y_cpu
.
dims
());
auto
start
=
GetCurrentUS
();
mul_kernel
.
PrepareForRun
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
mul_kernel
.
Run
();
}
cudaDeviceSynchronize
();
auto
duration
=
(
GetCurrentUS
()
-
start
)
/
1000.0
;
LOG
(
INFO
)
<<
"fp32, warmup: "
<<
FLAGS_warmup
<<
", repeats: "
<<
FLAGS_repeats
<<
", spend "
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
operators
::
MulParam
param
;
param
.
x
=
&
x
;
param
.
y
=
&
y
;
param
.
output
=
&
out
;
mul_kernel
.
SetParam
(
param
);
CopySync
<
TARGET
(
kCUDA
)
>
(
out_cpu_
.
mutable_data
<
float
>
(),
out_gpu_
.
data
<
float
>
(),
sizeof
(
float
)
*
out_gpu_
.
numel
(),
IoDirection
::
DtoH
);
cudaStream_t
stream
;
cudaStreamCreate
(
&
stream
);
context
.
SetExecStream
(
stream
);
for
(
int
i
=
0
;
i
<
out_gpu_
.
numel
();
++
i
)
{
float
res
=
out_cpu_
.
data
<
float
>
()[
i
];
float
ref
=
out_ref_
.
data
<
float
>
()[
i
];
EXPECT_NEAR
(
fabs
(
res
-
ref
)
/
(
ref
+
1e-5
),
0.
,
1e-4
);
}
}
TEST_F
(
MulTest
,
TestFP16
)
{
InitHalfInput
();
MulCompute
<
half
,
PRECISION
(
kFP16
)
>
mul_kernel
;
mul_kernel
.
SetParam
(
param_
);
mul_kernel
.
SetContext
(
std
::
move
(
ctx_
));
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
{
mul_kernel
.
Launch
();
cudaDeviceSynchronize
();
}
mul_kernel
.
SetContext
(
std
::
move
(
ctx
));
mul_kernel
.
Launch
();
auto
start
=
GetCurrentUS
();
mul_kernel
.
PrepareForRun
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
mul_kernel
.
Run
();
}
cudaDeviceSynchronize
();
auto
duration
=
(
GetCurrentUS
()
-
start
)
/
1000.0
;
LOG
(
INFO
)
<<
"fp16, warmup: "
<<
FLAGS_warmup
<<
", repeats: "
<<
FLAGS_repeats
<<
", spend "
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
const
half
*
out_gpu_data
=
out_gpu_
.
data
<
half
>
();
half
*
out_cpu_data
=
out_cpu_
.
mutable_data
<
half
>
();
CopySync
<
TARGET
(
kCUDA
)
>
(
out_cpu_data
,
out_gpu_data
,
sizeof
(
half
)
*
out_gpu_
.
numel
(),
IoDirection
::
DtoH
);
CopySync
<
TARGET
(
kCUDA
)
>
(
out_cpu_data
,
out_data
,
sizeof
(
float
)
*
out
.
numel
(),
IoDirection
::
DtoH
);
for
(
int
i
=
0
;
i
<
out_cpu
.
numel
();
i
++
)
{
LOG
(
INFO
)
<<
out_cpu_data
[
i
]
;
for
(
int
i
=
0
;
i
<
out_cpu_
.
numel
();
++
i
)
{
float
res
=
static_cast
<
float
>
(
lite
::
float16
(
out_cpu_data
[
i
])
);
float
ref
=
out_ref_
.
data
<
float
>
()[
i
];
EXPECT_NEAR
(
fabs
(
res
-
ref
)
/
(
ref
+
1e-5
),
0.
,
1e-2
)
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录