stack_grad_kernel.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/stack_grad_kernel.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"

namespace phi {

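// Generic un-stack kernel: `input` is viewed as a
// [pre_dim_size, split_dim_size, suf_dim_size] tensor and each slice along
// the middle dimension is scattered into the matching pointer in `array`.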
template <typename T, typename IndexT, typename ArrayT>
__global__ void UnStackCudaKernel(const T* __restrict__ input,
                                  IndexT pre_dim_size,
                                  IndexT split_dim_size,
                                  IndexT suf_dim_size,
                                  IndexT num_split,
                                  ArrayT array) {
  assert(blockDim.y == 1);
  assert(blockDim.z == 1);
  // In stack_grad these are equal (split_dim_size == num_split), so the
  // divisibility check below holds trivially.
  assert(split_dim_size % num_split == 0);

  IndexT size = pre_dim_size * split_dim_size * suf_dim_size;
  IndexT each_dim_size = split_dim_size / num_split;

  for (IndexT offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size;
       offset += blockDim.x * gridDim.x) {
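    // Decompose the linear offset into (pre, split, suffix) coordinates.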
    IndexT i = offset / (split_dim_size * suf_dim_size);
    IndexT j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size;
    IndexT k = offset % suf_dim_size;

    T* output = array.data[j / each_dim_size];
    if (output == nullptr) {
      // Skip only this element: later grid-stride offsets handled by this
      // thread may map to outputs that do need their gradient written.
      continue;
    }
    IndexT output_ind = i * each_dim_size * suf_dim_size +
                        (j % each_dim_size) * suf_dim_size + k;
    *(output + output_ind) = input[offset];
  }
}

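// Un-stack kernel specialized for splitting along the last dimension:
// `in_data` is viewed as a [rows, cols] matrix, tiles are staged through
// shared memory, and column col_idx is written into array.data[col_idx].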
template <typename T, typename IndexT, typename ArrayT>
__global__ void UnStackCudaKernelForLastDim(const T* __restrict__ in_data,
                                            const IndexT cols,
                                            const IndexT rows,
                                            const IndexT tile_x_num,
                                            ArrayT array) {
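  // One shared-memory slot per thread; the block sizes chosen in
  // LaunchUnStackKernel below never exceed 512 threads per block.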
  constexpr int buffer_size = 512;
  __shared__ T s_buf[buffer_size];

  for (IndexT tile_x = blockIdx.x; tile_x < tile_x_num; tile_x += gridDim.x) {
    IndexT row_idx = tile_x * blockDim.x + threadIdx.x;
    IndexT col_idx = blockIdx.y * blockDim.y + threadIdx.y;
    int s_idx = threadIdx.y * blockDim.x + threadIdx.x;
    bool is_valid = (col_idx < cols && row_idx < rows);

    if (is_valid) {
      T data = in_data[row_idx * cols + col_idx];
      s_buf[s_idx] = data;
    }
    __syncthreads();
    if (is_valid) {
      if (array.data[col_idx]) {
        array.data[col_idx][row_idx] = s_buf[s_idx];
      }
    }
  }
}

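// Chooses a launch configuration and dispatches one of the kernels above:
// the last-dim specialization when suf_dim == 1, otherwise the generic
// grid-stride kernel over all pre_dim * split_dim * suf_dim elements.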
template <typename Context,
          typename T,
          typename IndexT,
          funcs::SegmentedArraySize Size>
void LaunchUnStackKernel(const Context& ctx,
                         const IndexT pre_dim,
                         const IndexT split_dim,
                         const IndexT suf_dim,
                         const IndexT num_splits,
                         const DenseTensor& out_grad,
                         std::vector<DenseTensor*>* x_grad) {
  // Each x_grad tensor should have the same shape.
  auto dout_ptr = out_grad.data<T>();
  funcs::PointerArraySetter<Context, T, Size> setter(ctx, x_grad);

  if (suf_dim == 1) {
    // For the case axis == (out_grad.dims().size() - 1)
    constexpr int kThreads = 512;
    constexpr int kWarpSize = 32;
    constexpr int kMaxOut = 16;

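    // Pick a 2-D block: threads.y indexes the split outputs (capped at
    // kMaxOut, with the remainder spilling into grid.y), while threads.x
    // covers rows of pre_dim.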
    int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
    if (split_dim < kMaxOut) {
      tid_y = split_dim;
      tid_x =
          std::min(backends::gpu::RoundToNextHighPowOfTwo(pre_dim, kWarpSize),
                   kThreads / backends::gpu::RoundToNextHighPowOfTwo(tid_y));
    } else {
      tid_y = kMaxOut;
      tid_x = kWarpSize;
      bid_y = backends::gpu::DivUp<int>(split_dim, kMaxOut);
    }
    int tile_x_num = backends::gpu::DivUp<int>(pre_dim, tid_x);
    bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit);
    dim3 blocks(tid_x, tid_y, 1);
    dim3 grids(bid_x, bid_y, 1);

    UnStackCudaKernelForLastDim<T, IndexT, decltype(setter.array)>
        <<<grids, blocks, 0, ctx.stream()>>>(
            dout_ptr, split_dim, pre_dim, tile_x_num, setter.array);
  } else {
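    // Generic path: one grid-stride thread per element of out_grad.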
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
        ctx, pre_dim * split_dim * suf_dim);

    UnStackCudaKernel<T, IndexT, decltype(setter.array)>
        <<<config.block_per_grid.x,
           config.thread_per_block.x,
           0,
           ctx.stream()>>>(
            dout_ptr, pre_dim, split_dim, suf_dim, num_splits, setter.array);
  }
}

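// Gradient of the stack op: un-stacks out_grad along `axis` into the x_grad
// tensors, using 32-bit indexing when out_grad.numel() fits in int32_t and
// 64-bit indexing otherwise.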
template <typename T, typename Context>
void StackGradKernel(const Context& ctx,
                     const DenseTensor& out_grad,
                     int axis,
                     std::vector<DenseTensor*> x_grad) {
  if (axis < 0) axis += out_grad.dims().size();

  int64_t split_dim = out_grad.dims()[axis];
  PADDLE_ENFORCE_EQ(
      split_dim,
      x_grad.size(),
      phi::errors::InvalidArgument(
          "Output x_grad size should be equal to the split_dim, but"
          " received split_dim is:%d x_grad size is:%d.",
          split_dim,
          x_grad.size()));

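  // Flatten out_grad to the shape [dout_pre, split_dim, dout_suf] around axis.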
  auto dout_dims = out_grad.dims();
  int64_t dout_pre = 1;
  for (int i = 0; i < axis; ++i) {
    dout_pre *= dout_dims[i];
  }
  int64_t dout_suf = out_grad.numel() / (split_dim * dout_pre);

  if (out_grad.numel() < std::numeric_limits<int32_t>::max()) {
    switch (funcs::CalcArraySize(split_dim)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          LaunchUnStackKernel<Context, T, int32_t, kArraySize>(ctx,
                                                               dout_pre,
                                                               split_dim,
                                                               dout_suf,
                                                               split_dim,
                                                               out_grad,
                                                               &x_grad));
    }
  } else {
    switch (funcs::CalcArraySize(split_dim)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          LaunchUnStackKernel<Context, T, int64_t, kArraySize>(ctx,
                                                               dout_pre,
                                                               split_dim,
                                                               dout_suf,
                                                               split_dim,
                                                               out_grad,
                                                               &x_grad));
    }
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(stack_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::StackGradKernel,
                   float,
                   double,
                   bool,
                   int64_t,
                   int,
                   uint8_t,
                   int8_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}