/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/affine_grid_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace operators {

// Short alias for framework::Tensor, used by the CUDA op kernels below.
using Tensor = framework::Tensor;

// Writes the arithmetic progression out[i] = start + step * i for every flat
// index i in [0, size) visited by Paddle's CUDA_KERNEL_LOOP macro.
// `out` must hold at least `size` elements of device memory.
template <typename T>
__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
  CUDA_KERNEL_LOOP(i, size) {
    // Explicit conversion of the loop index to T; it is the same promotion the
    // compiler applies to `step * i`.
    out[i] = start + step * static_cast<T>(i);
  }
}

template <typename T>
struct Linspace<paddle::platform::CUDADeviceContext, T> {
  // Fills `numbers` (resized to {count}) with `count` evenly spaced values
  // between `start` and `end` on the GPU.
  //
  //  align_corners == true : the first and last samples sit on `start` and
  //                          `end`; spacing is (end - start) / (count - 1).
  //  align_corners == false: spacing is (end - start) / count and `start` is
  //                          scaled by (count - 1) / count, i.e. samples land
  //                          on cell centres when start == -end.
  //
  // Edge cases:
  //  count <= 0 : the tensor is still resized, but no kernel is launched
  //               (a zero-block launch is an invalid configuration).
  //  count == 1 : with align_corners the old formula divided by zero and
  //               produced NaN. The single sample is now placed at the
  //               midpoint (start + end) / 2. For start == -end this is 0,
  //               the same value the align_corners == false branch gives.
  void operator()(T start, T end, int count, bool align_corners,
                  framework::Tensor* numbers,
                  const framework::ExecutionContext& ctx) {
    T* number_data = numbers->mutable_data<T>({count}, ctx.GetPlace());
    if (count <= 0) {
      return;
    }

    T slice;
    if (align_corners) {
      if (count == 1) {
        start = (start + end) / static_cast<T>(2);
        slice = static_cast<T>(0);
      } else {
        slice = (end - start) / static_cast<T>(count - 1);
      }
    } else {
      slice = (end - start) / static_cast<T>(count);
      start *= static_cast<T>(count - 1) / static_cast<T>(count);
    }

    auto stream = ctx.cuda_device_context().stream();
    const int block = 512;
    const int grid = (count + block - 1) / block;
    LinspaceKernel<T><<<grid, block, 0, stream>>>(start, slice, count,
                                                  number_data);
  }
};

// Forward affine-grid kernel: one iteration per flat output location in
// [0, count), visited via Paddle's CUDA_KERNEL_LOOP macro.
//
// Flat index layout (from the decomposition below):
//   index = (batch * out_h + h) * out_w + w
// so `count` is expected to equal N * out_h * out_w.
//
// For each location, the normalized base coordinates
//   w_coor = w_step * w + w_start,   h_coor = h_step * h + h_start
// are mapped through the batch's 2x3 matrix (6 contiguous values t0..t5):
//   output[2 * index]     = t0 * w_coor + t1 * h_coor + t2
//   output[2 * index + 1] = t3 * w_coor + t4 * h_coor + t5
//
// The `n` parameter is not read; the batch index is recovered from `index`.
template <typename T>
__global__ void affine_grid_kernel(const int count, int n, int out_h, int out_w,
                                   T h_start, T w_start, T h_step, T w_step,
                                   const T* theta,  // N, 2, 3
                                   T* output) {
  CUDA_KERNEL_LOOP(index, count) {
    const int w = index % out_w;
    const int h = (index / out_w) % out_h;
    // Named batch_idx so it no longer shadows the `n` parameter.
    const int batch_idx = index / (out_w * out_h);

    const T h_coor = h_step * static_cast<T>(h) + h_start;
    const T w_coor = w_step * static_cast<T>(w) + w_start;

    const T* t = theta + static_cast<int64_t>(batch_idx) * 6;  // 2 * 3
    // 64-bit offset: 2 * index overflows int once count exceeds 2^30.
    const int64_t out_offset = static_cast<int64_t>(index) * 2;
    output[out_offset] = t[0] * w_coor + t[1] * h_coor + t[2];
    output[out_offset + 1] = t[3] * w_coor + t[4] * h_coor + t[5];
  }
}

