// attn_bias_add.cu.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"

#ifdef __HIPCC__
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
#else
#define LAUNCH_BOUNDS(BlockDim)
#endif

#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/fast_divmod.h"

namespace paddle {
namespace operators {

// Number of inputs of BroadcastKernelBinary (in0 and in1); sizes its
// use_broadcast / configlists arrays. Undefined again at the end of this file.
#define MAX_INPUT_NUM 2

template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
// Accumulation type used by the column reductions in this file: the
// BatchNormParamType of CudnnDataType<T>. NOTE(review): presumably a wider
// type (e.g. float) for half-precision T — confirm in gpu_dnn.h.
template <typename T>
using ReduceParamType = typename CudnnDataType<T>::BatchNormParamType;

// Per-thread body of BroadcastKernelBinary: loads VecSize elements of each
// input (through the broadcast config when use_broadcast[i] is set), applies
// `func` element-wise and stores VecSize results.
//   fix : index of the block's first output element.
//   num : number of valid output elements starting at `fix`.
// IsBoundary selects the bounds-checked variants of the kps load primitives.
// It must be true whenever num < blockDim.x * VecSize (the tail block);
// otherwise the unchecked loads read past the end of the inputs.
template <typename InT, typename OutT, int ShapeSize, int VecSize,
          int DATA_PER_THREAD, bool IsBoundary, typename Functor>
__device__ __forceinline__ void BroadcastKernelBinaryImpl(
    const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out,
    phi::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
    phi::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
        configlists,
    int fix, int num, Functor func) {
  InT arg0[VecSize * DATA_PER_THREAD];
  InT arg1[VecSize * DATA_PER_THREAD];
  OutT result[VecSize * DATA_PER_THREAD];

  // load in0
  if (use_broadcast[0]) {
    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize,
                                  IsBoundary>(arg0, in0, fix, configlists[0],
                                              numel);
  } else {
    kernel_primitives::ReadData<InT, VecSize, 1, 1, IsBoundary>(
        arg0, in0 + fix, num);
  }
  // load in1
  if (use_broadcast[1]) {
    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize,
                                  IsBoundary>(arg1, in1, fix, configlists[1],
                                              numel);
  } else {
    kernel_primitives::ReadData<InT, VecSize, 1, 1, IsBoundary>(
        arg1, in1 + fix, num);
  }
  // compute
  kernel_primitives::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
      result, arg0, arg1, func);
  // store (always bounds-checked against num)
  kernel_primitives::WriteData<OutT, VecSize, 1, 1, true>(out + fix, result,
                                                          num);
}

// out = func(in0, in1) over `numel` output elements, with either input
// optionally read through a broadcast config (use_broadcast / configlists).
// Launch layout: 1-D grid, each thread handles VecSize consecutive outputs and
// each block blockDim.x * VecSize of them. The first `main_tid` blocks are
// full; a following tail block (if any) covers the remaining `tail_tid`
// elements. NOTE(review): only DATA_PER_THREAD == 1 is exercised in this file;
// the offsets below do not scale with DATA_PER_THREAD.
template <typename InT, typename OutT, int ShapeSize, int VecSize,
          int DATA_PER_THREAD, typename Functor>
__global__ void BroadcastKernelBinary(
    const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out,
    phi::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
    phi::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
        configlists,
    int main_tid, int tail_tid, Functor func) {
  int fix = blockIdx.x * blockDim.x * VecSize;
  if (blockIdx.x < main_tid) {
    // Full block: all blockDim.x * VecSize elements are in range, so the fast
    // (unchecked, vectorized) loads are safe.
    BroadcastKernelBinaryImpl<InT, OutT, ShapeSize, VecSize, DATA_PER_THREAD,
                              false, Functor>(
        in0, in1, out, use_broadcast, numel, configlists, fix,
        static_cast<int>(blockDim.x * VecSize), func);
  } else {
    // Tail block: only tail_tid elements remain, so loads must be checked.
    BroadcastKernelBinaryImpl<InT, OutT, ShapeSize, VecSize, DATA_PER_THREAD,
                              true, Functor>(in0, in1, out, use_broadcast,
                                             numel, configlists, fix, tail_tid,
                                             func);
  }
}

// bias add forward impl for "[m, n] + [n] = [m, n]"
// Computes out = in0 + in1 (AddFunctor<T>), where in0 and out hold m * n
// elements and in1 holds n elements. When m != 1, in1 is read through a
// broadcast config so it is reused for every one of the m rows. The vector
// width (4, 2 or 1) is the smallest of the widths phi::GetVectorizedSize
// reports for the three pointers, capped at 4. The kernel is launched
// asynchronously on ctx.stream(). m * n must fit in an int (the kernel uses
// int offsets).
template <typename T>
void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
                           const T* in0, const T* in1, T* out) {
  // Empty output: nothing to compute. Returning early also avoids launching a
  // zero-block grid (an invalid launch configuration) and building a
  // BroadcastConfig for a zero-sized shape.
  if (m <= 0 || n <= 0) {
    return;
  }

  int in_vec_size =
      std::min(phi::GetVectorizedSize<T>(in0), phi::GetVectorizedSize<T>(in1));
  int out_vec_size = std::min(4, phi::GetVectorizedSize<T>(out));
  int vec_size = std::min(out_vec_size, in_vec_size);

  int numel = m * n;
  const int threads = 256;
  const int data_per_thread = 1;
  // Output elements covered by one block.
  const int elems_per_block = data_per_thread * vec_size * threads;
  int blocks = (numel + elems_per_block - 1) / elems_per_block;
  // Number of full blocks, and the element count of the trailing partial one.
  int main_tid = numel / elems_per_block;
  int tail_tid = numel % elems_per_block;

  phi::Array<kps::details::BroadcastConfig<2>, MAX_INPUT_NUM> configlists;
  phi::Array<bool, MAX_INPUT_NUM> use_broadcast;

  // in0 always matches the output shape; in1 needs broadcasting unless the
  // output has a single row.
  use_broadcast[0] = false;
  use_broadcast[1] = (m != 1);
  // Here, dims are transposed due to the logic in BroadcastConfig.
  std::vector<int64_t> input1_dims = {n, 1};
  std::vector<int64_t> out_dims = {n, m};
  configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2);

  auto func = AddFunctor<T>();
  auto stream = ctx.stream();
  switch (vec_size) {
    case 4: {
      BroadcastKernelBinary<T, T, 2, 4,
                            data_per_thread><<<blocks, threads, 0, stream>>>(
          in0, in1, out, use_broadcast, numel, configlists, main_tid, tail_tid,
          func);
      break;
    }
    case 2: {
      BroadcastKernelBinary<T, T, 2, 2,
                            data_per_thread><<<blocks, threads, 0, stream>>>(
          in0, in1, out, use_broadcast, numel, configlists, main_tid, tail_tid,
          func);
      break;
    }
    case 1: {
      BroadcastKernelBinary<T, T, 2, 1,
                            data_per_thread><<<blocks, threads, 0, stream>>>(
          in0, in1, out, use_broadcast, numel, configlists, main_tid, tail_tid,
          func);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}

// Column reduction of a reduce_num x left_num matrix `in` (element (j, i) at
// in[j * left_num + i]) into out[i], accumulated in ReduceParamType<T>.
// Each block handles columns blockIdx.x, blockIdx.x + gridDim.x, ...; within a
// column, threads stride over rows by blockDim.x and the partial sums are
// combined with cub::BlockReduce (result written by thread 0).
// Requires blockDim.x == BlockDim.
template <typename T, int BlockDim>
__global__ void LAUNCH_BOUNDS(BlockDim)
    Compute1DColumnReduceKernel(const int reduce_num, const int left_num,
                                const T* in, T* out) {
  using AccT = ReduceParamType<T>;
  typedef cub::BlockReduce<AccT, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage reduce_storage;

  for (int i = blockIdx.x; i < left_num; i += gridDim.x) {
    AccT x_sum = static_cast<AccT>(0);
    for (int j = threadIdx.x; j < reduce_num; j += blockDim.x) {
      const int index = j * left_num + i;
      x_sum += static_cast<AccT>(in[index]);
    }
    x_sum = BlockReduce(reduce_storage).Reduce(x_sum, cub::Sum());
    if (threadIdx.x == 0) {
      out[i] = static_cast<T>(x_sum);
    }
    // reduce_storage is reused by the next iteration's BlockReduce; CUB
    // requires a block barrier before temp storage is reused. The loop bound
    // depends only on blockIdx/gridDim, so all threads reach this barrier.
    __syncthreads();
  }
}

// Launches Compute1DColumnReduceKernel on `stream` with 256-thread blocks and
// min(left_num, max(max_threads / 256, 1)) blocks; the kernel loops over any
// columns beyond the grid size.
template <typename T>
void Launch1DColumnReduce(gpuStream_t stream, const int max_threads,
                          const int reduce_num, const int left_num,
                          const T* d_out, T* d_bias) {
  // No output columns: nothing to write, and a zero-block launch is an
  // invalid configuration.
  if (left_num <= 0) {
    return;
  }
  constexpr int kBlock = 256;
  const int max_blocks = std::max(max_threads / kBlock, 1);
  const int grid = std::min(left_num, max_blocks);
  Compute1DColumnReduceKernel<T, kBlock><<<grid, kBlock, 0, stream>>>(
      reduce_num, left_num, d_out, d_bias);
}

// Chooses the launch configuration used by Launch2DColumnReduce to reduce a
// reduce_num x left_num matrix over its rows, one thread per column with 32
// threads per block along x.
//  - If max_threads / left_num > 1 and reduce_num >= REDUCE_SPLIT_BOUNDARY,
//    the rows are split into grid_dim->y chunks of *blocking_size rows each
//    and *should_reduce_again is set: a second pass must add up the
//    grid_dim->y partial results of every column.
//  - Otherwise a single pass is used: *blocking_size = reduce_num and
//    grid_dim->y = 1.
// Declared inline because this non-template function is defined in a header;
// without it every translation unit including the header emits its own
// external definition (duplicate symbols at link time).
inline void SetConfigForColumnReduce(const int max_threads,
                                     const int reduce_num, const int left_num,
                                     int* blocking_size,
                                     bool* should_reduce_again,
                                     dim3* block_dim, dim3* grid_dim) {
  block_dim->z = 1;
  grid_dim->z = 1;
  *should_reduce_again = false;

  // Guard the division: with no columns there is nothing to split.
  const int num_block = left_num > 0 ? (max_threads / left_num) : 0;
  if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
    *blocking_size = phi::funcs::details::GetLastPow2(reduce_num / num_block);
    if (*blocking_size <= 1) {
      *blocking_size = phi::funcs::details::GetLastPow2(
          static_cast<int>(sqrt(static_cast<double>(reduce_num))));
    } else if (*blocking_size * 2 < reduce_num) {
      *blocking_size *= 2;
    }
    *should_reduce_again = true;
    block_dim->x = 32;
    block_dim->y = 1;
    grid_dim->x = (left_num + block_dim->x - 1) / block_dim->x;
    grid_dim->y = (reduce_num + *blocking_size - 1) / *blocking_size;
  } else {
    block_dim->x = 32;
    block_dim->y = 1;
    *blocking_size = reduce_num;
    grid_dim->x = (left_num + block_dim->x - 1) / block_dim->x;
    grid_dim->y = 1;
  }
}

