From 14397ca02a691e6fc9f783cbfdf2a974e8ebd0e2 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Thu, 16 Jul 2020 19:02:12 +0800
Subject: [PATCH] [CUDA] [Kernel] Add sigmoid cuda kernel. (#3955)

---
 lite/backends/cuda/math/activation.cu     |  71 +++++++++
 lite/backends/cuda/math/activation.h      |   3 +
 lite/kernels/cuda/CMakeLists.txt          |   2 +
 lite/kernels/cuda/concat_compute_test.cc  |   2 +-
 lite/kernels/cuda/sigmoid_compute.cu      |  57 ++++++++
 lite/kernels/cuda/sigmoid_compute.h       |  35 +++++
 lite/kernels/cuda/sigmoid_compute_test.cc | 168 ++++++++++++++++++++++
 7 files changed, 337 insertions(+), 1 deletion(-)
 create mode 100644 lite/kernels/cuda/sigmoid_compute.cu
 create mode 100644 lite/kernels/cuda/sigmoid_compute.h
 create mode 100644 lite/kernels/cuda/sigmoid_compute_test.cc

diff --git a/lite/backends/cuda/math/activation.cu b/lite/backends/cuda/math/activation.cu
index 7524fbc4fb..4d97042aeb 100644
--- a/lite/backends/cuda/math/activation.cu
+++ b/lite/backends/cuda/math/activation.cu
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include
+#include "lite/backends/cuda/cuda_utils.h"
 #include "lite/backends/cuda/math/activation.h"
 #include "lite/backends/cuda/math/utils.h"
 
@@ -484,6 +485,76 @@ template void relu(int, const half*, half*, float, cudaStream_t);
 template void bias_relu(
     int, const float*, const float* bias, float*, float, cudaStream_t);
 
+// ------------- sigmoid -------------
+
+template <typename T>
+__global__ void sigmoid_kernel(const int num, const T* in, T* out) {
+  CUDA_KERNEL_LOOP(i, num) {
+#if __CUDA_ARCH__ >= 350
+    out[i] = static_cast<T>(1.0f) /
+             (static_cast<T>(1.0f) + expf(-1 * __ldg(in + i)));
+#else
+    out[i] = static_cast<T>(1.0f) / (static_cast<T>(1.0f) + expf(-in[i]));
+#endif
+  }
+}
+
+template <>
+__global__ void sigmoid_kernel(const int num, const half* in, half* out) {
+  CUDA_KERNEL_LOOP(i, num) {
+    half tmp = __float2half(1.0f);
+#if __CUDA_ARCH__ >= 530
+    out[i] = __hdiv(
+        tmp, __hadd(tmp, hexp(__hmul(__float2half(-1.0f), __ldg(in + i)))));
+#else
+    out[i] = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i]))));
+#endif
+  }
+}
+
+template <>
+__global__ void sigmoid_kernel(const int num, const half2* in, half2* out) {
+  CUDA_KERNEL_LOOP(i, num) {
+    half2 tmp = __floats2half2_rn(1.0f, 1.0f);
+#if __CUDA_ARCH__ >= 530
+    out[i] = __h2div(tmp,
+                     __hadd2(tmp,
+                             h2exp(__hmul2(__floats2half2_rn(-1.0f, -1.0f),
+                                           __ldg(in + i)))));
+#else
+    out[i].x = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i].x))));
+    out[i].y = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i].y))));
+#endif
+  }
+}
+
+template <typename T>
+void sigmoid(const int num, const T* din, T* dout, cudaStream_t stream) {
+  sigmoid_kernel<T><<<CUDA_GET_BLOCKS(num), CUDA_NUM_THREADS, 0, stream>>>(
+      num, din, dout);
+  CUDA_POST_KERNEL_CHECK;
+}
+
+template <>
+void sigmoid(const int num, const half* din, half* dout, cudaStream_t stream) {
+  if (num % 2 == 0) {
+    const half2* din2 = reinterpret_cast<const half2*>(din);
+    half2* dout2 = reinterpret_cast<half2*>(dout);
+    sigmoid_kernel<
+        half2><<<CUDA_GET_BLOCKS(num / 2), CUDA_NUM_THREADS, 0, stream>>>(
+        num / 2, din2, dout2);
+  } else {
+    sigmoid_kernel<half><<<CUDA_GET_BLOCKS(num), CUDA_NUM_THREADS, 0, stream>>>(
+        num, din, dout);
+  }
+  CUDA_POST_KERNEL_CHECK;
+}
+
+template void sigmoid(const int num,
+                      const float* din,
+                      float* dout,
+                      cudaStream_t stream);
+
 }  // namespace math
 }  // namespace cuda
 }  // namespace lite
diff --git a/lite/backends/cuda/math/activation.h b/lite/backends/cuda/math/activation.h
index 0150a32865..926ad8d99f 100644
--- a/lite/backends/cuda/math/activation.h
+++ b/lite/backends/cuda/math/activation.h
@@ -83,6 +83,9 @@ void bias_int8_nhwc(int num,
                     const void* scale,
                     cudaStream_t stream);
 
+template <typename T>
+void sigmoid(const int num, const T* din, T* dout, cudaStream_t stream);
+
 }  // namespace math
 }  // namespace cuda
 }  // namespace lite
diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index 76e2d1545e..3d396cfa12 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -15,6 +15,7 @@ add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${
 add_kernel(abs_compute_cuda CUDA basic SRCS abs_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(tanh_compute_cuda CUDA basic SRCS tanh_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(sigmoid_compute_cuda CUDA basic SRCS sigmoid_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
 add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(sequence_pool_concat_compute_cuda CUDA extra SRCS sequence_pool_concat_compute.cu DEPS ${lite_kernel_deps})
@@ -61,6 +62,7 @@ nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_
 nv_test(abs_compute_cuda_test SRCS abs_compute_test.cc DEPS abs_compute_cuda)
 nv_test(tanh_compute_cuda_test SRCS tanh_compute_test.cc DEPS tanh_compute_cuda)
 nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda)
+nv_test(sigmoid_compute_cuda_test SRCS sigmoid_compute_test.cc DEPS sigmoid_compute_cuda)
 nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
 nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
 nv_test(search_group_padding_compute_cuda_test SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_cuda)
diff --git a/lite/kernels/cuda/concat_compute_test.cc b/lite/kernels/cuda/concat_compute_test.cc
index cc12fcd289..08dd4013a5 100644
--- a/lite/kernels/cuda/concat_compute_test.cc
+++ b/lite/kernels/cuda/concat_compute_test.cc
@@ -69,7 +69,7 @@ void concat_compute_ref(const operators::ConcatParam& param) {
   std::vector<int64_t> input_cols(input.size());
   for (int i = 0; i < num; ++i) {
     int input_i_numel = input[i]->dims().size() == 0 ? 0 : 1;
-    for (int didx = 0; didx < input[i]->dims().size(); ++didx) {
+    for (size_t didx = 0; didx < input[i]->dims().size(); ++didx) {
       input_i_numel *= input[i]->dims()[didx];
     }
     int t_cols = input_i_numel / rows;
diff --git a/lite/kernels/cuda/sigmoid_compute.cu b/lite/kernels/cuda/sigmoid_compute.cu
new file mode 100644
index 0000000000..2879f50b4d
--- /dev/null
+++ b/lite/kernels/cuda/sigmoid_compute.cu
@@ -0,0 +1,57 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/backends/cuda/cuda_utils.h" +#include "lite/backends/cuda/math/activation.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/sigmoid_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +void SigmoidCompute::Run() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + int num = static_cast(param.X->numel()); + auto input = param.X->template data(); + auto output = param.Out->template mutable_data(TARGET(kCUDA)); + + lite::cuda::math::sigmoid(num, input, output, stream); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +using SigmoidFp32 = + paddle::lite::kernels::cuda::SigmoidCompute; + +using SigmoidFp16 = + paddle::lite::kernels::cuda::SigmoidCompute; + +REGISTER_LITE_KERNEL(sigmoid, kCUDA, kFloat, kNCHW, SigmoidFp32, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); + +REGISTER_LITE_KERNEL(sigmoid, kCUDA, kFP16, kNCHW, SigmoidFp16, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))}) + .Finalize(); diff --git a/lite/kernels/cuda/sigmoid_compute.h b/lite/kernels/cuda/sigmoid_compute.h new file mode 100644 index 0000000000..455dc38d1f --- /dev/null +++ b/lite/kernels/cuda/sigmoid_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +class SigmoidCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + virtual ~SigmoidCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/sigmoid_compute_test.cc b/lite/kernels/cuda/sigmoid_compute_test.cc new file mode 100644 index 0000000000..e27904333b --- /dev/null +++ b/lite/kernels/cuda/sigmoid_compute_test.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/cuda/sigmoid_compute.h" + +#include + +#include +#include +#include +#include + +#include "lite/api/test_helper.h" +#include "lite/backends/cuda/target_wrapper.h" +#include "lite/utils/float16.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class SigmoidTest : public ::testing::Test { + protected: + SigmoidTest() : m_(8), n_(64), shape_({m_, n_}) { + x_ref_.Resize(lite::DDim(shape_)); + x_gpu_.Resize(lite::DDim(shape_)); + + auto x_ref_data = x_ref_.mutable_data(); + + for (int64_t i = 0; i < x_ref_.numel(); i++) { + x_ref_data[i] = static_cast(i % 10 * 0.2); + } + + out_ref_.Resize(lite::DDim(shape_)); + out_cpu_.Resize(out_ref_.dims()); + out_gpu_.Resize(out_ref_.dims()); + RunBaseLine(); + + InitParamAndContext(); + } + + void InitParamAndContext() { + ctx_.reset(new KernelContext); + cudaStreamCreate(&stream_); + auto& context = ctx_->As(); + context.SetExecStream(stream_); + param_.X = &x_gpu_; + param_.Out = &out_gpu_; + } + + void InitFloatInput() { + x_gpu_.Assign(x_ref_.data(), + x_gpu_.dims()); + } + + void InitHalfInput() { + x_half_.Resize(lite::DDim(shape_)); + auto x_half_data = x_half_.mutable_data(); + for (int64_t i = 0; i < x_half_.numel(); i++) { + x_half_data[i] = half(lite::float16(x_ref_.data()[i])); + } + x_gpu_.Assign(x_half_data, x_gpu_.dims()); + } + + void RunBaseLine() { + for (int64_t i = 0; i < x_ref_.numel(); ++i) { + out_ref_.mutable_data()[i] = + 1.f / (1.f + expf(-1 * x_ref_.data()[i])); + } + } + + int m_, n_; + std::vector shape_; + lite::Tensor x_ref_, out_ref_; + lite::Tensor x_gpu_; + lite::Tensor x_half_; + lite::Tensor out_cpu_, out_gpu_; + + operators::ActivationParam param_; + std::unique_ptr ctx_; + cudaStream_t stream_; +}; + +TEST_F(SigmoidTest, TestFP32) { + InitFloatInput(); + SigmoidCompute kernel; + kernel.SetParam(param_); + kernel.SetContext(std::move(ctx_)); + + for (int i = 0; i < FLAGS_warmup; ++i) { + kernel.Launch(); + cudaDeviceSynchronize(); + } + + auto start = GetCurrentUS(); + kernel.PrepareForRun(); + for (int i = 0; i < FLAGS_repeats; ++i) { + kernel.Run(); + } + cudaDeviceSynchronize(); + auto duration = (GetCurrentUS() - start) / 1000.0; + LOG(INFO) << "fp32, warmup: " << FLAGS_warmup + << ", repeats: " << FLAGS_repeats << ", spend " + << duration / FLAGS_repeats << " ms in average."; + + CopySync(out_cpu_.mutable_data(), + out_gpu_.data(), + sizeof(float) * out_gpu_.numel(), + IoDirection::DtoH); + + for (int i = 0; i < out_gpu_.numel(); ++i) { + float res = out_cpu_.data()[i]; + float ref = out_ref_.data()[i]; + EXPECT_NEAR(fabs(res - ref) / ref, 0.f, 1e-5); + } +} + +TEST_F(SigmoidTest, TestFP16) { + InitHalfInput(); + SigmoidCompute kernel; + kernel.SetParam(param_); + kernel.SetContext(std::move(ctx_)); + + for (int i = 0; i < FLAGS_warmup; ++i) { + kernel.Launch(); + cudaDeviceSynchronize(); + } + + auto start = GetCurrentUS(); + kernel.PrepareForRun(); + for (int i = 0; i < FLAGS_repeats; ++i) { + kernel.Run(); + } + cudaDeviceSynchronize(); + auto duration = (GetCurrentUS() - start) / 1000.0; + LOG(INFO) << "fp16, warmup: " << FLAGS_warmup + << ", repeats: " << FLAGS_repeats << ", spend " + << duration / FLAGS_repeats << " ms in average."; + + const half* out_gpu_data = out_gpu_.data(); + half* out_cpu_data = out_cpu_.mutable_data(); + CopySync(out_cpu_data, + out_gpu_data, + sizeof(half) * out_gpu_.numel(), + IoDirection::DtoH); + + for (int i = 0; i < out_gpu_.numel(); ++i) { + float res = static_cast(lite::float16(out_cpu_data[i])); + float ref = 
+    EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 2e-2);
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
-- 
GitLab
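
Usage sketch (not part of the patch above): the new math helper can also be driven directly, without the kernel/test scaffolding. The snippet below is a minimal sketch assuming the sigmoid<T> declaration added to lite/backends/cuda/math/activation.h, a CUDA-capable device, and a build linked against the Paddle-Lite CUDA math sources; the buffer names and the element count are illustrative only.

// Hypothetical standalone driver for the fp32 sigmoid helper (sketch).
#include <cuda_runtime.h>
#include <vector>

#include "lite/backends/cuda/math/activation.h"

int main() {
  const int num = 512;  // illustrative element count
  std::vector<float> host_in(num, 0.5f), host_out(num);

  // Allocate device buffers and upload the input.
  float* dev_in = nullptr;
  float* dev_out = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&dev_in), num * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&dev_out), num * sizeof(float));
  cudaMemcpy(
      dev_in, host_in.data(), num * sizeof(float), cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Launches the fp32 sigmoid kernel added to activation.cu on the stream.
  paddle::lite::cuda::math::sigmoid<float>(num, dev_in, dev_out, stream);
  cudaStreamSynchronize(stream);

  cudaMemcpy(
      host_out.data(), dev_out, num * sizeof(float), cudaMemcpyDeviceToHost);
  // Each host_out[i] should now equal 1 / (1 + exp(-host_in[i])).

  cudaFree(dev_in);
  cudaFree(dev_out);
  cudaStreamDestroy(stream);
  return 0;
}

The sketch only exercises the fp32 path; for fp16 inputs the specialization in the patch reinterprets the buffer as half2 when num is even, so each thread processes two packed values per load.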