Unverified commit 14397ca0, authored by Wilber, committed by GitHub

[CUDA] [Kernel] Add sigmoid cuda kernel. (#3955)

Parent bebe26e5
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <iostream>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/utils.h"
......@@ -484,6 +485,76 @@ template void relu(int, const half*, half*, float, cudaStream_t);
template void bias_relu(
int, const float*, const float* bias, float*, float, cudaStream_t);
// ------------- sigmoid -------------
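// sigmoid(x) = 1 / (1 + exp(-x)). CUDA_KERNEL_LOOP is the usual grid-stride
// loop; on sm_35 and newer the input is read through __ldg() so it is served
// by the read-only data cache.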
template <typename T>
__global__ void sigmoid_kernel(const int num, const T* in, T* out) {
CUDA_KERNEL_LOOP(i, num) {
#if __CUDA_ARCH__ >= 350
out[i] = static_cast<T>(1.0f) /
(static_cast<T>(1.0f) + expf(-1 * __ldg(in + i)));
#else
out[i] = static_cast<T>(1.0f) / (static_cast<T>(1.0f) + expf(-in[i]));
#endif
}
}
template <>
__global__ void sigmoid_kernel(const int num, const half* in, half* out) {
CUDA_KERNEL_LOOP(i, num) {
half tmp = __float2half(1.0f);
#if __CUDA_ARCH__ >= 530
out[i] = __hdiv(
tmp, __hadd(tmp, hexp(__hmul(__float2half(-1.0f), __ldg(in + i)))));
#else
out[i] = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i]))));
#endif
}
}
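// Vectorized fp16 specialization: each thread applies sigmoid to a half2
// pair. Native half arithmetic (__h2div/__hadd2/h2exp/__hmul2) requires
// sm_53 or newer; older architectures unpack each lane to float, use expf,
// and pack the results back.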
template <>
__global__ void sigmoid_kernel(const int num, const half2* in, half2* out) {
CUDA_KERNEL_LOOP(i, num) {
half2 tmp = __floats2half2_rn(1.0f, 1.0f);
#if __CUDA_ARCH__ >= 530
out[i] = __h2div(tmp,
__hadd2(tmp,
h2exp(__hmul2(__floats2half2_rn(-1.0f, -1.0f),
__ldg(in + i)))));
#else
out[i].x = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i].x))));
out[i].y = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i].y))));
#endif
}
}
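// Host-side launchers: CUDA_GET_BLOCKS / CUDA_NUM_THREADS choose the launch
// configuration, and CUDA_POST_KERNEL_CHECK reports any launch failure.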
template <typename T>
void sigmoid(const int num, const T* din, T* dout, cudaStream_t stream) {
sigmoid_kernel<T><<<CUDA_GET_BLOCKS(num), CUDA_NUM_THREADS, 0, stream>>>(
num, din, dout);
CUDA_POST_KERNEL_CHECK;
}
template <>
void sigmoid(const int num, const half* din, half* dout, cudaStream_t stream) {
if (num % 2 == 0) {
const half2* din2 = reinterpret_cast<const half2*>(din);
half2* dout2 = reinterpret_cast<half2*>(dout);
sigmoid_kernel<
half2><<<CUDA_GET_BLOCKS(num / 2), CUDA_NUM_THREADS, 0, stream>>>(
num / 2, din2, dout2);
} else {
sigmoid_kernel<half><<<CUDA_GET_BLOCKS(num), CUDA_NUM_THREADS, 0, stream>>>(
num, din, dout);
}
CUDA_POST_KERNEL_CHECK;
}
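// When the element count is even, the fp16 path reinterprets the buffers as
// half2 and processes two elements per thread, halving the grid. This
// assumes the pointers are at least 4-byte aligned, which holds for typical
// CUDA allocations (cudaMalloc guarantees much stricter alignment).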
template void sigmoid(const int num,
const float* din,
float* dout,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -83,6 +83,9 @@ void bias_int8_nhwc(int num,
const void* scale,
cudaStream_t stream);
template <typename T>
void sigmoid(const int num, const T* din, T* dout, cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -15,6 +15,7 @@ add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${
add_kernel(abs_compute_cuda CUDA basic SRCS abs_compute.cu DEPS ${lite_kernel_deps})
add_kernel(tanh_compute_cuda CUDA basic SRCS tanh_compute.cu DEPS ${lite_kernel_deps})
add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sigmoid_compute_cuda CUDA basic SRCS sigmoid_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
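# Note: unlike the neighboring activation kernels, sigmoid_compute_cuda also
# depends on ${math_cuda}, since its Run() calls lite::cuda::math::sigmoid.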
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_concat_compute_cuda CUDA extra SRCS sequence_pool_concat_compute.cu DEPS ${lite_kernel_deps})
......@@ -61,6 +62,7 @@ nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_
nv_test(abs_compute_cuda_test SRCS abs_compute_test.cc DEPS abs_compute_cuda)
nv_test(tanh_compute_cuda_test SRCS tanh_compute_test.cc DEPS tanh_compute_cuda)
nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda)
nv_test(sigmoid_compute_cuda_test SRCS sigmoid_compute_test.cc DEPS sigmoid_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(search_group_padding_compute_cuda_test SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_cuda)
......
......@@ -69,7 +69,7 @@ void concat_compute_ref(const operators::ConcatParam& param) {
std::vector<int> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int input_i_numel = input[i]->dims().size() == 0 ? 0 : 1;
for (size_t didx = 0; didx < input[i]->dims().size(); ++didx) {
input_i_numel *= input[i]->dims()[didx];
}
int t_cols = input_i_numel / rows;
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/sigmoid_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
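// Run() is a thin wrapper: it takes the exec stream from the CUDAContext and
// forwards the flattened element count plus raw device pointers to
// lite::cuda::math::sigmoid.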
template <typename T, PrecisionType Ptype>
void SigmoidCompute<T, Ptype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
int num = static_cast<int>(param.X->numel());
auto input = param.X->template data<T>();
auto output = param.Out->template mutable_data<T>(TARGET(kCUDA));
lite::cuda::math::sigmoid<T>(num, input, output, stream);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using SigmoidFp32 =
paddle::lite::kernels::cuda::SigmoidCompute<float, PRECISION(kFloat)>;
using SigmoidFp16 =
paddle::lite::kernels::cuda::SigmoidCompute<half, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(sigmoid, kCUDA, kFloat, kNCHW, SigmoidFp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(sigmoid, kCUDA, kFP16, kNCHW, SigmoidFp16, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.Finalize();
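// Two kernels are registered under the same op name "sigmoid": the
// kFloat/kNCHW variant binds plain CUDA tensors, while the kFP16 variant
// binds fp16 tensors, so kernel selection can match on tensor precision.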
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
class SigmoidCompute : public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~SigmoidCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sigmoid_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
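// Test fixture: fills an 8x64 input with values in [0, 1.8), computes an
// fp32 reference on the host with expf once, and reuses it for both the
// FP32 and FP16 tests.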
class SigmoidTest : public ::testing::Test {
protected:
SigmoidTest() : m_(8), n_(64), shape_({m_, n_}) {
x_ref_.Resize(lite::DDim(shape_));
x_gpu_.Resize(lite::DDim(shape_));
auto x_ref_data = x_ref_.mutable_data<float>();
for (int64_t i = 0; i < x_ref_.numel(); i++) {
x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
out_ref_.Resize(lite::DDim(shape_));
out_cpu_.Resize(out_ref_.dims());
out_gpu_.Resize(out_ref_.dims());
RunBaseLine();
InitParamAndContext();
}
void InitParamAndContext() {
ctx_.reset(new KernelContext);
cudaStreamCreate(&stream_);
auto& context = ctx_->As<CUDAContext>();
context.SetExecStream(stream_);
param_.X = &x_gpu_;
param_.Out = &out_gpu_;
}
void InitFloatInput() {
x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
x_gpu_.dims());
}
void InitHalfInput() {
x_half_.Resize(lite::DDim(shape_));
auto x_half_data = x_half_.mutable_data<half>();
for (int64_t i = 0; i < x_half_.numel(); i++) {
x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
}
x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
}
void RunBaseLine() {
for (int64_t i = 0; i < x_ref_.numel(); ++i) {
out_ref_.mutable_data<float>()[i] =
1.f / (1.f + expf(-1 * x_ref_.data<float>()[i]));
}
}
int m_, n_;
std::vector<int64_t> shape_;
lite::Tensor x_ref_, out_ref_;
lite::Tensor x_gpu_;
lite::Tensor x_half_;
lite::Tensor out_cpu_, out_gpu_;
operators::ActivationParam param_;
std::unique_ptr<KernelContext> ctx_;
cudaStream_t stream_;
};
TEST_F(SigmoidTest, TestFP32) {
InitFloatInput();
SigmoidCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
out_gpu_.data<float>(),
sizeof(float) * out_gpu_.numel(),
IoDirection::DtoH);
for (int i = 0; i < out_gpu_.numel(); ++i) {
float res = out_cpu_.data<float>()[i];
float ref = out_ref_.data<float>()[i];
EXPECT_NEAR(fabs(res - ref) / ref, 0.f, 1e-5);
}
}
TEST_F(SigmoidTest, TestFP16) {
InitHalfInput();
SigmoidCompute<half, PRECISION(kFP16)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
const half* out_gpu_data = out_gpu_.data<half>();
half* out_cpu_data = out_cpu_.mutable_data<half>();
CopySync<TARGET(kCUDA)>(out_cpu_data,
out_gpu_data,
sizeof(half) * out_gpu_.numel(),
IoDirection::DtoH);
for (int i = 0; i < out_gpu_.numel(); ++i) {
float res = static_cast<float>(lite::float16(out_cpu_data[i]));
float ref = out_ref_.data<float>()[i];
EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 2e-2);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle