attn_bias_add.cu.h 12.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

25
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
26 27 28 29 30 31 32

// LAUNCH_BOUNDS(BlockDim) decorates a kernel declaration. When compiling with
// HIPCC it expands to __launch_bounds__(BlockDim); with any other compiler it
// expands to nothing.
#ifdef __HIPCC__
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
#else
#define LAUNCH_BOUNDS(BlockDim)
#endif

33
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
34 35
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
36
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
L
limingshu 已提交
37
#include "paddle/phi/kernels/funcs/fast_divmod.h"
38 39 40 41 42 43 44 45 46 47 48

namespace paddle {
namespace operators {

// Length of the phi::Array arguments (`use_broadcast`, `configlists`) taken by
// the binary broadcast kernel in this file. Undefined again at the end of the
// file.
#define MAX_INPUT_NUM 2

// Shorthand for platform::CudnnDataType<T>.
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
// Type in which the reduction kernels in this file accumulate their sums
// before casting the result back to T.
template <typename T>
using ReduceParamType = typename CudnnDataType<T>::BatchNormParamType;

49 50 51 52 53 54
// Element-wise binary kernel: out = func(in0, in1), where each input is read
// either linearly or through a broadcast mapping.
//
// Indexing, as established by the code below: block `b` handles the
// `blockDim.x * VecSize` consecutive output elements that start at offset
// `b * blockDim.x * VecSize`.
//
// Parameters:
//   in0, in1       input pointers.
//   out            output pointer.
//   use_broadcast  use_broadcast[i] selects ReadDataBc (with configlists[i])
//                  for input i; otherwise input i is read linearly at the
//                  block offset.
//   numel          total number of output elements (forwarded to ReadDataBc).
//   configlists    broadcast configuration for each input.
//   main_tid       number of leading blocks that process a full tile.
//   tail_tid       number of valid elements for every block at or past
//                  main_tid.
//   func           binary functor applied element-wise.
//
// ShapeSize is not referenced in the body. DATA_PER_THREAD sizes the
// per-thread arrays and is forwarded to ReadDataBc; the linear load, compute
// and store calls all pass 1 in that position.
template <typename InT,
          typename OutT,
          int ShapeSize,
          int VecSize,
          int DATA_PER_THREAD,
          typename Functor>
__global__ void BroadcastKernelBinary(
    const InT* __restrict__ in0,
    const InT* __restrict__ in1,
    OutT* out,
    phi::Array<bool, MAX_INPUT_NUM> use_broadcast,
    uint32_t numel,
    phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists,
    int main_tid,
    int tail_tid,
    Functor func) {
  // Offset of the first element owned by this block.
  int fix = blockIdx.x * blockDim.x * VecSize;
  // Number of valid elements in this block's tile: a full tile for the
  // leading `main_tid` blocks, `tail_tid` for the rest.
  int num = tail_tid;
  InT arg0[VecSize * DATA_PER_THREAD];
  InT arg1[VecSize * DATA_PER_THREAD];
  OutT result[VecSize * DATA_PER_THREAD];
  if (blockIdx.x < main_tid) {
    num = blockDim.x * VecSize;  // blockIdx.x < main_tid
  }

  // NOTE(review): the trailing `true` template argument on the linear
  // ReadData calls is taken to be the same boundary-check switch that the
  // WriteData call below passes explicitly - confirm against
  // kernel_primitives. The load of in1 previously omitted it, so in the tail
  // block the two inputs were read under different rules; both loads now pass
  // the flag so that `num` limits each of them.

  // load in0
  if (use_broadcast[0]) {
    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD>(
        arg0, in0, fix, configlists[0], numel);
  } else {
    kernel_primitives::ReadData<InT, VecSize, 1, true>(arg0, in0 + fix, num);
  }
  // load in1
  if (use_broadcast[1]) {
    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD>(
        arg1, in1, fix, configlists[1], numel);
  } else {
    kernel_primitives::ReadData<InT, VecSize, 1, true>(arg1, in1 + fix, num);
  }
  // compute
  kernel_primitives::ElementwiseBinary<InT, OutT, VecSize, 1, Functor>(
      result, arg0, arg1, func);
  // store
  kernel_primitives::WriteData<OutT, VecSize, 1, true>(out + fix, result, num);
}

// Bias-add forward implementation for "[m, n] + [n] = [m, n]".
//
// Launches BroadcastKernelBinary with AddFunctor on ctx.stream().
//
// Parameters:
//   ctx   GPU context; only its stream is used here.
//   m, n  extents of the output; m * n elements are produced.
//   in0   first operand, read linearly.
//   in1   second operand; read through a broadcast config unless m == 1.
//   out   destination pointer.
//
// Throws platform::errors::Unimplemented when the selected vector width is
// not 1, 2 or 4. Returns without launching anything when m or n is not
// positive.
template <typename T>
void LaunchBiasAddFwKernel(const phi::GPUContext& ctx,
                           int m,
                           int n,
                           const T* in0,
                           const T* in1,
                           T* out) {
  // An empty output would yield a zero-block grid, which is not a valid
  // launch configuration. There is nothing to compute, so return early.
  if (m <= 0 || n <= 0) {
    return;
  }

  // Use the smallest vector width reported for the three pointers, with the
  // output side capped at 4.
  int in_vec_size =
      std::min(phi::GetVectorizedSize<T>(in0), phi::GetVectorizedSize<T>(in1));
  int out_vec_size = std::min(4, phi::GetVectorizedSize<T>(out));
  int vec_size = std::min(out_vec_size, in_vec_size);

  int numel = m * n;
  const int threads = 256;
  const int data_per_thread = 1;
  // ceil(ceil(numel / elements-per-thread) / threads-per-block)
  int blocks =
      ((numel + vec_size * data_per_thread - 1) / (vec_size * data_per_thread) +
       threads - 1) /
      threads;
  // Number of blocks that own a full tile, and the element count left over
  // for the block after them.
  int main_tid = numel / (data_per_thread * vec_size * threads);
  int tail_tid = numel % (data_per_thread * vec_size * threads);

  phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists;
  phi::Array<bool, MAX_INPUT_NUM> use_broadcast;

  // in0 is always read linearly. in1 goes through the broadcast config
  // whenever m != 1.
  use_broadcast[0] = false;
  use_broadcast[1] = (m != 1);

  // Here, dims are transposed due to the logic in BroadcastConfig.
  std::vector<int64_t> input1_dims = {n, 1};
  std::vector<int64_t> out_dims = {n, m};
  configlists[1] = kps::details::BroadcastConfig(out_dims, input1_dims, 2);

  auto func = AddFunctor<T>();
  auto stream = ctx.stream();
  switch (vec_size) {
    case 4: {
      BroadcastKernelBinary<T, T, 2, 4, data_per_thread>
          <<<blocks, threads, 0, stream>>>(in0,
                                           in1,
                                           out,
                                           use_broadcast,
                                           numel,
                                           configlists,
                                           main_tid,
                                           tail_tid,
                                           func);
      break;
    }
    case 2: {
      BroadcastKernelBinary<T, T, 2, 2, data_per_thread>
          <<<blocks, threads, 0, stream>>>(in0,
                                           in1,
                                           out,
                                           use_broadcast,
                                           numel,
                                           configlists,
                                           main_tid,
                                           tail_tid,
                                           func);
      break;
    }
    case 1: {
      BroadcastKernelBinary<T, T, 2, 1, data_per_thread>
          <<<blocks, threads, 0, stream>>>(in0,
                                           in1,
                                           out,
                                           use_broadcast,
                                           numel,
                                           configlists,
                                           main_tid,
                                           tail_tid,
                                           func);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}

// Column sum: out[i] = sum over j in [0, reduce_num) of in[j * left_num + i],
// for every i in [0, left_num).
//
// Each block handles columns blockIdx.x, blockIdx.x + gridDim.x, ... ; within
// a column the threads of the block take rows threadIdx.x,
// threadIdx.x + blockDim.x, ... and their partial sums are combined with
// cub::BlockReduce. Accumulation is done in ReduceParamType<T> and the result
// is cast back to T by thread 0.
//
// BlockDim is the block size handed to cub::BlockReduce, so the kernel must be
// launched with blockDim.x == BlockDim.
template <typename T, int BlockDim>
__global__ void LAUNCH_BOUNDS(BlockDim) Compute1DColumnReduceKernel(
    const int reduce_num, const int left_num, const T* in, T* out) {
  typedef cub::BlockReduce<ReduceParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage mean_storage;

  for (int i = blockIdx.x; i < left_num; i += gridDim.x) {
    ReduceParamType<T> x_sum = static_cast<ReduceParamType<T>>(0);
    for (int j = threadIdx.x; j < reduce_num; j += blockDim.x) {
      const int index = j * left_num + i;
      ReduceParamType<T> x_i = static_cast<ReduceParamType<T>>(in[index]);
      x_sum += x_i;
    }
    // The reduced value is only meaningful in thread 0.
    x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
    if (threadIdx.x == 0) {
      out[i] = static_cast<T>(x_sum);
    }
    // The same TempStorage is reused by the next column. CUB requires a block
    // barrier before temp storage is reused, otherwise fast threads may
    // overwrite it while slower ones are still reading. The loop bounds depend
    // only on blockIdx.x / gridDim.x, so every thread of the block reaches
    // this barrier the same number of times.
    __syncthreads();
  }
}

// Host-side launcher for Compute1DColumnReduceKernel.
//
// Uses 256 threads per block. The grid has max(max_threads / 256, 1) blocks,
// reduced to left_num blocks when left_num is smaller. The kernel is enqueued
// on `stream` and reads d_out / writes d_bias.
template <typename T>
void Launch1DColumnReduce(gpuStream_t stream,
                          const int max_threads,
                          const int reduce_num,
                          const int left_num,
                          const T* d_out,
                          T* d_bias) {
  constexpr int kThreadsPerBlock = 256;

  // Start from the block budget (at least one block) ...
  int num_blocks = max_threads / kThreadsPerBlock;
  if (num_blocks < 1) {
    num_blocks = 1;
  }
  // ... and never use more blocks than left_num.
  if (left_num < num_blocks) {
    num_blocks = left_num;
  }

  Compute1DColumnReduceKernel<T, kThreadsPerBlock>
      <<<num_blocks, kThreadsPerBlock, 0, stream>>>(
          reduce_num, left_num, d_out, d_bias);
}

215 216 217 218 219 220
// Chooses the launch configuration for the column-reduce kernels.
//
// Inputs:
//   max_threads  thread budget used to decide how far the reduced dimension
//                is split.
//   reduce_num   extent of the dimension being summed.
//   left_num     extent of the dimension that is kept.
// Outputs:
//   blocking_size        number of rows of the reduced dimension assigned to
//                        one grid.y slice.
//   should_reduce_again  true when the reduced dimension is split across
//                        grid.y, so a second pass over the partial sums is
//                        needed.
//   block_dim, grid_dim  launch dimensions. block_dim is always (32, 1, 1);
//                        grid_dim->x covers left_num in steps of 32.
//
// Declared inline because this is a non-template function defined in a
// header.
inline void SetConfigForColumnReduce(const int max_threads,
                                     const int reduce_num,
                                     const int left_num,
                                     int* blocking_size,
                                     bool* should_reduce_again,
                                     dim3* block_dim,
                                     dim3* grid_dim) {
  // Set every component explicitly so the result does not depend on what the
  // caller's dim3 objects held before the call.
  block_dim->x = 32;
  block_dim->y = 1;
  block_dim->z = 1;
  grid_dim->x = (left_num + block_dim->x - 1) / block_dim->x;
  grid_dim->z = 1;
  *should_reduce_again = false;

  // Guard the division: with left_num == 0 there is nothing to split.
  const int num_block = left_num > 0 ? (max_threads / left_num) : 0;
  if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
    *blocking_size = phi::funcs::details::GetLastPow2(reduce_num / num_block);
    if (*blocking_size <= 1) {
      // sqrt() yields a double; truncate explicitly before handing it on.
      *blocking_size =
          phi::funcs::details::GetLastPow2(static_cast<int>(sqrt(reduce_num)));
    } else if (*blocking_size * 2 < reduce_num) {
      *blocking_size *= 2;
    }
    *should_reduce_again = true;
    grid_dim->y = (reduce_num + *blocking_size - 1) / *blocking_size;
  } else {
    // Single pass: each thread sums its whole column.
    *blocking_size = reduce_num;
    grid_dim->y = 1;
  }
}

// Single-pass column sum: each thread owns one column `idx` and computes
//   out[idx] = sum over iy in [0, reduce_num) of in[iy * left_num + idx].
//
// Threads whose global x index is >= left_num do nothing. Accumulation is
// done in ReduceParamType<T>; the result is cast back to T on store.
template <typename T>
__global__ void BiasAddBwSinglePassKernel(const T* in,
                                          int reduce_num,
                                          int left_num,
                                          T* out) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= left_num) {
    return;
  }

