// fused_adam_kernel.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fused_adam_kernel.h"
#include <vector>
#include "glog/logging.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/multi_tensor_apply.h"

namespace phi {

// This code is referenced from apex's multi_tensor_adam.cu.
// https://github.com/NVIDIA/apex

// Carries the running beta1/beta2 power accumulators into the Adam functor
// for the case where they live in host memory (CPUBetaPows == true).
//
// The constructor runs on the host and dereferences both pointers there, so
// the values are captured by copy; the struct is then passed to the device
// functor by value and the getters simply return the captured scalars.
template <typename T, bool CPUBetaPows /*=true*/>
struct FusedAdamBetaPowInfo {
  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

  // beta1pow / beta2pow must be readable from the host (they are
  // dereferenced immediately).
  FusedAdamBetaPowInfo(const MPDType* beta1pow, const MPDType* beta2pow)
      : beta1pow_(*beta1pow), beta2pow_(*beta2pow) {}

  DEVICE MPDType GetBeta1PowValue() const { return beta1pow_; }

  DEVICE MPDType GetBeta2PowValue() const { return beta2pow_; }

 private:
  MPDType beta1pow_;
  MPDType beta2pow_;
};

// Specialization for beta-pow accumulators that are NOT on the CPU
// (CPUBetaPows == false): only the pointers are captured on the host, and the
// values are read through them on the device each time a getter is called.
template <typename T>
struct FusedAdamBetaPowInfo<T, /*CPUBetaPows=*/false> {
  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

  // The pointers are stored as-is and must stay valid, and device-readable,
  // for the duration of the kernel launch that uses this object.
  FusedAdamBetaPowInfo(const MPDType* beta1pow, const MPDType* beta2pow)
      : beta1pow_(beta1pow), beta2pow_(beta2pow) {}

  DEVICE MPDType GetBeta1PowValue() const { return *beta1pow_; }

  DEVICE MPDType GetBeta2PowValue() const { return *beta2pow_; }

