bilinear_interp_op.cu 4.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

12
#include "paddle/fluid/operators/bilinear_interp_op.cu.h"
13
#include "paddle/fluid/operators/bilinear_interp_op.h"
W
wangyang59 已提交
14 15
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_helper.h"
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48

namespace paddle {
namespace operators {

template <typename T>
class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input_t = ctx.Input<Tensor>("X");      // float tensor
    auto* output_t = ctx.Output<Tensor>("Out");  // float tensor
    auto* input = input_t->data<T>();
    auto* output = output_t->mutable_data<T>(ctx.GetPlace());

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    int batch_size = input_t->dims()[0];
    int channels = input_t->dims()[1];
    int in_h = input_t->dims()[2];
    int in_w = input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      memcpy(output, input, input_t->numel() * sizeof(T));
    } else {
49 50 51 52 53 54 55
      int threadNum = batch_size * out_chw;
      int blocks = (threadNum + 1024 - 1) / 1024;

      KeBilinearInterpFw<
          T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
          input, in_h, in_w, batch_size, in_chw, output, out_h, out_w,
          batch_size, out_chw, channels, ratio_h, ratio_w);
56 57 58 59 60 61 62 63 64 65 66 67 68
    }
  }
};

template <typename T>
class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
    auto* d_output = d_output_t->data<T>();

W
wangyang59 已提交
69 70 71 72 73
    auto& device_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    math::SetConstant<platform::CUDADeviceContext, T> zero;
    zero(device_ctx, d_input_t, static_cast<T>(0.0));

74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    int batch_size = d_input_t->dims()[0];
    int channels = d_input_t->dims()[1];
    int in_h = d_input_t->dims()[2];
    int in_w = d_input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
    } else {
92 93 94 95 96 97 98
      int threadNum = batch_size * out_chw;
      int blocks = (threadNum + 1024 - 1) / 1024;

      KeBilinearInterpBw<
          T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
          d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w,
          batch_size, out_chw, channels, ratio_h, ratio_w);
99 100 101 102 103 104 105 106 107 108 109
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(bilinear_interp,
                        ops::BilinearInterpOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
110
                        ops::BilinearInterpGradOpCUDAKernel<float>);