  ReduceParamType<T> x_sum = static_cast<ReduceParamType<T>>(0);
  for (int iy = 0; iy < reduce_num; iy++) {
    // reduce_num * left_num may exceed the range of int even though each
    // extent fits, so form the flat offset in 64 bits.
    const int64_t id = static_cast<int64_t>(iy) * left_num + idx;
    x_sum += static_cast<ReduceParamType<T>>(in[id]);
  }
  out[idx] = static_cast<T>(x_sum);
}

// First pass of the two-pass column sum.
//
// Thread column `idx` (global x index) in grid row blockIdx.y sums up to
// `workload_per_thread` consecutive rows of `x`, starting at row
// blockIdx.y * workload_per_thread and clipped to reduce_num, reading
// x[row * left_num + idx]. The partial sum is stored at
// temp_x_sum[blockIdx.y * left_num + idx].
//
// Threads with idx >= left_num do nothing.
template <typename T>
__global__ void BiasAddBw2DReduceKernel(const T* x,
                                        int reduce_num,
                                        int left_num,
                                        int workload_per_thread,
                                        ReduceParamType<T>* temp_x_sum) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int idy = blockIdx.y * workload_per_thread;

  if (idx < left_num) {
    ReduceParamType<T> x_sum = static_cast<ReduceParamType<T>>(0);
    // Rows remaining from idy to the end, limited to this slice's workload.
    int loop = reduce_num - idy;
    loop = loop > workload_per_thread ? workload_per_thread : loop;
    for (int iy = 0; iy < loop; iy++) {
      // Flat offsets can exceed the range of int for large inputs; compute
      // them in 64 bits.
      const int64_t id = static_cast<int64_t>(idy + iy) * left_num + idx;
      x_sum += static_cast<ReduceParamType<T>>(x[id]);
    }
    temp_x_sum[static_cast<int64_t>(blockIdx.y) * left_num + idx] = x_sum;
  }
}