 private:
  const MPDType* __restrict__ beta1pow_;
  const MPDType* __restrict__ beta2pow_;
};

// Per-chunk Adam / AdamW update, following apex's multi_tensor_adam.cu.
//
// Template parameters:
//   T                 element type of params and grads.
//   MT                compute type; moments and master params are stored as MT.
//   VecSize           elements handled per thread per loop iteration.
//   IsMultiPrecision  when true, the MT master copy (tensor slot 3) is the
//                     value that is updated, and the T param is written from it.
//   IsCPUBetaPow      selects the FusedAdamBetaPowInfo variant.
//   UseAdamW          apply decoupled weight decay before the Adam step.
//   N, MaxTensorSize, MaxBlockSize
//                     forwarded to funcs::TensorAndBlockInfo.
template <typename T,
          typename MT,
          int VecSize,
          bool IsMultiPrecision,
          bool IsCPUBetaPow,
          bool UseAdamW,
          int N,
          int MaxTensorSize,
          int MaxBlockSize>
struct FusedAdamFunctor {
  // Updates the chunk selected by t_info.GetChunkIdAndTensorId().
  //
  // Tensor slots: tensor_addrs[0] param (T), [1] moment1 (MT),
  // [2] moment2 (MT), [3] master param (MT, IsMultiPrecision only);
  // grads[] holds the T gradients.
  //
  // Vector-path precondition: the caller is expected to choose VecSize so
  // that every base pointer (grads included) is VecSize-aligned, and to pass
  // a chunk_size that is a multiple of VecSize. Trailing elements of a chunk
  // that do not fill a whole vector are handled with scalar accesses.
  __device__ __forceinline__ void operator()(
      int chunk_size,
      const funcs::TensorAndBlockInfo<N, MaxTensorSize, MaxBlockSize>& t_info,
      MT beta1,
      MT beta2,
      FusedAdamBetaPowInfo<T, IsCPUBetaPow> beta_pow,
      MT epsilon,
      const MT* learning_rate,
      MT decay) const {
    MT lr = *learning_rate;
    MT beta1_pow = beta_pow.GetBeta1PowValue();
    MT beta2_pow = beta_pow.GetBeta2PowValue();
    T* __restrict__ p_ptr;
    const T* __restrict__ g_ptr;
    MT* __restrict__ mom1_ptr, * __restrict__ mom2_ptr;
    MT* __restrict__ mp_ptr;
    int n;

    {
      int chunk_id, tensor_id;
      t_info.GetChunkIdAndTensorId(&chunk_id, &tensor_id);

      n = t_info.sizes[tensor_id];
      int offset = chunk_id * chunk_size;
      g_ptr = static_cast<const T*>(t_info.grads[tensor_id]) + offset;
      p_ptr = static_cast<T*>(t_info.tensor_addrs[0][tensor_id]) + offset;
      mom1_ptr = static_cast<MT*>(t_info.tensor_addrs[1][tensor_id]) + offset;
      mom2_ptr = static_cast<MT*>(t_info.tensor_addrs[2][tensor_id]) + offset;
      mp_ptr =
          IsMultiPrecision
              ? static_cast<MT*>(t_info.tensor_addrs[3][tensor_id]) + offset
              : nullptr;

      // Elements remaining in this chunk; the last chunk of a tensor may be
      // shorter than chunk_size.
      n -= offset;
      if (n > chunk_size) {
        n = chunk_size;
      }
    }

    // Each thread owns VecSize consecutive elements per iteration; the whole
    // block advances by blockDim.x * VecSize elements.
    int stride = blockDim.x * VecSize;
    int idx = threadIdx.x * VecSize;

    for (; idx < n; idx += stride) {
      phi::AlignedVector<T, VecSize> g_vec;
      phi::AlignedVector<T, VecSize> p_vec;
      phi::AlignedVector<MT, VecSize> mp_vec;
      phi::AlignedVector<MT, VecSize> mom1_vec;
      phi::AlignedVector<MT, VecSize> mom2_vec;
      // True when all VecSize elements starting at idx lie inside the chunk.
      const bool full_vec = idx <= n - VecSize;

      if (full_vec) {
        if (IsMultiPrecision) {
          phi::Load<MT, VecSize>(mp_ptr + idx, &mp_vec);
        } else {
          phi::Load<T, VecSize>(p_ptr + idx, &p_vec);
        }
        phi::Load<T, VecSize>(g_ptr + idx, &g_vec);
        phi::Load<MT, VecSize>(mom1_ptr + idx, &mom1_vec);
        phi::Load<MT, VecSize>(mom2_ptr + idx, &mom2_vec);
      } else {
        int size = n - idx;
        for (int j = 0; j < size; j++) {
          if (IsMultiPrecision) {
            mp_vec[j] = mp_ptr[idx + j];
          } else {
            p_vec[j] = p_ptr[idx + j];
          }
          g_vec[j] = g_ptr[idx + j];
          mom1_vec[j] = mom1_ptr[idx + j];
          mom2_vec[j] = mom2_ptr[idx + j];
        }
        // Zero the unused lanes so the unrolled update below reads defined
        // values; these lanes are never written back.
#pragma unroll
        for (int j = size; j < VecSize; j++) {
          g_vec[j] = T(0);
          p_vec[j] = T(0);
          mp_vec[j] = MT(0);
          mom1_vec[j] = MT(0);
          mom2_vec[j] = MT(0);
        }
      }

#pragma unroll
      for (int j = 0; j < VecSize; j++) {
        MT p = IsMultiPrecision ? mp_vec[j] : static_cast<MT>(p_vec[j]);
        UpdateMoments(&mom1_vec[j],
                      &mom2_vec[j],
                      static_cast<MT>(g_vec[j]),
                      beta1,
                      beta2);
        // mp_vec holds the updated parameter in MT for both precision modes.
        mp_vec[j] = UpdateParameter(p,
                                    mom1_vec[j],
                                    mom2_vec[j],
                                    beta1_pow,
                                    beta2_pow,
                                    lr,
                                    epsilon,
                                    decay);
      }

      if (full_vec) {
        phi::Store<MT, VecSize>(mom1_vec, mom1_ptr + idx);
        phi::Store<MT, VecSize>(mom2_vec, mom2_ptr + idx);
        if (IsMultiPrecision) {
          phi::Store<MT, VecSize>(mp_vec, mp_ptr + idx);
        }
        // p_ptr + idx has the same alignment as the vectorized load of p
        // above, so the T copy can be written back as one vector as well.
#pragma unroll
        for (int j = 0; j < VecSize; j++) {
          p_vec[j] = static_cast<T>(mp_vec[j]);
        }
        phi::Store<T, VecSize>(p_vec, p_ptr + idx);
      } else {
        int size = n - idx;
        for (int j = 0; j < size; j++) {
          if (IsMultiPrecision) {
            mp_ptr[idx + j] = mp_vec[j];
          }
          p_ptr[idx + j] = static_cast<T>(mp_vec[j]);
          mom1_ptr[idx + j] = mom1_vec[j];
          mom2_ptr[idx + j] = mom2_vec[j];
        }
      }
    }
  }