// Single-pass column reduction: the thread assigned to column `col` adds up
// in[row * left_num + col] for every row in [0, reduce_num), accumulating in
// ReduceParamType<T>, and writes the total to out[col]. Threads whose column
// index is >= left_num do nothing.
template <typename T>
__global__ void BiasAddBwSinglePassKernel(const T* in, int reduce_num,
                                          int left_num, T* out) {
  using AccT = ReduceParamType<T>;
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= left_num) {
    return;
  }
  AccT acc = static_cast<AccT>(0);
  for (int row = 0; row < reduce_num; ++row) {
    acc += static_cast<AccT>(in[row * left_num + col]);
  }
  out[col] = static_cast<T>(acc);
}

// First pass of the two-pass column reduction. The thread for column `idx` in
// grid row blockIdx.y sums rows [blockIdx.y * workload_per_thread,
// blockIdx.y * workload_per_thread + workload_per_thread) of x (clipped to
// reduce_num; element (row, col) at x[row * left_num + col]) and stores the
// partial sum in temp_x_sum[blockIdx.y * left_num + idx].
template <typename T>
__global__ void BiasAddBw2DReduceKernel(const T* x, int reduce_num,
                                        int left_num, int workload_per_thread,
                                        ReduceParamType<T>* temp_x_sum) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int idy = blockIdx.y * workload_per_thread;

  ReduceParamType<T> x_sum = static_cast<ReduceParamType<T>>(0);
  if (idx < left_num) {
    // Rows in this chunk; the last chunk may be shorter than the workload.
    int loop = reduce_num - idy;
    loop = loop > workload_per_thread ? workload_per_thread : loop;
    for (int iy = 0; iy < loop; iy++) {
      int id = (idy + iy) * left_num + idx;
      ReduceParamType<T> x_val = static_cast<ReduceParamType<T>>(x[id]);
      x_sum += x_val;
    }
    temp_x_sum[idx + blockIdx.y * left_num] = x_sum;
  }
}