// Second pass of the two-pass column sum.
//
// For its column (global x index), each thread adds up the
// `workload_per_thread` entries temp_sum[row * left_num + column] and stores
// the total, cast to T, in out[column]. Threads whose column is >= left_num
// do nothing.
template <typename T>
__global__ void BiasAddBw1DReduceKernel(const ReduceParamType<T>* temp_sum,
                                        int workload_per_thread,
                                        int left_num,
                                        T* out) {
  using AccT = ReduceParamType<T>;

  const int column = blockIdx.x * blockDim.x + threadIdx.x;
  if (column >= left_num) {
    return;
  }

  AccT total = static_cast<AccT>(0);
  int row = 0;
  while (row < workload_per_thread) {
    total += temp_sum[row * left_num + column];
    ++row;
  }
  out[column] = static_cast<T>(total);
}

// Host-side driver for the column sum of d_out into d_bias.
//
// Asks SetConfigForColumnReduce for a launch configuration. If no second pass
// is requested, a single BiasAddBwSinglePassKernel launch produces d_bias.
// Otherwise BiasAddBw2DReduceKernel writes grid.y rows of partial sums into a
// temporary tensor and BiasAddBw1DReduceKernel folds those rows into d_bias.
// All kernels are enqueued on dev_ctx.stream().
template <typename T>
void Launch2DColumnReduce(const phi::GPUContext& dev_ctx,
                          const int max_threads,
                          const int reduce_num,
                          const int left_num,
                          const T* d_out,
                          T* d_bias) {
  using AccT = ReduceParamType<T>;

  dim3 launch_block;
  dim3 launch_grid;
  bool needs_second_pass = false;
  int rows_per_slice = 1;
  SetConfigForColumnReduce(max_threads,
                           reduce_num,
                           left_num,
                           &rows_per_slice,
                           &needs_second_pass,
                           &launch_block,
                           &launch_grid);
  auto stream = dev_ctx.stream();

  if (needs_second_pass) {
    // One row of partial sums per grid.y slice.
    phi::DenseTensor partial_sums;
    partial_sums.Resize({launch_grid.y, left_num});
    dev_ctx.template Alloc<AccT>(&partial_sums,
                                 partial_sums.numel() * sizeof(AccT));
    AccT* partial_ptr = partial_sums.template data<AccT>();

    BiasAddBw2DReduceKernel<T><<<launch_grid, launch_block, 0, stream>>>(
        d_out, reduce_num, left_num, rows_per_slice, partial_ptr);

    BiasAddBw1DReduceKernel<T>
        <<<launch_grid.x, launch_block.x, 0, stream>>>(
            partial_ptr, launch_grid.y, left_num, d_bias);
    return;
  }

  BiasAddBwSinglePassKernel<T><<<launch_grid, launch_block, 0, stream>>>(
      d_out, reduce_num, left_num, d_bias);
}

