From 91a58fbae7c0b22f3012ab1d67bee9ccd050e481 Mon Sep 17 00:00:00 2001 From: Wilber Date: Wed, 1 Apr 2020 21:48:23 +0800 Subject: [PATCH] add cuda kernels. test=develop (#3315) add cuda kernel. abs, tanh, elementwise_sub --- lite/api/paddle_place.h | 3 +- lite/backends/cuda/math/utils.h | 2 + lite/kernels/cuda/CMakeLists.txt | 6 +- lite/kernels/cuda/abs_compute.cu | 71 ++++++++++++++++++++++++ lite/kernels/cuda/abs_compute.h | 34 ++++++++++++ lite/kernels/cuda/abs_compute_test.cc | 71 ++++++++++++++++++++++++ lite/kernels/cuda/elementwise_compute.cu | 43 ++++++++++++++ lite/kernels/cuda/elementwise_compute.h | 18 ++++++ lite/kernels/cuda/tanh_compute.cu | 56 +++++++++++++++++++ lite/kernels/cuda/tanh_compute.h | 35 ++++++++++++ lite/kernels/cuda/tanh_compute_test.cc | 70 +++++++++++++++++++++++ lite/operators/activation_ops.cc | 4 ++ 12 files changed, 411 insertions(+), 2 deletions(-) create mode 100644 lite/kernels/cuda/abs_compute.cu create mode 100644 lite/kernels/cuda/abs_compute.h create mode 100644 lite/kernels/cuda/abs_compute_test.cc create mode 100644 lite/kernels/cuda/tanh_compute.cu create mode 100644 lite/kernels/cuda/tanh_compute.h create mode 100644 lite/kernels/cuda/tanh_compute_test.cc diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index 1de46a3946..e48686b913 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -99,7 +99,8 @@ enum class ActivationType : int { kTanh = 6, kSwish = 7, kExp = 8, - NUM = 9, + kAbs = 9, + NUM = 10, }; static size_t PrecisionTypeLength(PrecisionType type) { diff --git a/lite/backends/cuda/math/utils.h b/lite/backends/cuda/math/utils.h index b6aa9c7d16..78aa689ff7 100644 --- a/lite/backends/cuda/math/utils.h +++ b/lite/backends/cuda/math/utils.h @@ -29,6 +29,7 @@ enum class BinaryOperation { kADD = 0, kMUL = 1, kDIV = 2, + kSUB = 3, }; template @@ -41,6 +42,7 @@ __device__ __forceinline__ float binary_calc(float x, if (type == BinaryOperation::kADD) return x + y; if (type == BinaryOperation::kMUL) return x * y; if (type == BinaryOperation::kDIV) return x / y; + if (type == BinaryOperation::kSUB) return x - y; } template diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index 3fb3136bfc..0fb3c2ea7a 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -8,6 +8,8 @@ add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_de add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps}) add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps}) add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(abs_compute_cuda CUDA basic SRCS abs_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(tanh_compute_cuda CUDA basic SRCS tanh_compute.cu DEPS ${lite_kernel_deps}) add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps}) add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps}) @@ -45,6 +47,8 @@ lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_ #nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda) nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda) nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda) +nv_test(abs_compute_cuda_test SRCS abs_compute_test.cc DEPS abs_compute_cuda) +nv_test(tanh_compute_cuda_test SRCS tanh_compute_test.cc DEPS tanh_compute_cuda) nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda) nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda) nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda) @@ -61,7 +65,7 @@ nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc #nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda) #nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda) nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda) -#nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda sequence_topk_avg_pooling_compute_cuda) +#nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda) #nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda) if(LITE_BUILD_EXTRA) diff --git a/lite/kernels/cuda/abs_compute.cu b/lite/kernels/cuda/abs_compute.cu new file mode 100644 index 0000000000..4f00aacc0c --- /dev/null +++ b/lite/kernels/cuda/abs_compute.cu @@ -0,0 +1,71 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/abs_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +__global__ void AbsKernel(const int num, const T* input, T* output); + +template <> +__global__ void AbsKernel(const int num, + const float* input, + float* output) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < num) { + output[index] = fabsf(input[index]); + } +} + +template <> +__global__ void AbsKernel(const int num, + const double* input, + double* output) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < num) { + output[index] = fabs(input[index]); + } +} + +void AbsCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + int num = static_cast(param.X->numel()); + auto input = param.X->data(); + auto output = param.Out->mutable_data(TARGET(kCUDA)); + + const int threads = 512; + const int blocks = (num + threads - 1) / threads; + AbsKernel<<>>(num, input, output); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + abs, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::AbsCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/abs_compute.h b/lite/kernels/cuda/abs_compute.h new file mode 100644 index 0000000000..d1f8a0cc5a --- /dev/null +++ b/lite/kernels/cuda/abs_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class AbsCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + virtual ~AbsCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/abs_compute_test.cc b/lite/kernels/cuda/abs_compute_test.cc new file mode 100644 index 0000000000..bfbcae56fa --- /dev/null +++ b/lite/kernels/cuda/abs_compute_test.cc @@ -0,0 +1,71 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/abs_compute.h" +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +TEST(abs, normal) { + AbsCompute abs_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::ActivationParam param; + + Tensor x, y, x_cpu, y_cpu; + int h = 3, w = 3; + y.Resize({h, w}); + x_cpu.Resize({h, w}); + y_cpu.Resize({h, w}); + + auto* y_data = y.mutable_data(TARGET(kCUDA)); + float* x_cpu_data = x_cpu.mutable_data(); + float* y_cpu_data = y_cpu.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = i - 1.5; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + + param.X = &x; + param.Out = &y; + abs_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + abs_kernel.SetContext(std::move(ctx)); + abs_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH); + for (int i = 0; i < y.numel(); i++) { + EXPECT_NEAR(y_cpu_data[i], std::fabs(x_cpu_data[i]), 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/elementwise_compute.cu b/lite/kernels/cuda/elementwise_compute.cu index 64759f86f5..02b7c8f7d9 100644 --- a/lite/kernels/cuda/elementwise_compute.cu +++ b/lite/kernels/cuda/elementwise_compute.cu @@ -152,6 +152,18 @@ void ElementwiseAddComputeNHWC::Run() { if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); } +void ElementwiseSubCompute::Run() { + ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kSUB, false) + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +void ElementwiseSubComputeNHWC::Run() { + ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kSUB, false) + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + void ElementwiseMulCompute::Run() { ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, false) cudaError_t error = cudaGetLastError(); @@ -204,6 +216,17 @@ REGISTER_LITE_KERNEL(elementwise_add, .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) .Finalize(); +REGISTER_LITE_KERNEL(elementwise_sub, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::ElementwiseSubCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); + REGISTER_LITE_KERNEL(elementwise_add, kCUDA, kFloat, @@ -224,6 +247,26 @@ REGISTER_LITE_KERNEL(elementwise_add, DATALAYOUT(kNHWC))}) .Finalize(); +REGISTER_LITE_KERNEL(elementwise_sub, + kCUDA, + kFloat, + kNHWC, + paddle::lite::kernels::cuda::ElementwiseSubComputeNHWC, + nhwc_format) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindInput("Y", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mul, kCUDA, kFloat, diff --git a/lite/kernels/cuda/elementwise_compute.h b/lite/kernels/cuda/elementwise_compute.h index 986a4db227..bc9ffd5d27 100644 --- a/lite/kernels/cuda/elementwise_compute.h +++ b/lite/kernels/cuda/elementwise_compute.h @@ -38,6 +38,24 @@ class ElementwiseAddComputeNHWC virtual ~ElementwiseAddComputeNHWC() = default; }; +class ElementwiseSubCompute + : public KernelLite { + public: + using param_t = operators::ElementwiseParam; + + void Run() override; + virtual ~ElementwiseSubCompute() = default; +}; + +class ElementwiseSubComputeNHWC + : public KernelLite { + public: + using param_t = operators::ElementwiseParam; + + void Run() override; + virtual ~ElementwiseSubComputeNHWC() = default; +}; + class ElementwiseMulCompute : public KernelLite { public: diff --git a/lite/kernels/cuda/tanh_compute.cu b/lite/kernels/cuda/tanh_compute.cu new file mode 100644 index 0000000000..4f9e2729a7 --- /dev/null +++ b/lite/kernels/cuda/tanh_compute.cu @@ -0,0 +1,56 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/tanh_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +__global__ void TanhKernel(const int num, const T* input, T* output) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < num) { + output[index] = tanh(input[index]); + } +} + +void TanhCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + int num = static_cast(param.X->numel()); + auto input = param.X->data(); + auto output = param.Out->mutable_data(TARGET(kCUDA)); + + const int threads = 512; + const int blocks = (num + threads - 1) / threads; + TanhKernel<<>>(num, input, output); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + tanh, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::TanhCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/tanh_compute.h b/lite/kernels/cuda/tanh_compute.h new file mode 100644 index 0000000000..b23b27882c --- /dev/null +++ b/lite/kernels/cuda/tanh_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class TanhCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + virtual ~TanhCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/tanh_compute_test.cc b/lite/kernels/cuda/tanh_compute_test.cc new file mode 100644 index 0000000000..7bc8f25df0 --- /dev/null +++ b/lite/kernels/cuda/tanh_compute_test.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/tanh_compute.h" +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +TEST(tanh, fp32) { + TanhCompute tanh_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::ActivationParam param; + + Tensor x, y, x_cpu, y_cpu; + int h = 3, w = 3; + y.Resize({h, w}); + x_cpu.Resize({h, w}); + y_cpu.Resize({h, w}); + + auto* y_data = y.mutable_data(TARGET(kCUDA)); + float* x_cpu_data = x_cpu.mutable_data(); + float* y_cpu_data = y_cpu.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = i - 1.5; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + + param.X = &x; + param.Out = &y; + tanh_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + tanh_kernel.SetContext(std::move(ctx)); + tanh_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH); + for (int i = 0; i < y.numel(); i++) { + EXPECT_NEAR(y_cpu_data[i], tanh(x_cpu_data[i]), 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/activation_ops.cc b/lite/operators/activation_ops.cc index abaaa1a705..13abe0c53e 100644 --- a/lite/operators/activation_ops.cc +++ b/lite/operators/activation_ops.cc @@ -71,6 +71,9 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { } else if (opdesc.Type() == "exp") { // exp param_.active_type = lite_api::ActivationType::kExp; + } else if (opdesc.Type() == "abs") { + // abs + param_.active_type = lite_api::ActivationType::kAbs; } VLOG(4) << "opdesc.Type():" << opdesc.Type(); @@ -92,6 +95,7 @@ REGISTER_LITE_OP(swish, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp); +REGISTER_LITE_OP(abs, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(sqrt, paddle::lite::operators::ActivationOp); -- GitLab