// fused_bias_act_kernel.cu
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <limits>

#include "glog/logging.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h"

PHI_DECLARE_bool(use_fast_math);

namespace phi {
namespace fusion {

#ifndef PADDLE_WITH_HIP
// GLU-style kernel. Each input row holds 2 * hid_dim values laid out as
// [first half | second half]. For output position (row, col) it computes
//   act(x[row, col] + bias[col]) * (x[row, col + hid_dim] + bias[col + hid_dim])
// (the bias terms only when `bias` is non-null) and stores the result at
// row * hid_dim + col. Each thread handles VecSize contiguous elements per
// iteration and advances by gridDim.x * blockDim.x * VecSize.
// `token_num` is not read by this kernel.
// NOTE(review): offsets are computed in `int`; presumably callers keep the
// input element count below INT_MAX — confirm.
template <typename T,
          typename Functor,
          int VecSize,
          typename LoadFunc,
          typename StoreFunc>
__global__ void ActFFNGlu(const T *bias,
                          Functor act_functor,
                          const int token_num,
                          const int hid_dim,
                          const int elem_num,
                          LoadFunc load_func,
                          StoreFunc store_func) {
  using VecT = phi::AlignedVector<T, VecSize>;
  VecT gate;
  VecT up;
  VecT gate_bias;
  VecT up_bias;
  const bool has_bias = bias != nullptr;
  const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  for (int offset = thread_id * VecSize; offset < elem_num;
       offset += gridDim.x * blockDim.x * VecSize) {
    const int row = offset / hid_dim;
    const int col = offset % hid_dim;
    // Input rows are twice as wide as output rows.
    const int first_half = row * hid_dim * 2 + col;

    load_func.template load<VecSize>(&gate, first_half);
    load_func.template load<VecSize>(&up, first_half + hid_dim);

    if (has_bias) {
      phi::Load<T, VecSize>(&bias[col], &gate_bias);
      phi::Load<T, VecSize>(&bias[col + hid_dim], &up_bias);
#pragma unroll
      for (int k = 0; k < VecSize; ++k) {
        gate[k] += gate_bias[k];
        up[k] += up_bias[k];
      }
    }

#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      gate[k] = act_functor(gate[k]);
      gate[k] *= up[k];
    }
    store_func.template store<VecSize>(gate, row * hid_dim + col);
  }
}

// Host-side launcher for ActFFNGlu on dev_ctx.stream().
// `hid_dim` is the width of one GLU half, so token_num * hid_dim output
// elements are produced. When hid_dim is a multiple of the pack size
// (16 bytes / sizeof(LoadT)) the vectorized instantiation is launched;
// otherwise a one-element-per-iteration instantiation is used.
// Block size is fixed at 128 threads; the grid size comes from GetNumBlocks.
template <typename T,
          typename Context,
          typename Functor,
          typename LoadFunc,
          typename StoreFunc,
          typename LoadT = T>
void LaunchActFFNGlu(const Context &dev_ctx,
                     const T *bias,
                     const int token_num,
                     const int hid_dim,
                     LoadFunc load_func,
                     StoreFunc store_func) {
  constexpr int kBytesPerAccess = 16;
  constexpr int kPackSize = kBytesPerAccess / sizeof(LoadT);
  constexpr int kThreadsPerBlock = 128;
  const int total = token_num * hid_dim;
  int blocks = 1;
  Functor functor;
  const bool vectorizable = hid_dim % kPackSize == 0;
  if (vectorizable) {
    GetNumBlocks(total / kPackSize, &blocks);
    ActFFNGlu<T, Functor, kPackSize>
        <<<blocks, kThreadsPerBlock, 0, dev_ctx.stream()>>>(
            bias, functor, token_num, hid_dim, total, load_func, store_func);
  } else {
    GetNumBlocks(total, &blocks);
    ActFFNGlu<T, Functor, 1>
        <<<blocks, kThreadsPerBlock, 0, dev_ctx.stream()>>>(
            bias, functor, token_num, hid_dim, total, load_func, store_func);
  }
}

// Element-wise bias + activation over a row-major input with `cols` columns:
//   out[i] = act(x[i] + bias[i % cols])   (bias term only when non-null)
// Each thread handles VecSize contiguous elements per iteration and advances
// by gridDim.x * blockDim.x * VecSize. `rows` is not read by this kernel.
template <typename T,
          typename Functor,
          int VecSize,
          typename LoadFunc,
          typename StoreFunc>
__global__ void BiasAct(const T *bias,
                        Functor act_functor,
                        const int rows,
                        const int cols,
                        const int elem_num,
                        LoadFunc load_func,
                        StoreFunc store_func) {
  using VecT = phi::AlignedVector<T, VecSize>;
  VecT values;
  VecT bias_values;

  // Start the bias vector at zero.
#pragma unroll
  for (int k = 0; k < VecSize; ++k) {
    bias_values[k] = 0;
  }

  const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  for (int offset = thread_id * VecSize; offset < elem_num;
       offset += gridDim.x * blockDim.x * VecSize) {
    // (offset / cols) * cols + offset % cols == offset, so `offset` is already
    // the linear index; only the column is needed to address the bias.
    const int col = offset % cols;
    load_func.template load<VecSize>(&values, offset);
    if (bias) {
      phi::Load<T, VecSize>(&bias[col], &bias_values);
    }
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      if (bias) {
        values[k] += bias_values[k];
      }
      values[k] = act_functor(values[k]);
    }
    store_func.template store<VecSize>(values, offset);
  }
}

