未验证 提交 498f147d 编写于 作者: W Wilber 提交者: GitHub

[CUDA] [Kernel] Add matmul cuda kernel. (#3897)

上级 4776f8f4
......@@ -106,7 +106,8 @@ lite_option(LITE_BUILD_EXTRA "Enable extra algorithm support in Lite, both kerne
lite_option(LITE_BUILD_TAILOR "Enable tailoring library according to model" OFF)
# cv build options
lite_option(LITE_WITH_CV "Enable build cv image in lite" OFF)
lite_option(LITE_WITH_STATIC_CUDA "Statically link cuda libraries." ON)
lite_option(LITE_WITH_STATIC_CUDA "Statically link cuda libraries." OFF)
lite_option(CUDA_WITH_FP16 "Compile with cuda half support" OFF)
lite_option(LITE_WITH_ARM_CLANG "when arm lang is clang, its ON." OFF)
# TODO(Superjomn) Remove WITH_ANAKIN option if not needed latter.
......
......@@ -2,6 +2,10 @@ if(NOT LITE_WITH_CUDA)
return()
endif()
if(WITH_CUDA_FP16)
add_definitions("-DCUDA_WITH_FP16")
endif()
set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs7 "30 35 50 52")
set(paddle_known_gpu_archs8 "30 35 50 52 53 60 61 62")
......@@ -167,6 +171,10 @@ elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x
add_definitions("-DPADDLE_CUDA_BINVER=\"100\"")
endif()
if (CUDA_WITH_FP16)
STRING(REGEX REPLACE "30|35|50|52" "" paddle_known_gpu_archs ${paddle_known_gpu_archs})
endif()
include_directories(${CUDA_INCLUDE_DIRS})
if(NOT WITH_DSO)
if(WIN32)
......
......@@ -13,6 +13,7 @@ nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_strided_gemm SRCS strided_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps})
set (
......@@ -26,6 +27,7 @@ set (
cudnn_pool
cuda_gemm
cuda_batched_gemm
cuda_strided_gemm
cuda_sequence_padding
)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/strided_gemm.h"
#include <iostream>
#include "lite/core/device_info.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename PtypeIn, typename PtypeOut>
bool StridedGemm<PtypeIn, PtypeOut>::init(const bool trans_a,
const bool trans_b,
Context<TARGET(kCUDA)>* ctx) {
if (cu_handle_ == nullptr) {
this->exe_stream_ = ctx->exec_stream();
CUBLAS_CALL(cublasCreate(&cu_handle_));
CUBLAS_CALL(cublasSetStream(cu_handle_, this->exe_stream_));
}
cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
return true;
}
template <>
bool StridedGemm<float, float>::run(const float alpha,
const float beta,
const int m,
const int n,
const int k,
const float* a_data,
const float* b_data,
float* c_data,
const int batch_size,
const int64_t stride_a,
const int64_t stride_b) {
lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m;
ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k;
ldc_ = n;
m_ = m;
n_ = n;
k_ = k;
const int64_t stride_c = m_ * n_;
CUBLAS_CALL(cublasGemmStridedBatchedEx(cu_handle_,
cu_trans_b_,
cu_trans_a_,
n_,
m_,
k_,
&alpha,
b_data,
CUDA_R_32F,
ldb_,
stride_b,
a_data,
CUDA_R_32F,
lda_,
stride_a,
&beta,
c_data,
CUDA_R_32F,
ldc_,
stride_c,
batch_size,
CUDA_R_32F,
algo_));
return true;
}
template <>
bool StridedGemm<half, half>::run(const half alpha,
const half beta,
const int m,
const int n,
const int k,
const half* a_data,
const half* b_data,
half* c_data,
const int batch_size,
const int64_t stride_a,
const int64_t stride_b) {
lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m;
ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k;
ldc_ = n;
m_ = m;
n_ = n;
k_ = k;
const int64_t stride_c = m_ * n_;
CUBLAS_CALL(cublasGemmStridedBatchedEx(cu_handle_,
cu_trans_b_,
cu_trans_a_,
n_,
m_,
k_,
&alpha,
b_data,
CUDA_R_16F,
ldb_,
stride_b,
a_data,
CUDA_R_16F,
lda_,
stride_a,
&beta,
c_data,
CUDA_R_16F,
ldc_,
stride_c,
batch_size,
CUDA_R_16F,
algo_));
return true;
}
template class StridedGemm<float, float>;
template class StridedGemm<half, half>;
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename PtypeIn, typename PtypeOut>
class StridedGemm {
public:
StridedGemm() : cu_handle_(nullptr) {}
~StridedGemm() {}
bool init(const bool trans_a,
const bool trans_b,
Context<TARGET(kCUDA)>* ctx);
bool run(const PtypeIn alpha,
const PtypeIn beta,
const int m,
const int n,
const int k,
const PtypeIn* a_data,
const PtypeIn* b_data,
PtypeOut* c_data,
const int batch_size,
const int64_t stride_a,
const int64_t stride_b);
private:
cudaStream_t exe_stream_;
cublasHandle_t cu_handle_;
cublasOperation_t cu_trans_a_;
cublasOperation_t cu_trans_b_;
int m_{-1};
int n_{-1};
int k_{-1};
int lda_{-1};
int ldb_{-1};
int ldc_{-1};
cublasGemmAlgo_t algo_{CUBLAS_GEMM_DEFAULT_TENSOR_OP};
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
......@@ -7,6 +7,7 @@ message(STATUS "compile with lite CUDA kernels")
# basic kernels
add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(fc_compute_cuda CUDA basic SRCS fc_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(matmul_compute_cuda CUDA basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
......@@ -68,6 +69,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(fc_compute_cuda_test SRCS fc_compute_test.cc DEPS fc_compute_cuda)
nv_test(matmul_compute_cuda_test SRCS matmul_compute_test.cc DEPS matmul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
#nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda)
......
......@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/fc_compute.h"
#include <string>
......@@ -32,6 +31,74 @@ struct FcTypeTraits<float> {
typedef float4 Type;
};
template <typename T>
__global__ void AddBiasV2(const int num, const T* bias, T* data, int K) {
CUDA_KERNEL_LOOP(index, num) {
int bias_idx = index % K;
const T bias_ptr = bias[bias_idx];
const T in_ptr = data[index];
T packed_val;
packed_val.x = in_ptr.x + bias_ptr.x;
packed_val.y = in_ptr.y + bias_ptr.y;
data[index] = packed_val;
}
}
template <>
__global__ void AddBiasV2(const int num,
const half2* bias,
half2* data,
int K) {
CUDA_KERNEL_LOOP(index, num) {
int bias_idx = index % K;
const half2 bias_ptr = bias[bias_idx];
const half2 in_ptr = data[index];
#if __CUDA_ARCH__ >= 530
data[index] = __hadd2(in_ptr, bias_ptr);
#else
half2 packed_val;
packed_val.x = __hadd(in_ptr.x, bias_ptr.x);
packed_val.y = __hadd(in_ptr.y, bias_ptr.y);
data[index] = packed_val;
#endif
}
}
template <typename T>
__global__ void AddBiasReluV2(const int num, const T* bias, T* data, int K) {
CUDA_KERNEL_LOOP(index, num) {
int bias_idx = index % K;
const T bias_ptr = bias[bias_idx];
const T in_ptr = data[index];
T packed_val;
packed_val.x = fmaxf(0.f, in_ptr.x + bias_ptr.x);
packed_val.y = fmaxf(0.f, in_ptr.y + bias_ptr.y);
data[index] = packed_val;
}
}
template <>
__global__ void AddBiasReluV2(const int num,
const half2* bias,
half2* data,
int K) {
CUDA_KERNEL_LOOP(index, num) {
int bias_idx = index % K;
const half2 bias_ptr = bias[bias_idx];
const half2 in_ptr = data[index];
#if __CUDA_ARCH__ >= 530
data[index] = __hmul2(__hgt2(in_ptr + bias_ptr, __float2half2_rn(0.f)),
in_ptr + bias_ptr);
#else
const float2 bias = __half22float2(bias_ptr);
const float2 in = __half22float2(in_ptr);
data[index] = __floats2half2_rn(
bias.x + in.x > 0.0f ? static_cast<float>(bias.x + in.x) : 0.0f,
bias.y + in.y > 0.0f ? static_cast<float>(bias.y + in.y) : 0.0f);
#endif
}
}
template <typename T>
__global__ void AddBiasV4(const int num, const T* bias, T* data, int K) {
CUDA_KERNEL_LOOP(index, num) {
......@@ -77,6 +144,21 @@ __global__ void AddBias(const int num, const T* bias, T* data) {
}
}
template <>
__global__ void AddBias(const int num, const half* bias, half* data) {
int offset = blockIdx.x * num;
for (int i = threadIdx.x; i < num; i += blockDim.x) {
half temp;
#if __CUDA_ARCH__ >= 350
temp = __hadd(__ldg(data + offset + i), __ldg(bias + i));
#else
temp = __hadd(data[offset + i], bias[i]);
#endif
data[offset + i] = temp;
}
}
template <typename T>
__global__ void AddBiasRelu(const int num, const T* bias, T* data) {
int offset = blockIdx.x * num;
......@@ -92,6 +174,28 @@ __global__ void AddBiasRelu(const int num, const T* bias, T* data) {
}
}
template <>
__global__ void AddBiasRelu<half>(const int num, const half* bias, half* data) {
int offset = blockIdx.x * num;
for (int i = threadIdx.x; i < num; i += blockDim.x) {
half temp;
#if __CUDA_ARCH__ >= 350
temp = __hadd(__ldg(data + offset + i), __ldg(bias + i));
#else
temp = __hadd(data[offset + i], bias[i]);
#endif
#if __CUDA_ARCH__ >= 530
data[offset + i] =
__hgt(temp, __float2half(0.0f)) ? temp : __float2half(0.0f);
#else
data[offset + i] =
__float2half(__half2float(temp) > 0.f ? __half2float(temp) : 0.f);
#endif
}
}
template <typename T, PrecisionType PType>
void FcCompute<T, PType>::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<T, T>);
......@@ -161,6 +265,69 @@ void FcCompute<T, PType>::Run() {
}
}
template <>
void FcCompute<half, PRECISION(kFP16)>::Run() {
auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream();
auto& param = this->template Param<param_t>();
const auto* x_data = param.input->template data<half>();
const auto* w_data = param.w->template data<half>();
const auto* b_data = param.bias ? param.bias->template data<half>() : nullptr;
auto out_vec = param.output->dims().Vectorize();
out_vec.back() = param.w->dims()[1];
param.output->Resize(out_vec);
auto* out_data = param.output->template mutable_data<half>(TARGET(kCUDA));
int in_num_col_dims = param.in_num_col_dims;
int M = static_cast<int>(
param.input->dims().Slice(0, param.in_num_col_dims).production());
int K = static_cast<int>(
param.input->dims()
.Slice(param.in_num_col_dims, param.input->dims().size())
.production());
int K2 = static_cast<int>(param.w->dims()[0]);
int N = static_cast<int>(param.w->dims()[1]);
CHECK_EQ(K, K2) << "x_w must be equal with y_h";
CHECK(gemm_impl_->init(false, false, M, N, K, &context));
gemm_impl_->run(1.0f, 0.0f, x_data, w_data, out_data, &context);
if (b_data == nullptr) {
return;
}
std::string activation_type = param.activation_type;
if (N % 2 == 0) {
const int threads = 256;
const int num = M * N / 2;
const int blocks = (num + threads - 1) / threads;
const auto* bias_ptr_v2 = reinterpret_cast<const half2*>(b_data);
auto* data_ptr_v2 = reinterpret_cast<half2*>(out_data);
if (activation_type == "relu") {
AddBiasReluV2<half2><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v2, data_ptr_v2, N / 2);
} else if (activation_type == "") {
AddBiasV2<half2><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v2, data_ptr_v2, N / 2);
} else {
LOG(FATAL) << "not supported activation type: " << activation_type;
}
} else {
const int threads = 256;
const int blocks = M;
if (activation_type == "relu") {
AddBiasRelu<half><<<blocks, threads, 0, stream>>>(N, b_data, out_data);
} else if (activation_type == "") {
AddBias<half><<<blocks, threads, 0, stream>>>(N, b_data, out_data);
} else {
LOG(FATAL) << "not supported activation type: " << activation_type;
}
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
......@@ -168,9 +335,19 @@ void FcCompute<T, PType>::Run() {
using FcFp32 = paddle::lite::kernels::cuda::FcCompute<float, PRECISION(kFloat)>;
using FcFp16 = paddle::lite::kernels::cuda::FcCompute<half, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(fc, kCUDA, kFloat, kNCHW, FcFp32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(fc, kCUDA, kFP16, kNCHW, FcFp16, def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.Finalize();
......@@ -31,8 +31,8 @@ namespace cuda {
class FcTest : public ::testing::Test {
protected:
FcTest()
: m_(128),
k_(512),
: m_(8),
k_(16),
n_(64),
in_num_col_dims_(1),
act_type_("relu"),
......@@ -189,6 +189,42 @@ TEST_F(FcTest, TestFP32) {
}
}
TEST_F(FcTest, TestFP16) {
InitHalfInput();
FcCompute<half, PRECISION(kFP16)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
const half* out_gpu_data = out_gpu_.data<half>();
half* out_cpu_data = out_cpu_.mutable_data<half>();
CopySync<TARGET(kCUDA)>(out_cpu_data,
out_gpu_data,
sizeof(half) * out_gpu_.numel(),
IoDirection::DtoH);
for (int i = 0; i < out_gpu_.numel(); ++i) {
float res = static_cast<float>(lite::float16(out_cpu_data[i]));
float ref = out_ref_.data<float>()[i];
EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 2e-2);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/matmul_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType PType>
void MatMulCompute<T, PType>::Run() {
auto& context = this->ctx_->template As<CUDAContext>();
auto& param = this->template Param<param_t>();
const auto* x_data = param.X->template data<T>();
const auto* y_data = param.Y->template data<T>();
auto* out_data = param.Out->template mutable_data<T>(TARGET(kCUDA));
bool transpose_x = param.transpose_X;
bool transpose_y = param.transpose_Y;
float alpha = param.alpha;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int m = 0;
int k = 0;
int n = 0;
int batch = 0;
int64_t stride_x = 0;
int64_t stride_y = 0;
if (x_dims.size() >= 2 && y_dims.size() >= 2 &&
(x_dims.size() != 2 || y_dims.size() != 2)) {
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [B, M, K], y: [K, N], out: [B, M, N]
// or
// x: [M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [M, K], y: [B, K, N], out: [B, M, N]
strided_gemm_impl_->init(transpose_x, transpose_y, &context);
m = transpose_x ? x_dims[x_dims.size() - 1] : x_dims[x_dims.size() - 2];
k = transpose_x ? x_dims[x_dims.size() - 2] : x_dims[x_dims.size() - 1];
n = transpose_y ? y_dims[y_dims.size() - 2] : y_dims[y_dims.size() - 1];
int batch_x = x_dims.size() == 2 ? 0 : x_dims.count(0, x_dims.size() - 2);
int batch_y = y_dims.size() == 2 ? 0 : y_dims.count(0, y_dims.size() - 2);
CHECK(batch_x == batch_y || batch_x == 0 || batch_y == 0)
<< "batch_size x should be equal to batch_size y, or "
"one of batch_size x and batch_size y should be 0. "
"But got batch_size x = "
<< batch_x << ", batch_size y = " << batch_y;
batch = batch_x == 0 ? batch_y : batch_x;
stride_x = x_dims.size() == 2 ? 0 : m * k;
stride_y = y_dims.size() == 2 ? 0 : k * n;
strided_gemm_impl_->run(alpha,
0.f,
m,
n,
k,
x_data,
y_data,
out_data,
batch,
stride_x,
stride_y);
} else if (x_dims.size() == 2 && y_dims.size() == 2) {
// x: [M, K], y: [K, N], out: [M, N]
m = transpose_x ? x_dims[1] : x_dims[0];
k = transpose_x ? x_dims[0] : x_dims[1];
n = transpose_y ? y_dims[0] : y_dims[1];
gemm_impl_->init(transpose_x, transpose_y, m, n, k, &context);
gemm_impl_->run(alpha, 0.0f, x_data, y_data, out_data, &context);
} else if (x_dims.size() > 2 && y_dims.size() == 1) {
// x: [B, M, K], y: [K], out: [B, M]
strided_gemm_impl_->init(transpose_x, transpose_y, &context);
m = transpose_x ? x_dims[x_dims.size() - 1] : x_dims[x_dims.size() - 2];
k = transpose_x ? x_dims[x_dims.size() - 2] : x_dims[x_dims.size() - 1];
n = 1;
batch = x_dims.count(0, x_dims.size() - 2);
stride_x = m * k;
stride_y = 0;
strided_gemm_impl_->run(alpha,
0.f,
m,
n,
k,
x_data,
y_data,
out_data,
batch,
stride_x,
stride_y);
} else if (x_dims.size() == 1 && y_dims.size() == 1) {
if (!transpose_x && !transpose_y) {
// x: [K], y: [K], out: [1]
m = 1;
k = x_dims[0];
n = 1;
CHECK_EQ(x_dims[0], y_dims[0])
<< "x_dims[0] should be equal to y_dims[0]";
gemm_impl_->init(false, false, m, n, k, &context);
gemm_impl_->run(alpha, 0.0f, x_data, y_data, out_data, &context);
} else if (transpose_x && transpose_y) {
// x: [M], y: [N], x_transpose: true, y_transpose: true, out: [M, N]
m = x_dims[0];
k = 1;
n = y_dims[0];
gemm_impl_->init(false, false, m, n, k, &context);
gemm_impl_->run(alpha, 0.0f, x_data, y_data, out_data, &context);
} else {
LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims("
<< y_dims << "), transpose_x(" << transpose_x
<< "), transpose_y(" << transpose_y << ")";
}
} else {
LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ")";
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using MatMulFp32 =
paddle::lite::kernels::cuda::MatMulCompute<float, PRECISION(kFloat)>;
using MatMulFp16 =
paddle::lite::kernels::cuda::MatMulCompute<half, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(matmul, kCUDA, kFloat, kNCHW, MatMulFp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(matmul, kCUDA, kFP16, kNCHW, MatMulFp16, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/backends/cuda/math/strided_gemm.h"
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
class MatMulCompute : public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::MatMulParam;
void PrepareForRun() override {
strided_gemm_impl_.reset(new lite::cuda::math::StridedGemm<T, T>);
gemm_impl_.reset(new lite::cuda::math::Gemm<T, T>);
}
void Run() override;
virtual ~MatMulCompute() = default;
private:
std::unique_ptr<lite::cuda::math::StridedGemm<T, T>> strided_gemm_impl_{
nullptr};
std::unique_ptr<lite::cuda::math::Gemm<T, T>> gemm_impl_{nullptr};
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/matmul_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class MatMulTest : public ::testing::Test {
protected:
MatMulTest()
: x_trans_(false),
y_trans_(true),
alpha_(1.0f),
x_shape_({4, 1, 2}),
y_shape_({4, 1, 2}),
out_shape_({4, 1, 1}) {
x_ref_.Resize(lite::DDim(x_shape_));
x_gpu_.Resize(x_ref_.dims());
y_ref_.Resize(lite::DDim(y_shape_));
y_gpu_.Resize(y_ref_.dims());
auto x_ref_data = x_ref_.mutable_data<float>();
auto y_ref_data = y_ref_.mutable_data<float>();
// prepare input
for (int64_t i = 0; i < x_ref_.numel(); i++) {
x_ref_data[i] = static_cast<float>(1);
}
for (int64_t i = 0; i < y_ref_.numel(); i++) {
y_ref_data[i] = static_cast<float>(1);
}
out_ref_.Resize(lite::DDim(out_shape_));
out_cpu_.Resize(out_ref_.dims());
out_gpu_.Resize(out_ref_.dims());
RunBaseLine();
InitParamAndContext();
}
void InitParamAndContext() {
ctx_.reset(new KernelContext);
cudaStreamCreate(&stream_);
auto& context = ctx_->As<CUDAContext>();
context.SetExecStream(stream_);
param_.X = &x_gpu_;
param_.Y = &y_gpu_;
param_.transpose_X = x_trans_;
param_.transpose_Y = y_trans_;
param_.alpha = alpha_;
param_.Out = &out_gpu_;
}
void InitFloatInput() {
x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
x_gpu_.dims());
y_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(y_ref_.data<float>(),
y_gpu_.dims());
}
void InitHalfInput() {
x_half_.Resize(x_ref_.dims());
auto x_half_data = x_half_.mutable_data<half>();
for (int64_t i = 0; i < x_half_.numel(); ++i) {
x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
}
x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
y_half_.Resize(y_ref_.dims());
auto y_half_data = y_half_.mutable_data<half>();
for (int64_t i = 0; i < y_half_.numel(); i++) {
y_half_data[i] = half(lite::float16(y_ref_.data<float>()[i]));
}
y_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(y_half_data, y_gpu_.dims());
}
void RunBaseLine() {
auto* out_data = out_ref_.mutable_data<float>();
for (int64_t i = 0; i < out_ref_.numel(); ++i) {
out_data[i] = 2;
}
}
bool x_trans_, y_trans_;
float alpha_;
std::vector<int64_t> x_shape_, y_shape_, out_shape_;
lite::Tensor x_ref_, y_ref_, out_ref_;
lite::Tensor x_gpu_, y_gpu_;
lite::Tensor x_half_, y_half_;
lite::Tensor out_cpu_, out_gpu_;
operators::MatMulParam param_;
std::unique_ptr<KernelContext> ctx_;
cudaStream_t stream_;
};
TEST_F(MatMulTest, TestFP32) {
InitFloatInput();
MatMulCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
out_gpu_.data<float>(),
sizeof(float) * out_gpu_.numel(),
IoDirection::DtoH);
for (int i = 0; i < out_gpu_.numel(); ++i) {
float res = out_cpu_.data<float>()[i];
float ref = out_ref_.data<float>()[i];
EXPECT_NEAR(fabs(res - ref) / ref, 0.f, 1e-5);
}
}
TEST_F(MatMulTest, TestFP16) {
InitHalfInput();
MatMulCompute<half, PRECISION(kFP16)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
const half* out_gpu_data = out_gpu_.data<half>();
half* out_cpu_data = out_cpu_.mutable_data<half>();
CopySync<TARGET(kCUDA)>(out_cpu_data,
out_gpu_data,
sizeof(half) * out_gpu_.numel(),
IoDirection::DtoH);
for (int i = 0; i < out_gpu_.numel(); ++i) {
float res = static_cast<float>(lite::float16(out_cpu_data[i]));
float ref = out_ref_.data<float>()[i];
EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 1e-2);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -71,7 +71,7 @@ void Nchw2nhwcBaseLine(lite::Tensor* input,
n * output_h * output_w * output_c]
void Nhwc2nchwBaseLine(lite::Tensor* input,
lite::Tensor* output,
const std::vector<int> axies) {
const std::vector<int>& axies) {
auto* input_data = input->data<float>();
auto* output_data = output->mutable_data<float>();
......@@ -175,7 +175,6 @@ TEST(transpose_nchw, normal) {
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
Nchw2nhwcBaseLine(&x_ref, &out_ref, axes);
auto* out_ref_data = out_ref.mutable_data<float>();
// TransBaseLine(&x_ref, &out_ref, axes);
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}
......@@ -226,7 +225,6 @@ TEST(transpose_nhwc, normal) {
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
Nhwc2nchwBaseLine(&x_ref, &out_ref, axes);
// TransBaseLine(&x_ref, &out_ref, axes);
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
......@@ -277,11 +275,11 @@ class TransposeTest : public ::testing::Test {
void InitHalfInput() {
x_half_.Resize(lite::DDim(x_ref_.dims()));
auto X_half__data = x_half_.mutable_data<half>();
auto x_half_data = x_half_.mutable_data<half>();
for (int64_t i = 0; i < x_half_.numel(); i++) {
X_half__data[i] = half(lite::float16(x_ref_.data<float>()[i]));
x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
}
x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(X_half__data, x_gpu_.dims());
x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
}
void RunBaseLine(const lite::Tensor* x, lite::Tensor* out) {
......@@ -355,15 +353,15 @@ TEST_F(TransposeTest, TestFP16) {
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
const half* Out_gpu__data = out_gpu_.data<half>();
half* Out_cpu__data = out_cpu_.mutable_data<half>();
CopySync<TARGET(kCUDA)>(Out_cpu__data,
Out_gpu__data,
const half* out_gpu_data = out_gpu_.data<half>();
half* out_cpu_data = out_cpu_.mutable_data<half>();
CopySync<TARGET(kCUDA)>(out_cpu_data,
out_gpu_data,
sizeof(half) * out_gpu_.numel(),
IoDirection::DtoH);
for (int i = 0; i < out_cpu_.numel(); ++i) {
float res = static_cast<float>(lite::float16(Out_cpu__data[i]));
float res = static_cast<float>(lite::float16(out_cpu_data[i]));
float ref = out_ref_.data<float>()[i];
EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 1e-2);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册