From 469a349ae3f8d08dec28b88176769508b3d66d92 Mon Sep 17 00:00:00 2001 From: wangyang59 Date: Mon, 23 Apr 2018 17:26:19 -0700 Subject: [PATCH] polishing after qingqing's comments --- paddle/fluid/operators/bilinear_interp_op.cc | 11 +- paddle/fluid/operators/bilinear_interp_op.cu | 81 +++++++++++++- .../fluid/operators/bilinear_interp_op.cu.h | 101 ------------------ .../unittests/test_bilinear_interp_op.py | 14 +++ 4 files changed, 99 insertions(+), 108 deletions(-) delete mode 100644 paddle/fluid/operators/bilinear_interp_op.cu.h diff --git a/paddle/fluid/operators/bilinear_interp_op.cc b/paddle/fluid/operators/bilinear_interp_op.cc index a4c3a2a132..69f79bf93b 100644 --- a/paddle/fluid/operators/bilinear_interp_op.cc +++ b/paddle/fluid/operators/bilinear_interp_op.cc @@ -44,10 +44,13 @@ class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker { BilinearInterpOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", - "The input tensor of bilinear interpolation, 4-D with NCHW shape"); - AddOutput("Out", "The output tensor with the same shape as X"); - AddAttr("out_h", "output height of bilinear interpolation op."); - AddAttr("out_w", "output weight of bilinear interpolation op."); + "(Tensor) The input tensor of bilinear interpolation, " + "This is a 4-D tensor with shape of (N x C x h x w)"); + AddOutput("Out", + "(Tensor) The dimension of output is (N x C x out_h x out_w]"); + + AddAttr("out_h", "(int) output height of bilinear interpolation op."); + AddAttr("out_w", "(int) output width of bilinear interpolation op."); AddComment(R"DOC( Bilinear interpolation is an extension of linear interpolation for interpolating functions of two variables (e.g. H-direction and diff --git a/paddle/fluid/operators/bilinear_interp_op.cu b/paddle/fluid/operators/bilinear_interp_op.cu index 0dbbe29b2d..82eb9e83bd 100644 --- a/paddle/fluid/operators/bilinear_interp_op.cu +++ b/paddle/fluid/operators/bilinear_interp_op.cu @@ -9,15 +9,90 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/bilinear_interp_op.cu.h" -#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/bilinear_interp_op.h" +#include "paddle/fluid/platform/cuda_helper.h" namespace paddle { namespace operators { using framework::Tensor; +template +__global__ void KeBilinearInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const T ratio_h, const T ratioW) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < nthreads) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = ratio_h * out_img_idy; + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; + T h1lambda = ratio_h * out_img_idy - in_img_idy; + T h2lambda = 1.f - h1lambda; + + int out_img_idx = tid % out_img_w; + int in_img_idx = ratioW * out_img_idx; + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; + T w1lambda = ratioW * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; + + const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + + // bilinear interpolation + out[out_id_h * output_w + out_id_w] = + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + + h1lambda * (w2lambda * in_pos[h_id * in_img_w] + + w1lambda * in_pos[h_id * in_img_w + w_id]); + } +} + +template +__global__ void KeBilinearInterpBw( + T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, + const size_t input_w, const T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const T ratio_h, const T ratioW) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < nthreads) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; + + int out_img_idy = (out_id_w % out_img_size) / out_img_w; + int in_img_idy = ratio_h * out_img_idy; + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; + T h1lambda = ratio_h * out_img_idy - in_img_idy; + T h2lambda = 1.f - h1lambda; + + int out_img_idx = tid % out_img_w; + int in_img_idx = ratioW * out_img_idx; + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; + T w1lambda = ratioW * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; + + T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + const T* out_pos = &out[out_id_h * output_w + out_id_w]; + atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); + atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]); + atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]); + atomicAdd(&in_pos[h_id * in_img_w + w_id], + h1lambda * w1lambda * out_pos[0]); + } +} + template class BilinearInterpOpCUDAKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/bilinear_interp_op.cu.h b/paddle/fluid/operators/bilinear_interp_op.cu.h deleted file mode 100644 index 0eb568d80e..0000000000 --- a/paddle/fluid/operators/bilinear_interp_op.cu.h +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/platform/cuda_helper.h" - -namespace paddle { -namespace operators { - -template -__global__ void KeBilinearInterpFw(const T* in, const size_t inImgH, - const size_t inImgW, const size_t inputH, - const size_t inputW, T* out, - const size_t outImgH, const size_t outImgW, - const size_t outputH, const size_t outputW, - const size_t numChannels, const T ratioH, - const T ratioW) { - int nthreads = outputH * outputW; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { - int outIdH = tid / outputW; - int outIdW = tid % outputW; - int inImgSize = inputW / numChannels; - int outImgSize = outputW / numChannels; - int channelId = outIdW / outImgSize; - - int outImgIdy = (outIdW % outImgSize) / outImgW; - int inImgIdy = ratioH * outImgIdy; - int hId = (inImgIdy < inImgH - 1) ? 1 : 0; - T h1lambda = ratioH * outImgIdy - inImgIdy; - T h2lambda = 1.f - h1lambda; - - int outImgIdx = tid % outImgW; - int inImgIdx = ratioW * outImgIdx; - int wId = (inImgIdx < inImgW - 1) ? 1 : 0; - T w1lambda = ratioW * outImgIdx - inImgIdx; - T w2lambda = 1.f - w1lambda; - - const T* inPos = &in[outIdH * inputW + channelId * inImgSize + - inImgIdy * inImgW + inImgIdx]; - - // bilinear interpolation - out[outIdH * outputW + outIdW] = - h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wId]) + - h1lambda * (w2lambda * inPos[hId * inImgW] + - w1lambda * inPos[hId * inImgW + wId]); - } -} - -template -__global__ void KeBilinearInterpBw(T* in, const size_t inImgH, - const size_t inImgW, const size_t inputH, - const size_t inputW, const T* out, - const size_t outImgH, const size_t outImgW, - const size_t outputH, const size_t outputW, - const size_t numChannels, const T ratioH, - const T ratioW) { - int nthreads = outputH * outputW; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < nthreads) { - int outIdH = tid / outputW; - int outIdW = tid % outputW; - int inImgSize = inputW / numChannels; - int outImgSize = outputW / numChannels; - int channelId = outIdW / outImgSize; - - int outImgIdy = (outIdW % outImgSize) / outImgW; - int inImgIdy = ratioH * outImgIdy; - int hId = (inImgIdy < inImgH - 1) ? 1 : 0; - T h1lambda = ratioH * outImgIdy - inImgIdy; - T h2lambda = 1.f - h1lambda; - - int outImgIdx = tid % outImgW; - int inImgIdx = ratioW * outImgIdx; - int wId = (inImgIdx < inImgW - 1) ? 1 : 0; - T w1lambda = ratioW * outImgIdx - inImgIdx; - T w2lambda = 1.f - w1lambda; - - T* inPos = &in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW + - inImgIdx]; - const T* outPos = &out[outIdH * outputW + outIdW]; - atomicAdd(&inPos[0], h2lambda * w2lambda * outPos[0]); - atomicAdd(&inPos[wId], h2lambda * w1lambda * outPos[0]); - atomicAdd(&inPos[hId * inImgW], h1lambda * w2lambda * outPos[0]); - atomicAdd(&inPos[hId * inImgW + wId], h1lambda * w1lambda * outPos[0]); - } -} - -} // namespace operators -} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py index b5ec3942e8..4af5f524a5 100644 --- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py @@ -84,5 +84,19 @@ class TestCase2(TestBilinearInterpOp): self.out_w = 12 +class TestCase2(TestBilinearInterpOp): + def init_test_case(self): + self.input_shape = [16, 3, 512, 1024] + self.out_h = 128 + self.out_w = 256 + + +class TestCase2(TestBilinearInterpOp): + def init_test_case(self): + self.input_shape = [8, 1, 256, 128] + self.out_h = 1024 + self.out_w = 1024 + + if __name__ == "__main__": unittest.main() -- GitLab