// Host-side launcher for BiasAct on dev_ctx.stream() over a
// token_num x hid_dim input. When hid_dim is a multiple of the pack size
// (16 bytes / sizeof(LoadT)) the vectorized instantiation is launched;
// otherwise a one-element-per-iteration instantiation is used.
// Block size is fixed at 128 threads; the grid size comes from GetNumBlocks.
template <typename T,
          typename Context,
          typename Functor,
          typename LoadFunc,
          typename StoreFunc,
          typename LoadT = T>
void LaunchBiasAct(const Context &dev_ctx,
                   const T *bias,
                   const int token_num,
                   const int hid_dim,
                   LoadFunc load_func,
                   StoreFunc store_func) {
  constexpr int kBytesPerAccess = 16;
  constexpr int kPackSize = kBytesPerAccess / sizeof(LoadT);
  constexpr int kThreadsPerBlock = 128;
  const int total = token_num * hid_dim;
  int blocks = 1;
  Functor functor;
  const bool vectorizable = hid_dim % kPackSize == 0;
  if (vectorizable) {
    GetNumBlocks(total / kPackSize, &blocks);
    BiasAct<T, Functor, kPackSize>
        <<<blocks, kThreadsPerBlock, 0, dev_ctx.stream()>>>(
            bias, functor, token_num, hid_dim, total, load_func, store_func);
  } else {
    GetNumBlocks(total, &blocks);
    BiasAct<T, Functor, 1>
        <<<blocks, kThreadsPerBlock, 0, dev_ctx.stream()>>>(
            bias, functor, token_num, hid_dim, total, load_func, store_func);
  }
}

// Runs bias + activation with the caller-selected load/store functors.
// `cols` is the width of one input row.
//   "geglu" / "swiglu": gated variants (GeluFunctor / CudaSwishFunctor).
//     Each input row is split into two halves, and the launcher receives
//     cols / 2 as its hid_dim.
//   "gelu": plain bias + GELU over all `cols`. FastGeluFunctor is used when
//     FLAGS_use_fast_math is set, and GeluFunctor otherwise.
// LoadT is forwarded to the launchers, where it only determines the vector
// pack size.
// Throws InvalidArgument when a gated method gets an odd `cols`, and
// Unimplemented for any other act_method.
template <typename T,
          typename Context,
          typename LoadFunc,
          typename StoreFunc,
          typename LoadT = T>
