From da3285944fd62d002d5266d420d3cd23638e9b27 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Mon, 9 Sep 2019 09:08:06 +0800
Subject: [PATCH] add calib cuda kernel. (#1977)

* add calib cuda kernel.

* add unit test for calib cuda kernel. test=develop
---
 lite/core/op_registry.cc                     |   1 +
 lite/kernels/cuda/CMakeLists.txt             |   3 +
 lite/kernels/cuda/calib_compute.cu           | 131 ++++++++++++++
 lite/kernels/cuda/calib_compute.h            |  52 ++++++
 lite/kernels/cuda/calib_compute_cuda_test.cc | 178 +++++++++++++++++++
 5 files changed, 365 insertions(+)
 create mode 100644 lite/kernels/cuda/calib_compute.cu
 create mode 100644 lite/kernels/cuda/calib_compute.h
 create mode 100644 lite/kernels/cuda/calib_compute_cuda_test.cc

diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc
index 80dd5a4cfc..53d4afa9ff 100644
--- a/lite/core/op_registry.cc
+++ b/lite/core/op_registry.cc
@@ -105,6 +105,7 @@ KernelRegistry::KernelRegistry()
                                     DATALAYOUT(layout__)>::Global());
   // Currently, just register 2 kernel targets.
   INIT_FOR(kCUDA, kFloat, kNCHW);
+  INIT_FOR(kCUDA, kInt8, kNCHW);
   INIT_FOR(kCUDA, kAny, kNCHW);
   INIT_FOR(kCUDA, kAny, kAny);
   INIT_FOR(kCUDA, kInt8, kNHWC);

diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index f94e02f3b1..a0c79465ec 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -17,6 +17,9 @@ nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEP
 nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
 nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
 
+nv_library(calib_compute_cuda SRCS calib_compute.cu DEPS ${lite_kernel_deps})
+lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
+
 set(cuda_kernels
     conv2d_cuda
     mul_compute_cuda
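Note: the kernels added below implement plain symmetric per-tensor quantization: fp32 -> int8 divides by `scale`, clamps to the int8 range, and rounds; int8 -> fp32 multiplies back by `scale`. A minimal host-side sketch of the same math (illustrative names, not part of the patch):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// fp32 -> int8: q = round(clamp(x / scale, INT8_MIN, INT8_MAX)).
// The device code below rounds with __float2int_rn (round half to even);
// the patch's own test reference uses roundf (round half away from zero),
// so they differ only at exact .5 boundaries.
int8_t QuantizeFp32ToInt8(float x, float scale) {
  float v = x / scale;
  v = std::max(v, static_cast<float>(INT8_MIN));
  v = std::min(v, static_cast<float>(INT8_MAX));
  return static_cast<int8_t>(std::round(v));
}

// int8 -> fp32: x is recovered up to the rounding/clamping error.
float DequantizeInt8ToFp32(int8_t q, float scale) {
  return static_cast<float>(q) * scale;
}
```

Round-tripping a value x through these two functions costs at most scale/2 of absolute error, provided x/scale stays within [INT8_MIN, INT8_MAX].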
diff --git a/lite/kernels/cuda/calib_compute.cu b/lite/kernels/cuda/calib_compute.cu
new file mode 100644
index 0000000000..04f199e91f
--- /dev/null
+++ b/lite/kernels/cuda/calib_compute.cu
@@ -0,0 +1,131 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdint>
+#include "lite/core/op_registry.h"
+#include "lite/core/type_system.h"
+#include "lite/kernels/cuda/calib_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+__device__ __forceinline__ int8_t float2int8(float x) {
+  x = fmaxf(x, INT8_MIN);
+  x = fminf(x, INT8_MAX);
+  return __float2int_rn(x);
+}
+
+__global__ void Fp32ToInt8Kernel(const int num,
+                                 const float scale,
+                                 const float* input,
+                                 int8_t* output) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < num) {
+    output[index] = float2int8(input[index] / scale);
+  }
+}
+
+__global__ void Int8ToFp32Kernel(const int num,
+                                 const float scale,
+                                 const int8_t* input,
+                                 float* output) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < num) {
+    output[index] = input[index] * scale;
+  }
+}
+
+void CalibComputeFp32ToInt8::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
+  auto scale = param.scale;
+  const auto* din = param.input->data<float>();
+  auto* dout = param.output->mutable_data<int8_t>(TARGET(kCUDA));
+  int num = static_cast<int>(param.input->numel());
+  int threads = 1024;
+  int blocks = (num + threads - 1) / threads;
+  Fp32ToInt8Kernel<<<blocks, threads, 0, stream>>>(num, scale, din, dout);
+  cudaError_t error = cudaGetLastError();
+  CHECK(error == cudaSuccess) << cudaGetErrorString(error);
+}
+
+void CalibComputeInt8ToFp32::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
+  auto scale = param.scale;
+  const auto* din = param.input->data<int8_t>();
+  auto* dout = param.output->mutable_data<float>(TARGET(kCUDA));
+  int num = static_cast<int>(param.input->numel());
+  int threads = 1024;
+  int blocks = (num + threads - 1) / threads;
+  Int8ToFp32Kernel<<<blocks, threads, 0, stream>>>(num, scale, din, dout);
+  cudaError_t error = cudaGetLastError();
+  CHECK(error == cudaSuccess) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(calib,
+                     kCUDA,
+                     kInt8,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::CalibComputeFp32ToInt8,
+                     fp32_to_int8)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(calib,
+                     kCUDA,
+                     kInt8,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::CalibComputeInt8ToFp32,
+                     int8_to_fp32)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(calib_once,
+                     kCUDA,
+                     kInt8,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::CalibComputeFp32ToInt8,
+                     fp32_to_int8)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(calib_once,
+                     kCUDA,
+                     kInt8,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::CalibComputeInt8ToFp32,
+                     int8_to_fp32)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
+    .Finalize();
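Note: a standalone harness can exercise the same device code outside the Lite runtime. In the sketch below only the two kernels and the launch configuration mirror the patch; the `main` driver, buffer names, and sample values are illustrative:

```cuda
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__device__ __forceinline__ int8_t float2int8(float x) {
  x = fmaxf(x, INT8_MIN);
  x = fminf(x, INT8_MAX);
  return __float2int_rn(x);  // round to nearest even, as in the patch
}

__global__ void Fp32ToInt8Kernel(const int num, const float scale,
                                 const float* input, int8_t* output) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < num) output[index] = float2int8(input[index] / scale);
}

__global__ void Int8ToFp32Kernel(const int num, const float scale,
                                 const int8_t* input, float* output) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < num) output[index] = input[index] * scale;
}

int main() {
  const int num = 4;
  const float scale = 0.013f;
  float h_in[num] = {-1.001f, 0.0f, 0.013f, 2.5f};  // 2.5f will saturate
  float* d_fp32;
  int8_t* d_int8;
  cudaMalloc(&d_fp32, num * sizeof(float));
  cudaMalloc(&d_int8, num * sizeof(int8_t));
  cudaMemcpy(d_fp32, h_in, num * sizeof(float), cudaMemcpyHostToDevice);

  // Same launch configuration as the Run() methods: one thread per element.
  int threads = 1024;
  int blocks = (num + threads - 1) / threads;
  Fp32ToInt8Kernel<<<blocks, threads>>>(num, scale, d_fp32, d_int8);
  Int8ToFp32Kernel<<<blocks, threads>>>(num, scale, d_int8, d_fp32);
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) printf("%s\n", cudaGetErrorString(error));

  float h_out[num];
  cudaMemcpy(h_out, d_fp32, num * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < num; ++i) printf("%9.4f -> %9.4f\n", h_in[i], h_out[i]);
  cudaFree(d_fp32);
  cudaFree(d_int8);
  return 0;
}
```

The last input demonstrates saturation: 2.5 / 0.013 ≈ 192 exceeds INT8_MAX, so it comes back as 127 * 0.013 = 1.651 rather than 2.5.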
diff --git a/lite/kernels/cuda/calib_compute.h b/lite/kernels/cuda/calib_compute.h
new file mode 100644
index 0000000000..f161f69992
--- /dev/null
+++ b/lite/kernels/cuda/calib_compute.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "lite/core/kernel.h"
+#include "lite/operators/calib_op.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class CalibComputeFp32ToInt8
+    : public KernelLite<TARGET(kCUDA), PRECISION(kInt8)> {
+ public:
+  using param_t = operators::CalibParam;
+
+  void Run() override;
+
+  virtual ~CalibComputeFp32ToInt8() = default;
+
+  std::string doc() const override { return "Fp32 --> Int8"; }
+};
+
+class CalibComputeInt8ToFp32
+    : public KernelLite<TARGET(kCUDA), PRECISION(kInt8)> {
+ public:
+  using param_t = operators::CalibParam;
+
+  void Run() override;
+
+  virtual ~CalibComputeInt8ToFp32() = default;
+
+  std::string doc() const override { return "Int8 --> Fp32"; }
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
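Note: since both classes are ordinary `KernelLite` subclasses, they can also be driven directly, without going through the registry. A minimal sketch mirroring the unit test below (the helper function, its arguments, and the scale value are illustrative):

```cpp
#include <memory>
#include <utility>
#include "lite/kernels/cuda/calib_compute.h"

// Hypothetical helper: enqueue fp32 -> int8 calibration of `x` into `out`.
void RunCalibDirectly(const paddle::lite::Tensor& x,
                      paddle::lite::Tensor* out,
                      cudaStream_t stream) {
  using namespace paddle::lite;  // NOLINT
  kernels::cuda::CalibComputeFp32ToInt8 calib;

  operators::CalibParam param;
  param.scale = 0.013f;  // illustrative per-tensor scale
  param.input = &x;
  param.output = out;
  calib.SetParam(param);

  std::unique_ptr<KernelContext> ctx(new KernelContext);
  ctx->As<CUDAContext>().SetExecStream(stream);
  calib.SetContext(std::move(ctx));

  calib.Launch();  // enqueues Fp32ToInt8Kernel on `stream`
}
```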
diff --git a/lite/kernels/cuda/calib_compute_cuda_test.cc b/lite/kernels/cuda/calib_compute_cuda_test.cc
new file mode 100644
index 0000000000..691b52d257
--- /dev/null
+++ b/lite/kernels/cuda/calib_compute_cuda_test.cc
@@ -0,0 +1,178 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <utility>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+static void int8_to_fp32_basic(const int8_t* din,
+                               float* dout,
+                               const float scale,
+                               int num) {
+  for (int j = 0; j < num; ++j) {
+    dout[j] = din[j] * scale;
+  }
+}
+
+static void fp32_to_int8_basic(const float* din,
+                               int8_t* dout,
+                               const float scale,
+                               int num) {
+  for (int j = 0; j < num; ++j) {
+    auto v = din[j] / scale;
+    v = std::max(v, static_cast<float>(INT8_MIN));
+    v = std::min(v, static_cast<float>(INT8_MAX));
+    v = roundf(v);
+    dout[j] = static_cast<int8_t>(v);
+  }
+}
+
+void calib_ref(const operators::CalibParam& param, bool to_float = true) {
+  auto scale = param.scale;
+  if (to_float) {
+    const auto* din = param.input->data<int8_t>();
+    auto* dout = param.output->mutable_data<float>();
+    int8_to_fp32_basic(din, dout, scale, param.input->numel());
+  } else {
+    const auto* din = param.input->data<float>();
+    auto* dout = param.output->mutable_data<int8_t>();
+    fp32_to_int8_basic(din, dout, scale, param.input->numel());
+  }
+}
+
+TEST(calib_cuda, int8_to_fp32) {
+  LOG(INFO) << "to get kernel ...";
+  auto kernels = KernelRegistry::Global().Create(
+      "calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
+  ASSERT_FALSE(kernels.empty());
+  // fp32_to_int8 is registered first, so int8_to_fp32 is the second kernel.
+  auto calib = std::move(*std::next(kernels.begin(), 1));
+  LOG(INFO) << "get kernel: " << calib->doc();
+  const int n = 64, c = 32, h = 18, w = 18;
+  Tensor x;
+  Tensor x_cpu;
+  Tensor output;
+  Tensor output_cpu;
+  // set the dims of the input and output tensors
+  x.Resize({n, c, h, w});
+  x_cpu.Resize({n, c, h, w});
+  output.Resize({n, c, h, w});
+  output_cpu.Resize({n, c, h, w});
+  // initialize the data of the input tensors
+  auto* x_data = x.mutable_data<int8_t>(TARGET(kCUDA));
+  auto* x_cpu_data = x_cpu.mutable_data<int8_t>();
+  for (int i = 0; i < x.dims().production(); i++) {
+    float sign = i % 3 == 0 ? -1.0f : 1.0f;
+    x_cpu_data[i] = static_cast<int8_t>(sign * (i % 127));
+  }
+  x.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+  // prepare kernel params and run
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+  calib->SetContext(std::move(ctx));
+
+  operators::CalibParam param;
+  param.scale = 0.013f;
+  param.input = &x;
+  param.output = &output;
+  calib->SetParam(param);
+  calib->Launch();
+  cudaDeviceSynchronize();
+  // invoke the ref implementation and compare the results
+  param.input = &x_cpu;
+  param.output = &output_cpu;
+  calib_ref(param);
+  auto* output_data = output.mutable_data<float>(TARGET(kCUDA));
+  std::unique_ptr<float[]> output_gpu_copy(new float[output.numel()]);
+  CopySync<TARGET(kCUDA)>(output_gpu_copy.get(),
+                          output_data,
+                          sizeof(float) * output.numel(),
+                          IoDirection::DtoH);
+  const auto* output_cpu_data = output_cpu.data<float>();
+  for (int i = 0; i < output.dims().production(); i++) {
+    EXPECT_NEAR(output_gpu_copy[i], output_cpu_data[i], 1e-5);
+  }
+}
+
+TEST(calib_cuda, fp32_to_int8) {
+  LOG(INFO) << "to get kernel ...";
+  auto kernels = KernelRegistry::Global().Create(
+      "calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
+  ASSERT_FALSE(kernels.empty());
+  auto calib = std::move(kernels.front());
+  LOG(INFO) << "get kernel: " << calib->doc();
+  const int n = 64, c = 32, h = 18, w = 18;
+  Tensor x;
+  Tensor x_cpu;
+  Tensor output;
+  Tensor output_cpu;
+  // set the dims of the input and output tensors
+  x.Resize({n, c, h, w});
+  x_cpu.Resize({n, c, h, w});
+  output.Resize({n, c, h, w});
+  output_cpu.Resize({n, c, h, w});
+  // initialize the data of the input tensors
+  auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
+  auto* x_cpu_data = x_cpu.mutable_data<float>();
+  for (int i = 0; i < x.dims().production(); i++) {
+    float sign = i % 3 == 0 ? -1.0f : 1.0f;
+    x_cpu_data[i] = sign * (i % 127) * 0.013f;
+  }
+  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+  // prepare kernel params and run
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+  calib->SetContext(std::move(ctx));
+
+  operators::CalibParam param;
+  param.scale = 0.013f;
+  param.input = &x;
+  param.output = &output;
+  calib->SetParam(param);
+  calib->Launch();
+  cudaDeviceSynchronize();
+  // invoke the ref implementation and compare the results
+  param.input = &x_cpu;
+  param.output = &output_cpu;
+  calib_ref(param, false);
+  auto* output_data = output.mutable_data<int8_t>(TARGET(kCUDA));
+  std::unique_ptr<int8_t[]> output_gpu_copy(new int8_t[output.numel()]);
+  CopySync<TARGET(kCUDA)>(output_gpu_copy.get(),
+                          output_data,
+                          sizeof(int8_t) * output.numel(),
+                          IoDirection::DtoH);
+  const auto* output_cpu_data = output_cpu.data<int8_t>();
+  for (int i = 0; i < output.dims().production(); i++) {
+    EXPECT_EQ(output_gpu_copy[i], output_cpu_data[i]);
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(calib, kCUDA, kInt8, kNCHW, int8_to_fp32);
+USE_LITE_KERNEL(calib, kCUDA, kInt8, kNCHW, fp32_to_int8);
--
GitLab