diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index 46c0a10ff7885bb05e92cbfb9c5eea1ac2f8ae53..348a55db117245582a8f13c5abf9161a8c880940 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -16,6 +16,7 @@ add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute.
 add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose)
 add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
 
 lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
 nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
@@ -26,3 +27,4 @@ nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpos
 nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
 nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
 #nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
+nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
diff --git a/lite/kernels/cuda/bilinear_interp_compute.cu b/lite/kernels/cuda/bilinear_interp_compute.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7e1dbaf228c31d8123e48832e93e0180c4920359
--- /dev/null
+++ b/lite/kernels/cuda/bilinear_interp_compute.cu
@@ -0,0 +1,195 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <vector>
+#include "lite/core/op_registry.h"
+#include "lite/kernels/cuda/bilinear_interp_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+using Tensor = lite::Tensor;
+
+template <typename T>
+__global__ void BilinearInterp(const T* in,
+                               const size_t in_img_h,
+                               const size_t in_img_w,
+                               const size_t input_h,
+                               const size_t input_w,
+                               T* out,
+                               const size_t out_img_h,
+                               const size_t out_img_w,
+                               const size_t output_h,
+                               const size_t output_w,
+                               const size_t num_channels,
+                               const float ratio_h,
+                               const float ratio_w,
+                               const bool align_corners,
+                               const int align_mode) {
+  int nthreads = output_h * output_w;
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  bool align_flag = (align_mode == 0 && !align_corners);
+  for (; tid < nthreads; tid += stride) {
+    int out_id_h = tid / output_w;
+    int out_id_w = tid % output_w;
+    int in_img_size = input_w / num_channels;
+    int out_img_size = output_w / num_channels;
+
+    int channel_id = out_id_w / out_img_size;
+    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
+    int out_img_idx = tid % out_img_w;
+
+    int in_img_idy = align_flag
+                         ? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
+                         : static_cast<int>(ratio_h * out_img_idy);
+    in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
+    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
+    T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
+    src_h = (src_h > 0) ? src_h : 0;
+    T h1lambda =
+        align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
+    T h2lambda = 1.f - h1lambda;
+
+    int in_img_idx = align_flag
+                         ? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
+                         : static_cast<int>(ratio_w * out_img_idx);
+    in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
+    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
+    T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
+    src_w = (src_w > 0) ? src_w : 0;
+    T w1lambda =
+        align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
+    T w2lambda = 1.f - w1lambda;
+
+    const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
+                          in_img_idy * in_img_w + in_img_idx];
+
+    // bilinear interpolation
+    out[out_id_h * output_w + out_id_w] =
+        h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
+        h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
+                    w1lambda * in_pos[h_id * in_img_w + w_id]);
+  }
+}
+
+void BilinearInterpCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
+  Tensor* input = param.X;
+  Tensor* output = param.Out;
+  Tensor* out_size = param.OutSize;
+
+  auto* input_data = input->data<float>();
+
+  const int n = input->dims()[0];
+  const int c = input->dims()[1];
+  const int in_h = input->dims()[2];
+  const int in_w = input->dims()[3];
+
+  int out_h = param.out_h;
+  int out_w = param.out_w;
+  float scale = param.scale;
+  bool align_corners = param.align_corners;
+  if (scale > 0) {
+    out_h = static_cast<int>(in_h * scale);
+    out_w = static_cast<int>(in_w * scale);
+  }
+
+  if (out_size != nullptr) {
+    Tensor sizes;
+    float* size_data = sizes.mutable_data<float>();
+    float* outsize_data = out_size->mutable_data<float>(TARGET(kCUDA));
+    cudaMemcpy(
+        size_data, outsize_data, sizeof(float) * 2, cudaMemcpyDeviceToHost);
+    out_h = static_cast<int>(size_data[0]);
+    out_w = static_cast<int>(size_data[1]);
+  }
+
+  auto output_data = output->mutable_data<float>(TARGET(kCUDA));
+
+  if (in_h == out_h && in_w == out_w) {
+    cudaMemcpy(output_data,
+               input_data,
+               sizeof(float) * n * c * in_h * in_w,
+               cudaMemcpyHostToDevice);
+    return;
+  }
+
+  float ratio_h = 0.f;
+  float ratio_w = 0.f;
+  if (out_h > 1) {
+    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
+                              : static_cast<float>(in_h) / out_h;
+  }
+  if (out_w > 1) {
+    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
+                              : static_cast<float>(in_w) / out_w;
+  }
+
+  int in_hw = in_h * in_w;
+  int out_hw = out_h * out_w;
+  int in_chw = c * in_hw;
+  int out_chw = c * out_hw;
+
+  int pixel_num = n * out_chw;
+  int threads = 512;
+  int blocks = (pixel_num + threads - 1) / threads;
+  blocks = blocks > 8 ? 8 : blocks;
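+  // The launch is capped at 8 blocks; BilinearInterp covers the remaining
+  // output elements with its grid-stride loop (tid += blockDim.x * gridDim.x).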
+  int align_mode = param.align_mode;
+
+  BilinearInterp<<<blocks, threads, 0, stream>>>(input_data,
+                                                 in_h,
+                                                 in_w,
+                                                 n,
+                                                 in_chw,
+                                                 output_data,
+                                                 out_h,
+                                                 out_w,
+                                                 n,
+                                                 out_chw,
+                                                 c,
+                                                 ratio_h,
+                                                 ratio_w,
+                                                 align_corners,
+                                                 align_mode);
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(bilinear_interp,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::BilinearInterpCompute,
+                     def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("OutSize",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
diff --git a/lite/kernels/cuda/bilinear_interp_compute.h b/lite/kernels/cuda/bilinear_interp_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..333e67f8fff373a84ac9f3a19fc57214376bd34f
--- /dev/null
+++ b/lite/kernels/cuda/bilinear_interp_compute.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class BilinearInterpCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::InterpolateParam;
+
+  void Run() override;
+  virtual ~BilinearInterpCompute() = default;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/bilinear_interp_compute_test.cc b/lite/kernels/cuda/bilinear_interp_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e7e8143150d2963fb4cb74c3530cfd6e125a454c
--- /dev/null
+++ b/lite/kernels/cuda/bilinear_interp_compute_test.cc
@@ -0,0 +1,104 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/bilinear_interp_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+using Tensor = lite::Tensor;
+
+TEST(bilinear_interp, normal) {
+  BilinearInterpCompute bilinear_interp_kernel;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+
+  operators::InterpolateParam param;
+
+  Tensor x, osz, out;
+  Tensor x_cpu, osz_cpu, out_cpu;
+  Tensor x_ref, osz_ref, out_ref;
+
+  int n = 1, c = 1, in_h = 3, in_w = 3;
+  int out_h = 6, out_w = 6;
+  float scale = 2.0;
+
+  param.out_h = out_h;
+  param.out_w = out_w;
+  param.scale = scale;
+  param.align_corners = false;
+  param.align_mode = 0;
+
+  x.Resize({n, c, in_h, in_w});
+  osz.Resize({2});
+  out.Resize({n, c, out_h, out_w});
+
+  x_cpu.Resize({n, c, in_h, in_w});
+  osz_cpu.Resize({2});
+  out_cpu.Resize({n, c, out_h, out_w});
+
+  x_ref.Resize({n, c, in_h, in_w});
+  osz_ref.Resize({2});
+  out_ref.Resize({n, c, out_h, out_w});
+
+  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
+
+  float* x_cpu_data = x_cpu.mutable_data<float>();
+  float* osz_cpu_data = osz_cpu.mutable_data<float>();
+  float* out_cpu_data = out_cpu.mutable_data<float>();
+
+  float* x_ref_data = x_ref.mutable_data<float>();
+  float* osz_ref_data = osz_ref.mutable_data<float>();
+
+  for (int i = 0; i < x_cpu.numel(); ++i) {
+    x_cpu_data[i] = i + 5.0;
+    x_ref_data[i] = i + 5.0;
+  }
+  osz_cpu_data[0] = out_h;
+  osz_cpu_data[1] = out_w;
+  osz_ref_data[0] = out_h;
+  osz_ref_data[1] = out_w;
+
+  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+  osz.Assign<float, lite::DDim, TARGET(kCUDA)>(osz_cpu_data, osz_cpu.dims());
+
+  param.X = &x;
+  param.OutSize = &osz;
+  param.Out = &out;
+  bilinear_interp_kernel.SetParam(param);
+
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+
+  bilinear_interp_kernel.SetContext(std::move(ctx));
+  bilinear_interp_kernel.Launch();
+  cudaDeviceSynchronize();
+
+  CopySync<TARGET(kCUDA)>(
+      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
+  for (int i = 0; i < out.numel(); i++) {
+    LOG(INFO) << out_cpu_data[i];
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/interpolate_op.cc b/lite/operators/interpolate_op.cc
index f29acf70a75a7ac6464d8df5da145e760fb1faa3..b98240ba4f255377c0ac661950a45bef0a7d0516 100644
--- a/lite/operators/interpolate_op.cc
+++ b/lite/operators/interpolate_op.cc
@@ -88,6 +88,9 @@ bool InterpolateOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   if (op_desc.HasAttr("out_h")) {
     param_.out_h = op_desc.GetAttr<int>("out_h");
   }
+  if (op_desc.HasAttr("align_mode")) {
+    param_.align_mode = op_desc.GetAttr<int>("align_mode");
+  }
   param_.align_corners = op_desc.GetAttr<bool>("align_corners");
   param_.interp_method = op_desc.GetAttr<std::string>("interp_method");
   return true;
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 0a957624d6b120eb2513e640d1320f4f0b4e47a1..5ac43e58c9138de18e7e91049aed488af75a2017 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -97,6 +97,7 @@ struct InterpolateParam {
   int out_h{-1};
   int out_w{-1};
   bool align_corners{true};
+  int align_mode{1};
   std::string interp_method{"Nearest"};
 };
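For reference, a minimal host-side sketch of the index/weight math the BilinearInterp kernel above computes, assuming contiguous NCHW float data and the same align_corners / align_mode semantics; BilinearRef is a hypothetical helper used only for illustration and is not part of the patch.

#include <algorithm>
#include <vector>

// Host-side reference: truncate-to-int source index, clamp to >= 0,
// then blend the 2x2 neighborhood with (1 - lambda) / lambda weights.
std::vector<float> BilinearRef(const std::vector<float>& in, int n, int c,
                               int in_h, int in_w, int out_h, int out_w,
                               bool align_corners, int align_mode) {
  const bool align_flag = (align_mode == 0 && !align_corners);
  const float ratio_h =
      out_h > 1 ? (align_corners ? static_cast<float>(in_h - 1) / (out_h - 1)
                                 : static_cast<float>(in_h) / out_h)
                : 0.f;
  const float ratio_w =
      out_w > 1 ? (align_corners ? static_cast<float>(in_w - 1) / (out_w - 1)
                                 : static_cast<float>(in_w) / out_w)
                : 0.f;
  std::vector<float> out(1ull * n * c * out_h * out_w);
  for (int b = 0; b < n; ++b) {
    for (int ch = 0; ch < c; ++ch) {
      for (int oy = 0; oy < out_h; ++oy) {
        for (int ox = 0; ox < out_w; ++ox) {
          // Source coordinates; align_flag shifts by half a pixel.
          float sy = align_flag ? ratio_h * (oy + 0.5f) - 0.5f : ratio_h * oy;
          float sx = align_flag ? ratio_w * (ox + 0.5f) - 0.5f : ratio_w * ox;
          int iy = std::max(static_cast<int>(sy), 0);
          int ix = std::max(static_cast<int>(sx), 0);
          float ly = std::max(sy, 0.f) - iy;  // h1lambda
          float lx = std::max(sx, 0.f) - ix;  // w1lambda
          int hid = iy < in_h - 1 ? 1 : 0;    // step to next row, if any
          int wid = ix < in_w - 1 ? 1 : 0;    // step to next column, if any
          auto at = [&](int y, int x) {
            return in[((b * c + ch) * in_h + y) * in_w + x];
          };
          out[((b * c + ch) * out_h + oy) * out_w + ox] =
              (1.f - ly) * ((1.f - lx) * at(iy, ix) + lx * at(iy, ix + wid)) +
              ly * ((1.f - lx) * at(iy + hid, ix) +
                    lx * at(iy + hid, ix + wid));
        }
      }
    }
  }
  return out;
}

With the configuration used in the test (3x3 input upscaled to 6x6, align_corners=false, align_mode=0), its output can be compared against the values the test above logs.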