Unverified commit cb138726, authored by Wilber, committed by GitHub

[CUDA] [Kernel] Add fc cuda kernel. (#3873)

Parent 0edd6cf1
@@ -40,5 +40,4 @@ REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
.BindTargets({TARGET(kAny)})
.ExcludeTargets({TARGET(kXPU), TARGET(kX86)})
.ExcludeTargets({TARGET(kBM)})
.ExcludeTargets({TARGET(kCUDA)})
.BindKernel("fc");
@@ -6,6 +6,7 @@ message(STATUS "compile with lite CUDA kernels")
# basic kernels
add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(fc_compute_cuda CUDA basic SRCS fc_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
@@ -65,7 +66,8 @@ nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute
nv_test(elementwise_compute_cuda_test SRCS elementwise_compute_test.cc DEPS elementwise_compute_cuda)
nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(fc_compute_cuda_test SRCS fc_compute_test.cc DEPS fc_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
#nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/fc_compute.h"
#include <string>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
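// Maps a scalar type to its packed CUDA vector type (float -> float4) for the
// vectorized bias kernels below.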
template <typename T>
struct FcTypeTraits;
template <>
struct FcTypeTraits<float> {
typedef float4 Type;
};
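// Vectorized bias add over num packed elements; K is the number of packed
// elements per output row (N / 4), so index % K selects the matching bias group.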
template <typename T>
__global__ void bias_v4(const int num, const T* bias, T* data, int K) {
CUDA_KERNEL_LOOP(index, num) {
int bias_idx = index % K;
const T bias_ptr = bias[bias_idx];
const T in_ptr = data[index];
T packed_val;
packed_val.x = in_ptr.x + bias_ptr.x;
packed_val.y = in_ptr.y + bias_ptr.y;
packed_val.z = in_ptr.z + bias_ptr.z;
packed_val.w = in_ptr.w + bias_ptr.w;
data[index] = packed_val;
}
}
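// Same as bias_v4, but fuses a ReLU (fmaxf with 0) on each component.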
template <typename T>
__global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
CUDA_KERNEL_LOOP(index, num) {
int bias_idx = index % K;
const T bias_ptr = bias[bias_idx];
const T in_ptr = data[index];
T packed_val;
packed_val.x = fmaxf(0.f, in_ptr.x + bias_ptr.x);
packed_val.y = fmaxf(0.f, in_ptr.y + bias_ptr.y);
packed_val.z = fmaxf(0.f, in_ptr.z + bias_ptr.z);
packed_val.w = fmaxf(0.f, in_ptr.w + bias_ptr.w);
data[index] = packed_val;
}
}
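// Fallback bias add for N not divisible by 4: one thread block per output row,
// with threads striding over the num (= N) columns; __ldg gives read-only
// cached loads on sm_35 and newer.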
template <typename T>
__global__ void general_bias(const int num, const T* bias, T* data) {
int offset = blockIdx.x * num;
for (int i = threadIdx.x; i < num; i += blockDim.x) {
T temp;
#if __CUDA_ARCH__ >= 350
temp = __ldg(data + offset + i) + __ldg(bias + i);
#else
temp = data[offset + i] + bias[i];
#endif
data[offset + i] = temp;
}
}
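// Fallback bias add fused with ReLU: x * (x > 0) keeps positive values and
// zeroes the rest.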
template <typename T>
__global__ void general_relu_bias(const int num, const T* bias, T* data) {
int offset = blockIdx.x * num;
for (int i = threadIdx.x; i < num; i += blockDim.x) {
T temp;
#if __CUDA_ARCH__ >= 350
temp = __ldg(data + offset + i) + __ldg(bias + i);
#else
temp = data[offset + i] + bias[i];
#endif
data[offset + i] = static_cast<int>(temp > 0) * temp;
}
}
template <typename T, PrecisionType PType>
void FcCompute<T, PType>::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<T, T>);
}
template <typename T, PrecisionType PType>
void FcCompute<T, PType>::Run() {
auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream();
auto& param = this->template Param<param_t>();
const auto* x_data = param.input->template data<T>();
const auto* w_data = param.w->template data<T>();
const auto* b_data = param.bias ? param.bias->template data<T>() : nullptr;
auto out_vec = param.output->dims().Vectorize();
out_vec.back() = param.w->dims()[1];
param.output->Resize(out_vec);
auto* out_data = param.output->template mutable_data<T>(TARGET(kCUDA));
int in_num_col_dims = param.in_num_col_dims;
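// Flatten the input to an M x K matrix: M is the product of the dims before
// in_num_col_dims and K the product of the remaining dims; the weight W is K x N.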
int M = static_cast<int>(
param.input->dims().Slice(0, param.in_num_col_dims).production());
int K = static_cast<int>(
param.input->dims()
.Slice(param.in_num_col_dims, param.input->dims().size())
.production());
int K2 = static_cast<int>(param.w->dims()[0]);
int N = static_cast<int>(param.w->dims()[1]);
CHECK_EQ(K, K2) << "x_w must be equal to w_h";
CHECK(gemm_impl_->init(false, false, M, N, K, &context));
gemm_impl_->run(1.0f, 0.0f, x_data, w_data, out_data, &context);
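// GEMM gives out = x * w; the bias (and optional activation) is applied below
// only when a bias tensor is present.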
if (b_data == nullptr) {
return;
}
std::string activation_type = param.activation_type;
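// When N is a multiple of 4, reinterpret the buffers as float4 and use the
// vectorized kernels; otherwise fall back to the one-block-per-row kernels.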
if (N % 4 == 0) {
const int threads = 256;
const int num = M * N / 4;
const int blocks = (num + threads - 1) / threads;
typedef typename FcTypeTraits<T>::Type trans_type;
const auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(b_data);
auto* data_ptr_v4 = reinterpret_cast<trans_type*>(out_data);
if (activation_type == "relu") {
bias_relu_v4<trans_type><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
} else if (activation_type == "") {
bias_v4<trans_type><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
} else {
LOG(FATAL) << "not supported activation type: " << activation_type;
}
} else {
const int threads = 256;
const int blocks = M;
if (activation_type == "relu") {
general_relu_bias<T><<<blocks, threads, 0, stream>>>(N, b_data, out_data);
} else if (activation_type == "") {
general_bias<T><<<blocks, threads, 0, stream>>>(N, b_data, out_data);
} else {
LOG(FATAL) << "not supported activation type: " << activation_type;
}
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using FcFp32 = paddle::lite::kernels::cuda::FcCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(fc, kCUDA, kFloat, kNCHW, FcFp32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
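For reference, below is a minimal standalone sketch of the packed bias + ReLU path used above, runnable outside of Lite. The demo name bias_relu_v4_demo and the toy sizes are illustrative, not part of this patch, and the grid-stride loop is written out explicitly under the assumption that CUDA_KERNEL_LOOP expands to the usual pattern. With N divisible by 4, the M x N output is reinterpreted as float4 so each thread handles four columns at a time.

#include <cuda_runtime.h>
#include <cstdio>

// Grid-stride loop over `num` packed float4 elements; k = N / 4 packed columns
// per row, so i % k selects the matching packed bias element.
__global__ void bias_relu_v4_demo(int num, const float4* bias, float4* data, int k) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    float4 b = bias[i % k];
    float4 v = data[i];
    v.x = fmaxf(0.f, v.x + b.x);
    v.y = fmaxf(0.f, v.y + b.y);
    v.z = fmaxf(0.f, v.z + b.z);
    v.w = fmaxf(0.f, v.w + b.w);
    data[i] = v;
  }
}

int main() {
  const int M = 2, N = 8;        // N % 4 == 0, so the packed path applies
  const int num = M * N / 4;     // number of float4 elements in the output
  float h_out[M * N], h_bias[N];
  for (int i = 0; i < M * N; ++i) h_out[i] = (i % 2) ? 1.f : -1.f;
  for (int j = 0; j < N; ++j) h_bias[j] = 0.5f;

  float *d_out, *d_bias;
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMalloc(&d_bias, sizeof(h_bias));
  cudaMemcpy(d_out, h_out, sizeof(h_out), cudaMemcpyHostToDevice);
  cudaMemcpy(d_bias, h_bias, sizeof(h_bias), cudaMemcpyHostToDevice);

  // Same launch configuration as the kernel above: 256 threads per block,
  // enough blocks to cover all packed elements.
  const int threads = 256;
  const int blocks = (num + threads - 1) / threads;
  bias_relu_v4_demo<<<blocks, threads>>>(num,
                                         reinterpret_cast<const float4*>(d_bias),
                                         reinterpret_cast<float4*>(d_out),
                                         N / 4);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("out[0] = %f\n", h_out[0]);  // expect max(0, -1 + 0.5) = 0
  cudaFree(d_out);
  cudaFree(d_bias);
  return 0;
}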
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType PType>
class FcCompute : public KernelLite<TARGET(kCUDA), PType> {
public:
using param_t = operators::FcParam;
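// operators::FcParam provides the Input, W and Bias tensors, the output
// tensor, in_num_col_dims, and activation_type used by Run().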
void PrepareForRun() override;
void Run() override;
virtual ~FcCompute() = default;
private:
std::unique_ptr<lite::cuda::math::Gemm<T, T>> gemm_impl_{nullptr};
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/fc_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class FcTest : public ::testing::Test {
protected:
FcTest()
: m(128),
k(512),
n(64),
in_num_col_dims(1),
act_type("relu"),
x_shape({m, k}),
w_shape({k, n}),
b_shape({n}),
out_shape({m, n}) {
X_gpu.Resize(lite::DDim(x_shape));
X_ref.Resize(lite::DDim(x_shape));
W_gpu.Resize(lite::DDim(w_shape));
W_ref.Resize(lite::DDim(w_shape));
b_gpu.Resize(lite::DDim(b_shape));
b_ref.Resize(lite::DDim(b_shape));
auto x_ref_data = X_ref.mutable_data<float>();
auto w_ref_data = W_ref.mutable_data<float>();
auto b_ref_data = b_ref.mutable_data<float>();
// prepare input
for (int64_t i = 0; i < X_ref.numel(); i++) {
x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
for (int64_t i = 0; i < W_ref.numel(); i++) {
w_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
for (int64_t i = 0; i < b_ref.numel(); i++) {
b_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
Out_ref.Resize(lite::DDim(out_shape));
Out_cpu.Resize(Out_ref.dims());
Out_gpu.Resize(Out_ref.dims());
fc_cpu_base(&X_ref, &W_ref, &b_ref, &Out_ref);
device_init();
}
void device_init() {
ctx.reset(new KernelContext);
cudaStreamCreate(&stream);
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
param.input = &X_gpu;
param.w = &W_gpu;
param.bias = &b_gpu;
param.in_num_col_dims = in_num_col_dims;
param.activation_type = act_type;
param.output = &Out_gpu;
}
void float_data_init() {
X_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(X_ref.data<float>(),
X_gpu.dims());
W_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(W_ref.data<float>(),
W_gpu.dims());
b_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(b_ref.data<float>(),
b_gpu.dims());
}
void half_data_init() {
X_half.Resize(lite::DDim(x_shape));
auto x_half_data = X_half.mutable_data<half>();
for (int64_t i = 0; i < X_half.numel(); i++) {
x_half_data[i] = half(lite::float16(X_ref.data<float>()[i]));
}
X_gpu.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, X_gpu.dims());
W_half.Resize(W_ref.dims());
auto w_half_data = W_half.mutable_data<half>();
for (int64_t i = 0; i < W_half.numel(); i++) {
w_half_data[i] = half(lite::float16(W_ref.data<float>()[i]));
}
W_gpu.Assign<half, lite::DDim, TARGET(kCUDA)>(w_half_data, W_gpu.dims());
b_half.Resize(b_ref.dims());
auto b_half_data = b_half.mutable_data<half>();
for (int64_t i = 0; i < b_half.numel(); i++) {
b_half_data[i] = half(lite::float16(b_ref.data<float>()[i]));
}
b_gpu.Assign<half, lite::DDim, TARGET(kCUDA)>(b_half_data, b_gpu.dims());
}
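// Naive CPU reference: Out = X * W + b with an optional ReLU, used to
// validate the CUDA kernel output.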
void fc_cpu_base(const lite::Tensor* X,
const lite::Tensor* W,
const lite::Tensor* b,
lite::Tensor* Out) {
const float* data_in = X->data<float>();
const float* bias = b->data<float>();
const float* weights = W->data<float>();
float* data_out = Out->mutable_data<float>();
int out_rows = X->dims()[0];
int in_cols = X->numel() / out_rows;
int out_cols = W->numel() / in_cols;
int index_out;
for (int i = 0; i < out_rows; i++) {
for (int j = 0; j < out_cols; j++) {
index_out = i * out_cols + j;
data_out[index_out] = bias ? bias[j] : 0;
for (int k = 0; k < in_cols; k++) {
data_out[index_out] +=
data_in[i * in_cols + k] * weights[k * out_cols + j];
}
if (act_type == "relu") {
data_out[index_out] *= static_cast<int>(data_out[index_out] > 0);
}
}
}
}
int m, k, n, in_num_col_dims;
std::string act_type;
std::vector<int64_t> x_shape, w_shape, b_shape, out_shape;
lite::Tensor X_ref, W_ref, b_ref, Out_ref;
lite::Tensor X_gpu, W_gpu, b_gpu;
lite::Tensor X_half, W_half, b_half;
lite::Tensor Out_cpu, Out_gpu;
operators::FcParam param;
std::unique_ptr<KernelContext> ctx;
cudaStream_t stream;
};
TEST_F(FcTest, TestFP32) {
float_data_init();
FcCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param);
kernel.SetContext(std::move(ctx));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
CopySync<TARGET(kCUDA)>(Out_cpu.mutable_data<float>(),
Out_gpu.data<float>(),
sizeof(float) * Out_gpu.numel(),
IoDirection::DtoH);
for (int i = 0; i < Out_gpu.numel(); ++i) {
float res = Out_cpu.data<float>()[i];
float ref = Out_ref.data<float>()[i];
EXPECT_NEAR(fabs(res - ref) / ref, 0.f, 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle