From ad3b3d9dc111c1dced4632bb2f5e7d137be274a8 Mon Sep 17 00:00:00 2001
From: wangyang59
Date: Wed, 21 Mar 2018 17:45:18 -0700
Subject: [PATCH] ported old paddle gpu bilinear_interp

---
 paddle/fluid/operators/bilinear_interp_op.cu  |  24 ++--
 .../fluid/operators/bilinear_interp_op.cu.h   | 105 ++++++++++++++++++
 2 files changed, 121 insertions(+), 8 deletions(-)
 create mode 100644 paddle/fluid/operators/bilinear_interp_op.cu.h

diff --git a/paddle/fluid/operators/bilinear_interp_op.cu b/paddle/fluid/operators/bilinear_interp_op.cu
index 187ad60f2d..c4abdbd3b5 100644
--- a/paddle/fluid/operators/bilinear_interp_op.cu
+++ b/paddle/fluid/operators/bilinear_interp_op.cu
@@ -9,7 +9,7 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "hl_cnn.h"
+#include "paddle/fluid/operators/bilinear_interp_op.cu.h"
 #include "paddle/fluid/operators/bilinear_interp_op.h"
 
 namespace paddle {
@@ -44,9 +44,13 @@ class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
     if (in_h == out_h && in_w == out_w) {
       memcpy(output, input, input_t->numel() * sizeof(T));
     } else {
-      hl_bilinear_forward(input, in_h, in_w, batch_size, in_chw, output, out_h,
-                          out_w, batch_size, out_chw, channels, ratio_h,
-                          ratio_w);
+      int threadNum = batch_size * out_chw;
+      int blocks = (threadNum + 1024 - 1) / 1024;
+
+      KeBilinearInterpFw<
+          T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
+          input, in_h, in_w, batch_size, in_chw, output, out_h, out_w,
+          batch_size, out_chw, channels, ratio_h, ratio_w);
     }
   }
 };
@@ -78,9 +82,13 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
     if (in_h == out_h && in_w == out_w) {
       memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
     } else {
-      hl_bilinear_backward(d_input, in_h, in_w, batch_size, in_chw, d_output,
-                           out_h, out_w, batch_size, out_chw, channels, ratio_h,
-                           ratio_w);
+      int threadNum = batch_size * out_chw;
+      int blocks = (threadNum + 1024 - 1) / 1024;
+
+      KeBilinearInterpBw<
+          T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
+          d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w,
+          batch_size, out_chw, channels, ratio_h, ratio_w);
     }
   }
 };
@@ -92,4 +100,4 @@ namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(bilinear_interp,
                         ops::BilinearInterpOpCUDAKernel<float>);
 REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
-                        ops::BilinearInterpGradOpCUDAKernel<float>);
\ No newline at end of file
+                        ops::BilinearInterpGradOpCUDAKernel<float>);
diff --git a/paddle/fluid/operators/bilinear_interp_op.cu.h b/paddle/fluid/operators/bilinear_interp_op.cu.h
new file mode 100644
index 0000000000..ea9a19bf3f
--- /dev/null
+++ b/paddle/fluid/operators/bilinear_interp_op.cu.h
@@ -0,0 +1,105 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/cuda_helper.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+__global__ void KeBilinearInterpFw(const T* in, const size_t inImgH,
+                                   const size_t inImgW, const size_t inputH,
+                                   const size_t inputW, T* out,
+                                   const size_t outImgH, const size_t outImgW,
+                                   const size_t outputH, const size_t outputW,
+                                   const size_t numChannels, const T ratioH,
+                                   const T ratioW) {
+  int nthreads = outputH * outputW;
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < nthreads) {
+    int outIdH = tid / outputW;
+    int outIdW = tid % outputW;
+    int inImgSize = inputW / numChannels;
+    int outImgSize = outputW / numChannels;
+    int channelId = outIdW / outImgSize;
+
+    int outImgIdy = (outIdW % outImgSize) / outImgW;
+    int inImgIdy = ratioH * outImgIdy;
+    int hId = (inImgIdy < inImgH - 1) ? 1 : 0;
+    T h1lambda = ratioH * outImgIdy - inImgIdy;
+    T h2lambda = 1.f - h1lambda;
+
+    int outImgIdx = tid % outImgW;
+    int inImgIdx = ratioW * outImgIdx;
+    int wId = (inImgIdx < inImgW - 1) ? 1 : 0;
+    T w1lambda = ratioW * outImgIdx - inImgIdx;
+    T w2lambda = 1.f - w1lambda;
+
+    const T* inPos = &in[outIdH * inputW + channelId * inImgSize +
+                         inImgIdy * inImgW + inImgIdx];
+
+    // bilinear interpolation
+    out[outIdH * outputW + outIdW] =
+        h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wId]) +
+        h1lambda * (w2lambda * inPos[hId * inImgW] +
+                    w1lambda * inPos[hId * inImgW + wId]);
+  }
+}
+
+template <typename T>
+__global__ void KeBilinearInterpBw(T* in, const size_t inImgH,
+                                   const size_t inImgW, const size_t inputH,
+                                   const size_t inputW, const T* out,
+                                   const size_t outImgH, const size_t outImgW,
+                                   const size_t outputH, const size_t outputW,
+                                   const size_t numChannels, const T ratioH,
+                                   const T ratioW) {
+  int nthreads = outputH * outputW;
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < nthreads) {
+    int outIdH = tid / outputW;
+    int outIdW = tid % outputW;
+    int inImgSize = inputW / numChannels;
+    int outImgSize = outputW / numChannels;
+    int channelId = outIdW / outImgSize;
+
+    int outImgIdy = (outIdW % outImgSize) / outImgW;
+    int inImgIdy = ratioH * outImgIdy;
+    int hId = (inImgIdy < inImgH - 1) ? 1 : 0;
+    T h1lambda = ratioH * outImgIdy - inImgIdy;
+    T h2lambda = 1.f - h1lambda;
+
+    int outImgIdx = tid % outImgW;
+    int inImgIdx = ratioW * outImgIdx;
+    int wId = (inImgIdx < inImgW - 1) ? 1 : 0;
+    T w1lambda = ratioW * outImgIdx - inImgIdx;
+    T w2lambda = 1.f - w1lambda;
+
+    T* inPos = &in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW +
+                   inImgIdx];
+    const T* outPos = &out[outIdH * outputW + outIdW];
+    atomicAdd(&inPos[0], h2lambda * w2lambda * outPos[0]);
+    atomicAdd(&inPos[wId], h2lambda * w1lambda * outPos[0]);
+    atomicAdd(&inPos[hId * inImgW], h1lambda * w2lambda * outPos[0]);
+    atomicAdd(&inPos[hId * inImgW + wId], h1lambda * w1lambda * outPos[0]);
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
--
GitLab
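
For readers checking the kernels above, here is a minimal host-side reference sketch that mirrors the same NCHW index math and edge clamping as KeBilinearInterpFw. It is not taken from the patch; the function and variable names are illustrative, and it assumes the align-corners ratio convention the lambda weights imply (ratio_h = (in_h - 1) / (out_h - 1) for out_h > 1, else 0; likewise for width), since the ratio computation itself lies outside the hunks shown.

// Illustrative CPU reference, assuming NCHW layout and align-corners ratios.
#include <cstddef>

template <typename T>
void BilinearInterpCpuRef(const T* in, std::size_t in_h, std::size_t in_w,
                          T* out, std::size_t out_h, std::size_t out_w,
                          std::size_t batch, std::size_t channels, T ratio_h,
                          T ratio_w) {
  const std::size_t in_hw = in_h * in_w;
  const std::size_t out_hw = out_h * out_w;
  for (std::size_t n = 0; n < batch; ++n) {
    for (std::size_t c = 0; c < channels; ++c) {
      const T* in_img = in + (n * channels + c) * in_hw;
      T* out_img = out + (n * channels + c) * out_hw;
      for (std::size_t oy = 0; oy < out_h; ++oy) {
        // Top source row and vertical weights (h2lambda/h1lambda in the kernel).
        const std::size_t iy = static_cast<std::size_t>(ratio_h * oy);
        const std::size_t ys = (iy < in_h - 1) ? 1 : 0;  // clamp at bottom edge
        const T y1 = ratio_h * oy - iy;
        const T y0 = static_cast<T>(1) - y1;
        for (std::size_t ox = 0; ox < out_w; ++ox) {
          // Left source column and horizontal weights (w2lambda/w1lambda).
          const std::size_t ix = static_cast<std::size_t>(ratio_w * ox);
          const std::size_t xs = (ix < in_w - 1) ? 1 : 0;  // clamp at right edge
          const T x1 = ratio_w * ox - ix;
          const T x0 = static_cast<T>(1) - x1;
          const T* p = in_img + iy * in_w + ix;
          // Weighted sum of the 2x2 source neighborhood, as in KeBilinearInterpFw.
          out_img[oy * out_w + ox] =
              y0 * (x0 * p[0] + x1 * p[xs]) +
              y1 * (x0 * p[ys * in_w] + x1 * p[ys * in_w + xs]);
        }
      }
    }
  }
}

Running this over the same input buffer and diffing element-wise against the GPU output is a quick sanity check for the forward path; the backward kernel can be checked the same way, since it simply scatters the four weighted contributions of each output gradient back to the input gradient with atomicAdd.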