flip_kernel.cu 4.9 KB
Newer Older
Y
Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/kernels/flip_kernel.h"
Y
Yang 已提交
16 17 18
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
傅剑寒 已提交
19
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
Y
Yang 已提交
20 21
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
22
#include "paddle/phi/core/utils/array.h"
Y
Yang 已提交
23 24 25

namespace phi {

26
// Device kernel: writes into `out_data` a copy of `in_data` with every
// dimension listed in `flip_dims` reversed.
//
// For each output flat index `idx`, the per-dimension coordinates are
// recovered by successive division by `stride[i]`. The coordinate of each
// flipped dimension is mirrored (c -> shape[i] - 1 - c), and the coordinates
// are recombined into the source offset that is read from `in_data`.
//
// Launch layout: 1-D grid of 1-D blocks. The outer loop advances by the total
// number of launched threads, so any grid size still covers all N elements.
// No shared memory is used.
//
// Params:
//   N              - number of elements to process.
//   in_data        - source buffer; read at computed offsets.
//   out_data       - destination buffer; element idx is written by exactly
//                    one thread. Reads and writes are unsynchronized, so it
//                    presumably must not alias in_data.
//   shape          - extent of each of the Rank dimensions.
//   stride         - per-dimension stride in elements (the launcher fills it
//                    from phi::stride(x_dims); assumed contiguous).
//   flip_dims      - axes to flip, expected to be non-negative. Only the
//                    first flip_dims_size entries are read.
//   flip_dims_size - number of valid entries in flip_dims (must be <= Rank).
template <typename T, size_t Rank>
__global__ void flip_cuda_kernel(const int64_t N,
                                 const T* in_data,
                                 T* out_data,
                                 phi::Array<int64_t, Rank> shape,
                                 phi::Array<int64_t, Rank> stride,
                                 phi::Array<int, Rank> flip_dims,
                                 int flip_dims_size) {
  // All index arithmetic is 64-bit. N and the strides are int64_t, and a
  // 32-bit index/offset overflows once the tensor exceeds INT_MAX elements.
  const int64_t grid_stride =
      static_cast<int64_t>(gridDim.x) * static_cast<int64_t>(blockDim.x);
  for (int64_t idx =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < N;
       idx += grid_stride) {
    int64_t cur_indices = idx;
    int64_t dst_offset = 0;
    for (int i = 0; i < static_cast<int>(Rank); ++i) {
      // Split off the coordinate along dimension i. The remainder carries
      // the coordinates of the remaining (inner) dimensions.
      int64_t coord = cur_indices / stride[i];
      const int64_t rem = cur_indices - coord * stride[i];
      // Mirror the coordinate once for each occurrence of i in flip_dims.
      for (int j = 0; j < flip_dims_size; ++j) {
        if (i == flip_dims[j]) {
          coord = shape[i] - 1 - coord;
        }
      }
      dst_offset += coord * stride[i];
      cur_indices = rem;
    }
    out_data[idx] = in_data[dst_offset];
  }
}

56
// Host-side launcher for flip_cuda_kernel with compile-time rank N.
//
// Allocates `out` and packs the shape/stride of `x` and the normalized flip
// axes into fixed-size phi::Array values. It then launches the kernel
// asynchronously on dev_ctx.stream(). No synchronization or launch-error
// check is done here.
//
// Template params:
//   T       - element type.
//   Context - device context type (provides Alloc<T> and stream()).
//   N       - rank of `x`; FlipKernel selects it from x.dims().size().
// Params:
//   dev_ctx - device context used for allocation and as the launch stream.
//   x       - input tensor.
//   axis    - axes to flip. Negative values are offset by the rank. At most
//             N entries are used (flip_dims_a has only N slots).
//   out     - output tensor; allocated here.
template <typename T, typename Context, size_t N>
void LaunchFlipCudaKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const std::vector<int>& axis,
                          DenseTensor* out) {
  auto* in_data = x.data<T>();
  auto* out_data = dev_ctx.template Alloc<T>(out);

  const int64_t numel = x.numel();
  // Nothing to copy for an empty tensor. A zero-block launch would also be
  // an invalid launch configuration.
  if (numel == 0) {
    return;
  }

  auto x_dims = x.dims();
  const int total_dims = x_dims.size();
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
  auto x_stride = phi::stride(x_dims);

  phi::Array<int64_t, N> stride_a;
  phi::Array<int64_t, N> shape_a;
  phi::Array<int, N> flip_dims_a;

  // The kernel reads flip_dims[0 .. flip_dims_size). Never report more axes
  // than flip_dims_a can hold, or the kernel would read past the array.
  const size_t flip_dims_size = axis.size() < N ? axis.size() : N;

  for (size_t idx = 0; idx < N; ++idx) {
    stride_a[idx] = x_stride[idx];
    shape_a[idx] = x_dims[idx];
    // Unused slots are padded with 0; the kernel never reads them.
    const int dim = idx < flip_dims_size ? axis[idx] : 0;
    // Map negative axes to their non-negative equivalents.
    flip_dims_a[idx] = dim < 0 ? dim + total_dims : dim;
  }

  flip_cuda_kernel<T, N>
      <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
          numel,
          in_data,
          out_data,
          shape_a,
          stride_a,
          flip_dims_a,
          static_cast<int>(flip_dims_size));
}
Y
Yang 已提交
96

