// File: crop_op.cu (4.5 KB). Revisions committed by wanghaoshuang.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#define EIGEN_USE_GPU
#include <stdio.h>
#include "paddle/operators/crop_op.h"

namespace paddle {
namespace operators {

using framework::LoDTensor;
using framework::Tensor;

// Copies a rectangular window of `x_data` into `out_data`.
//
// Launch layout: 1-D grid of 1-D blocks. Each thread starts at its flat
// global index and advances by blockDim.x * gridDim.x, so any grid with at
// least one block produces the complete output.
//
// Parameters:
//   N          - number of output elements.
//   out_shape  - device array with the D output extents.
//   x_shape    - device array with the D input extents.
//   crop_rules - device array of 2 * D ints; crop_rules[2 * i] is the start of
//                the window along dimension i. crop_rules[2 * i + 1] is not
//                read by this kernel.
//   x_data     - input elements; the indexing below treats them as a
//                contiguous row-major D-dimensional array.
//   out_data   - output elements (N of them), same layout assumption.
//
// All index arithmetic is 64-bit so tensors with more than 2^31 elements are
// addressed correctly.
template <typename T, int D>
__global__ void CropKernel(const int64_t N, const int64_t* out_shape,
                           const int64_t* x_shape, const int* crop_rules,
                           const T* x_data, T* out_data) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t out_index =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       out_index < N; out_index += stride) {
    // Split the flat output index into per-dimension coordinates (last
    // dimension varies fastest) and shift each one by its crop offset.
    int64_t pos[D];
    int64_t rest = out_index;
#pragma unroll
    for (int i = D - 1; i >= 0; --i) {
      pos[i] = rest % out_shape[i] + crop_rules[i * 2];
      rest /= out_shape[i];
    }

    // Re-linearize the shifted coordinates against the input extents.
    int64_t x_index = pos[0];
#pragma unroll
    for (int i = 1; i < D; ++i) {
      x_index = x_index * x_shape[i] + pos[i];
    }
    out_data[out_index] = x_data[x_index];
  }
}

// Crops input "X" into output "Out" on the GPU for a rank-D tensor.
//
// The window starts at the per-dimension "offsets" attribute and has the
// (already inferred) shape of "Out". The kernel runs on the stream of the
// context's CUDADeviceContext.
//
// Enforcement failures (PADDLE_ENFORCE):
//   - the context place is not a GPU place;
//   - the number of offsets differs from D.
template <typename T, int D>
void CropCUDAFunctoin(const framework::ExecutionContext& context) {
  PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                 "It must use GPUPlace.");
  // Validate the attribute before allocating or copying anything.
  auto offsets = context.op().Attr<std::vector<int>>("offsets");
  PADDLE_ENFORCE_EQ(
      D, offsets.size(),
      "Offsets size should be equal to dimension size of input tensor.");

  auto* x = context.Input<LoDTensor>("X");
  auto* out = context.Output<LoDTensor>("Out");
  auto x_data = x->data<T>();
  T* out_data = out->mutable_data<T>(paddle::platform::GPUPlace());
  auto x_dims = x->dims();
  auto out_dims = out->dims();
  int64_t out_count = out->numel();
  // An empty output needs no work; launching zero blocks would be an invalid
  // launch configuration.
  if (out_count <= 0) return;

  // Stage the shapes and crop rules on the host, then copy them to the
  // device so the kernel can read them.
  Tensor x_shape;
  Tensor out_shape;
  Tensor crop_rules;
  int64_t* x_shape_data =
      x_shape.mutable_data<int64_t>({D}, paddle::platform::CPUPlace());
  int64_t* out_shape_data =
      out_shape.mutable_data<int64_t>({D}, paddle::platform::CPUPlace());
  int* crop_rules_data =
      crop_rules.mutable_data<int>({D * 2}, paddle::platform::CPUPlace());
  for (int i = 0; i < D; ++i) {
    x_shape_data[i] = x_dims[i];
    out_shape_data[i] = out_dims[i];
    // [2*i]: leading offset of the window; [2*i+1]: elements left after the
    // window along this dimension (not read by CropKernel).
    crop_rules_data[i * 2] = offsets[i];
    crop_rules_data[i * 2 + 1] = x_dims[i] - out_dims[i] - offsets[i];
  }

  Tensor x_shape_gpu;
  Tensor out_shape_gpu;
  Tensor crop_rules_gpu;
  x_shape_gpu.CopyFrom<int64_t>(x_shape, paddle::platform::GPUPlace());
  out_shape_gpu.CopyFrom<int64_t>(out_shape, paddle::platform::GPUPlace());
  crop_rules_gpu.CopyFrom<int>(crop_rules, paddle::platform::GPUPlace());
  // NOTE(review): the three device tensors above are released when this
  // function returns, while the launch below is asynchronous on the context
  // stream. Confirm the allocator / CopyFrom semantics make that safe.

  // Size the grid from the total element count, which is valid for every
  // rank (including D == 1). The kernel's strided loop handles any elements
  // beyond grid * block, so the grid is clamped to a limit that every compute
  // capability accepts for gridDim.x.
  const int block = 512;
  const int64_t kMaxGrid = 65535;
  const int64_t blocks_needed = (out_count + block - 1) / block;
  const int grid =
      static_cast<int>(blocks_needed < kMaxGrid ? blocks_needed : kMaxGrid);

  auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                    context.device_context())
                    .stream();
  CropKernel<T, D><<<grid, block, 0, stream>>>(
      out_count, out_shape_gpu.data<int64_t>(), x_shape_gpu.data<int64_t>(),
      crop_rules_gpu.data<int>(), x_data, out_data);
}

// GPU kernel for the crop op. Dispatches on the rank of input "X" to the
// rank-specialized CropCUDAFunctoin<T, D>. Only ranks 1 through 6 are
// supported; any other rank throws via PADDLE_THROW.
template <typename T>
class CropOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    size_t rank = context.Input<LoDTensor>("X")->dims().size();
    switch (rank) {
      case 1:
        CropCUDAFunctoin<T, 1>(context);
        break;
      case 2:
        CropCUDAFunctoin<T, 2>(context);
        break;
      case 3:
        CropCUDAFunctoin<T, 3>(context);
        break;
      case 4:
        CropCUDAFunctoin<T, 4>(context);
        break;
      case 5:
        CropCUDAFunctoin<T, 5>(context);
        break;
      case 6:
        CropCUDAFunctoin<T, 6>(context);
        break;
      default:
        PADDLE_THROW(
            "CropOp only support tensors with no more than 6 dimensions.");
    }
  }
};

}  // namespace operators
}  // namespace paddle

// GPU kernel registrations for the crop operator and its gradient.
// Both ops are registered for float only.
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(crop, ops::CropOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(
    crop_grad, ops::CropGradKernel<paddle::platform::GPUPlace, float>);