From 7931c7583e6e5a3dcf84ca9d92f3974806b76e3f Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Fri, 30 Aug 2019 16:28:48 +0800 Subject: [PATCH] add nearest_interp_cuda kernel, test=develop (#1920) add nearest_interp cuda kernel for Paddle-Lite --- lite/api/_paddle_use_kernels.h | 1 + lite/kernels/cuda/CMakeLists.txt | 7 +- lite/kernels/cuda/nearest_interp_compute.cu | 160 +++++++++++++++ lite/kernels/cuda/nearest_interp_compute.h | 35 ++++ .../cuda/nearest_interp_compute_test.cc | 186 ++++++++++++++++++ 5 files changed, 388 insertions(+), 1 deletion(-) create mode 100644 lite/kernels/cuda/nearest_interp_compute.cu create mode 100644 lite/kernels/cuda/nearest_interp_compute.h create mode 100644 lite/kernels/cuda/nearest_interp_compute_test.cc diff --git a/lite/api/_paddle_use_kernels.h b/lite/api/_paddle_use_kernels.h index d54caa83e1..a95c9e7c62 100644 --- a/lite/api/_paddle_use_kernels.h +++ b/lite/api/_paddle_use_kernels.h @@ -155,6 +155,7 @@ USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, host_to_device); USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, device_to_host); USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def); +USE_LITE_KERNEL(nearest_interp, kCUDA, kFloat, kNCHW, def); USE_LITE_KERNEL(yolo_box, kCUDA, kFloat, kNCHW, def); #endif diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index 1a198c1dd5..6623894ec7 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -7,14 +7,19 @@ message(STATUS "compile with lite CUDA kernels") nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context) lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps}) nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps}) + +nv_library(nearest_interp_compute_cuda SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps}) +lite_cc_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda) + lite_cc_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda) nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps}) lite_cc_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda) set(cuda_kernels -mul_compute_cuda +mul_compute_cuda io_copy_compute_cuda leaky_relu_compute_cuda +nearest_interp_compute_cuda yolo_box_compute_cuda ) diff --git a/lite/kernels/cuda/nearest_interp_compute.cu b/lite/kernels/cuda/nearest_interp_compute.cu new file mode 100644 index 0000000000..8edeacfe5a --- /dev/null +++ b/lite/kernels/cuda/nearest_interp_compute.cu @@ -0,0 +1,160 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/nearest_interp_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { +using Tensor = lite::Tensor; + +__global__ void KeNearestNeighborInterp(const float* in, + const size_t in_img_h, + const size_t in_img_w, + const size_t input_h, + const size_t input_w, + float* out, + const size_t out_img_h, + const size_t out_img_w, + const size_t output_h, + const size_t output_w, + const size_t num_channels, + const float ratio_h, + const float ratio_w, + const bool align_corners) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = (align_corners) + ? static_cast(ratio_h * out_img_idy + 0.5) + : static_cast(ratio_h * out_img_idy); + + int out_img_idx = tid % out_img_w; + int in_img_idx = (align_corners) + ? static_cast(ratio_w * out_img_idx + 0.5) + : static_cast(ratio_w * out_img_idx); + + out[tid] = in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + } +} + +void NearestInterpCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + Tensor* input = param.X; + Tensor* output = param.Out; + Tensor* out_size = param.OutSize; + + auto* input_data = input->data(); + + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; + + int out_h = param.out_h; + int out_w = param.out_w; + float scale = param.scale; + bool align_corners = param.align_corners; + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + + if (out_size != nullptr) { + Tensor sizes; + float* size_data = sizes.mutable_data(); + float* outsize_data = out_size->mutable_data(TARGET(kCUDA)); + cudaMemcpy( + size_data, outsize_data, sizeof(float) * 2, cudaMemcpyDeviceToHost); + out_h = static_cast(size_data[0]); + out_w = static_cast(size_data[1]); + } + + auto output_data = output->mutable_data(TARGET(kCUDA)); + + if (in_h == out_h && in_w == out_w) { + cudaMemcpy(output_data, + input_data, + sizeof(float) * n * c * in_h * in_w, + cudaMemcpyHostToDevice); + return; + } + + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } + + int in_hw = in_h * in_w; + int out_hw = out_h * out_w; + int in_chw = c * in_hw; + int out_chw = c * out_hw; + + int pixelNum = n * out_chw; + int threads = 512; + int blocks = (pixelNum + threads - 1) / threads; + blocks = blocks > 8 ? 8 : blocks; + + KeNearestNeighborInterp<<>>(input_data, + in_h, + in_w, + n, + in_chw, + output_data, + out_h, + out_w, + n, + out_chw, + c, + ratio_h, + ratio_w, + align_corners); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(nearest_interp, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::NearestInterpCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("OutSize", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/nearest_interp_compute.h b/lite/kernels/cuda/nearest_interp_compute.h new file mode 100644 index 0000000000..d4fb0f43c6 --- /dev/null +++ b/lite/kernels/cuda/nearest_interp_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class NearestInterpCompute + : public KernelLite { + public: + using param_t = operators::InterpolateParam; + + void Run() override; + virtual ~NearestInterpCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/nearest_interp_compute_test.cc b/lite/kernels/cuda/nearest_interp_compute_test.cc new file mode 100644 index 0000000000..4aec6db1a2 --- /dev/null +++ b/lite/kernels/cuda/nearest_interp_compute_test.cc @@ -0,0 +1,186 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/nearest_interp_compute.h" +#include +#include +#include +#include "lite/fluid/eigen.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +using EigenTensor = lite::fluid::EigenTensor; +using Tensor = lite::Tensor; + +static void NearestNeighborInterpolate(const Tensor& input, + Tensor* output, + const float ratio_h, + const float ratio_w, + const int n, + const int c, + const int out_h, + const int out_w, + const bool align_corners) { + auto input_t = EigenTensor::From(input); + auto output_t = EigenTensor::From(*output); + for (int k = 0; k < out_h; k++) { // loop for images + int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) + : static_cast(ratio_h * k); + for (int l = 0; l < out_w; l++) { + int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) + : static_cast(ratio_w * l); + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + output_t(i, j, k, l) = input_t(i, j, in_k, in_l); + } + } + } + } +} + +static void NearestInterpRef(operators::InterpolateParam param, + Tensor* input, + const size_t scale, + const size_t n, + const size_t c, + const size_t in_h, + const size_t in_w, + Tensor* output_size, + Tensor* output, + size_t out_h, + size_t out_w) { + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + bool align_corners = param.align_corners; + if (output_size != nullptr) { + auto out_size_data = output_size->mutable_data(); + out_h = static_cast(out_size_data[0]); + out_w = static_cast(out_size_data[1]); + } + + float* input_data = input->mutable_data(); + LOG(INFO) << *(input_data + 2); + float* output_data = output->mutable_data(); + LOG(INFO) << *(output_data + 2); + if (in_h == out_h && in_w == out_w) { + std::memcpy(output_data, input_data, sizeof(float) * n * c * in_h * in_w); + LOG(INFO) << *(output_data + 2); + return; + } + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } + NearestNeighborInterpolate( + *input, output, ratio_h, ratio_w, n, c, out_h, out_w, align_corners); +} + +TEST(nearest_interp, normal) { + NearestInterpCompute nearest_interp_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::InterpolateParam param; + + Tensor x, osz, out; + Tensor x_cpu, osz_cpu, out_cpu; + Tensor x_ref, osz_ref, out_ref; + + int n = 1, c = 3, in_h = 4, in_w = 4; + int in_chw = c * in_h * in_w; + int out_h = 4, out_w = 4; + float scale = 2.0; + + param.out_h = out_h; + param.out_w = out_w; + param.scale = scale; + param.align_corners = false; + + x.Resize({n, c, in_h, in_w}); + osz.Resize({2}); + out.Resize({n, c, out_h, out_w}); + + x_cpu.Resize({n, c, in_h, in_w}); + osz_cpu.Resize({2}); + out_cpu.Resize({n, c, out_h, out_w}); + + x_ref.Resize({n, c, in_h, in_w}); + osz_ref.Resize({2}); + out_ref.Resize({n, c, out_h, out_w}); + + auto* x_data = x.mutable_data(TARGET(kCUDA)); + auto* osz_data = osz.mutable_data(TARGET(kCUDA)); + auto* out_data = out.mutable_data(TARGET(kCUDA)); + + float* x_cpu_data = x_cpu.mutable_data(); + float* osz_cpu_data = osz_cpu.mutable_data(); + float* out_cpu_data = out_cpu.mutable_data(); + + float* x_ref_data = x_ref.mutable_data(); + float* osz_ref_data = osz_ref.mutable_data(); + float* out_ref_data = out_ref.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); ++i) { + x_cpu_data[i] = i + 5.0; + x_ref_data[i] = i + 5.0; + } + osz_cpu_data[0] = out_h; + osz_cpu_data[1] = out_w; + osz_ref_data[0] = out_h; + osz_ref_data[1] = out_w; + + x.Assign(x_cpu_data, x_cpu.dims()); + osz.Assign(osz_cpu_data, osz_cpu.dims()); + + param.X = &x; + param.OutSize = &osz; + param.Out = &out; + nearest_interp_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + nearest_interp_kernel.SetContext(std::move(ctx)); + nearest_interp_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); + NearestInterpRef( + param, &x_ref, scale, n, c, in_h, in_w, &osz_ref, &out_ref, out_h, out_w); + for (int i = 0; i < out.numel(); i++) { + EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle -- GitLab