// Second pass of the two-pass column reduction: the thread for column `col`
// adds up the partial sums temp_sum[k * left_num + col] for
// k in [0, workload_per_thread) and writes the total, converted to T, to
// out[col]. Threads whose column index is >= left_num do nothing.
template <typename T>
__global__ void BiasAddBw1DReduceKernel(const ReduceParamType<T>* temp_sum,
                                        int workload_per_thread, int left_num,
                                        T* out) {
  using AccT = ReduceParamType<T>;
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= left_num) {
    return;
  }
  AccT total = static_cast<AccT>(0);
  for (int k = 0; k < workload_per_thread; ++k) {
    total += temp_sum[k * left_num + col];
  }
  out[col] = static_cast<T>(total);
}

// Column reduction of d_out (reduce_num x left_num) into d_bias (left_num),
// using the configuration chosen by SetConfigForColumnReduce:
//  - single pass: BiasAddBwSinglePassKernel sums each column directly;
//  - two passes: BiasAddBw2DReduceKernel writes grid.y partial sums per column
//    into a temporary [grid.y, left_num] tensor of ReduceParamType<T>, then
//    BiasAddBw1DReduceKernel adds the partials into d_bias.
// All kernels are enqueued on dev_ctx.stream().
template <typename T>
void Launch2DColumnReduce(const platform::CUDADeviceContext& dev_ctx,
                          const int max_threads, const int reduce_num,
                          const int left_num, const T* d_out, T* d_bias) {
  // No output columns: nothing to write. This also keeps left_num == 0 away
  // from the configuration step (which sizes the grid from it) and avoids an
  // empty, invalid launch grid.
  if (left_num <= 0) {
    return;
  }

  dim3 block;
  dim3 grid;
  bool should_reduce_again = false;
  int blocking_size = 1;
  SetConfigForColumnReduce(max_threads, reduce_num, left_num, &blocking_size,
                           &should_reduce_again, &block, &grid);
  const auto& stream = dev_ctx.stream();

  if (!should_reduce_again) {
    BiasAddBwSinglePassKernel<T><<<grid, block, 0, stream>>>(d_out, reduce_num,
                                                             left_num, d_bias);
  } else {
    // NOTE(review): tmp_sum is released when this scope ends while the two
    // kernels below may still be running on `stream`; this relies on the
    // allocator ordering reuse of the buffer on the same stream — confirm.
    framework::Tensor tmp_sum;
    tmp_sum.Resize({grid.y, left_num});
    tmp_sum.mutable_data<ReduceParamType<T>>(dev_ctx.GetPlace());

    BiasAddBw2DReduceKernel<T><<<grid, block, 0, stream>>>(
        d_out, reduce_num, left_num, blocking_size,
        tmp_sum.template data<ReduceParamType<T>>());

    BiasAddBw1DReduceKernel<T><<<grid.x, block.x, 0, stream>>>(
        tmp_sum.template data<ReduceParamType<T>>(), grid.y, left_num, d_bias);
  }
}

// Bias-add backward: a column reduction with d_out[m, n] as input and
// d_bias[n] as output, i.e. d_bias[j] = sum over i of d_out[i * n + j].
// Problems with few rows and columns go through Launch1DColumnReduce (cub
// block reduction per column); larger ones through Launch2DColumnReduce.
template <typename T>
void LaunchBiasAddBwKernel(const platform::CUDADeviceContext& dev_ctx, int m,
                           int n, const T* d_out, T* d_bias) {
  // d_bias is empty: nothing to compute. Returning early also avoids a
  // zero-block launch and a division by the column count when configuring
  // the 2-D reduction.
  if (n <= 0) {
    return;
  }
  int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  int reduce_num = m;
  int left_num = n;
  bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
                         (left_num > REDUCE_SPLIT_BOUNDARY);
  if (!is_large_enough) {
    Launch1DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num,
                         d_out, d_bias);
  } else {
    Launch2DColumnReduce(dev_ctx, max_threads, reduce_num, left_num, d_out,
                         d_bias);
  }
}

// MAX_INPUT_NUM is only meant for this header; undefine it so it does not
// leak into files that include it.
#undef MAX_INPUT_NUM

}  // namespace operators
}  // namespace paddle