 private:
  // In-place moment update:
  //   m1 = beta1 * m1 + (1 - beta1) * g
  //   m2 = beta2 * m2 + (1 - beta2) * g^2
  static __device__ __forceinline__ void UpdateMoments(
      MT* __restrict__ mom1_ptr,
      MT* __restrict__ mom2_ptr,
      MT g,
      MT beta1,
      MT beta2) {
    MT mom1 = mom1_ptr[0];
    MT mom2 = mom2_ptr[0];
    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;

    mom1_ptr[0] = mom1;
    mom2_ptr[0] = mom2;
  }

  // Returns the updated parameter value:
  //   if UseAdamW: p *= (1 - lr * decay)          (decoupled weight decay)
  //   p -= lr / (1 - beta1_pow) * m1 / (sqrt(m2) / sqrt(1 - beta2_pow) + eps)
  static __device__ __forceinline__ MT UpdateParameter(MT p,
                                                       MT mom1,
                                                       MT mom2,
                                                       MT beta1_pow,
                                                       MT beta2_pow,
                                                       MT lr,
                                                       MT epsilon,
                                                       MT decay) {
    if (UseAdamW) {
      p *= (static_cast<MT>(1.0) - lr * decay);
    }
    MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
    p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
    return p;
  }
};

// Scales up to N beta-pow accumulators in place: thread t multiplies
// *beta1_pow[t] by beta1 and *beta2_pow[t] by beta2. Only threadIdx.x is
// used for indexing, so this is meant for a single-block launch with at least
// `n` threads; threads whose index is >= n return without touching memory.
template <typename T, int N>
__global__ void UpdateBetaPowGroup(
    Array<T*, N> beta1_pow, Array<T*, N> beta2_pow, T beta1, T beta2, int n) {
  const auto tid = threadIdx.x;
  if (tid >= n) {
    return;
  }
  *beta1_pow[tid] *= beta1;
  *beta2_pow[tid] *= beta2;
}

// For every index k where src[k] and dst[k] are distinct tensor objects,
// copies *src[k] into dst[k] via phi::Copy with blocking=false; pairs that
// already alias (in-place input/output) are skipped. The destination place is
// dev_ctx's place, or src[k]'s own place when use_src_place is true.
// dst is expected to be at least as long as src.
template <typename Context>
static void CopyTensorIfDifferent(const Context& dev_ctx,
                                  const std::vector<const DenseTensor*>& src,
                                  const std::vector<DenseTensor*>& dst,
                                  bool use_src_place = false) {
  const size_t count = src.size();
  for (size_t k = 0; k != count; ++k) {
    const DenseTensor* from = src[k];
    DenseTensor* to = dst[k];
    if (from == to) {
      continue;
    }
    VLOG(10) << "Copy Tensor " << k;
    const phi::Place target =
        use_src_place ? from->place() : dev_ctx.GetPlace();
    phi::Copy<Context>(dev_ctx, *from, target, false, to);
  }
}

// Narrows `vec_size` (default 4) to the smallest vector width that
// GetVectorizedSize reports for any tensor's data<T>() pointer. TensorT may
// be const-qualified, so the same helper serves input and output lists.
// An empty list returns `vec_size` unchanged.
template <typename T, typename TensorT>
static int GetVecSizeFromTensors(const std::vector<TensorT*>& tensors,
                                 int vec_size = 4) {
  int narrowest = vec_size;
  for (size_t i = 0; i < tensors.size(); ++i) {
    const auto* tensor = tensors[i];
    const int supported = GetVectorizedSize(tensor->template data<T>());
    narrowest = std::min<int>(narrowest, supported);
  }
  return narrowest;
}

// Fused multi-tensor Adam / AdamW step on GPU.
//
// Copies inputs to outputs when they are not aliased. Unless skip_update is
// true, it then launches FusedAdamFunctor over
// (params_out, moments1_out, moments2_out[, master_params_out]) with `grads`.
// Finally, unless use_global_beta_pow is set, it multiplies every
// beta1/beta2 pow accumulator by beta1/beta2.
//
// Notes grounded in the implementation below:
//   * All beta pows must share one place. CPU beta pows are captured by value
//     on the host and updated on the host; otherwise they are read and
//     updated on the device.
//   * Only beta1_pows[0] / beta2_pows[0] feed the bias correction for every
//     parameter, so all accumulators are assumed to hold the same value.
//   * The vector width is the largest of {4, 2, 1} that every params/grads/
//     moments/master-params buffer is aligned for.
template <typename T, typename Context>
void FusedAdamKernel(
    const Context& dev_ctx,
    const std::vector<const DenseTensor*>& params,
    const std::vector<const DenseTensor*>& grads,
    const DenseTensor& learning_rate,
    const std::vector<const DenseTensor*>& moments1,
    const std::vector<const DenseTensor*>& moments2,
    const std::vector<const DenseTensor*>& beta1_pows,
    const std::vector<const DenseTensor*>& beta2_pows,
    const paddle::optional<std::vector<const DenseTensor*>>& master_params,
    const paddle::optional<DenseTensor>& skip_update,
    const Scalar& beta1,
    const Scalar& beta2,
    const Scalar& epsilon,
    int chunk_size,
    float weight_decay,
    bool use_adamw,
    bool multi_precision,
    bool use_global_beta_pow,
    std::vector<DenseTensor*> params_out,
    std::vector<DenseTensor*> moments1_out,
    std::vector<DenseTensor*> moments2_out,
    std::vector<DenseTensor*> beta1_pows_out,
    std::vector<DenseTensor*> beta2_pows_out,
    std::vector<DenseTensor*> master_params_out) {
  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

  auto n = params.size();
  // Nothing to update. Returning here also avoids reading beta1_pows[0] /
  // beta2_pows[0] from empty lists below.
  if (n == 0) {
    VLOG(4) << "FusedAdam: no parameters to update";
    return;
  }

  PADDLE_ENFORCE_EQ(
      beta1_pows.empty(),
      false,
      phi::errors::InvalidArgument("Input(Beta1Pows) must not be empty."));
  PADDLE_ENFORCE_EQ(
      beta1_pows.size(),
      beta2_pows.size(),
      phi::errors::InvalidArgument("Input(Beta1Pows) and Input(Beta2Pows) "
                                   "must have the same size, but got %d "
                                   "and %d.",
                                   beta1_pows.size(),
                                   beta2_pows.size()));

  auto beta1_pow_first = beta1_pows[0];
  auto beta2_pow_first = beta2_pows[0];

  for (size_t i = 1; i < beta1_pows.size(); ++i) {
    PADDLE_ENFORCE_EQ(beta1_pow_first->place(),
                      beta1_pows[i]->place(),
                      phi::errors::InvalidArgument(
                          "All Beta1Pow must be in the same place."));
    PADDLE_ENFORCE_EQ(beta2_pow_first->place(),
                      beta2_pows[i]->place(),
                      phi::errors::InvalidArgument(
                          "All Beta2Pow must be in the same place."));
  }

  PADDLE_ENFORCE_EQ(
      beta1_pow_first->place(),
      beta2_pow_first->place(),
      phi::errors::InvalidArgument(
          "Input(Beta1Pows) and Input(Beta2Pows) must be in the same place."));

  bool is_cpu_betapow = (beta1_pow_first->place() == CPUPlace());

  VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

  CopyTensorIfDifferent(dev_ctx, params, params_out);
  CopyTensorIfDifferent(dev_ctx, moments1, moments1_out);
  CopyTensorIfDifferent(dev_ctx, moments2, moments2_out);
  // Beta pows stay in their original place (CPU or device).
  CopyTensorIfDifferent(dev_ctx, beta1_pows, beta1_pows_out, true);
  CopyTensorIfDifferent(dev_ctx, beta2_pows, beta2_pows_out, true);
  if (master_params) {
    CopyTensorIfDifferent(dev_ctx, master_params.get(), master_params_out);
  }

  bool skip_update_value = false;
  if (skip_update.is_initialized()) {
    PADDLE_ENFORCE_EQ(
        skip_update->numel(),
        1,
        errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
                                skip_update->numel()));
    if (skip_update->place() == CPUPlace()) {
      skip_update_value = skip_update->data<bool>()[0];
    } else {
      DenseTensor skip_update_tensor;
      phi::Copy(
          dev_ctx, skip_update.get(), CPUPlace(), false, &skip_update_tensor);
      // The non-blocking copy is ordered on dev_ctx; make sure it has
      // finished before the host reads the destination buffer.
      dev_ctx.Wait();
      skip_update_value = skip_update_tensor.data<bool>()[0];
    }
    VLOG(4) << "skip_update_value:" << skip_update_value;
  }

