From 46980d687aebea9151124193981c5ecdcddfb32c Mon Sep 17 00:00:00 2001
From: Wilber
Date: Mon, 24 Aug 2020 10:51:40 +0800
Subject: [PATCH] add ltgemm gemv. test=develop (#4155)

---
 lite/backends/cuda/cuda_utils.h        |   4 +
 lite/backends/cuda/math/CMakeLists.txt |   2 +
 lite/backends/cuda/math/gemm.cc        | 150 +++++++++++++++++++
 lite/backends/cuda/math/gemm.h         |  74 ++++++++-
 lite/backends/cuda/math/gemv.cc        |  73 +++++++++
 lite/backends/cuda/math/gemv.h         |  67 +++++++++
 lite/kernels/cuda/mul_compute.cc       |  14 +-
 lite/kernels/cuda/mul_compute.h        |  17 +--
 lite/kernels/cuda/mul_compute_test.cc  | 200 ++++++++++++++++++++-----
 9 files changed, 551 insertions(+), 50 deletions(-)
 create mode 100644 lite/backends/cuda/math/gemv.cc
 create mode 100644 lite/backends/cuda/math/gemv.h

diff --git a/lite/backends/cuda/cuda_utils.h b/lite/backends/cuda/cuda_utils.h
index 012004a65f..ac40a70a0f 100644
--- a/lite/backends/cuda/cuda_utils.h
+++ b/lite/backends/cuda/cuda_utils.h
@@ -21,6 +21,10 @@
 #include
 #include "lite/utils/cp_logging.h"
 
+#if (CUBLAS_VER_MAJOR * 10 + CUBLAS_VER_MINOR) >= 101
+#include <cublasLt.h>
+#endif
+
 /*
  * This file contains some CUDA specific utils.
  */
diff --git a/lite/backends/cuda/math/CMakeLists.txt b/lite/backends/cuda/math/CMakeLists.txt
index 495b273a30..f9877cabff 100644
--- a/lite/backends/cuda/math/CMakeLists.txt
+++ b/lite/backends/cuda/math/CMakeLists.txt
@@ -16,6 +16,7 @@ nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
 nv_library(cuda_gru_forward SRCS gru_forward.cu DEPS cuda_activation ${cuda_static_deps})
 nv_library(cuda_sequence2batch SRCS sequence2batch.cu DEPS ${cuda_static_deps})
 nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
+nv_library(cuda_gemv SRCS gemv.cc DEPS ${cuda_static_deps})
 nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
 nv_library(cuda_strided_gemm SRCS strided_gemm.cc DEPS ${cuda_static_deps})
 nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps})
@@ -35,6 +36,7 @@ set (
     cuda_gru_forward
     cuda_sequence2batch
     cuda_gemm
+    cuda_gemv
     cuda_batched_gemm
     cuda_strided_gemm
     cuda_sequence_padding
diff --git a/lite/backends/cuda/math/gemm.cc b/lite/backends/cuda/math/gemm.cc
index baba1d8526..ab269dec95 100644
--- a/lite/backends/cuda/math/gemm.cc
+++ b/lite/backends/cuda/math/gemm.cc
@@ -123,6 +123,156 @@ bool Gemm<half, half>::run(const half alpha,
 template class Gemm<float, float>;
 template class Gemm<half, half>;
 
+// LtGemm
+template <typename T>
+class cublasTypeWrapper;
+
+template <>
+class cublasTypeWrapper<float> {
+ public:
+  static const cudaDataType_t type = CUDA_R_32F;
+};
+
+template <>
+class cublasTypeWrapper<half> {
+ public:
+  static const cudaDataType_t type = CUDA_R_16F;
+};
+
+#if (CUBLAS_VER_MAJOR * 10 + CUBLAS_VER_MINOR) >= 101
+
+template <typename PtypeIn, typename PtypeOut>
+bool LtGemm<PtypeIn, PtypeOut>::init(const bool trans_a,
+                                     const bool trans_b,
+                                     const int m,
+                                     const int n,
+                                     const int k,
+                                     Context<TARGET(kCUDA)> *ctx) {
+  int lda = (!trans_a) ? k : m;
+  int ldb = (!trans_b) ? n : k;
+  int ldc = n;
+
+  return this->init(trans_a, trans_b, m, n, k, lda, ldb, ldc, ctx);
+}
+
+template <typename PtypeIn, typename PtypeOut>
+bool LtGemm<PtypeIn, PtypeOut>::init(const bool trans_a,
+                                     const bool trans_b,
+                                     const int m,
+                                     const int n,
+                                     const int k,
+                                     const int lda,
+                                     const int ldb,
+                                     const int ldc,
+                                     Context<TARGET(kCUDA)> *ctx) {
+  if (handle_ == nullptr) {
+    this->exe_stream_ = ctx->exec_stream();
+    CUBLAS_CALL(cublasLtCreate(&handle_));
+  }
+  m_ = m;
+  n_ = n;
+  k_ = k;
+  lda_ = lda;
+  ldb_ = ldb;
+  ldc_ = ldc;
+  cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+  // create operation descriptor; see cublasLtMatmulDescAttributes_t for
+  // details about defaults; here we just need to set the transforms for A and B
+  CUBLAS_CALL(cublasLtMatmulDescCreate(&matmul_desc_,
+                                       cublasTypeWrapper<PtypeOut>::type));
+  CUBLAS_CALL(cublasLtMatmulDescSetAttribute(matmul_desc_,
+                                             CUBLASLT_MATMUL_DESC_TRANSA,
+                                             &cu_trans_b_,
+                                             sizeof(cu_trans_b_)));
+  CUBLAS_CALL(cublasLtMatmulDescSetAttribute(matmul_desc_,
+                                             CUBLASLT_MATMUL_DESC_TRANSB,
+                                             &cu_trans_a_,
+                                             sizeof(cu_trans_a_)));
+
+  // create matrix descriptors; we are good with the defaults here, so no need
+  // to set any extra attributes
+  CUBLAS_CALL(cublasLtMatrixLayoutCreate(&a_desc_,
+                                         cublasTypeWrapper<PtypeIn>::type,
+                                         trans_a == false ? k : m,
+                                         trans_a == false ? m : k,
+                                         lda));
+  CUBLAS_CALL(cublasLtMatrixLayoutCreate(&b_desc_,
+                                         cublasTypeWrapper<PtypeIn>::type,
+                                         trans_b == false ? n : k,
+                                         trans_b == false ? k : n,
+                                         ldb));
+  CUBLAS_CALL(cublasLtMatrixLayoutCreate(
+      &c_desc_, cublasTypeWrapper<PtypeOut>::type, n, m, ldc));
+
+  // create preference handle; here we could use extra attributes to disable
+  // tensor ops or to make sure algo selected will work with badly aligned A,
+  // B, C; here for simplicity we just assume A, B, C are always well aligned
+  // (e.g. directly come from cudaMalloc)
+  CUBLAS_CALL(cublasLtMatmulPreferenceCreate(&preference_));
+
+  if (!workspace_) {
+    CUDA_CALL(cudaMalloc(&this->workspace_, workspace_size_));
+  }
+  CUBLAS_CALL(cublasLtMatmulPreferenceSetAttribute(
+      preference_,
+      CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+      &workspace_size_,
+      sizeof(workspace_size_)));
+
+  // we just need the best available heuristic to try and run matmul. There is
+  // no guarantee this will work, e.g. if A is badly aligned, you can request
+  // more (e.g. 32) algos and try to run them one by one until something works
+  CUBLAS_CALL(cublasLtMatmulAlgoGetHeuristic(handle_,
+                                             matmul_desc_,
+                                             b_desc_,
+                                             a_desc_,
+                                             c_desc_,
+                                             c_desc_,
+                                             preference_,
+                                             1,
+                                             &heuristic_result_,
+                                             &returned_results_));
+  if (returned_results_ == 0) {
+    LOG(FATAL) << "cuBLAS API failed with status "
+               << CUBLAS_STATUS_NOT_SUPPORTED;
+  }
+
+  return true;
+}
+
+template <typename PtypeIn, typename PtypeOut>
+bool LtGemm<PtypeIn, PtypeOut>::run(const PtypeOut alpha,
+                                    const PtypeOut beta,
+                                    const PtypeIn *a,
+                                    const PtypeIn *b,
+                                    PtypeOut *c,
+                                    Context<TARGET(kCUDA)> *ctx) {
+  CUBLAS_CALL(cublasLtMatmul(handle_,
+                             matmul_desc_,
+                             &alpha,
+                             b,
+                             b_desc_,
+                             a,
+                             a_desc_,
+                             &beta,
+                             c,
+                             c_desc_,
+                             c,
+                             c_desc_,
+                             &heuristic_result_.algo,
+                             workspace_,
+                             workspace_size_,
+                             this->exe_stream_));
+  return true;
+}
+
+template class LtGemm<float, float>;
+template class LtGemm<half, half>;
+
+#endif
+
 }  // namespace math
 }  // namespace cuda
 }  // namespace lite
diff --git a/lite/backends/cuda/math/gemm.h b/lite/backends/cuda/math/gemm.h
index 85576e6501..71e7dbab4e 100644
--- a/lite/backends/cuda/math/gemm.h
+++ b/lite/backends/cuda/math/gemm.h
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 #pragma once
-#include
 #include
 #include
 #include "lite/api/paddle_place.h"
@@ -70,6 +69,79 @@ class Gemm {
   int ldc_{-1};
 };
 
+#if (CUBLAS_VER_MAJOR * 10 + CUBLAS_VER_MINOR) >= 101
+
+template <typename PtypeIn, typename PtypeOut>
+class LtGemm {
+ public:
+  LtGemm()
+      : handle_(nullptr),
+        matmul_desc_(nullptr),
+        a_desc_(nullptr),
+        b_desc_(nullptr),
+        c_desc_(nullptr),
+        preference_(nullptr),
+        returned_results_(0),
+        workspace_size_(4 * 1024 * 1024),
+        workspace_{nullptr} {}
+
+  ~LtGemm() {
+    if (this->workspace_) {
+      CUDA_CALL(cudaFree(this->workspace_));
+    }
+    this->workspace_ = nullptr;
+  }
+  bool init(const bool trans_a,
+            const bool trans_b,
+            const int m,
+            const int n,
+            const int k,
+            Context<TARGET(kCUDA)>* ctx);
+  bool init(const bool trans_a,
+            const bool trans_b,
+            const int m,
+            const int n,
+            const int k,
+            const int lda,
+            const int ldb,
+            const int ldc,
+            Context<TARGET(kCUDA)>* ctx);
+
+  bool run(const PtypeOut alpha,
+           const PtypeOut beta,
+           const PtypeIn* a,
+           const PtypeIn* b,
+           PtypeOut* c,
+           Context<TARGET(kCUDA)>* ctx);
+
+  cublasLtHandle_t get_handle() const { return handle_; }
+
+ private:
+  cudaStream_t exe_stream_;
+
+  cublasLtHandle_t handle_;
+  cublasLtMatmulDesc_t matmul_desc_;
+  cublasLtMatrixLayout_t a_desc_;
+  cublasLtMatrixLayout_t b_desc_;
+  cublasLtMatrixLayout_t c_desc_;
+  cublasLtMatmulPreference_t preference_;
+  int returned_results_;
+  cublasLtMatmulHeuristicResult_t heuristic_result_{};
+
+  cublasOperation_t cu_trans_a_;
+  cublasOperation_t cu_trans_b_;
+  int m_{-1};
+  int n_{-1};
+  int k_{-1};
+  int lda_{-1};
+  int ldb_{-1};
+  int ldc_{-1};
+
+  size_t workspace_size_;
+  void* workspace_;
+};
+#endif
+
 }  // namespace math
 }  // namespace cuda
 }  // namespace lite
diff --git a/lite/backends/cuda/math/gemv.cc b/lite/backends/cuda/math/gemv.cc
new file mode 100644
index 0000000000..f126d8f0ef
--- /dev/null
+++ b/lite/backends/cuda/math/gemv.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/backends/cuda/math/gemv.h"
+
+#include
+
+#include "lite/core/device_info.h"
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+namespace math {
+
+template <typename PtypeIn, typename PtypeOut>
+bool Gemv<PtypeIn, PtypeOut>::init(const bool trans,
+                                   const int m,
+                                   const int n,
+                                   const int lda,
+                                   const int ldb,
+                                   const int ldc,
+                                   Context<TARGET(kCUDA)> *ctx) {
+  if (cu_handle_ == nullptr) {
+    this->exe_stream_ = ctx->exec_stream();
+    CUBLAS_CALL(cublasCreate(&cu_handle_));
+    CUBLAS_CALL(cublasSetMathMode(cu_handle_, CUBLAS_TENSOR_OP_MATH));
+    CUBLAS_CALL(cublasSetStream(cu_handle_, this->exe_stream_));
+  }
+  m_ = m;
+  n_ = n;
+  lda_ = lda;
+  ldb_ = ldb;
+  ldc_ = ldc;
+  cu_trans_ = trans ? CUBLAS_OP_N : CUBLAS_OP_T;
+  return true;
+}
+
+template <>
+bool Gemv<float, float>::run(const float alpha,
+                             const float beta,
+                             const float *a,
+                             const float *b,
+                             float *c) {
+  CUBLAS_CALL(cublasSgemv(
+      cu_handle_, cu_trans_, n_, m_, &alpha, a, lda_, b, ldb_, &beta, c, ldc_));
+  return true;
+}
+
+template <>
+bool Gemv<half, half>::run(
+    const half alpha, const half beta, const half *a, const half *b, half *c) {
+  LOG(FATAL) << "not supported";
+  return false;
+}
+
+template class Gemv<float, float>;
+template class Gemv<half, half>;
+
+}  // namespace math
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/cuda/math/gemv.h b/lite/backends/cuda/math/gemv.h
new file mode 100644
index 0000000000..3fe5a01947
--- /dev/null
+++ b/lite/backends/cuda/math/gemv.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+#include
+#include
+#include "lite/api/paddle_place.h"
+#include "lite/backends/cuda/cuda_utils.h"
+#include "lite/core/context.h"
+#include "lite/core/target_wrapper.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+namespace math {
+
+template <typename PtypeIn, typename PtypeOut>
+class Gemv {
+ public:
+  Gemv() : cu_handle_(nullptr) {}
+  ~Gemv() {}
+
+  bool init(const bool trans_,
+            const int m,
+            const int n,
+            const int lda,
+            const int ldb,
+            const int ldc,
+            Context<TARGET(kCUDA)>* ctx);
+
+  bool run(const PtypeOut alpha,
+           const PtypeOut beta,
+           const PtypeIn* a,
+           const PtypeIn* b,
+           PtypeOut* c);
+
+  cublasHandle_t get_handle() const { return cu_handle_; }
+
+ private:
+  cudaStream_t exe_stream_;
+  cublasHandle_t cu_handle_;
+  cublasOperation_t cu_trans_;
+  int m_{-1};
+  int n_{-1};
+  int lda_{-1};
+  int ldb_{-1};
+  int ldc_{-1};
+};
+
+}  // namespace math
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/mul_compute.cc b/lite/kernels/cuda/mul_compute.cc
index f59e9e5046..61c1c5b220 100644
--- a/lite/kernels/cuda/mul_compute.cc
+++ b/lite/kernels/cuda/mul_compute.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
#include "lite/kernels/cuda/mul_compute.h" + #include "lite/core/op_registry.h" namespace paddle { @@ -23,9 +24,18 @@ namespace cuda {} // namespace cuda } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL( - mul, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::MulCompute, def) +using MulFp32 = + paddle::lite::kernels::cuda::MulCompute; +using MulFp16 = paddle::lite::kernels::cuda::MulCompute; + +REGISTER_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, MulFp32, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) .Finalize(); + +REGISTER_LITE_KERNEL(mul, kCUDA, kFP16, kNCHW, MulFp16, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))}) + .Finalize(); diff --git a/lite/kernels/cuda/mul_compute.h b/lite/kernels/cuda/mul_compute.h index aa80919920..1fe469750e 100644 --- a/lite/kernels/cuda/mul_compute.h +++ b/lite/kernels/cuda/mul_compute.h @@ -23,22 +23,21 @@ namespace lite { namespace kernels { namespace cuda { -class MulCompute : public KernelLite { +template +class MulCompute : public KernelLite { public: using param_t = operators::MulParam; void PrepareForRun() override { - gemm_impl_.reset(new lite::cuda::math::Gemm); + gemm_impl_.reset(new lite::cuda::math::Gemm); } void Run() override { - CHECK(ctx_) << "running context should be set first"; auto& context = this->ctx_->template As(); - - auto& param = this->Param(); - const auto* x_data = param.x->data(); - const auto* y_data = param.y->data(); - auto* out_data = param.output->mutable_data(TARGET(kCUDA)); + auto& param = this->template Param(); + const auto* x_data = param.x->template data(); + const auto* y_data = param.y->template data(); + auto* out_data = param.output->template mutable_data(TARGET(kCUDA)); int x_h = static_cast( param.x->dims().Slice(0, param.x_num_col_dims).production()); @@ -61,7 +60,7 @@ class MulCompute : public KernelLite { virtual ~MulCompute() = default; private: - std::unique_ptr> gemm_impl_{nullptr}; + std::unique_ptr> gemm_impl_{nullptr}; }; } // namespace cuda diff --git a/lite/kernels/cuda/mul_compute_test.cc b/lite/kernels/cuda/mul_compute_test.cc index 60bee07694..27ada603ad 100644 --- a/lite/kernels/cuda/mul_compute_test.cc +++ b/lite/kernels/cuda/mul_compute_test.cc @@ -16,58 +16,182 @@ #include #include #include -#include "lite/backends/cuda/blas.h" +#include +#include "lite/api/test_helper.h" +#include "lite/utils/float16.h" namespace paddle { namespace lite { namespace kernels { namespace cuda { -TEST(mul_compute, normal) { - MulCompute mul_kernel; - std::unique_ptr ctx(new KernelContext); - auto& context = ctx->As(); - - Tensor x, y, out, x_cpu, y_cpu, out_cpu; - int x_h = 2, x_w_y_h = 3, y_w = 4; - out.Resize({x_h, y_w}); - x_cpu.Resize({x_h, x_w_y_h}); - y_cpu.Resize({x_w_y_h, y_w}); - out_cpu.Resize({x_h, y_w}); - - auto* out_data = out.mutable_data(TARGET(kCUDA)); - float* x_cpu_data = x_cpu.mutable_data(); - float* y_cpu_data = y_cpu.mutable_data(); - float* out_cpu_data = out_cpu.mutable_data(); - - for (int i = 0; i < x_cpu.numel(); i++) { - x_cpu_data[i] = i + 1.0; +class MulTest : public ::testing::Test { + protected: + MulTest() + : m_(2), + k_(3), + n_(4), + x_shape_({m_, k_}), + y_shape_({k_, n_}), + out_shape_({m_, n_}) { + x_gpu_.Resize(lite::DDim(x_shape_)); + 
+    x_ref_.Resize(lite::DDim(x_shape_));
+
+    y_gpu_.Resize(lite::DDim(y_shape_));
+    y_ref_.Resize(lite::DDim(y_shape_));
+
+    auto x_ref_data = x_ref_.mutable_data<float>();
+    auto y_ref_data = y_ref_.mutable_data<float>();
+
+    // prepare input
+    for (int64_t i = 0; i < x_ref_.numel(); i++) {
+      x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
+    }
+    for (int64_t i = 0; i < y_ref_.numel(); i++) {
+      y_ref_data[i] = static_cast<float>(i % 10 * 0.2);
+    }
+
+    out_ref_.Resize(lite::DDim(out_shape_));
+    out_cpu_.Resize(lite::DDim(out_shape_));
+    out_gpu_.Resize(lite::DDim(out_shape_));
+    RunBaseLine(&x_ref_, &y_ref_, &out_ref_);
+
+    InitParamAndContext();
+  }
+
+  void InitParamAndContext() {
+    ctx_.reset(new KernelContext);
+    cudaStreamCreate(&stream_);
+    auto& context = ctx_->As<CUDAContext>();
+    context.SetExecStream(stream_);
+    param_.x = &x_gpu_;
+    param_.y = &y_gpu_;
+    param_.output = &out_gpu_;
+  }
+
+  void InitFloatInput() {
+    x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
+                                                    x_gpu_.dims());
+    y_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(y_ref_.data<float>(),
+                                                    y_gpu_.dims());
+  }
+
+  void InitHalfInput() {
+    x_half_.Resize(lite::DDim(x_ref_.dims()));
+    auto x_half_data = x_half_.mutable_data<half>();
+    for (int64_t i = 0; i < x_half_.numel(); i++) {
+      x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
+    }
+    x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
+    y_half_.Resize(y_ref_.dims());
+    auto y_half_data = y_half_.mutable_data<half>();
+    for (int64_t i = 0; i < y_half_.numel(); i++) {
+      y_half_data[i] = half(lite::float16(y_ref_.data<float>()[i]));
+    }
+    y_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(y_half_data, y_gpu_.dims());
+  }
+
+  void RunBaseLine(const lite::Tensor* x,
+                   const lite::Tensor* w,
+                   lite::Tensor* out) {
+    const float* data_in = x->data<float>();
+    const float* weights = w->data<float>();
+    float* data_out = out->mutable_data<float>();
+    int out_rows = x->dims()[0];
+    int in_cols = x->numel() / out_rows;
+    int out_cols = w->numel() / in_cols;
+    int index_out;
+    for (int i = 0; i < out_rows; i++) {
+      for (int j = 0; j < out_cols; j++) {
+        index_out = i * out_cols + j;
+        data_out[index_out] = 0;
+        for (int k = 0; k < in_cols; k++) {
+          data_out[index_out] +=
+              data_in[i * in_cols + k] * weights[k * out_cols + j];
+        }
+      }
+    }
   }
-  for (int i = 0; i < y_cpu.numel(); i++) {
-    y_cpu_data[i] = i + 1.0;
+
+  int m_, k_, n_;
+  std::vector<int64_t> x_shape_, y_shape_, out_shape_;
+  lite::Tensor x_ref_, y_ref_, out_ref_;
+  lite::Tensor x_gpu_, y_gpu_;
+  lite::Tensor x_half_, y_half_;
+  lite::Tensor out_cpu_, out_gpu_;
+
+  operators::MulParam param_;
+  std::unique_ptr<KernelContext> ctx_;
+  cudaStream_t stream_;
+};
+
+TEST_F(MulTest, TestFP32) {
+  InitFloatInput();
+  MulCompute<float, PRECISION(kFloat)> mul_kernel;
+  mul_kernel.SetParam(param_);
+  mul_kernel.SetContext(std::move(ctx_));
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    mul_kernel.Launch();
+    cudaDeviceSynchronize();
   }
-  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
-  y.Assign<float, lite::DDim, TARGET(kCUDA)>(y_cpu_data, y_cpu.dims());
+  auto start = GetCurrentUS();
+  mul_kernel.PrepareForRun();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    mul_kernel.Run();
+  }
+  cudaDeviceSynchronize();
+  auto duration = (GetCurrentUS() - start) / 1000.0;
+  LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
+            << ", repeats: " << FLAGS_repeats << ", spend "
+            << duration / FLAGS_repeats << " ms in average.";
 
-  operators::MulParam param;
-  param.x = &x;
-  param.y = &y;
-  param.output = &out;
-  mul_kernel.SetParam(param);
+  CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
+                          out_gpu_.data<float>(),
+                          sizeof(float) * out_gpu_.numel(),
+                          IoDirection::DtoH);
 
-  cudaStream_t stream;
-  cudaStreamCreate(&stream);
-  context.SetExecStream(stream);
+  for (int i = 0; i < out_gpu_.numel(); ++i) {
+    float res = out_cpu_.data<float>()[i];
+    float ref = out_ref_.data<float>()[i];
+    EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 1e-4);
+  }
+}
+
+TEST_F(MulTest, TestFP16) {
+  InitHalfInput();
+  MulCompute<half, PRECISION(kFP16)> mul_kernel;
+  mul_kernel.SetParam(param_);
+  mul_kernel.SetContext(std::move(ctx_));
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    mul_kernel.Launch();
+    cudaDeviceSynchronize();
+  }
-  mul_kernel.SetContext(std::move(ctx));
-  mul_kernel.Launch();
+  auto start = GetCurrentUS();
+  mul_kernel.PrepareForRun();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    mul_kernel.Run();
+  }
   cudaDeviceSynchronize();
+  auto duration = (GetCurrentUS() - start) / 1000.0;
+  LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
+            << ", repeats: " << FLAGS_repeats << ", spend "
+            << duration / FLAGS_repeats << " ms in average.";
+
+  const half* out_gpu_data = out_gpu_.data<half>();
+  half* out_cpu_data = out_cpu_.mutable_data<half>();
+  CopySync<TARGET(kCUDA)>(out_cpu_data,
+                          out_gpu_data,
+                          sizeof(half) * out_gpu_.numel(),
+                          IoDirection::DtoH);
 
-  CopySync<TARGET(kCUDA)>(
-      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
-  for (int i = 0; i < out_cpu.numel(); i++) {
-    LOG(INFO) << out_cpu_data[i];
+  for (int i = 0; i < out_cpu_.numel(); ++i) {
+    float res = static_cast<float>(lite::float16(out_cpu_data[i]));
+    float ref = out_ref_.data<float>()[i];
+    EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 1e-2);
   }
 }
--
GitLab
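
Editor's note (not part of the patch): a minimal sketch of how the new Gemv wrapper is meant to be driven. The function and pointer names below are illustrative only, and the CUDAContext type is assumed to be the usual alias for Context<TARGET(kCUDA)> in lite/core/context.h. Because cuBLAS works in column-major order, init() stores cu_trans_ = trans ? CUBLAS_OP_N : CUBLAS_OP_T and run() passes the dimensions as (n, m), so a row-major m x n matrix multiplies a length-n vector without explicitly transposing the data.

#include "lite/backends/cuda/math/gemv.h"

// y = A * x for a row-major A of shape m x n; all pointers are device memory.
void gemv_rowmajor_example(const float* d_a,
                           const float* d_x,
                           float* d_y,
                           int m,
                           int n,
                           paddle::lite::CUDAContext* ctx) {
  paddle::lite::cuda::math::Gemv<float, float> gemv;
  // trans = false, lda = n (row-major leading dimension), vector strides 1.
  gemv.init(false, m, n, /*lda=*/n, /*ldb=*/1, /*ldc=*/1, ctx);
  gemv.run(1.0f, 0.0f, d_a, d_x, d_y);
}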
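A second note on the LtGemm path: run() passes b before a and describes c as an n x m column-major matrix. That is the standard trick for getting a row-major C = A * B out of a column-major BLAS by computing C^T = B^T * A^T, and it is also why the two cublasLtMatmulDescSetAttribute calls set CUBLASLT_MATMUL_DESC_TRANSA from cu_trans_b_ and CUBLASLT_MATMUL_DESC_TRANSB from cu_trans_a_. The self-contained CPU snippet below, using the same m = 2, k = 3, n = 4 shapes as the test, shows the index bookkeeping; it is an illustration, not code from the repository.

#include <cstdio>

int main() {
  const int m = 2, k = 3, n = 4;
  float a[m * k], b[k * n], c[m * n];  // row-major buffers
  for (int i = 0; i < m * k; ++i) a[i] = (i % 10) * 0.2f;
  for (int i = 0; i < k * n; ++i) b[i] = (i % 10) * 0.2f;

  // Column-major view of the same buffers: a is k x m (A^T) and b is n x k
  // (B^T). Filling the n x m product B^T * A^T column by column writes c so
  // that, read back as an m x n row-major matrix, it equals A * B.
  for (int col = 0; col < m; ++col) {    // columns of C^T == rows of C
    for (int row = 0; row < n; ++row) {  // rows of C^T == columns of C
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += b[p * n + row] * a[col * k + p];
      c[col * n + row] = acc;
    }
  }

  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) printf("%6.2f ", c[i * n + j]);
    printf("\n");
  }
  return 0;
}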