void ComputeImpl(const Context &dev_ctx,
                 const T *bias_data,
                 const std::string &act_method,
                 int rows,
                 int cols,
                 LoadFunc load_func,
                 StoreFunc store_func) {
  if (act_method == "geglu" || act_method == "swiglu") {
    // The gated launcher addresses input rows with a stride of
    // 2 * (cols / 2). An odd `cols` would make that stride differ from the
    // real row width and silently corrupt every row after the first.
    PADDLE_ENFORCE_EQ(
        cols % 2,
        0,
        phi::errors::InvalidArgument(
            "For act_method (%s) the last dimension of Input(x) must be even, "
            "but received %d.",
            act_method,
            cols));
  }

  if (act_method == "geglu") {
    // GLU structure: the launcher works on half of the input columns.
    VLOG(8) << "Doing geglu";
    LaunchActFFNGlu<T, Context, GeluFunctor<T>, LoadFunc, StoreFunc, LoadT>(
        dev_ctx, bias_data, rows, cols / 2, load_func, store_func);
  } else if (act_method == "swiglu") {
    VLOG(8) << "Doing swiglu";
    LaunchActFFNGlu<T,
                    Context,
                    CudaSwishFunctor<T>,
                    LoadFunc,
                    StoreFunc,
                    LoadT>(
        dev_ctx, bias_data, rows, cols / 2, load_func, store_func);
  } else if (act_method == "gelu") {
    if (FLAGS_use_fast_math) {
      VLOG(8) << "Doing Fast GELU";
      LaunchBiasAct<T, Context, FastGeluFunctor<T>, LoadFunc, StoreFunc, LoadT>(
          dev_ctx, bias_data, rows, cols, load_func, store_func);
    } else {
      VLOG(8) << "Doing GELU";
      LaunchBiasAct<T, Context, GeluFunctor<T>, LoadFunc, StoreFunc, LoadT>(
          dev_ctx, bias_data, rows, cols, load_func, store_func);
    }
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Currently Only Support GeGLU, SwiGLU, GeLU, but received "
        "act_method (%s).",
        act_method));
  }
}

// Picks the load/store functors for the path without shift/smooth:
//   - dequant_scales given -> read int32 `x` through DequantLoad<T>;
//     otherwise read `x` directly as T.
//   - quant_scale > 0 -> write int8 through QuantStore<T>;
//     quant_scale <= 0 -> write T through Store<T>.
// When dequant_scales is given but quant_scale satisfies neither comparison
// (e.g. NaN), the final plain T -> T path is taken.
template <typename T, typename Context>
void DispatchComputeImpl(const Context &dev_ctx,
                         const DenseTensor &x,
                         const DenseTensor *bias,
                         const DenseTensor *dequant_scales,
                         const std::string &act_method,
                         int rows,
                         int cols,
                         const float quant_scale,
                         const int quant_round_type,
                         const float quant_max_bound,
                         const float quant_min_bound,
                         DenseTensor *out) {
  const T *bias_data = nullptr;
  if (bias != nullptr) {
    bias_data = bias->data<T>();
  }
  const bool has_dequant = dequant_scales != nullptr;
  const bool quantize_out = quant_scale > 0;
  const bool keep_dtype_out = quant_scale <= 0;

  if (has_dequant && quantize_out) {
    DequantLoad<T> loader(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    QuantStore<T> storer(dev_ctx.template Alloc<int8_t>(out),
                         quant_round_type,
                         quant_scale,
                         quant_max_bound,
                         quant_min_bound);
    ComputeImpl<T, Context, DequantLoad<T>, QuantStore<T>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, loader, storer);
    return;
  }

  if (!has_dequant && quantize_out) {
    Load<T> loader(x.data<T>());
    QuantStore<T> storer(dev_ctx.template Alloc<int8_t>(out),
                         quant_round_type,
                         quant_scale,
                         quant_max_bound,
                         quant_min_bound);
    ComputeImpl<T>(dev_ctx, bias_data, act_method, rows, cols, loader, storer);
    return;
  }

  if (has_dequant && keep_dtype_out) {
    DequantLoad<T> loader(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    Store<T> storer(dev_ctx.template Alloc<T>(out));
    ComputeImpl<T, Context, DequantLoad<T>, Store<T>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, loader, storer);
    return;
  }

  Load<T> loader(x.data<T>());
  Store<T> storer(dev_ctx.template Alloc<T>(out));
  ComputeImpl<T>(dev_ctx, bias_data, act_method, rows, cols, loader, storer);
}