  // skip_update=true
  if (skip_update_value) {
    VLOG(4) << "Adam skip update";
    return;
  }

  MPDType beta1_tmp = beta1.to<MPDType>();
  MPDType beta2_tmp = beta2.to<MPDType>();

  // Tensor lists handed to the functor, in slot order:
  // 0 params, 1 moments1, 2 moments2, [3 master params].
  std::vector<std::vector<DenseTensor*>> input_vector;
  input_vector.reserve(4);

  input_vector.push_back(params_out);
  input_vector.push_back(moments1_out);
  input_vector.push_back(moments2_out);
  if (multi_precision) {
    input_vector.push_back(master_params_out);
  }

  VLOG(4) << "use_adamw: " << use_adamw;
  VLOG(4) << "multi_precision: " << multi_precision;

  // Launches FusedAdamFunctor for one combination of compile-time switches.
  // kInputNum counts the tensor lists (input_vector plus grads).
  // kMaxTensorSize / kMaxBlockSize are forwarded to
  // LaunchMultiTensorApplyKernel. chunk_size is rounded up to a multiple of
  // the vector width so every chunk starts vector-aligned.
#define PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE(                       \
    __multi_precision, __is_cpu_betapow, __use_adamw, __vec_size)            \
  do {                                                                       \
    constexpr int kInputNum = __multi_precision ? 5 : 4;                     \
    constexpr int kMaxTensorSize = __multi_precision ? 48 : 60;              \
    constexpr int kMaxBlockSize = __multi_precision ? 320 : 320;             \
    constexpr int kBlockSize = 512;                                          \
    FusedAdamBetaPowInfo<T, __is_cpu_betapow> beta_pow_info(                 \
        beta1_pow_first->data<MPDType>(), beta2_pow_first->data<MPDType>()); \
    FusedAdamFunctor<T,                                                      \
                     MPDType,                                                \
                     __vec_size,                                             \
                     __multi_precision,                                      \
                     __is_cpu_betapow,                                       \
                     __use_adamw,                                            \
                     kInputNum,                                              \
                     kMaxTensorSize,                                         \
                     kMaxBlockSize>                                          \
        functor;                                                             \
    funcs::LaunchMultiTensorApplyKernel<kInputNum,                           \
                                        kMaxTensorSize,                      \
                                        kMaxBlockSize>(                      \
        dev_ctx,                                                             \
        kBlockSize,                                                          \
        ((chunk_size + __vec_size - 1) / __vec_size) * __vec_size,           \
        input_vector,                                                        \
        grads,                                                               \
        functor,                                                             \
        beta1_tmp,                                                           \
        beta2_tmp,                                                           \
        beta_pow_info,                                                       \
        epsilon.to<MPDType>(),                                               \
        learning_rate.data<MPDType>(),                                       \
        static_cast<MPDType>(weight_decay));                                 \
  } while (0)

  // One `case` of the vec_size switch: dispatches the three runtime flags
  // onto template arguments.