// Bias-add backward implementation: a column reduce with d_out[m, n] as input
// and d_bias[n] as output.
//
// Dispatches on the problem size: when both m <= REDUCE_SPLIT_BOUNDARY / 2
// and n <= REDUCE_SPLIT_BOUNDARY, Launch1DColumnReduce is used; otherwise
// Launch2DColumnReduce. Returns without launching anything when n is not
// positive.
template <typename T>
void LaunchBiasAddBwKernel(
    const phi::GPUContext& dev_ctx, int m, int n, const T* d_out, T* d_bias) {
  // With no output elements there is nothing to write. The launchers would
  // otherwise build a zero-sized grid, and SetConfigForColumnReduce divides
  // by left_num.
  if (n <= 0) {
    return;
  }

  int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  int reduce_num = m;
  int left_num = n;
  bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
                         (left_num > REDUCE_SPLIT_BOUNDARY);
  if (!is_large_enough) {
    Launch1DColumnReduce(
        dev_ctx.stream(), max_threads, reduce_num, left_num, d_out, d_bias);
  } else {
    Launch2DColumnReduce(
        dev_ctx, max_threads, reduce_num, left_num, d_out, d_bias);
  }
}

#undef MAX_INPUT_NUM

}  // namespace operators
}  // namespace paddle