// Same four-way dispatch as the overload without shift/smooth, but the
// stores are the `<T, true>` variants, which also receive `shift`, `smooth`
// and the stored row width.
//   - dequant_scales given -> read int32 `x` through DequantLoad<T>;
//     otherwise read `x` as T.
//   - quant_scale > 0 -> write int8 through QuantStore<T, true>;
//     quant_scale <= 0 -> write T through Store<T, true>.
// Both `shift` and `smooth` must be non-null (InvalidArgument otherwise).
template <typename T, typename Context>
void DispatchComputeImpl(const Context &dev_ctx,
                         const DenseTensor &x,
                         const DenseTensor *bias,
                         const DenseTensor *dequant_scales,
                         const DenseTensor *shift,
                         const DenseTensor *smooth,
                         const std::string &act_method,
                         int rows,
                         int cols,
                         const float quant_scale,
                         const int quant_round_type,
                         const float quant_max_bound,
                         const float quant_min_bound,
                         DenseTensor *out) {
  // Both tensors are dereferenced unconditionally below.
  PADDLE_ENFORCE_NOT_NULL(
      shift,
      phi::errors::InvalidArgument(
          "Input(shift) must be provided for the shift/smooth path."));
  PADDLE_ENFORCE_NOT_NULL(
      smooth,
      phi::errors::InvalidArgument(
          "Input(smooth) must be provided together with Input(shift)."));

  const bool use_glu = (act_method == "geglu" || act_method == "swiglu");
  // The stores receive the stored row width: cols / 2 for the gated
  // methods (whose output is half as wide), cols otherwise.
  const int store_cols = use_glu ? cols / 2 : cols;
  const T *bias_data = bias == nullptr ? nullptr : bias->data<T>();
  const T *shift_data = shift->data<T>();
  const T *smooth_data = smooth->data<T>();

  if (dequant_scales != nullptr && quant_scale > 0) {
    DequantLoad<T> load_func(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    QuantStore<T, true> store_func(dev_ctx.template Alloc<int8_t>(out),
                                   shift_data,
                                   smooth_data,
                                   store_cols,
                                   quant_round_type,
                                   quant_scale,
                                   quant_max_bound,
                                   quant_min_bound);
    ComputeImpl<T, Context, DequantLoad<T>, QuantStore<T, true>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else if (dequant_scales == nullptr && quant_scale > 0) {
    Load<T> load_func(x.data<T>());
    QuantStore<T, true> store_func(dev_ctx.template Alloc<int8_t>(out),
                                   shift_data,
                                   smooth_data,
                                   store_cols,
                                   quant_round_type,
                                   quant_scale,
                                   quant_max_bound,
                                   quant_min_bound);
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else if (dequant_scales != nullptr && quant_scale <= 0) {
    DequantLoad<T> load_func(
        x.data<int32_t>(), dequant_scales->data<float>(), cols);
    Store<T, true> store_func(
        dev_ctx.template Alloc<T>(out), shift_data, smooth_data, store_cols);
    ComputeImpl<T, Context, DequantLoad<T>, Store<T, true>, int32_t>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  } else {
    Load<T> load_func(x.data<T>());
    Store<T, true> store_func(
        dev_ctx.template Alloc<T>(out), shift_data, smooth_data, store_cols);
    ComputeImpl<T>(
        dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
  }
}

// Tag types that select a DispatchWithDtype overload via tag dispatch.
// NormalVersion selects the overload that performs the computation;
// UnusedVersion selects the placeholder overload instantiated for int32_t.
struct NormalVersion {};
struct UnusedVersion {};

// Maps a dtype to its dispatch tag: every type maps to NormalVersion except
// int32_t, which maps to UnusedVersion.
template <typename T>
struct DispatchDtypeTrait {
  typedef NormalVersion FuncVersion;
};

template <>
struct DispatchDtypeTrait<int32_t> {
  typedef UnusedVersion FuncVersion;
};

// Computing overload for floating compute types T.
// Routes through DispatchComputeImpl when the input must be dequantized
// (`dequant_scales` given) or the output must be quantized
// (quant_scale > 0). The shift/smooth overload is used when `shift` is
// given. Otherwise it runs the plain T -> T path, which does not use
// shift/smooth.
template <typename T, typename Context>
void DispatchWithDtype(const Context &dev_ctx,
                       const DenseTensor &x,
                       const paddle::optional<DenseTensor> &bias,
                       const paddle::optional<DenseTensor> &dequant_scales,
                       const paddle::optional<DenseTensor> &shift,
                       const paddle::optional<DenseTensor> &smooth,
                       const std::string &act_method,
                       int rows,
                       int cols,
                       float quant_scale,
                       int quant_round_type,
                       float quant_max_bound,
                       float quant_min_bound,
                       DenseTensor *out,
                       NormalVersion) {
  auto *bias_p = bias.get_ptr();
  auto *dequant_scales_p = dequant_scales.get_ptr();
  auto *shift_p = shift.get_ptr();
  auto *smooth_p = smooth.get_ptr();

  // DispatchComputeImpl handles every combination of input dequantization
  // and output quantization, including quantizing a T input without
  // dequant_scales. Route through it whenever either is requested, so that
  // quant_scale > 0 is honoured even when dequant_scales is absent.
  const bool needs_quant_dispatch =
      dequant_scales_p != nullptr || quant_scale > 0;
  if (needs_quant_dispatch) {
    if (shift_p != nullptr) {
      DispatchComputeImpl<T>(dev_ctx,
                             x,
                             bias_p,
                             dequant_scales_p,
                             shift_p,
                             smooth_p,
                             act_method,
                             rows,
                             cols,
                             quant_scale,
                             quant_round_type,
                             quant_max_bound,
                             quant_min_bound,
                             out);
    } else {
      DispatchComputeImpl<T>(dev_ctx,
                             x,
                             bias_p,
                             dequant_scales_p,
                             act_method,
                             rows,
                             cols,
                             quant_scale,
                             quant_round_type,
                             quant_max_bound,
                             quant_min_bound,
                             out);
    }
    return;
  }

  // Plain path: T input, T output.
  const T *bias_data = bias_p == nullptr ? nullptr : bias_p->data<T>();
  Load<T> load_func(x.data<T>());
  Store<T> store_func(dev_ctx.template Alloc<T>(out));
  ComputeImpl<T>(
      dev_ctx, bias_data, act_method, rows, cols, load_func, store_func);
}

// Placeholder overload selected for T = int32_t via DispatchDtypeTrait. It
// exists so FusedBiasActKernel<int32_t, ...> compiles for kernel
// registration. INT32 inputs are dispatched by Attr(compute_dtype) inside
// FusedBiasActKernel, so this is presumably never reached. If it is, fail
// explicitly instead of returning with `out` left unallocated.
template <typename T, typename Context>
void DispatchWithDtype(const Context &dev_ctx,
                       const DenseTensor &x,
                       const paddle::optional<DenseTensor> &bias,
                       const paddle::optional<DenseTensor> &dequant_scales,
                       const paddle::optional<DenseTensor> &shift,
                       const paddle::optional<DenseTensor> &smooth,
                       const std::string &act_method,
                       int rows,
                       int cols,
                       float quant_scale,
                       int quant_round_type,
                       float quant_max_bound,
                       float quant_min_bound,
                       DenseTensor *out,
                       UnusedVersion) {
  PADDLE_THROW(phi::errors::Unimplemented(
      "fused_bias_act has no compute path for this data type. For INT32 "
      "Input(x), set Attr(compute_dtype) to one of (bf16, fp16, fp32)."));
}
#endif

// Entry point of the fused_bias_act GPU kernel. The body is empty when
// PADDLE_WITH_HIP is defined.
// `x` is treated as a row-major matrix: the last dimension gives `cols` and
// all leading dimensions are flattened into `rows`.
// For INT32 `x`, Attr(compute_dtype) ("bf16", "fp16" or "fp32") selects the
// floating type used for compute; otherwise the registered type T is used.
template <typename T, typename Context>
void FusedBiasActKernel(const Context &dev_ctx,
                        const DenseTensor &x,
                        const paddle::optional<DenseTensor> &bias,
                        const paddle::optional<DenseTensor> &dequant_scales,
                        const paddle::optional<DenseTensor> &shift,
                        const paddle::optional<DenseTensor> &smooth,
                        const std::string &act_method,
                        const std::string &compute_dtype,
                        float quant_scale,
                        int quant_round_type,
                        float quant_max_bound,
                        float quant_min_bound,
                        DenseTensor *out) {
#ifndef PADDLE_WITH_HIP
  const auto &x_dims = x.dims();
  const int rank = x_dims.size();
  PADDLE_ENFORCE_GE(
      rank,
      1,
      phi::errors::InvalidArgument(
          "The rank of Input(x) must be at least 1, but received %d.", rank));
  // The launchers and kernels index elements with `int`.
  PADDLE_ENFORCE_LE(
      x.numel(),
      static_cast<int64_t>(std::numeric_limits<int>::max()),
      phi::errors::InvalidArgument(
          "The number of elements of Input(x) must not exceed %d, but "
          "received %d.",
          std::numeric_limits<int>::max(),
          x.numel()));

  // Flatten all leading dimensions into rows; for a 2-D input this is
  // exactly (dims[0], dims[1]).
  int64_t rows64 = 1;
  for (int d = 0; d + 1 < rank; ++d) {
    rows64 *= x_dims[d];
  }
  const int rows = static_cast<int>(rows64);
  const int cols = static_cast<int>(x_dims[rank - 1]);

  if (x.dtype() == phi::DataType::INT32) {
    if (compute_dtype == "bf16") {
      DispatchWithDtype<phi::dtype::bfloat16, Context>(
          dev_ctx,
          x,
          bias,
          dequant_scales,
          shift,
          smooth,
          act_method,
          rows,
          cols,
          quant_scale,
          quant_round_type,
          quant_max_bound,
          quant_min_bound,
          out,
          typename DispatchDtypeTrait<phi::dtype::bfloat16>::FuncVersion{});

    } else if (compute_dtype == "fp16") {
      DispatchWithDtype<phi::dtype::float16, Context>(
          dev_ctx,
          x,
          bias,
          dequant_scales,
          shift,
          smooth,
          act_method,
          rows,
          cols,
          quant_scale,
          quant_round_type,
          quant_max_bound,
          quant_min_bound,
          out,
          typename DispatchDtypeTrait<phi::dtype::float16>::FuncVersion{});
    } else if (compute_dtype == "fp32") {
      DispatchWithDtype<float, Context>(
          dev_ctx,
          x,
          bias,
          dequant_scales,
          shift,
          smooth,
          act_method,
          rows,
          cols,
          quant_scale,
          quant_round_type,
          quant_max_bound,
          quant_min_bound,
          out,
          typename DispatchDtypeTrait<float>::FuncVersion{});
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "In the case of quantization enabled with Input(x) INT32, "
          "Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
          "but get compute_dtype (%s)",
          compute_dtype));
    }
  } else {
    DispatchWithDtype<T, Context>(
        dev_ctx,
        x,
        bias,
        dequant_scales,
        shift,
        smooth,
        act_method,
        rows,
        cols,
        quant_scale,
        quant_round_type,
        quant_max_bound,
        quant_min_bound,
        out,
        typename DispatchDtypeTrait<T>::FuncVersion{});
  }
#endif
}

}  // namespace fusion
}  // namespace phi

// Registers phi::fusion::FusedBiasActKernel as the GPU kernel for
// fused_bias_act (all layouts) for float, bfloat16, float16 and int32_t.
// For the int32_t instantiation, INT32 inputs are dispatched inside the
// kernel to the floating type named by Attr(compute_dtype).
PD_REGISTER_KERNEL(fused_bias_act,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusedBiasActKernel,
                   float,
                   phi::dtype::bfloat16,
                   phi::dtype::float16,
                   int32_t) {}