#define PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL(__vec_size) \
  case __vec_size: {                                         \
    if (multi_precision) {                                   \
      if (is_cpu_betapow) {                                  \
        if (use_adamw) {                                     \
          PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE(     \
              true, true, true, __vec_size);                 \
        } else {                                             \
          PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE(     \
              true, true, false, __vec_size);                \
        }                                                    \
      } else {                                               \
        if (use_adamw) {                                     \
          PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE(     \
              true, false, true, __vec_size);                \
        } else {                                             \
          PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE(     \
              true, false, false, __vec_size);               \
        }                                                    \
      }                                                      \
    } else {                                                 \
      if (is_cpu_betapow) {                                  \
        if (use_adamw) {                                     \
          PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE(     \
              false, true, true, __vec_size);                \
        } else {                                             \
          PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE(     \
              false, true, false, __vec_size);               \
        }                                                    \
      } else {                                               \
        if (use_adamw) {                                     \
          PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE(     \
              false, false, true, __vec_size);               \
        } else {                                             \
          PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE(     \
              false, false, false, __vec_size);              \
        }                                                    \
      }                                                      \
    }                                                        \
  } break

  // The functor uses vector loads/stores of width vec_size on every buffer it
  // touches, so every one of them, including the gradients, must be aligned
  // for that width.
  int vec_size = GetVecSizeFromTensors<T>(params_out);
  vec_size = GetVecSizeFromTensors<T>(grads, vec_size);
  vec_size = GetVecSizeFromTensors<MPDType>(moments1_out, vec_size);
  vec_size = GetVecSizeFromTensors<MPDType>(moments2_out, vec_size);
  if (master_params) {
    vec_size = GetVecSizeFromTensors<MPDType>(master_params_out, vec_size);
  }

  switch (vec_size) {
    PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL(4);
    PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL(2);
    PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL(1);
    default:
      PADDLE_THROW(
          errors::InvalidArgument("Unsupported vectorized size %d", vec_size));
      break;
  }