// Backward of affine_grid_kernel with respect to theta.
//
// Uses the same flat layout as the forward kernel:
//   index = (batch * out_h + h) * out_w + w,   0 <= index < count
// With (gx, gy) = (out_grad[2 * index], out_grad[2 * index + 1]), each
// location contributes to the batch's six theta gradients:
//   g0 += gx * w_coor   g1 += gx * h_coor   g2 += gx
//   g3 += gy * w_coor   g4 += gy * h_coor   g5 += gy
//
// Every location of a batch targets the same six addresses, so contributions
// are accumulated with atomic adds.
//  - theta_grad must be zero-filled before launch.
//  - The accumulation order is not fixed, so float results may differ in the
//    last bits between runs.
//
// The `n` parameter is not read; the batch index is recovered from `index`.
template <typename T>
__global__ void affine_grid_grad_kernel(const int count, int n, int out_h,
                                        int out_w, T h_start, T w_start,
                                        T h_step, T w_step,
                                        const T* out_grad,  // N, H, W, 2
                                        T* theta_grad) {    // N, 2, 3
  CUDA_KERNEL_LOOP(index, count) {
    const int w = index % out_w;
    const int h = (index / out_w) % out_h;
    // Named batch_idx so it no longer shadows the `n` parameter.
    const int batch_idx = index / (out_w * out_h);

    const T h_coor = h_step * static_cast<T>(h) + h_start;
    const T w_coor = w_step * static_cast<T>(w) + w_start;

    T* g = theta_grad + static_cast<int64_t>(batch_idx) * 6;  // 2 * 3
    // 64-bit offset: 2 * index overflows int once count exceeds 2^30.
    const int64_t in_offset = static_cast<int64_t>(index) * 2;
    const T out_grad_x = out_grad[in_offset];
    const T out_grad_y = out_grad[in_offset + 1];

    platform::CudaAtomicAdd(g, out_grad_x * w_coor);
    platform::CudaAtomicAdd(g + 1, out_grad_x * h_coor);
    platform::CudaAtomicAdd(g + 2, out_grad_x);

    platform::CudaAtomicAdd(g + 3, out_grad_y * w_coor);
    platform::CudaAtomicAdd(g + 4, out_grad_y * h_coor);
    platform::CudaAtomicAdd(g + 5, out_grad_y);
  }
}

// CUDA forward kernel of the affine_grid operator.
//
// Inputs and attributes:
//  - Theta: read as 6 values per batch, N = Theta.dims()[0].
//  - output_shape (attr): if non-empty, entries [2] and [3] give H and W.
//    Otherwise they are read from the "OutputShape" input tensor.
//  - align_corners (attr): selects the coordinate spacing.
//
// Output: "Output", resized to {N, H, W, 2}, holds the affine-transformed
// normalized (x, y) coordinates.
template <typename T>
class AffineGridOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* theta = ctx.Input<Tensor>("Theta");
    int n = theta->dims()[0];
    auto align_corners = ctx.Attr<bool>("align_corners");
    int h = 0;
    int w = 0;
    GetOutputHW(ctx, &h, &w);

    auto* output = ctx.Output<Tensor>("Output");
    T* out_data = output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());

    const int count = n * h * w;
    if (count <= 0) {
      // Empty output: a zero-block launch would be an invalid configuration.
      return;
    }

    T h_start, h_step, w_start, w_step;
    GetAxisParams(h, align_corners, &h_start, &h_step);
    GetAxisParams(w, align_corners, &w_start, &w_step);

    const int block = 512;
    const int grid = (count + block - 1) / block;
    auto cu_stream = ctx.cuda_device_context().stream();
    affine_grid_kernel<<<grid, block, 0, cu_stream>>>(
        count, n, h, w, h_start, w_start, h_step, w_step,
        theta->data<T>(),  // N, 2, 3
        out_data);
  }

 private:
  // Reads the target H and W from the "output_shape" attribute, or from the
  // "OutputShape" input tensor when the attribute is empty.
  static void GetOutputHW(const framework::ExecutionContext& ctx, int* h,
                          int* w) {
    auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
    if (size_attr.empty()) {
      auto* output_shape = ctx.Input<Tensor>("OutputShape");
      Tensor h_sizes;
      // Synchronous copy: the values are dereferenced on the host right away.
      framework::TensorCopySync(*output_shape, platform::CPUPlace(), &h_sizes);
      const int* h_size_data = h_sizes.data<int>();
      *h = h_size_data[2];
      *w = h_size_data[3];
    } else {
      *h = size_attr[2];
      *w = size_attr[3];
    }
  }

  // Computes the start and step of the normalized coordinates along one axis
  // of `size` (> 0) samples spanning [-1, 1].
  static void GetAxisParams(int size, bool align_corners, T* start, T* step) {
    *start = static_cast<T>(-1);
    if (align_corners) {
      if (size > 1) {
        *step = static_cast<T>(2) / static_cast<T>(size - 1);
      } else {
        // A single sample cannot span [-1, 1], so place it at the centre (0).
        // This is also what the non-aligned branch yields for size == 1.
        // The old formula divided by zero and produced NaN here.
        *start = static_cast<T>(0);
        *step = static_cast<T>(0);
      }
    } else {
      *step = static_cast<T>(2) / static_cast<T>(size);
      // Shift the first sample to the centre of the first cell.
      *start *= static_cast<T>(size - 1) / static_cast<T>(size);
    }
  }
};