97 98 99 100 101 102 103
// GPU implementation of flip: reverses `x` along every axis listed in `axis`
// and writes the result into `out`.
//
// The runtime rank of `x` is dispatched to a compile-time template argument.
// This lets LaunchFlipCudaKernel pass shape/stride metadata to the device
// by value in fixed-size phi::Array objects. Ranks 0 through 9 are
// supported; any other rank raises InvalidArgument.
template <typename T, typename Context>
void FlipKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const std::vector<int>& axis,
                DenseTensor* out) {
  const size_t total_dims = x.dims().size();
  switch (total_dims) {
    case 0:
      LaunchFlipCudaKernel<T, Context, 0>(dev_ctx, x, axis, out);
      break;
    case 1:
      LaunchFlipCudaKernel<T, Context, 1>(dev_ctx, x, axis, out);
      break;
    case 2:
      LaunchFlipCudaKernel<T, Context, 2>(dev_ctx, x, axis, out);
      break;
    case 3:
      LaunchFlipCudaKernel<T, Context, 3>(dev_ctx, x, axis, out);
      break;
    case 4:
      LaunchFlipCudaKernel<T, Context, 4>(dev_ctx, x, axis, out);
      break;
    case 5:
      LaunchFlipCudaKernel<T, Context, 5>(dev_ctx, x, axis, out);
      break;
    case 6:
      LaunchFlipCudaKernel<T, Context, 6>(dev_ctx, x, axis, out);
      break;
    case 7:
      LaunchFlipCudaKernel<T, Context, 7>(dev_ctx, x, axis, out);
      break;
    case 8:
      LaunchFlipCudaKernel<T, Context, 8>(dev_ctx, x, axis, out);
      break;
    case 9:
      LaunchFlipCudaKernel<T, Context, 9>(dev_ctx, x, axis, out);
      break;
    default:
      // The adjacent literals are concatenated; keep the trailing space so
      // the message reads "... But received <rank>".
      PADDLE_THROW(phi::errors::InvalidArgument(
          "dims of input tensor should be less than 10, But received "
          "%d",
          x.dims().size()));
  }
}
}  // namespace phi

// Registers phi::FlipKernel for the `flip` kernel on the GPU backend with
// ALL_LAYOUT, instantiated for each element type listed below. The trailing
// empty braces add no extra per-kernel configuration.
PD_REGISTER_KERNEL(flip,
                   GPU,
                   ALL_LAYOUT,
                   phi::FlipKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}