#undef PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL
#undef PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE

  if (!use_global_beta_pow) {
    if (is_cpu_betapow) {
      for (size_t i = 0; i < n; i++) {
        VLOG(10) << "CPU Update BetaPow here...";
        auto* beta1_ptr =
            dev_ctx.template HostAlloc<MPDType>(beta1_pows_out[i]);
        (*beta1_ptr) *= beta1_tmp;

        auto* beta2_ptr =
            dev_ctx.template HostAlloc<MPDType>(beta2_pows_out[i]);
        (*beta2_ptr) *= beta2_tmp;
      }
    } else {
      // Update device-side accumulators in groups of kGroupSize, one
      // single-block launch per group.
      constexpr size_t kGroupSize = 32;
      auto group_num = (n + kGroupSize - 1) / kGroupSize;
      VLOG(10) << "GPU Update BetaPow here...";
      for (size_t i = 0; i < group_num; ++i) {
        size_t start = i * kGroupSize;
        size_t end = std::min((i + 1) * kGroupSize, n);
        Array<MPDType*, kGroupSize> beta1_ptrs, beta2_ptrs;
        for (size_t j = start; j < end; ++j) {
          size_t idx = j - start;
          beta1_ptrs[idx] = dev_ctx.template Alloc<MPDType>(beta1_pows_out[j]);
          beta2_ptrs[idx] = dev_ctx.template Alloc<MPDType>(beta2_pows_out[j]);
        }
        UpdateBetaPowGroup<MPDType, kGroupSize>
            <<<1, kGroupSize, 0, dev_ctx.stream()>>>(
                beta1_ptrs, beta2_ptrs, beta1_tmp, beta2_tmp, end - start);
      }
    }
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(fused_adam,
                   GPU,
                   ALL_LAYOUT,
                   phi::FusedAdamKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   float,
                   double) {
  // Input indices follow FusedAdamKernel's tensor arguments:
  //   0 params, 1 grads, 2 learning_rate, 3 moments1, 4 moments2,
  //   5 beta1_pows, 6 beta2_pows, 7 master_params, 8 skip_update.
  // Skip beta1_pow, beta2_pow, skip_update data transform: the kernel
  // dispatches on the beta pows' place (CPU or device) and brings
  // skip_update to the host itself, so any backend is accepted.
  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
  // Output indices: 0 params_out, 1 moments1_out, 2 moments2_out,
  //   3 beta1_pows_out, 4 beta2_pows_out, 5 master_params_out.
  // Outputs 1-5 are accessed as MPDType rather than the registered T, so
  // their dtype is left UNDEFINED instead of being tied to T.
  kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
  kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
  kernel->OutputAt(3).SetDataType(phi::DataType::UNDEFINED);
  kernel->OutputAt(4).SetDataType(phi::DataType::UNDEFINED);
  kernel->OutputAt(5).SetDataType(phi::DataType::UNDEFINED);
}