// CUDA backward kernel of the affine_grid operator: computes d(Theta) from
// d(Output).
//
// H and W come from the "output_shape" attribute (entries [2] and [3]), or
// from the "OutputShape" input tensor when the attribute is empty.
// The gradient tensor is resized to {N, 2, 3} and zero-filled, because the
// device kernel accumulates into it with atomic adds.
template <typename T>
class AffineGridGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
    auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
    int n = output_grad->dims()[0];
    auto align_corners = ctx.Attr<bool>("align_corners");
    int h = 0;
    int w = 0;
    GetOutputHW(ctx, &h, &w);

    T* theta_grad_data = theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());
    // Accumulation target for the atomic adds: must start at zero.
    math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
        ctx.cuda_device_context(), theta_grad, static_cast<T>(0));

    const int count = n * h * w;
    if (count <= 0) {
      // Gradient stays all-zero; a zero-block launch would be invalid.
      return;
    }

    T h_start, h_step, w_start, w_step;
    GetAxisParams(h, align_corners, &h_start, &h_step);
    GetAxisParams(w, align_corners, &w_start, &w_step);

    VLOG(3) << "count: " << count << "; h_step: " << h_step
            << "; w_step: " << w_step << "; h_start: " << h_start
            << "; w_start: " << w_start;

    const int block = 512;
    const int grid = (count + block - 1) / block;
    auto cu_stream = ctx.cuda_device_context().stream();
    affine_grid_grad_kernel<<<grid, block, 0, cu_stream>>>(
        count, n, h, w, h_start, w_start, h_step, w_step,
        output_grad->data<T>(), theta_grad_data);
  }

 private:
  // Reads the target H and W from the "output_shape" attribute, or from the
  // "OutputShape" input tensor when the attribute is empty.
  static void GetOutputHW(const framework::ExecutionContext& ctx, int* h,
                          int* w) {
    auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
    if (size_attr.empty()) {
      auto* output_shape = ctx.Input<Tensor>("OutputShape");
      Tensor h_sizes;
      // Synchronous copy: the values are dereferenced on the host right away.
      framework::TensorCopySync(*output_shape, platform::CPUPlace(), &h_sizes);
      const int* h_size_data = h_sizes.data<int>();
      *h = h_size_data[2];
      *w = h_size_data[3];
    } else {
      *h = size_attr[2];
      *w = size_attr[3];
    }
  }

  // Computes the start and step of the normalized coordinates along one axis
  // of `size` (> 0) samples spanning [-1, 1].
  // This must match the forward kernel's grid.
  static void GetAxisParams(int size, bool align_corners, T* start, T* step) {
    *start = static_cast<T>(-1);
    if (align_corners) {
      if (size > 1) {
        *step = static_cast<T>(2) / static_cast<T>(size - 1);
      } else {
        // A single sample is placed at the centre (0) instead of dividing by
        // zero, which used to produce NaN gradients.
        *start = static_cast<T>(0);
        *step = static_cast<T>(0);
      }
    } else {
      *step = static_cast<T>(2) / static_cast<T>(size);
      // Shift the first sample to the centre of the first cell.
      *start *= static_cast<T>(size - 1) / static_cast<T>(size);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the CUDA forward (affine_grid) and backward (affine_grid_grad)
// kernels for float and double element types.
REGISTER_OP_CUDA_KERNEL(affine_grid, ops::AffineGridOpCUDAKernel<float>,
                        ops::AffineGridOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(affine_grid_grad,
                        ops::AffineGridGradOpCUDAKernel<float>,
                        ops::AffineGridGradOpCUDAKernel<double>);