distributed_fused_lamb_op.cu 72.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
16

17
#include "paddle/fluid/memory/buffer.h"
18
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
19 20
#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
21
#include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
22 23 24 25
#include "paddle/fluid/operators/tensor_to_string.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/string/string_helper.h"
26
#include "paddle/phi/core/utils/data_type.h"
27
#include "paddle/phi/kernels/funcs/aligned_vector.h"
28 29 30 31 32 33 34 35

#ifdef __NVCC__
#include "cub/cub.cuh"
#include "math.h"  // NOLINT
#endif

#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
36

37 38 39 40 41 42 43 44 45 46
#include "math.h"  // NOLINT
namespace cub = hipcub;
#endif

namespace paddle {
namespace operators {

template <typename T>
using MasterT = typename details::MPTypeTrait<T>::Type;

47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
// Enqueues an asynchronous memset on `stream` that zeroes the n elements of
// type T starting at x (n * sizeof(T) bytes in total).
// NOTE(review): x is presumably device memory -- confirm against callers.
template <typename T>
static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
  static_assert(!std::is_same<T, void>::value, "T cannot be void.");
  const size_t num_bytes = n * sizeof(T);
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, num_bytes, stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, num_bytes, stream));
#endif
}

// Per-chunk functor that accumulates the sum of squares of the elements
// x[offset, offset + size) and writes the block-wide total to
// y[tensor_id * max_chunk_num + chunk_id].
//
// Each thread first consumes VecSize-wide groups through phi::Load, then
// any leftover elements one at a time, accumulating in MasterT<T>. The
// per-thread partial sums are combined with cub::BlockReduce and thread 0
// stores the result.
// NOTE(review): BlockDim is presumably the launch block size -- confirm
// against the MultiTensorApply call site.
template <typename T, int BlockDim, int VecSize>
struct L2NormFunctor {
  DEVICE void operator()(int tensor_id, int chunk_id, int offset, int size,
                         const T *x, MasterT<T> *y, int max_chunk_num) const {
    using MT = MasterT<T>;
    using BlockReduce = cub::BlockReduce<MT, BlockDim>;
    __shared__ typename BlockReduce::TempStorage storage;

    const T *chunk = x + offset;
    MT acc = static_cast<MT>(0);

    // Vectorized part: whole groups of VecSize elements.
    int idx = threadIdx.x * VecSize;
    while (idx + VecSize <= size) {
      phi::AlignedVector<T, VecSize> vec;
      phi::Load(chunk + idx, &vec);
#pragma unroll
      for (int k = 0; k < VecSize; ++k) {
        const MT v = static_cast<MT>(vec[k]);
        acc += (v * v);
      }
      idx += (BlockDim * VecSize);
    }

    // Scalar part: elements that do not fill a whole group.
    while (idx < size) {
      const MT v = static_cast<MT>(chunk[idx]);
      acc += (v * v);
      ++idx;
    }

    acc = BlockReduce(storage).Reduce(acc, cub::Sum());
    if (threadIdx.x == 0) {
      y[tensor_id * max_chunk_num + chunk_id] = acc;
    }
  }
};

92
// Second reduction stage. Block b sums the max_chunk_num values stored at
// x[b * max_chunk_num, (b + 1) * max_chunk_num) and thread 0 writes the
// total, converted to OutT, to y[b].
// NOTE(review): BlockDim is presumably equal to the launch block size --
// confirm at the launch site.
template <typename InT, typename OutT, int BlockDim>
static __global__ void MultiTensorL2NormReduceAgainCUDAKernel(
    const InT *x, OutT *y, int max_chunk_num) {
  using BlockReduce = cub::BlockReduce<InT, BlockDim>;
  __shared__ typename BlockReduce::TempStorage storage;

  const int tensor_id = blockIdx.x;
  const InT *row = x + (tensor_id * max_chunk_num);

  InT partial = static_cast<InT>(0);
  int idx = threadIdx.x;
  while (idx < max_chunk_num) {
    partial += row[idx];
    idx += BlockDim;
  }

  partial = BlockReduce(storage).Reduce(partial, cub::Sum());
  if (threadIdx.x == 0) {
    y[tensor_id] = static_cast<OutT>(partial);
  }
}

// Picks the widest vector size (8, 4, 2 or 1 elements) usable for `ptr`.
//
// A width W is accepted when both the address of `ptr` and the chunk size in
// bytes are multiples of alignof(phi::AlignedVector<T, W>). The result is
// additionally capped so that one vector never exceeds 128 bits. Passing
// chunk_size == 0 makes the decision depend on the pointer alignment only.
template <typename T>
static int GetChunkedVecSize(const T *ptr, int chunk_size) {
  static_assert(!std::is_same<T, void>::value, "T cannot be void.");

  constexpr int kMaxLoadBits = 128;
  constexpr int kAlign8 = alignof(phi::AlignedVector<T, 8>);
  constexpr int kAlign4 = alignof(phi::AlignedVector<T, 4>);
  constexpr int kAlign2 = alignof(phi::AlignedVector<T, 2>);

  const int max_vec_size = kMaxLoadBits / CHAR_BIT / sizeof(T);
  const auto address = reinterpret_cast<uintptr_t>(ptr);
  const int chunk_bytes = chunk_size * sizeof(T);

  // True when both the start address and the chunk byte count are multiples
  // of `alignment`.
  const auto aligned_to = [address, chunk_bytes](int alignment) {
    return address % alignment == 0 && chunk_bytes % alignment == 0;
  };

  if (aligned_to(kAlign8)) return std::min(8, max_vec_size);
  if (aligned_to(kAlign4)) return std::min(4, max_vec_size);
  if (aligned_to(kAlign2)) return std::min(2, max_vec_size);
  return 1;
}

131 132 133 134 135
// Expands to one `case` of the vector-size dispatch switch. Inside the case,
// `kVecSize` is a constexpr int equal to the case label, so the statement
// passed as the variadic argument can use it as a template argument.
#define PD_VEC_LAUNCH_KERNEL_CASE(vec_size_literal, ...) \
  case vec_size_literal: {                               \
    constexpr int kVecSize = vec_size_literal;           \
    __VA_ARGS__;                                         \
    break;                                               \
  }

138 139 140 141 142 143 144 145
// Runs the statement given as the variadic argument with the constexpr
// `kVecSize` bound to the run-time vector size (8, 4, 2 or 1). Any other
// value matches no case, so the statement is not executed.
#define PD_VEC_LAUNCH_KERNEL(runtime_vec_size, ...) \
  do {                                              \
    switch (runtime_vec_size) {                     \
      PD_VEC_LAUNCH_KERNEL_CASE(8, __VA_ARGS__);    \
      PD_VEC_LAUNCH_KERNEL_CASE(4, __VA_ARGS__);    \
      PD_VEC_LAUNCH_KERNEL_CASE(2, __VA_ARGS__);    \
      PD_VEC_LAUNCH_KERNEL_CASE(1, __VA_ARGS__);    \
    }                                               \
  } while (0)

// TODO(zengjinle): which chunk_size is better?
//
// Computes, for each of the n tensors packed in `x`, the sum of squares of
// its elements and stores it in y[i]. No square root is taken here.
//
// Tensor i occupies x[offsets[i] - offsets[0], offsets[i + 1] - offsets[0]).
// Stage one writes one partial sum per chunk into a zero-filled temporary
// buffer of n * max_chunk_num MasterT<InT> values; stage two launches one
// block per tensor to add that tensor's partial sums.
// NOTE(review): `offsets` is read on the host here; how MultiTensorApply
// hands it to the device is not visible in this file section.
template <typename InT, typename OutT, int MaxTensorNumPerLaunch = 160,
          int MaxChunkNumPerLaunch = 780>
static void MultiTensorL2Norm(const platform::CUDAPlace &place,
                              gpuStream_t stream, const InT *x,
                              const int *offsets, int n, OutT *y,
                              int chunk_size = 65536) {
  if (n <= 0) return;

  constexpr int kNumTensor = MaxTensorNumPerLaunch;
  constexpr int kNumChunk = MaxChunkNumPerLaunch;
  constexpr int kBlockDim = 512;

  // Find the vector width usable by every tensor and count the chunks.
  int vec_size = 8;
  int max_chunk_num = -1;
  int total_chunk_num = 0;
  for (int tensor_idx = 0; tensor_idx < n; ++tensor_idx) {
    const InT *tensor_begin = x + offsets[tensor_idx] - offsets[0];
    vec_size = std::min(vec_size, GetChunkedVecSize(tensor_begin, chunk_size));

    const int numel = offsets[tensor_idx + 1] - offsets[tensor_idx];
    const int chunk_num = (numel + chunk_size - 1) / chunk_size;
    max_chunk_num = std::max(max_chunk_num, chunk_num);
    total_chunk_num += chunk_num;
  }

  VLOG(1) << "MultiTensorL2Norm max_chunk_num = " << max_chunk_num
          << " , total_chunk_num = " << total_chunk_num
          << " , tensor_num = " << n;

  // One slot per (tensor, chunk); zero-filled so that slots no chunk writes
  // to contribute nothing to the second-stage sum.
  using MT = MasterT<InT>;
  memory::Buffer partial_buffer(place);
  auto *partial_sums = partial_buffer.Alloc<MT>(n * max_chunk_num);
  FillZeroWithPtr(partial_sums, n * max_chunk_num, stream);

#define PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL                             \
  do {                                                                          \
    using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>;                   \
    VLOG(10) << __func__ << " " << typeid(InT).name()                           \
             << " VecSize = " << kVecSize;                                      \
    MultiTensorApply<FunctorT, kNumTensor, kNumChunk>(                          \
        FunctorT(), stream, offsets, n, chunk_size, kBlockDim, x, partial_sums, \
        max_chunk_num);                                                         \
  } while (0)

  PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL);
#undef PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL

  // Stage two: one block per tensor adds up that tensor's partial sums.
  MultiTensorL2NormReduceAgainCUDAKernel<MT, OutT, kBlockDim>
      <<<n, kBlockDim, 0, stream>>>(partial_sums, y, max_chunk_num);
}

199 200 201 202 203 204 205 206 207
// Debug helper: when verbose logging at LogLevel is on, logs for every
// "Param" input its dtype, its variable name and the corresponding entries
// of the two square-norm arrays.
//
// Entry i of the norm arrays is reported next to the parameter whose index
// is ParamOrder[i]. The norm arrays are converted to host vectors with
// ToVector using the place of the first parameter.
template <int LogLevel>
static void LogParamAndTrustRatioDivSquareNorm(
    const framework::ExecutionContext &ctx, const float *param_square_norm,
    const float *trust_ratio_div_square_norm) {
  if (!VLOG_IS_ON(LogLevel)) return;

  auto params = ctx.MultiInput<framework::Tensor>("Param");
  if (params.empty()) return;

  const auto *order = ctx.Input<framework::Tensor>("ParamOrder")->data<int>();

  const size_t param_num = params.size();
  auto place = params[0]->place();

  auto pn_vec = ToVector(param_square_norm, param_num, place);
  auto tn_vec = ToVector(trust_ratio_div_square_norm, param_num, place);

  const auto &names = ctx.GetOp().Inputs("Param");
  for (size_t pos = 0; pos < param_num; ++pos) {
    const int param_idx = order[pos];
    VLOG(LogLevel) << "Param " << params[param_idx]->dtype() << " "
                   << names[param_idx] << " pn = " << pn_vec[pos]
                   << " , tn = " << tn_vec[pos];
  }
}

// Copies the single float at `ptr` to the host on the context's stream,
// waits for the stream to finish, logs the value, and reports whether it is
// finite (neither NaN nor +/-Inf).
// NOTE(review): `ptr` is presumably a device pointer, given the
// device-to-host copy kind -- confirm against callers.
static bool IsFinite(const platform::CUDADeviceContext &dev_ctx,
                     const float *ptr) {
  float host_value;
  auto stream = dev_ctx.stream();
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(&host_value, ptr, sizeof(float),
                                            hipMemcpyDeviceToHost, stream));
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamSynchronize(stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&host_value, ptr, sizeof(float),
                                             cudaMemcpyDeviceToHost, stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#endif
  LOG(INFO) << "NAN_INF indicator value: " << host_value;
  const bool finite = isfinite(host_value);
  return finite;
}

// Returns the data pointer of input tensor `in_name`.
//
// Throws InvalidArgument when the input variable is missing. When the tensor
// exists but is not initialized, returns nullptr. If `numel` is non-null it
// receives the element count (0 for an uninitialized tensor).
template <typename T>
static const T *GetInputTensorPtr(const framework::ExecutionContext &ctx,
                                  const char *in_name,
                                  int64_t *numel = nullptr) {
  const auto *in_tensor = ctx.Input<framework::Tensor>(in_name);
  PADDLE_ENFORCE_NOT_NULL(in_tensor, platform::errors::InvalidArgument(
                                         "Input(%s) cannot be NULL.", in_name));

  if (!in_tensor->IsInitialized()) {
    if (numel) *numel = 0;
    return nullptr;
  }

  if (numel) *numel = in_tensor->numel();
  return in_tensor->data<T>();
}

// Returns the mutable data pointer of a tensor that is passed both as input
// `in_name` and as output `out_name`.
//
// If the input is missing or uninitialized: throws InvalidArgument unless
// AllowNotExist is true, in which case nullptr is returned and `*numel` (if
// given) is set to 0.
// Otherwise the output must exist and its data pointer (obtained through
// mutable_data on `place`) must equal the input's data pointer; a violation
// throws InvalidArgument. On success `*numel` (if given) receives the
// output's element count.
template <typename T, bool AllowNotExist = false>
static T *GetSameInOutTensorPtr(const framework::ExecutionContext &ctx,
                                const platform::Place &place,
                                const char *in_name, const char *out_name,
                                int64_t *numel = nullptr) {
  const auto *in_tensor = ctx.Input<framework::Tensor>(in_name);
  if (in_tensor == nullptr || !in_tensor->IsInitialized()) {
    PADDLE_ENFORCE_EQ(AllowNotExist, true,
                      platform::errors::InvalidArgument(
                          "Input(%s) cannot be NULL.", in_name));
    if (numel) *numel = 0;
    return nullptr;
  }

  // in_tensor is known to be non-null and initialized from here on, so only
  // the output needs a null check.
  auto *out_tensor = ctx.Output<framework::Tensor>(out_name);
  PADDLE_ENFORCE_NOT_NULL(out_tensor,
                          platform::errors::InvalidArgument(
                              "Output(%s) cannot be NULL.", out_name));
  const T *in_data = in_tensor->data<T>();
  T *out_data = out_tensor->mutable_data<T>(place);
  PADDLE_ENFORCE_EQ(in_data, out_data,
                    platform::errors::InvalidArgument(
                        "Input(%s) and Output(%s) must be the same Tensor.",
                        in_name, out_name));
  if (numel) *numel = out_tensor->numel();
  return out_data;
}

// Unary functor returning x * x, computed after converting x to MasterT<T>.
template <typename T>
struct SquareFunctor {
  HOSTDEVICE MasterT<T> operator()(T x) const {
    using MT = MasterT<T>;
    const MT widened = static_cast<MT>(x);
    return widened * widened;
  }
};

// Unary predicate: true when x is not finite (i.e. NaN or +/-Inf).
template <typename T>
struct IsNanInfFunctor {
  HOSTDEVICE bool operator()(T x) const {
    const bool finite = isfinite(x);
    return !finite;
  }
};

// Binary functor computing the logical OR of two booleans.
struct OrFunctor {
  HOSTDEVICE bool operator()(bool x, bool y) const {
    if (x) return true;
    return y;
  }
};

// Binary functor computing the logical AND of two booleans.
struct AndFunctor {
  HOSTDEVICE bool operator()(bool x, bool y) const {
    if (!x) return false;
    return y;
  }
};

S
sneaxiy 已提交
308
// Element-wise scaling: y[i] = T1(T2(x[i]) * scale[0]) for i in [0, num).
//
// The multiplication is carried out in T2, which must be at least as wide as
// T1. Each thread handles VecSize consecutive elements per iteration through
// phi::Load / phi::Store and advances by the total thread count times
// VecSize; elements that do not fill a whole group are handled one by one.
// NOTE(review): the vector accesses presumably require x and y to be
// aligned for phi::AlignedVector<T1, VecSize> -- confirm at the launch site.
template <typename T1, typename T2, int VecSize>
static __global__ void ScaleCUDAKernel(const T1 *__restrict__ x,
                                       const T2 *__restrict__ scale,
                                       T1 *__restrict__ y, int num) {
  static_assert(sizeof(T1) <= sizeof(T2),
                "sizeof(T1) must be not greater than sizeof(T2).");
  const T2 factor = scale[0];

  int idx = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
  const int step = blockDim.x * gridDim.x * VecSize;

  // Vectorized part.
  while (idx + VecSize <= num) {
    phi::AlignedVector<T1, VecSize> in_vec;
    phi::AlignedVector<T1, VecSize> out_vec;

    phi::Load(x + idx, &in_vec);
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      out_vec[k] = static_cast<T1>(static_cast<T2>(in_vec[k]) * factor);
    }
    phi::Store(out_vec, y + idx);

    idx += step;
  }

  // Scalar tail.
  while (idx < num) {
    y[idx] = static_cast<T1>(static_cast<T2>(x[idx]) * factor);
    ++idx;
  }
}

// Adds the single value x[0] into y[0]. There is no thread-index guard, so
// every launched thread performs the same update.
// NOTE(review): presumably meant to be launched with exactly one thread --
// confirm at the launch site.
template <typename T>
static __global__ void AddToCUDAKernel(const T *__restrict__ x,
                                       T *__restrict__ y) {
  const T addend = x[0];
  y[0] += addend;
}

// If clip before allreduce,
// coeff = global_scale * max_global_grad_norm / (1e-6 + sqrt(square_grad_norm)
// * rescale_grad)
// if coeff >= 1 or coeff is Nan/Inf, scale = 1.0
// else scale = coeff
//
// The resulting scale is written to *out1 (as T1) and to *out2 (converted to
// T2); either output pointer may be null, in which case it is skipped.
// There is no thread-index guard, so every launched thread computes and
// stores the same value.
template <typename T1, typename T2>
static __global__ void CalcGradNormClipBeforeAllReduceScale(
    const T1 *__restrict__ global_scale, T1 max_global_grad_norm,
    const T1 *__restrict__ square_grad_norm, T1 *__restrict__ out1,
    T2 *__restrict__ out2, T1 clip_rescale_grad) {
  const T1 grad_norm =
      static_cast<T1>(sqrtf(*square_grad_norm)) * clip_rescale_grad;

  T1 coeff = global_scale[0] * max_global_grad_norm / (1e-6 + grad_norm);
  // Never amplify gradients, and fall back to 1 when the ratio is NaN/Inf.
  if (!isfinite(coeff) || coeff >= 1) {
    coeff = static_cast<T1>(1.0);
  }

  if (out1 != nullptr) {
    *out1 = coeff;
  }
  if (out2 != nullptr) {
    *out2 = static_cast<T2>(coeff);
  }
}

// Writes a float indicator to *out_p: the value whose bit pattern is
// 0x7fffffff (a NaN) when *in_flag_p is true, and 0.0f otherwise.
static __global__ void SetNanInfValueCUDAKernelOneFlag(const bool *in_flag_p,
                                                       float *out_p) {
  const bool flagged = *in_flag_p;
  *out_p = flagged ? __int_as_float(0x7fffffffU) : 0.0f;
}

// Writes a float indicator to *out_p: the value whose bit pattern is
// 0x7fffffff (a NaN) when either input flag is true, and 0.0f otherwise.
static __global__ void SetNanInfValueCUDAKernelTwoFlag(const bool *in_flag_p_1,
                                                       const bool *in_flag_p_2,
                                                       float *out_p) {
  const bool flagged = (*in_flag_p_1) || (*in_flag_p_2);
  *out_p = flagged ? __int_as_float(0x7fffffffU) : 0.0f;
}

379 380
// Updates the two moment buffers and computes the LAMB "trust ratio div"
// term for `num` contiguous elements.
//
// Early exit: if *square_grad_norm_p is not finite, nothing is updated.
// When `found_inf` is non-null, thread 0 of block 0 records the outcome in
// *found_inf and, in the finite case, also increments *step.
//
// Gradient scale: rescale_grad / global_scale[0]; when
// max_global_grad_norm > 0 it is further multiplied by
//   max_global_grad_norm / (sqrt(square_grad_norm) * scale + 1e-6)
// whenever that factor is below 1.
//
// Per element (g = grad * scale):
//   mom1 = beta1 * mom1 + (1 - beta1) * g
//   mom2 = beta2 * mom2 + (1 - beta2) * g * g
//   trust_ratio_div = (mom1 / (1 - beta1pow)) /
//                     (sqrt(mom2 / (1 - beta2pow)) + epsilon) +
//                     weight_decay * param
// Weight decay is applied to elements with index < weight_decay_end_numel.
// In the vectorized loop the decision is made once per VecSize-wide group
// from the group's first index, and parameter values are only loaded when
// the decay factor is non-zero.
// NOTE(review): this presumably relies on tensor boundaries being multiples
// of VecSize -- confirm in the host-side launcher.
template <typename T, typename GradT, int VecSize>
static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
    const T *__restrict__ param_p, const GradT *__restrict__ grad_p,
    const T *__restrict__ square_grad_norm_p,
    const T *__restrict__ global_scale, const T *__restrict__ beta1pow_p,
    const T *__restrict__ beta2pow_p, T *__restrict__ mom1_p,
    T *__restrict__ mom2_p, T *__restrict__ trust_ratio_div_p,
    bool *__restrict__ found_inf, int64_t *__restrict__ step, T weight_decay,
    int weight_decay_end_numel, T beta1, T beta2, T epsilon,
    T max_global_grad_norm, int num, T rescale_grad) {
  T square_grad_norm = *square_grad_norm_p;
  // Only one thread of the whole grid touches found_inf / step.
  bool need_update_found_inf =
      (found_inf && threadIdx.x == 0 && blockIdx.x == 0);
  if (!isfinite(square_grad_norm)) {
    if (need_update_found_inf) *found_inf = true;
    return;
  } else if (need_update_found_inf) {
    *found_inf = false;
    ++(*step);
  }

  T scale = rescale_grad / global_scale[0];
  if (max_global_grad_norm > 0) {
    T clip_scale =
        max_global_grad_norm / (sqrtf(square_grad_norm) * scale + 1e-6);
    if (clip_scale < static_cast<T>(1)) {
      scale *= clip_scale;
    }
  }

  T one_minus_beta1pow = 1 - beta1pow_p[0];
  T one_minus_beta2pow = 1 - beta2pow_p[0];

  int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
  int stride = blockDim.x * gridDim.x * VecSize;

  for (; i + VecSize <= num; i += stride) {
    phi::AlignedVector<T, VecSize> param_vec;
    phi::AlignedVector<GradT, VecSize> grad_vec;
    phi::AlignedVector<T, VecSize> mom1_vec;
    phi::AlignedVector<T, VecSize> mom2_vec;
    phi::AlignedVector<T, VecSize> trust_ratio_div_vec;

    T cur_weight_decay = (i < weight_decay_end_numel) * weight_decay;
    if (cur_weight_decay != static_cast<T>(0.0)) {
      phi::Load(param_p + i, &param_vec);
    } else {
      // The parameter value is multiplied by a zero decay factor below, so
      // skip the load and use zeros.
#pragma unroll
      for (int j = 0; j < VecSize; ++j) {
        param_vec[j] = static_cast<T>(0);
      }
    }
    phi::Load(grad_p + i, &grad_vec);
    phi::Load(mom1_p + i, &mom1_vec);
    phi::Load(mom2_p + i, &mom2_vec);

#define PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(__param, __grad, __mom1, __mom2,    \
                                           __trust_ratio_div, __idx)           \
  T p = __param[__idx];                                                        \
  T g = static_cast<T>(__grad[__idx]) * scale;                                 \
  T mom1 = __mom1[__idx];                                                      \
  T mom2 = __mom2[__idx];                                                      \
  mom1 = beta1 * mom1 + (1 - beta1) * g;                                       \
  mom2 = beta2 * mom2 + (1 - beta2) * g * g;                                   \
  T mom1_unbiased = mom1 / one_minus_beta1pow;                                 \
  T mom2_unbiased = mom2 / one_minus_beta2pow;                                 \
  __trust_ratio_div[__idx] =                                                   \
      mom1_unbiased / (sqrtf(mom2_unbiased) + epsilon) + cur_weight_decay * p; \
  __mom1[__idx] = mom1;                                                        \
  __mom2[__idx] = mom2;

#pragma unroll
    for (int j = 0; j < VecSize; ++j) {
      PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(param_vec, grad_vec, mom1_vec,
                                         mom2_vec, trust_ratio_div_vec, j);
    }

    phi::Store(mom1_vec, mom1_p + i);
    phi::Store(mom2_vec, mom2_p + i);
    phi::Store(trust_ratio_div_vec, trust_ratio_div_p + i);
  }

  // Scalar tail: same update applied directly to global memory.
  for (; i < num; ++i) {
    T cur_weight_decay = (i < weight_decay_end_numel) * weight_decay;
    PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(param_p, grad_p, mom1_p, mom2_p,
                                       trust_ratio_div_p, i);
  }
}
// The helper macro depends on locals of the kernel above, so keep it from
// leaking into the rest of the translation unit.
#undef PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE

468 469 470 471 472
// Host-side launcher for UpdateLambMomentAndTrustRatioDivCUDAKernel over the
// n tensors described by `offsets` (treated as one contiguous range of
// offsets[n] - offsets[0] elements).
//
// - Returns immediately when n <= 0.
// - weight_decay_end_idx must lie in [0, n]; weight decay is applied to the
//   elements of tensors [0, weight_decay_end_idx).
// - The vector width is the largest one compatible with the alignment of all
//   five buffers and dividing the length of every tensor.
// - `found_inf_p` and `step` must be both null or both non-null; otherwise
//   InvalidArgument is thrown.
template <typename T, typename GradT>
static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
    const platform::CUDADeviceContext &dev_ctx, const int *offsets, int n,
    const T *param_p, const GradT *grad_p, const T *square_grad_norm_p,
    const T *global_scale, const T *beta1pow_p, const T *beta2pow_p, T *mom1_p,
    T *mom2_p, T *trust_ratio_div_p, bool *found_inf_p, int64_t *step,
    T weight_decay, int weight_decay_end_idx, T beta1, T beta2, T epsilon,
    T max_global_grad_norm, T rescale_grad) {
  if (n <= 0) return;
  int numel = offsets[n] - offsets[0];
  PADDLE_ENFORCE_GE(weight_decay_end_idx, 0,
                    platform::errors::InvalidArgument(
                        "The weight decay end index should be >= 0."));
  // weight_decay_end_idx == n is valid (decay applies to every tensor), so
  // the message states the inclusive bound that is actually enforced.
  PADDLE_ENFORCE_LE(weight_decay_end_idx, n,
                    platform::errors::InvalidArgument(
                        "The weight decay end index should be <= %d.", n));
  auto weight_decay_end_numel = offsets[weight_decay_end_idx] - offsets[0];

  // chunk_size == 0: only the pointer alignment limits the vector width.
  int vec_size = GetChunkedVecSize(param_p, 0);
  vec_size = std::min(vec_size, GetChunkedVecSize(grad_p, 0));
  vec_size = std::min(vec_size, GetChunkedVecSize(mom1_p, 0));
  vec_size = std::min(vec_size, GetChunkedVecSize(mom2_p, 0));
  vec_size = std::min(vec_size, GetChunkedVecSize(trust_ratio_div_p, 0));
  // Shrink the width until it divides every tensor length, so that a vector
  // group never straddles two tensors.
  for (int i = 0; i < n; ++i) {
    auto length = offsets[i + 1] - offsets[i];
    while (length % vec_size != 0) {
      vec_size /= 2;
    }
  }

  VLOG(1) << __func__ << " VecSize = " << vec_size;

  auto stream = dev_ctx.stream();
  auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
  if (found_inf_p == nullptr) {
    PADDLE_ENFORCE_EQ(
        step, nullptr,
        platform::errors::InvalidArgument(
            "Output(Step) cannot be updated twice in one mini-batch."));
  } else {
    PADDLE_ENFORCE_NOT_NULL(step, platform::errors::InvalidArgument(
                                      "Output(Step) cannot be nullptr."));
  }

#define PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL                             \
  do {                                                                        \
    UpdateLambMomentAndTrustRatioDivCUDAKernel<T, GradT, kVecSize>            \
        <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(      \
            param_p, grad_p, square_grad_norm_p, global_scale, beta1pow_p,    \
            beta2pow_p, mom1_p, mom2_p, trust_ratio_div_p, found_inf_p, step, \
            weight_decay, weight_decay_end_numel, beta1, beta2, epsilon,      \
            max_global_grad_norm, numel, rescale_grad);                       \
  } while (0)

  PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL);
#undef PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL
}

526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580
// Bundles the beta-power accumulators with their decay rates so that a
// kernel can advance them (beta1pow *= beta1, beta2pow *= beta2).
// This primary template is the NeedUpdate == true case: both pointers must
// be non-null, otherwise the constructor throws InvalidArgument.
template <typename T, bool NeedUpdate /*=true*/>
struct LambBetaPowUpdateOnceHelper {
  LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2)
      : beta1pow_(beta1pow),
        beta2pow_(beta2pow),
        beta1_(beta1),
        beta2_(beta2) {
    PADDLE_ENFORCE_NOT_NULL(beta1pow,
                            platform::errors::InvalidArgument(
                                "The beta1pow should not be nullptr."));
    PADDLE_ENFORCE_NOT_NULL(beta2pow,
                            platform::errors::InvalidArgument(
                                "The beta2pow should not be nullptr."));
  }

  // Multiplies each accumulator in place by its decay rate.
  HOSTDEVICE void UpdateBetaPows() const {
    *beta1pow_ *= beta1_;
    *beta2pow_ *= beta2_;
  }

 private:
  T *__restrict__ beta1pow_;
  T *__restrict__ beta2pow_;
  T beta1_;
  T beta2_;
};

// Specialization for NeedUpdate == false: carries no state and performs no
// update.
template <typename T>
struct LambBetaPowUpdateOnceHelper<T, false> {
  // Both pointers are required to be nullptr (InvalidArgument otherwise);
  // beta1 and beta2 are accepted for signature parity but are not used.
  LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2) {
    PADDLE_ENFORCE_EQ(
        beta1pow, nullptr,
        platform::errors::InvalidArgument("The beta1pow should be nullptr."));
    PADDLE_ENFORCE_EQ(
        beta2pow, nullptr,
        platform::errors::InvalidArgument("The beta2pow should be nullptr."));
  }

  // Intentionally empty: there is nothing to update in this specialization.
  HOSTDEVICE void UpdateBetaPows() const {}
};

// Holds the parameter pointer together with its master-parameter pointer
// (of type MasterT<T>). This primary template is the HasMasterParam == true
// case: T must differ from MasterT<T> and `master_param` must be non-null;
// both conditions are checked at construction and raise InvalidArgument.
template <typename T, bool HasMasterParam /*=true*/>
struct LambParamHelper {
  LambParamHelper(T *param, MasterT<T> *master_param)
      : param_(param), master_param_(master_param) {
    constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value;
    PADDLE_ENFORCE_EQ(kIsSameType, false,
                      platform::errors::InvalidArgument(
                          "T must not be the same with MasterT<T>."));
    PADDLE_ENFORCE_NOT_NULL(master_param,
                            platform::errors::InvalidArgument(
                                "Master parameter must be provided."));
  }

  // Pointer to the parameter buffer.
  HOSTDEVICE T *__restrict__ ParamPtr() { return param_; }

  // Pointer to the master-parameter buffer.
  HOSTDEVICE MasterT<T> *__restrict__ MasterParamPtr() { return master_param_; }

 private:
  T *__restrict__ param_;
  MasterT<T> *__restrict__ master_param_;
};

// Specialization for HasMasterParam == false: only the parameter pointer is
// stored. T must be the same type as MasterT<T>, and `master_param` must be
// either nullptr or the same address as `param`; violations raise
// InvalidArgument at construction.
template <typename T>
struct LambParamHelper<T, false> {
  LambParamHelper(T *param, MasterT<T> *master_param) : param_(param) {
    constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value;
    PADDLE_ENFORCE_EQ(kIsSameType, true,
                      platform::errors::InvalidArgument(
                          "T must be the same with MasterT<T>."));
    if (master_param) {
      PADDLE_ENFORCE_EQ(static_cast<void *>(param),
                        static_cast<void *>(master_param),
                        platform::errors::InvalidArgument(
                            "Master parameter must be nullptr or the same as "
                            "non-master parameter."));
    }
  }

  // Pointer to the parameter buffer.
  HOSTDEVICE T *__restrict__ ParamPtr() { return param_; }

  // There is no separate master copy in this specialization.
  HOSTDEVICE constexpr MasterT<T> *MasterParamPtr() { return nullptr; }

 private:
  T *__restrict__ param_;
};

615 616 617 618 619 620 621
// Per-chunk functor applying the LAMB parameter step
//   p = p - ratio * trust_ratio_div
// to the elements [offset, offset + size) of one tensor.
//
// ratio is lr * sqrt(param_square_norm / trust_ratio_div_square_norm) for
// the tensor when both norms are non-zero, and plain lr otherwise.
// Nothing happens when *found_inf is true.
// With HasMasterParam the update is computed on the master copy and the
// result is also written, converted to ParamT, to the parameter buffer;
// without it the parameter buffer is updated directly.
// When NeedUpdateBetaPow is set, thread 0 of block 0 additionally advances
// the beta powers through `betapow_helper`.
template <typename ParamT, bool HasMasterParam, bool NeedUpdateBetaPow,
          int VecSize>
struct LambUpdateParamAndBetaPowsFunctor {
  DEVICE void operator()(
      int tensor_id, int chunk_id, int offset, int size,
      LambParamHelper<ParamT, HasMasterParam> param_helper,
      const MasterT<ParamT> *trust_ratio_div, const MasterT<ParamT> *lr,
      const MasterT<ParamT> *param_square_norm,
      const MasterT<ParamT> *trust_ratio_div_square_norm, const bool *found_inf,
      LambBetaPowUpdateOnceHelper<MasterT<ParamT>, NeedUpdateBetaPow>
          betapow_helper) const {
    if (*found_inf) return;

    using MT = MasterT<ParamT>;

    const MT p_square_norm = param_square_norm[tensor_id];
    const MT t_square_norm = trust_ratio_div_square_norm[tensor_id];
    const MT lr_value = *lr;

    // Fall back to the plain learning rate when either norm is zero.
    MT ratio = lr_value;
    if (p_square_norm != static_cast<MT>(0) &&
        t_square_norm != static_cast<MT>(0)) {
      ratio = lr_value * sqrtf(p_square_norm / t_square_norm);
    }

    ParamT *param = param_helper.ParamPtr() + offset;
    MT *master_param = HasMasterParam ? param_helper.MasterParamPtr() + offset
                                      : param_helper.MasterParamPtr();
    trust_ratio_div += offset;

    const int step = blockDim.x * VecSize;
    int idx = threadIdx.x * VecSize;

    // Vectorized part: whole groups of VecSize elements.
    while (idx + VecSize <= size) {
      phi::AlignedVector<MT, VecSize> update_vec;
      phi::Load(trust_ratio_div + idx, &update_vec);
      if (HasMasterParam) {
        phi::AlignedVector<MT, VecSize> master_vec;
        phi::Load(master_param + idx, &master_vec);
        phi::AlignedVector<ParamT, VecSize> out_vec;
#pragma unroll
        for (int k = 0; k < VecSize; ++k) {
          MT p = master_vec[k] - ratio * update_vec[k];
          master_vec[k] = p;
          out_vec[k] = static_cast<ParamT>(p);
        }
        phi::Store(master_vec, master_param + idx);
        phi::Store(out_vec, param + idx);
      } else {
        phi::AlignedVector<ParamT, VecSize> out_vec;
        phi::Load(param + idx, &out_vec);
#pragma unroll
        for (int k = 0; k < VecSize; ++k) {
          MT p = static_cast<MT>(out_vec[k]) - ratio * update_vec[k];
          out_vec[k] = static_cast<ParamT>(p);
        }
        phi::Store(out_vec, param + idx);
      }
      idx += step;
    }

    // Scalar tail.
    while (idx < size) {
      if (HasMasterParam) {
        MT p = master_param[idx] - ratio * trust_ratio_div[idx];
        master_param[idx] = p;
        param[idx] = static_cast<ParamT>(p);
      } else {
        MT p = static_cast<MT>(param[idx]) - ratio * trust_ratio_div[idx];
        param[idx] = static_cast<ParamT>(p);
      }
      ++idx;
    }

    if (NeedUpdateBetaPow && threadIdx.x == 0 && blockIdx.x == 0) {
      betapow_helper.UpdateBetaPows();
    }
  }
};
689

690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708
// TODO(zengjinle): which block_dim and chunk_size would be better?
//
// Host-side launcher that applies LambUpdateParamAndBetaPowsFunctor to the
// n tensors described by `offsets`.
//
// - A master copy is used exactly when ParamT differs from MasterT<ParamT>.
// - `beta1pow` and `beta2pow` must be both null or both non-null
//   (InvalidArgument otherwise). When they are non-null, only the launch
//   with launch_n == 0 uses the beta-power-updating functor; the pointers
//   are then cleared so that later launches do not update them again.
// - The vector width is the largest one compatible with every tensor's
//   start address in each buffer and with chunk_size.
// NOTE(review): how often MultiTensorApplyWithCallback invokes the callback,
// and what launch_n counts, is not visible in this file section.
template <typename ParamT, int MaxTensorNumPerLaunch = 160,
          int MaxChunkNumPerLaunch = 780>
static void MultiTensorUpdateLambParamAndBetaPows(
    const platform::CUDADeviceContext &dev_ctx, const int *offsets, int n,
    const MasterT<ParamT> *trust_ratio_div, const MasterT<ParamT> *lr,
    const MasterT<ParamT> *param_square_norm,
    const MasterT<ParamT> *trust_ratio_div_square_norm, const bool *found_inf,
    ParamT *param, MasterT<ParamT> *master_param, MasterT<ParamT> *beta1pow,
    MasterT<ParamT> *beta2pow, MasterT<ParamT> beta1, MasterT<ParamT> beta2,
    int chunk_size = 65536) {
  constexpr bool kHasMasterParam =
      !(std::is_same<ParamT, MasterT<ParamT>>::value);
  constexpr auto kNumTensor = MaxTensorNumPerLaunch;
  constexpr auto kNumChunk = MaxChunkNumPerLaunch;
  const int block_dim = 512;

  const bool has_beta_pow = (beta1pow != nullptr);
  if (!has_beta_pow) {
    PADDLE_ENFORCE_EQ(
        beta2pow, nullptr,
        platform::errors::InvalidArgument("Beta2Pow should be nullptr."));
  } else {
    PADDLE_ENFORCE_NOT_NULL(beta2pow, platform::errors::InvalidArgument(
                                          "Beta2Pow should not be nullptr."));
  }

  // Widest vector size usable by every tensor in every buffer involved.
  int vec_size = 8;
  for (int tensor_idx = 0; tensor_idx < n; ++tensor_idx) {
    const int rel_offset = offsets[tensor_idx] - offsets[0];
    vec_size =
        std::min(vec_size, GetChunkedVecSize(param + rel_offset, chunk_size));
    if (kHasMasterParam) {
      vec_size = std::min(
          vec_size, GetChunkedVecSize(master_param + rel_offset, chunk_size));
    }
    vec_size = std::min(
        vec_size, GetChunkedVecSize(trust_ratio_div + rel_offset, chunk_size));
  }

  VLOG(1) << __func__ << " VecSize = " << vec_size;

  auto stream = dev_ctx.stream();

#define PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(need_betapow_update)       \
  do {                                                                         \
    using FunctorT =                                                           \
        LambUpdateParamAndBetaPowsFunctor<ParamT, kHasMasterParam,             \
                                          need_betapow_update, kVecSize>;      \
    LambParamHelper<ParamT, kHasMasterParam> param_helper(param,               \
                                                          master_param);       \
    LambBetaPowUpdateOnceHelper<MasterT<ParamT>, need_betapow_update>          \
        betapow_helper(beta1pow, beta2pow, beta1, beta2);                      \
    launcher.Launch(FunctorT(), param_helper, trust_ratio_div, lr,             \
                    param_square_norm, trust_ratio_div_square_norm, found_inf, \
                    betapow_helper);                                           \
  } while (0)

#define PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE            \
  do {                                                                  \
    auto callback =                                                     \
        [&](const MultiTensorLauncher<kNumTensor, kNumChunk> &launcher, \
            int launch_n) {                                             \
          if (has_beta_pow && launch_n == 0) {                          \
            PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(true);          \
            beta1pow = nullptr;                                         \
            beta2pow = nullptr;                                         \
          } else {                                                      \
            PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(false);         \
          }                                                             \
        };                                                              \
    MultiTensorApplyWithCallback<kNumTensor, kNumChunk>(                \
        stream, offsets, n, chunk_size, block_dim, callback);           \
  } while (0)

  PD_VEC_LAUNCH_KERNEL(vec_size,
                       PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE);

#undef PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW
#undef PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE
}

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
// Tries to create an NCCL "pre-multiply then sum" reduction operator.
//
// When the code is compiled against NCCL >= 2.11 and the NCCL version
// reported at runtime is also >= 2.11, `*op` is set to a newly created
// ncclRedOpCreatePreMulSum operator (the scalar is passed with
// ncclScalarDevice) and true is returned. In that case the caller owns `*op`.
// Otherwise `*op` is left untouched and false is returned.
static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype,
                                           ncclComm_t comm, const void *scale,
                                           ncclRedOp_t *op) {
#if NCCL_VERSION_CODE >= 21100
  // The headers are new enough; the loaded library still has to be checked.
  int runtime_version;
  PADDLE_ENFORCE_GPU_SUCCESS(
      platform::dynload::ncclGetVersion(&runtime_version));
  const bool runtime_supported = (runtime_version >= 21100);
  if (runtime_supported) {
    VLOG(10) << "ncclRedOpCreatePreMulSum is supported.";
    void *mutable_scale = const_cast<void *>(scale);
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpCreatePreMulSum(
        op, mutable_scale, dtype, ncclScalarDevice, comm));
    return true;
  }
#endif
  VLOG(10) << "ncclRedOpCreatePreMulSum is not supported.";
  return false;
}

S
sneaxiy 已提交
791 792 793 794 795 796 797
// Launches ScaleCUDAKernel on `stream` with input `x`, scalar pointer `scale`,
// output `y` and element count `n`.
//
// The vector width is the smaller of the widths that `x` and `y` support, as
// reported by GetChunkedVecSize(ptr, 0).
// NOTE(review): ScaleCUDAKernel is defined outside this view; presumably it
// writes y[i] = x[i] * scale[0] -- confirm against its definition.
template <typename T1, typename T2>
static void LaunchScaleKernel(const platform::CUDADeviceContext &dev_ctx,
                              const T1 *x, const T2 *scale, T1 *y, int n,
                              gpuStream_t stream) {
  const int x_vec_size = GetChunkedVecSize(x, 0);
  const int y_vec_size = GetChunkedVecSize(y, 0);
  int vec_size = std::min(x_vec_size, y_vec_size);
  const auto launch_cfg = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);

// `kVecSize` is expected to be supplied by PD_VEC_LAUNCH_KERNEL.
#define PD_LAMB_LAUNCH_SCALE_KERNEL_WITH_VEC_SIZE                    \
  do {                                                               \
    ScaleCUDAKernel<T1, T2, kVecSize>                                \
        <<<launch_cfg.block_per_grid, launch_cfg.thread_per_block, 0, \
           stream>>>(x, scale, y, n);                                \
  } while (0)

  PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAMB_LAUNCH_SCALE_KERNEL_WITH_VEC_SIZE);
#undef PD_LAMB_LAUNCH_SCALE_KERNEL_WITH_VEC_SIZE
}

809 810 811 812 813 814
// Shared implementation of the scaled NCCL sum (reduce-scatter or all-reduce).
//
// Template parameters:
//   T                 float or platform::float16 (enforced at compile time).
//   UseReduceScatter  true  -> ncclReduceScatter, the send buffer holds
//                              recvcount * nranks elements;
//                     false -> ncclAllReduce, the send buffer holds recvcount
//                              elements.
//
// Behavior:
//   * recvcount == 0: nothing is done.
//   * comm == nullptr: no communication is issued. If `scale` is given,
//     nranks must be 1 and the scale kernel writes sendbuff -> recvbuff;
//     without `scale` recvbuff is not written at all.
//   * otherwise: if `scale` is given, either a pre-multiply-sum reduction
//     operator is used (when supported) or the input is scaled into a
//     temporary buffer first; then the collective is enqueued on `stream`.
template <typename T, bool UseReduceScatter>
static void NCCLSumWithScaleBase(const T *sendbuff, T *recvbuff,
                                 size_t recvcount, size_t nranks,
                                 ncclComm_t comm, gpuStream_t stream,
                                 const platform::CUDADeviceContext &dev_ctx,
                                 const T *scale = nullptr) {
  static_assert(std::is_same<T, float>::value ||
                    std::is_same<T, platform::float16>::value,
                "T must be either float32 or float16.");
  if (recvcount == 0) return;

  const bool has_scale = (scale != nullptr);
  // Number of elements in the send buffer.
  // NOTE(review): this size_t is narrowed to int when handed to
  // LaunchScaleKernel below.
  const auto numel = UseReduceScatter ? (recvcount * nranks) : recvcount;

  if (comm == nullptr) {
    // No communicator: only the optional local scaling is performed.
    if (has_scale) {
      PADDLE_ENFORCE_EQ(nranks, 1,
                        platform::errors::InvalidArgument(
                            "nranks must be 1 when scale != nullptr."));
      LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, numel, stream);
    }
    return;
  }

  const ncclDataType_t dtype =
      std::is_same<T, float>::value ? ncclFloat32 : ncclFloat16;
  ncclRedOp_t op = ncclSum;
  // True only when a custom pre-multiply-sum operator was created; such an
  // operator must be destroyed after the collective has been enqueued.
  bool should_destroy_op =
      has_scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op);

  // Fallback: scale into a temporary buffer and reduce that with ncclSum.
  memory::Buffer buffer(dev_ctx.GetPlace());
  if (has_scale && !should_destroy_op) {
    T *scaled_sendbuff = buffer.Alloc<T>(numel);
    LaunchScaleKernel(dev_ctx, sendbuff, scale, scaled_sendbuff, numel,
                      stream);
    sendbuff = scaled_sendbuff;
  }

  if (!UseReduceScatter) {
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
  } else {
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
        sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
  }

#if NCCL_VERSION_CODE >= 21100
  if (should_destroy_op) {
    VLOG(10) << "ncclRedOpDestroy starts";
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpDestroy(op, comm));
    VLOG(10) << "ncclRedOpDestroy ends";
  }
#endif
}
859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878

// Reduce-scatter flavor of NCCLSumWithScaleBase: `sendbuff` holds
// recvcount * nranks elements and `recvbuff` receives recvcount elements.
// All arguments are forwarded unchanged.
template <typename T>
static void NCCLReduceScatterWithScale(
    const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks,
    ncclComm_t comm, gpuStream_t stream,
    const platform::CUDADeviceContext &dev_ctx, const T *scale = nullptr) {
  constexpr bool kUseReduceScatter = true;
  NCCLSumWithScaleBase<T, kUseReduceScatter>(
      sendbuff, recvbuff, recvcount, nranks, comm, stream, dev_ctx, scale);
}

// All-reduce flavor of NCCLSumWithScaleBase: both `sendbuff` and `recvbuff`
// hold recvcount elements. All arguments are forwarded unchanged.
template <typename T>
static void NCCLAllReduceWithScale(const T *sendbuff, T *recvbuff,
                                   size_t recvcount, size_t nranks,
                                   ncclComm_t comm, gpuStream_t stream,
                                   const platform::CUDADeviceContext &dev_ctx,
                                   const T *scale = nullptr) {
  constexpr bool kUseReduceScatter = false;
  NCCLSumWithScaleBase<T, kUseReduceScatter>(
      sendbuff, recvbuff, recvcount, nranks, comm, stream, dev_ctx, scale);
}

879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005
#endif

// Runs cub::DeviceReduce::Reduce over `num_items` elements of `d_in` and
// stores the result at `d_out`, using `reduction_op` and the initial value
// `init`, on `stream`.
//
// cub is called twice: the first call (with a null workspace pointer) only
// reports the required workspace size; the workspace is then allocated from
// `buffer` and the second call performs the reduction.
template <typename InputIteratorT, typename OutputIteratorT, typename ReduceOpT,
          typename T>
static void CubDeviceReduce(InputIteratorT d_in, OutputIteratorT d_out,
                            int num_items, ReduceOpT reduction_op, T init,
                            gpuStream_t stream, memory::Buffer *buffer) {
  size_t workspace_bytes = 0;
  void *workspace = nullptr;

  // Size query.
  PADDLE_ENFORCE_GPU_SUCCESS(
      cub::DeviceReduce::Reduce(workspace, workspace_bytes, d_in, d_out,
                                num_items, reduction_op, init, stream));

  workspace = buffer->Alloc<void>(workspace_bytes);
  VLOG(10) << "cub::DeviceReduce::Reduce needs " << workspace_bytes
           << " byte(s), ptr = " << workspace;

  // Actual reduction.
  PADDLE_ENFORCE_GPU_SUCCESS(
      cub::DeviceReduce::Reduce(workspace, workspace_bytes, d_in, d_out,
                                num_items, reduction_op, init, stream));
}

template <typename T>
static void GetSquareGradNormImpl(const T *grad, int n, float *square_norm,
                                  gpuStream_t stream,
                                  memory::Buffer *cub_tmp_buffer) {
  using Iterator =
      cub::TransformInputIterator<float, SquareFunctor<T>, const T *>;
  Iterator iter(grad, SquareFunctor<T>());
  CubDeviceReduce(iter, square_norm, n, cub::Sum(), static_cast<float>(0),
                  stream, cub_tmp_buffer);
}

// Computes the sum of squares over the FP32 and FP16 gradients and leaves the
// total in square_norm[0].
//
// `square_norm` must have at least 2 elements: when both gradient kinds are
// present, the FP16 partial result is first written to square_norm[1] and
// then folded into square_norm[0] by AddToCUDAKernel. When only one kind is
// present its result is written directly to square_norm[0]. When neither is
// present nothing is written.
static void GetSquareGradNorm(const float *fp32_grad, int fp32_numel,
                              const platform::float16 *fp16_grad,
                              int fp16_numel, float *square_norm,
                              gpuStream_t stream,
                              memory::Buffer *cub_tmp_buffer) {
  const bool has_fp32 = (fp32_numel > 0);
  const bool has_fp16 = (fp16_numel > 0);

  VLOG(10) << "GetSquareGradNorm starts, fp32_numel = " << fp32_numel
           << " , fp16_numel = " << fp16_numel;

  if (has_fp32) {
    GetSquareGradNormImpl(fp32_grad, fp32_numel, square_norm, stream,
                          cub_tmp_buffer);
    VLOG(10) << "FP32 square L2-Norm: "
             << FlattenToString(square_norm, 1, cub_tmp_buffer->GetPlace());
  }

  if (has_fp16) {
    // Use the second slot only when the first one already holds the FP32 sum.
    float *fp16_square_norm = square_norm;
    if (has_fp32) {
      fp16_square_norm = square_norm + 1;
    }
    GetSquareGradNormImpl(fp16_grad, fp16_numel, fp16_square_norm, stream,
                          cub_tmp_buffer);
    VLOG(10) << "FP16 square L2-Norm: "
             << FlattenToString(fp16_square_norm, 1,
                                cub_tmp_buffer->GetPlace());
    if (has_fp32) {
      AddToCUDAKernel<<<1, 1, 0, stream>>>(fp16_square_norm, square_norm);
      VLOG(10) << "FP32+FP16 square L2-Norm: "
               << FlattenToString(square_norm, 1, cub_tmp_buffer->GetPlace());
    }
  }

  VLOG(10) << "GetSquareGradNorm ends, fp32_numel = " << fp32_numel
           << " , fp16_numel = " << fp16_numel;
}

// Formats `x` with the default stream insertion operator (operator<<) and
// returns the resulting text.
template <typename T>
std::string NumToString(T x) {
  std::ostringstream formatted;
  formatted << x;
  return formatted.str();
}

// Returns a JSON-like string holding the minimum and maximum of the `n`
// elements at device pointer `x`:  {"min": <v> , "max": <v>}.
// For n == 0 the values are reported as null.
//
// Only GPU places are accepted. The function copies the two results to the
// host and synchronizes the device context's stream before returning.
template <typename T>
static std::string GetMinMaxStr(const T *x, size_t n,
                                const platform::Place &place) {
  PADDLE_ENFORCE_EQ(
      platform::is_gpu_place(place), true,
      platform::errors::InvalidArgument("Only support CUDAPlace currently."));

  auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
      platform::DeviceContextPool::Instance().Get(place));
  auto stream = dev_ctx->stream();

  // Device storage for the pair {min, max}.
  memory::Buffer ret_buffer(place);
  T *device_min_max = ret_buffer.Alloc<T>(2);

  if (n == 0) {
    return "{\"min\": null, \"max\": null}";
  }

  memory::Buffer cub_buffer(place);
  CubDeviceReduce(x, device_min_max, n, cub::Min(),
                  std::numeric_limits<T>::max(), stream, &cub_buffer);
  CubDeviceReduce(x, device_min_max + 1, n, cub::Max(),
                  std::numeric_limits<T>::lowest(), stream, &cub_buffer);

  // Bring both values back to the host and wait for the copy to finish.
  T host_min_max[2];
  const size_t copy_bytes = 2 * sizeof(T);
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(&host_min_max[0], device_min_max,
                                            copy_bytes, hipMemcpyDeviceToHost,
                                            stream));
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamSynchronize(stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&host_min_max[0], device_min_max,
                                             copy_bytes, cudaMemcpyDeviceToHost,
                                             stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#endif

  std::string result("{\"min\": ");
  result += NumToString(host_min_max[0]);
  result += " , \"max\": ";
  result += NumToString(host_min_max[1]);
  result += "}";
  return result;
}

struct VisitDTypeFunctor {
  VisitDTypeFunctor(const framework::Tensor *x, std::string *s)
      : x_(x), s_(s) {}

  template <typename T>
  void apply() const {
    *s_ = GetMinMaxStr<T>(x_->template data<T>(), x_->numel(), x_->place());
  }

 private:
  const framework::Tensor *x_;
  std::string *s_;
};

// Returns a textual min/max summary of tensor `x`.
// Special cases: "null" for a null pointer, "not_inited" for an
// uninitialized tensor and "CPUTensor" for a tensor that is not on a GPU
// place. Otherwise the result comes from the typed GetMinMaxStr overload,
// selected by the tensor's dtype.
static std::string GetMinMaxStr(const framework::Tensor *x) {
  if (x == nullptr) {
    return "null";
  }
  if (!x->IsInitialized()) {
    return "not_inited";
  }
  if (!platform::is_gpu_place(x->place())) {
    return "CPUTensor";
  }

  std::string summary;
  VisitDTypeFunctor visitor(x, &summary);
  phi::VisitDataType(x->dtype(), visitor);
  return summary;
}

// Debug helper: logs (at VLOG level 1) the variable name and the min/max
// summary of every input tensor of the current op and, unless `only_inputs`
// is true, of every output tensor as well. Returns immediately when VLOG
// level 1 is disabled.
static void PrintAllMinMaxRange(const framework::ExecutionContext &ctx,
                                bool only_inputs) {
  if (!VLOG_IS_ON(1)) return;

  for (const auto &slot : ctx.GetOp().Inputs()) {
    const auto &slot_name = slot.first;
    const auto &var_names = slot.second;
    const auto tensors = ctx.MultiInput<framework::Tensor>(slot_name);
    for (size_t idx = 0; idx < tensors.size(); ++idx) {
      VLOG(1) << "Input(" << slot_name << ")[" << idx
              << "] = " << var_names[idx] << " , "
              << GetMinMaxStr(tensors[idx]);
    }
  }

  if (!only_inputs) {
    for (const auto &slot : ctx.GetOp().Outputs()) {
      const auto &slot_name = slot.first;
      const auto &var_names = slot.second;
      const auto tensors = ctx.MultiOutput<framework::Tensor>(slot_name);
      for (size_t idx = 0; idx < tensors.size(); ++idx) {
        VLOG(1) << "Output(" << slot_name << ")[" << idx
                << "] = " << var_names[idx] << " , "
                << GetMinMaxStr(tensors[idx]);
      }
    }
  }
}

// Checks the FP32 and FP16 gradients with IsNanInfFunctor and records the
// outcome through `nan_inf_flag`.
//
// The storage of nan_inf_flag[1] is reused as scratch space for two bool
// flags (byte 0: FP32 result, byte 1: FP16 result), so `nan_inf_flag` must
// point to at least two floats. Each present gradient buffer is reduced with
// OrFunctor into its bool flag; a final single-thread kernel then consumes
// the available flag(s) together with `nan_inf_flag`.
// NOTE(review): the SetNanInfValueCUDAKernel* kernels are defined outside
// this view; presumably they write the combined result to nan_inf_flag[0].
// NOTE(review): if both element counts are 0, the one-flag kernel is launched
// with a null flag pointer -- callers are expected to rule that case out.
static void CheckHasNanInfGrad(const float *fp32_grad, int fp32_numel,
                               const platform::float16 *fp16_grad,
                               int fp16_numel, float *nan_inf_flag,
                               gpuStream_t stream,
                               memory::Buffer *cub_tmp_buffer) {
  bool *fp32_has_nan_inf = nullptr;
  if (fp32_numel > 0) {
    using Fp32Checker = IsNanInfFunctor<float>;
    using Fp32Iter =
        cub::TransformInputIterator<bool, Fp32Checker, const float *>;
    fp32_has_nan_inf = reinterpret_cast<bool *>(nan_inf_flag + 1);
    Fp32Iter fp32_iter(fp32_grad, Fp32Checker());
    CubDeviceReduce(fp32_iter, fp32_has_nan_inf, fp32_numel, OrFunctor(),
                    false, stream, cub_tmp_buffer);
  }

  bool *fp16_has_nan_inf = nullptr;
  if (fp16_numel > 0) {
    using Fp16Checker = IsNanInfFunctor<platform::float16>;
    using Fp16Iter = cub::TransformInputIterator<bool, Fp16Checker,
                                                 const platform::float16 *>;
    fp16_has_nan_inf = reinterpret_cast<bool *>(nan_inf_flag + 1) + 1;
    Fp16Iter fp16_iter(fp16_grad, Fp16Checker());
    CubDeviceReduce(fp16_iter, fp16_has_nan_inf, fp16_numel, OrFunctor(),
                    false, stream, cub_tmp_buffer);
  }

  const bool fp32_checked = (fp32_has_nan_inf != nullptr);
  const bool fp16_checked = (fp16_has_nan_inf != nullptr);
  if (!fp32_checked) {
    SetNanInfValueCUDAKernelOneFlag<<<1, 1, 0, stream>>>(fp16_has_nan_inf,
                                                         nan_inf_flag);
  } else if (!fp16_checked) {
    SetNanInfValueCUDAKernelOneFlag<<<1, 1, 0, stream>>>(fp32_has_nan_inf,
                                                         nan_inf_flag);
  } else {
    SetNanInfValueCUDAKernelTwoFlag<<<1, 1, 0, stream>>>(
        fp32_has_nan_inf, fp16_has_nan_inf, nan_inf_flag);
  }
}

1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111
// Element-wise z[i] = T3(MT(x[i]) + MT(y[i])) for i in [0, n), where MT is
// MasterT<T2>.
//
// Launch layout: 1-D grid of 1-D blocks; any grid size is valid because the
// kernel uses a vectorized grid-stride loop. Each thread handles VecSize
// consecutive elements per iteration through phi::Load / phi::Store, so the
// caller must pick a VecSize that x, y and z all support. The n % VecSize
// trailing elements are processed one by one by the single thread whose
// index lands on the last partial vector.
//
// `z` may alias `x` or `y`: every element is read before it is written by the
// same thread.
//
// All index arithmetic is done in 64 bits. `n` may be close to INT_MAX, and
// with 32-bit indices both `i + VecSize` and `i += stride` could overflow,
// which would let the loops continue with a wrapped index and access memory
// out of bounds.
template <typename T1, typename T2, typename T3, int VecSize>
static __global__ void ElementwiseAddWithCastCUDAKernel(const T1 *x,
                                                        const T2 *y, T3 *z,
                                                        int n) {
  static_assert(sizeof(T1) <= sizeof(T2),
                "sizeof(T1) must not be larger than sizeof(T2).");
  using MT = MasterT<T2>;

  const int64_t numel = static_cast<int64_t>(n);
  const int64_t thread_id =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride =
      static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;

  int64_t i = thread_id * VecSize;
  // Vectorized main loop over full groups of VecSize elements.
  for (; i + VecSize <= numel; i += stride) {
    phi::AlignedVector<T1, VecSize> x_vec;
    phi::AlignedVector<T2, VecSize> y_vec;
    phi::AlignedVector<T3, VecSize> z_vec;
    phi::Load(x + i, &x_vec);
    phi::Load(y + i, &y_vec);
#pragma unroll
    for (int j = 0; j < VecSize; ++j) {
      auto x_tmp = static_cast<MT>(x_vec[j]);
      auto y_tmp = static_cast<MT>(y_vec[j]);
      z_vec[j] = static_cast<T3>(x_tmp + y_tmp);
    }
    phi::Store(z_vec, z + i);
  }

  // Scalar tail: only the thread left with i < numel has work to do here.
  for (; i < numel; ++i) {
    auto x_tmp = static_cast<MT>(x[i]);
    auto y_tmp = static_cast<MT>(y[i]);
    z[i] = static_cast<T3>(x_tmp + y_tmp);
  }
}

// Launches ElementwiseAddWithCastCUDAKernel on `stream` to compute
// z = x + y over `n` elements.
//
// The vector width is the smallest of the widths that `x`, `y` and `z`
// support, as reported by GetChunkedVecSize(ptr, 0); the launch configuration
// comes from GetGpuLaunchConfig1D for that width.
template <typename T1, typename T2, typename T3>
static void LaunchElementwiseAddWithCastKernel(
    const platform::CUDADeviceContext &dev_ctx, const T1 *x, const T2 *y, T3 *z,
    int n, gpuStream_t stream) {
  const int x_vec_size = GetChunkedVecSize(x, 0);
  const int y_vec_size = GetChunkedVecSize(y, 0);
  const int z_vec_size = GetChunkedVecSize(z, 0);
  int vec_size = std::min(std::min(x_vec_size, y_vec_size), z_vec_size);
  const auto launch_cfg = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);

// `kVecSize` is expected to be supplied by PD_VEC_LAUNCH_KERNEL.
#define PD_LAMB_LAUNCH_ADD_WITH_CAST_KERNEL_WITH_VEC_SIZE             \
  do {                                                                \
    ElementwiseAddWithCastCUDAKernel<T1, T2, T3, kVecSize>            \
        <<<launch_cfg.block_per_grid, launch_cfg.thread_per_block, 0, \
           stream>>>(x, y, z, n);                                     \
  } while (0)

  PD_VEC_LAUNCH_KERNEL(vec_size,
                       PD_LAMB_LAUNCH_ADD_WITH_CAST_KERNEL_WITH_VEC_SIZE);
#undef PD_LAMB_LAUNCH_ADD_WITH_CAST_KERNEL_WITH_VEC_SIZE
}

1123 1124 1125 1126 1127 1128 1129 1130 1131 1132
template <typename T>
class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto stream = dev_ctx.stream();
    auto place = dev_ctx.GetPlace();

1133 1134 1135
    auto *found_inf_t = ctx.Output<framework::Tensor>("FoundInf");
    found_inf_t->Resize({1});

1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179
    // Step 1: Get fp16 param and grad tensors
    int64_t fp16_numel;
    auto *fp16_param = GetSameInOutTensorPtr<platform::float16, true>(
        ctx, place, "FP16FusedParam", "FP16FusedParamOut", &fp16_numel);
    bool has_fp16_param = (fp16_numel > 0);
    const platform::float16 *fp16_grad = nullptr;
    if (has_fp16_param) {
      fp16_grad = GetInputTensorPtr<platform::float16>(ctx, "FP16FusedGrad");
    } else {
      fp16_param = nullptr;
    }

    // Step 2: Get fp32 param and grad tensors
    int64_t fp32_numel = 0;
    auto *fp32_param = GetSameInOutTensorPtr<float, true>(
        ctx, place, "FP32FusedParam", "FP32FusedParamOut", &fp32_numel);
    PADDLE_ENFORCE_GE(fp32_numel, fp16_numel,
                      platform::errors::InvalidArgument(
                          "The element number in FP32FusedParam should be not "
                          "less than FP16FusedParam."));

    fp32_numel -= fp16_numel;  // the FP32FusedParam contains fp32 param and
                               // fp16 master weight
    bool has_fp32_param = (fp32_numel > 0);
    const float *fp32_grad = nullptr;
    if (has_fp32_param) {
      fp32_grad = GetInputTensorPtr<float>(ctx, "FP32FusedGrad");
    } else {
      PADDLE_ENFORCE_EQ(
          has_fp16_param, true,
          platform::errors::InvalidArgument(
              "Either FP32FusedGrad or FP16FusedGrad cannot be NULL."));
    }

    auto numel = fp32_numel + fp16_numel;
    VLOG(1) << "numel = " << numel << " , fp32_numel = " << fp32_numel
            << " , fp16_numel = " << fp16_numel;

    // The NVIDIA cub library does not support number > INT32_MAX
    PADDLE_ENFORCE_LE(numel, std::numeric_limits<int>::max(),
                      platform::errors::Unimplemented(
                          "Too many parameter number. Only <= %d is supported.",
                          std::numeric_limits<int>::max()));

1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220
    auto acc_steps = ctx.Attr<int>("acc_steps");
    PADDLE_ENFORCE_GE(
        acc_steps, 1,
        platform::errors::InvalidArgument(
            "The gradient accumulation steps should be not less than 1."));
    if (acc_steps > 1) {
      auto *step_t = ctx.Output<framework::Tensor>("AccStep");
      PADDLE_ENFORCE_NOT_NULL(
          step_t,
          platform::errors::InvalidArgument(
              "Output(AccStep) cannot be nullptr when Attr(acc_steps) > 1."));
      bool is_initialized = step_t->IsInitialized();
      int64_t *step_ptr;
      if (is_initialized) {
        step_ptr = step_t->mutable_data<int64_t>(platform::CPUPlace());
        ++(*step_ptr);
      } else {
        step_t->Resize({1});
        step_ptr = step_t->mutable_data<int64_t>(platform::CPUPlace());
        *step_ptr = 1;
      }
      int64_t rounded_step = (*step_ptr) % acc_steps;

      float *fp32_acc_grad = nullptr;
      if (has_fp32_param) {
        auto *fp32_acc_grad_t =
            ctx.Output<framework::Tensor>("FP32AccFusedGrad");
        PADDLE_ENFORCE_NOT_NULL(
            fp32_acc_grad_t, platform::errors::InvalidArgument(
                                 "Output(FP32AccFusedGrad) cannot be nullptr "
                                 "when Attr(acc_steps) > 1."));
        if (!fp32_acc_grad_t->IsInitialized()) {
          fp32_acc_grad_t->Resize({static_cast<int64_t>(fp32_numel)});
          fp32_acc_grad = fp32_acc_grad_t->mutable_data<float>(place);
        } else {
          fp32_acc_grad = fp32_acc_grad_t->data<float>();
        }
      }

      platform::float16 *fp16_acc_grad = nullptr;
      float *master_acc_grad = nullptr;
1221
      bool use_master_acc_grad = false;
1222
      if (has_fp16_param) {
1223
        use_master_acc_grad = ctx.Attr<bool>("use_master_acc_grad");
1224 1225 1226 1227 1228 1229 1230
        auto *fp16_acc_grad_t =
            ctx.Output<framework::Tensor>("FP16AccFusedGrad");
        PADDLE_ENFORCE_NOT_NULL(
            fp16_acc_grad_t, platform::errors::InvalidArgument(
                                 "Output(FP16AccFusedGrad) cannot be nullptr "
                                 "when Attr(acc_steps) > 1."));
        if (!fp16_acc_grad_t->IsInitialized()) {
1231 1232 1233
          auto acc_grad_size =
              use_master_acc_grad ? (3 * fp16_numel) : fp16_numel;
          fp16_acc_grad_t->Resize({static_cast<int64_t>(acc_grad_size)});
1234 1235 1236 1237 1238
          fp16_acc_grad =
              fp16_acc_grad_t->mutable_data<platform::float16>(place);
        } else {
          fp16_acc_grad = fp16_acc_grad_t->data<platform::float16>();
        }
1239 1240 1241 1242
        if (use_master_acc_grad) {
          master_acc_grad =
              reinterpret_cast<float *>(fp16_acc_grad + fp16_numel);
        }
1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256
      }

      // Inplace addto
      if (has_fp32_param) {
        if (rounded_step == 1) {
          memory::Copy(place, fp32_acc_grad, place, fp32_grad,
                       fp32_numel * sizeof(float), stream);
        } else {
          LaunchElementwiseAddWithCastKernel(dev_ctx, fp32_grad, fp32_acc_grad,
                                             fp32_acc_grad, fp32_numel, stream);
        }
      }

      if (has_fp16_param) {
1257 1258
        if (acc_steps == 2 || !use_master_acc_grad) {
          if (rounded_step != 1) {
1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308
            LaunchElementwiseAddWithCastKernel(dev_ctx, fp16_acc_grad,
                                               fp16_grad, fp16_acc_grad,
                                               fp16_numel, stream);
          } else {
            memory::Copy(place, fp16_acc_grad, place, fp16_grad,
                         fp16_numel * sizeof(platform::float16), stream);
          }
        } else {  // acc_steps >= 3
          if (rounded_step == 0) {
            LaunchElementwiseAddWithCastKernel(dev_ctx, fp16_grad,
                                               master_acc_grad, fp16_acc_grad,
                                               fp16_numel, stream);
          } else if (rounded_step == 1) {
            memory::Copy(place, fp16_acc_grad, place, fp16_grad,
                         fp16_numel * sizeof(platform::float16), stream);
          } else if (rounded_step == 2) {
            LaunchElementwiseAddWithCastKernel(dev_ctx, fp16_grad,
                                               fp16_acc_grad, master_acc_grad,
                                               fp16_numel, stream);
          } else {
            LaunchElementwiseAddWithCastKernel(dev_ctx, fp16_grad,
                                               master_acc_grad, master_acc_grad,
                                               fp16_numel, stream);
          }
        }
      }

      auto *stop_update_t = ctx.Output<framework::Tensor>("StopUpdate");
      stop_update_t->Resize({1});
      auto *stop_update =
          stop_update_t->mutable_data<bool>(platform::CPUPlace());

      auto *found_inf_cpu =
          found_inf_t->mutable_data<bool>(platform::CPUPlace());

      if (rounded_step != 0) {
        *stop_update = true;
        auto *found_inf_cpu =
            found_inf_t->mutable_data<bool>(platform::CPUPlace());
        *found_inf_cpu = false;
        return;
      } else {
        // swap pointer
        fp32_grad = fp32_acc_grad;
        fp16_grad = fp16_acc_grad;
        *stop_update = false;
        found_inf_t->clear();
      }
    }

1309
    // Step 3: Get ParamInfo
1310 1311 1312 1313
    const auto *param_info_tensor = GetInputTensorPtr<int>(ctx, "ParamInfo");
    auto fp32_local_start_idx = param_info_tensor[0];
    auto fp32_local_param_num = param_info_tensor[1];
    auto fp32_global_param_num = param_info_tensor[2];
1314 1315 1316 1317 1318
    auto fp32_weight_decay_end_idx = param_info_tensor[3];
    auto fp16_local_start_idx = param_info_tensor[4];
    auto fp16_local_param_num = param_info_tensor[5];
    auto fp16_global_param_num = param_info_tensor[6];
    auto fp16_weight_decay_end_idx = param_info_tensor[7];
1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335

    auto local_param_num = fp32_local_param_num + fp16_local_param_num;
    auto param_num = fp32_global_param_num + fp16_global_param_num;
    PADDLE_ENFORCE_LE(local_param_num, param_num,
                      platform::errors::InvalidArgument(
                          "The local parameter number should not exceed the "
                          "global parameter number."));
    VLOG(1) << "local_param_num = " << local_param_num
            << " , global_param_num = " << param_num
            << " , fp32_local_start_idx = " << fp32_local_start_idx
            << " , fp32_local_param_num = " << fp32_local_param_num
            << " , fp32_global_param_num = " << fp32_global_param_num
            << " , fp16_local_start_idx = " << fp16_local_start_idx
            << " , fp16_local_param_num = " << fp16_local_param_num
            << " , fp16_global_param_num = " << fp16_global_param_num;

    // Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow,
1336
    // GlobalScale
1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348
    const auto *global_scale = GetInputTensorPtr<float>(ctx, "GlobalScale");
    const auto *lr = GetInputTensorPtr<float>(ctx, "LearningRate");
    int64_t partial_numel = 0;
    auto *moment1 = GetSameInOutTensorPtr<float>(ctx, place, "Moment1",
                                                 "Moment1Out", &partial_numel);

    PADDLE_ENFORCE_EQ(numel % partial_numel, 0,
                      platform::errors::InvalidArgument(
                          "The total parameter number %d should be divided "
                          "exactly by the element number %d of Moment1.",
                          numel, partial_numel));

1349 1350 1351
    // The num_devices means the number of devices that shard a complete set
    // of all parameters. It may be num_devices < nranks or num_devices ==
    // nranks.
1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375
    int64_t num_devices = numel / partial_numel;
    VLOG(1) << "num_devices = " << num_devices
            << " , partial_numel = " << partial_numel;

    PADDLE_ENFORCE_EQ(fp32_numel % num_devices, 0,
                      platform::errors::InvalidArgument(
                          "The fp32 parameter number %d should be divided "
                          "exactly by the device number %d.",
                          fp32_numel, num_devices));
    PADDLE_ENFORCE_EQ(fp16_numel % num_devices, 0,
                      platform::errors::InvalidArgument(
                          "The fp16 parameter number %d should be divided "
                          "exactly by the device number %d.",
                          fp16_numel, num_devices));

    auto *moment2 =
        GetSameInOutTensorPtr<float>(ctx, place, "Moment2", "Moment2Out");
    auto *beta1pow =
        GetSameInOutTensorPtr<float>(ctx, place, "Beta1Pow", "Beta1PowOut");
    auto *beta2pow =
        GetSameInOutTensorPtr<float>(ctx, place, "Beta2Pow", "Beta2PowOut");

    auto *found_inf = found_inf_t->mutable_data<bool>(place);

1376 1377
    // Step 5: Get attributes weight_decay, beta1, beta2, epsilon,
    // max_grad_norm, ring_id,
1378
    // use_master_param_norm, is_grad_scaled_by_nranks
1379
    auto weight_decay = ctx.Attr<float>("weight_decay");
1380 1381 1382 1383 1384
    auto beta1 = ctx.Attr<float>("beta1");
    auto beta2 = ctx.Attr<float>("beta2");
    auto epsilon = ctx.Attr<float>("epsilon");
    auto max_global_grad_norm = ctx.Attr<float>("max_global_grad_norm");
    auto clip_after_allreduce = ctx.Attr<bool>("clip_after_allreduce");
1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395
    auto nranks = ctx.Attr<int64_t>("nranks");
    PADDLE_ENFORCE_GE(nranks, num_devices,
                      phi::errors::InvalidArgument(
                          "The nranks must be not less than num_devices."));
    PADDLE_ENFORCE_EQ(
        nranks % num_devices, 0,
        phi::errors::InvalidArgument(
            "The nranks must be exactly divided by num_devices."));
    bool local_shard = (nranks > num_devices);

    const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_id");
1396 1397 1398 1399 1400
    auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
    auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
    VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
             << " , clip_after_allreduce = " << clip_after_allreduce
             << " , use_master_param_norm = " << use_master_param_norm
1401 1402
             << " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
             << " , local_shard = " << local_shard;
1403 1404

    // Step 6: allreduce + global norm gradient clip
1405 1406 1407
    int64_t global_rank = 0, local_rank = 0;
    ncclComm_t global_comm = nullptr, local_comm = 0;
    if (nranks > 1) {
1408
      auto *nccl_comm_handle =
1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421
          platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
      global_comm = nccl_comm_handle->comm();
      global_rank = nccl_comm_handle->rank();

      if (local_shard) {
        auto *local_nccl_comm_handle =
            platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
        local_comm = local_nccl_comm_handle->comm();
        local_rank = local_nccl_comm_handle->rank();
      } else {
        local_comm = global_comm;
        local_rank = global_rank;
      }
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432
    }

    memory::Buffer grad_norm_square_buffer(place);
    auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2);
    memory::Buffer cub_tmp_buffer(place);

    memory::Buffer sum_grad_buffer(place);
    float *fp32_sum_grad;
    platform::float16 *fp16_sum_grad;
    auto fp32_numel_each_device = fp32_numel / num_devices;
    auto fp16_numel_each_device = fp16_numel / num_devices;
1433 1434 1435 1436 1437 1438 1439 1440 1441
    if (local_shard) {
      auto ptr = sum_grad_buffer.Alloc<uint8_t>(
          fp32_numel * sizeof(float) + fp16_numel * sizeof(platform::float16));
      fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
      fp16_sum_grad = has_fp16_param ? reinterpret_cast<platform::float16 *>(
                                           ptr + fp32_numel * sizeof(float))
                                     : nullptr;
    } else if (nranks > 1 ||
               (max_global_grad_norm > 0 && !clip_after_allreduce)) {
1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462
      auto ptr = sum_grad_buffer.Alloc<uint8_t>(
          fp32_numel_each_device * sizeof(float) +
          fp16_numel_each_device * sizeof(platform::float16));
      fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
      fp16_sum_grad = has_fp16_param
                          ? reinterpret_cast<platform::float16 *>(
                                ptr + fp32_numel_each_device * sizeof(float))
                          : nullptr;
    } else {
      // NOTE: The const_cast here is harmless: fp32_sum_grad and
      // fp16_sum_grad are not modified when num_devices == 1.
      // Without the const_cast, the code below would need extra
      // `num_devices > 1` if-else branches, so the input gradient pointers
      // are reused directly to keep the remaining steps on a single,
      // unified code path.
      fp32_sum_grad = const_cast<float *>(fp32_grad);
      fp16_sum_grad = const_cast<platform::float16 *>(fp16_grad);
    }

    float rescale_grad = 1.0f;
    if (!is_grad_scaled_by_nranks) {
1463
      rescale_grad /= nranks;
1464 1465 1466 1467 1468
    }

    if (max_global_grad_norm > 0) {
      if (clip_after_allreduce) {
        // (1) ReduceScatter first
1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483
        if (local_shard) {
          NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
                                 global_comm, stream, dev_ctx);
          NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
                                 global_comm, stream, dev_ctx);
          fp32_sum_grad += (local_rank * fp32_numel_each_device);
          fp16_sum_grad += (local_rank * fp16_numel_each_device);
        } else {
          NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
                                     fp32_numel_each_device, nranks,
                                     global_comm, stream, dev_ctx);
          NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
                                     fp16_numel_each_device, nranks,
                                     global_comm, stream, dev_ctx);
        }
1484 1485 1486 1487 1488 1489 1490 1491 1492
        // (2) Calculate the global grad norm
        GetSquareGradNorm(fp32_sum_grad, fp32_numel_each_device, fp16_sum_grad,
                          fp16_numel_each_device, fp32_square_grad_norm, stream,
                          &cub_tmp_buffer);
        VLOG(1) << "Grad square norm before all reduce: "
                << FlattenToString(fp32_square_grad_norm, 1, place);
        if (num_devices > 1) {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
              fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
1493
              ncclSum, local_comm, stream));
1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519
        }
        VLOG(1) << "Grad square norm after all reduce: "
                << FlattenToString(fp32_square_grad_norm, 1, place);
      } else {
        // (1) Calculate the local grad norm
        GetSquareGradNorm(fp32_grad, fp32_numel, fp16_grad, fp16_numel,
                          fp32_square_grad_norm, stream, &cub_tmp_buffer);
        VLOG(1) << "Grad square norm before all reduce: "
                << FlattenToString(fp32_square_grad_norm, 1, place);
        // (2) Calculate the gradient clip scale
        float *fp32_scale = nullptr;
        platform::float16 *fp16_scale = nullptr;
        if (has_fp32_param && has_fp16_param) {
          auto *ptr = cub_tmp_buffer.Alloc<uint8_t>(sizeof(float) +
                                                    sizeof(platform::float16));
          fp32_scale = reinterpret_cast<float *>(ptr);
          fp16_scale =
              reinterpret_cast<platform::float16 *>(ptr + sizeof(float));
        } else if (has_fp32_param) {
          fp32_scale = cub_tmp_buffer.Alloc<float>(1);
        } else {
          fp16_scale = cub_tmp_buffer.Alloc<platform::float16>(1);
        }

        float clip_scale = 1.0f;
        if (is_grad_scaled_by_nranks) {
1520
          clip_scale *= nranks;
1521
        }
1522 1523 1524 1525
        CalcGradNormClipBeforeAllReduceScale<float, platform::float16>
            <<<1, 1, 0, stream>>>(global_scale, max_global_grad_norm,
                                  fp32_square_grad_norm, fp32_scale, fp16_scale,
                                  clip_scale);
1526 1527 1528 1529 1530
        if (fp32_scale) {
          VLOG(1) << "Grad scale: " << FlattenToString(fp32_scale, 1, place);
        } else {
          VLOG(1) << "Grad scale: " << FlattenToString(fp16_scale, 1, place);
        }
1531
        if (nranks > 1) {
1532 1533
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
              fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
1534
              ncclSum, global_comm, stream));
1535 1536
        }
        // (3) Do ReduceScatter with scale
1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551
        if (local_shard) {
          NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
                                 global_comm, stream, dev_ctx, fp32_scale);
          NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
                                 global_comm, stream, dev_ctx, fp16_scale);
          fp32_sum_grad += (local_rank * fp32_numel_each_device);
          fp16_sum_grad += (local_rank * fp16_numel_each_device);
        } else {
          NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
                                     fp32_numel_each_device, nranks,
                                     global_comm, stream, dev_ctx, fp32_scale);
          NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
                                     fp16_numel_each_device, nranks,
                                     global_comm, stream, dev_ctx, fp16_scale);
        }
1552 1553 1554 1555 1556
        // (4) mark max_global_grad_norm as 0, meaning that clip has been
        // already performed
        max_global_grad_norm = 0;
      }
    } else {
1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571
      if (local_shard) {
        NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
                               global_comm, stream, dev_ctx);
        NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
                               global_comm, stream, dev_ctx);
        fp32_sum_grad += (local_rank * fp32_numel_each_device);
        fp16_sum_grad += (local_rank * fp16_numel_each_device);
      } else {
        NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
                                   fp32_numel_each_device, num_devices,
                                   global_comm, stream, dev_ctx);
        NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
                                   fp16_numel_each_device, num_devices,
                                   global_comm, stream, dev_ctx);
      }
1572 1573 1574 1575 1576 1577
      CheckHasNanInfGrad(fp32_sum_grad, fp32_numel_each_device, fp16_sum_grad,
                         fp16_numel_each_device, fp32_square_grad_norm, stream,
                         &cub_tmp_buffer);
      if (num_devices > 1) {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
            fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
1578
            ncclSum, local_comm, stream));
1579 1580 1581 1582 1583 1584
      }
      max_global_grad_norm = 0;
    }
    VLOG(10) << "ReduceScatter done";

    // Step 7: update the moment1, moment2. Calculate the trust_ratio_div
1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595
    auto *fused_offsets_t = ctx.Input<framework::Tensor>("FusedParamOffsets");
    auto *fused_offsets = fused_offsets_t->data<int>();
    auto *fp32_partial_fused_offsets_t =
        ctx.Input<framework::Tensor>("FP32ShardFusedParamOffsets");
    const auto *fp32_partial_fused_offsets =
        fp32_partial_fused_offsets_t->data<int>();
    auto *fp16_partial_fused_offsets_t =
        ctx.Input<framework::Tensor>("FP16ShardFusedParamOffsets");
    const auto *fp16_partial_fused_offsets =
        fp16_partial_fused_offsets_t->data<int>();

1596 1597
    auto *step = ctx.Output<framework::Tensor>("Step")->data<int64_t>();

1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609
    VLOG(1) << "FusedParamOffsets: "
            << FlattenToString(fused_offsets, fused_offsets_t->numel(),
                               fused_offsets_t->place());
    VLOG(1) << "FP32ShardFusedParamOffsets: "
            << FlattenToString(fp32_partial_fused_offsets,
                               fp32_partial_fused_offsets_t->numel(),
                               fp32_partial_fused_offsets_t->place());
    VLOG(1) << "FP16ShardFusedParamOffsets: "
            << FlattenToString(fp16_partial_fused_offsets,
                               fp16_partial_fused_offsets_t->numel(),
                               fp16_partial_fused_offsets_t->place());

1610 1611
    memory::Buffer trust_ratio_div_buffer(place);
    auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
1612 1613
    auto fp32_offset = local_rank * fp32_numel_each_device;
    auto fp16_offset = local_rank * fp16_numel_each_device;
1614 1615
    if (has_fp32_param) {
      VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts";
1616 1617
      MultiTensorUpdateLambMomentAndTrustRatioDiv(
          dev_ctx, fp32_partial_fused_offsets, fp32_local_param_num,
1618
          fp32_param + fp32_offset, fp32_sum_grad, fp32_square_grad_norm,
1619
          global_scale, beta1pow, beta2pow, moment1, moment2, trust_ratio_div,
1620 1621
          found_inf, step, weight_decay, fp32_weight_decay_end_idx, beta1,
          beta2, epsilon, max_global_grad_norm, rescale_grad);
1622 1623 1624 1625 1626 1627
      VLOG(10) << "Update FP32 Moment and TrustRatioDiv done";
    }
    float *master_param = nullptr;
    if (has_fp16_param) {
      master_param = fp32_param + fp32_numel;
      VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts";
1628
      auto tmp_found_inf = has_fp32_param ? nullptr : found_inf;
1629
      auto tmp_step = has_fp32_param ? nullptr : step;
1630 1631
      MultiTensorUpdateLambMomentAndTrustRatioDiv(
          dev_ctx, fp16_partial_fused_offsets, fp16_local_param_num,
1632
          master_param + fp16_offset, fp16_sum_grad, fp32_square_grad_norm,
1633
          global_scale, beta1pow, beta2pow, moment1 + fp32_numel_each_device,
1634
          moment2 + fp32_numel_each_device,
1635 1636
          trust_ratio_div + fp32_numel_each_device, tmp_found_inf, tmp_step,
          weight_decay, fp16_weight_decay_end_idx, beta1, beta2, epsilon,
1637
          max_global_grad_norm, rescale_grad);
1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654
      VLOG(10) << "Update FP16 Moment and TrustRatioDiv done";
    }

    VLOG(10) << "Update Moment and TrustRatioDiv done hehahaha";

    // Step 8: calculate L2-Norm square of parameter and trust_ratio_div
    memory::Buffer square_norm_buffer(place);
    auto *param_square_norm = square_norm_buffer.Alloc<float>(2 * param_num);
    auto *trust_ratio_div_square_norm = param_square_norm + param_num;
    if (num_devices > 1) {
      if (use_master_param_norm) {
        FillZeroWithPtr(param_square_norm + fp32_global_param_num,
                        2 * param_num - fp32_global_param_num, stream);
      } else {
        FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream);
      }
    }
1655 1656
    MultiTensorL2Norm(place, stream, fp32_param, fused_offsets,
                      fp32_global_param_num, param_square_norm);
1657
    if (use_master_param_norm) {
1658 1659 1660
      MultiTensorL2Norm(place, stream, master_param + fp16_offset,
                        fp16_partial_fused_offsets, fp16_local_param_num,
                        param_square_norm + fp16_local_start_idx);
1661
    } else {
1662 1663 1664 1665 1666 1667
      MultiTensorL2Norm(place, stream,
                        fp16_param + fused_offsets[fp16_local_start_idx] -
                            fused_offsets[fp32_global_param_num],
                        fused_offsets + fp16_local_start_idx,
                        fp16_local_param_num,
                        param_square_norm + fp16_local_start_idx);
1668 1669
    }

1670 1671 1672 1673 1674 1675
    MultiTensorL2Norm(place, stream, trust_ratio_div,
                      fp32_partial_fused_offsets, fp32_local_param_num,
                      trust_ratio_div_square_norm + fp32_local_start_idx);
    MultiTensorL2Norm(place, stream, trust_ratio_div + fp32_numel_each_device,
                      fp16_partial_fused_offsets, fp16_local_param_num,
                      trust_ratio_div_square_norm + fp16_local_start_idx);
1676 1677 1678 1679 1680 1681 1682 1683

    VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: "
            << FlattenToString(trust_ratio_div_square_norm, param_num, place);
    if (num_devices > 1) {
      if (use_master_param_norm) {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
            param_square_norm + fp32_global_param_num,
            param_square_norm + fp32_global_param_num,
1684 1685
            2 * param_num - fp32_global_param_num, ncclFloat32, ncclSum,
            local_comm, stream));
1686 1687 1688
      } else {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
            trust_ratio_div_square_norm, trust_ratio_div_square_norm, param_num,
1689
            ncclFloat32, ncclSum, local_comm, stream));
1690 1691 1692 1693 1694 1695 1696 1697 1698 1699
      }
      VLOG(10) << "ncclAllReduce done";
    }

    LogParamAndTrustRatioDivSquareNorm<1>(ctx, param_square_norm,
                                          trust_ratio_div_square_norm);
    VLOG(10) << "Calculate L2-Norm of Param and TrustRatioDiv done";

    // Step 9: update parameter, beta1pow, beta2pow. All gather parameters.
    if (has_fp32_param) {
1700 1701 1702 1703 1704
      MultiTensorUpdateLambParamAndBetaPows<float>(
          dev_ctx, fp32_partial_fused_offsets, fp32_local_param_num,
          trust_ratio_div, lr, param_square_norm + fp32_local_start_idx,
          trust_ratio_div_square_norm + fp32_local_start_idx, found_inf,
          fp32_param + fp32_offset, nullptr, beta1pow, beta2pow, beta1, beta2);
1705 1706 1707 1708
      if (num_devices > 1) {
        // ncclAllGather
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
            fp32_param + fp32_offset, fp32_param, fp32_numel_each_device,
1709
            ncclFloat32, local_comm, stream));
1710
      }
1711 1712 1713

      beta1pow = nullptr;
      beta2pow = nullptr;
1714 1715
    }
    if (has_fp16_param) {
1716 1717 1718 1719 1720 1721 1722
      MultiTensorUpdateLambParamAndBetaPows<platform::float16>(
          dev_ctx, fp16_partial_fused_offsets, fp16_local_param_num,
          trust_ratio_div + fp32_numel_each_device, lr,
          param_square_norm + fp16_local_start_idx,
          trust_ratio_div_square_norm + fp16_local_start_idx, found_inf,
          fp16_param + fp16_offset, master_param + fp16_offset, beta1pow,
          beta2pow, beta1, beta2);
1723 1724 1725 1726
      if (num_devices > 1) {
        // ncclAllGather
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
            fp16_param + fp16_offset, fp16_param, fp16_numel_each_device,
1727
            ncclFloat16, local_comm, stream));
1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748
      }
    }
    VLOG(10) << "Update Param done";

    VLOG(1) << "IsFinite: " << IsFinite(dev_ctx, fp32_square_grad_norm);
#else
    PADDLE_THROW(platform::errors::Unimplemented(
        "distributed_fused_lamb op should be used with NCCL/RCCL."));
#endif
  }
};

}  // namespace operators
}  // namespace paddle

// Short aliases used by the kernel registration below.
namespace plat = paddle::platform;
namespace ops = paddle::operators;

// Registers the CUDA kernel for the `distributed_fused_lamb` operator.
// Only the <CUDADeviceContext, float> instantiation of
// DistributedFusedLambOpKernel is listed here; the kernel body itself works
// with both float and platform::float16 buffers (see the fp32_* / fp16_*
// pointers in the compute routine), so no separate float16 instantiation is
// registered.
// NOTE(review): the exact expansion of REGISTER_OP_CUDA_KERNEL is defined
// elsewhere in the project — confirm against the operator registry headers
// before changing the argument list.
REGISTER_OP_CUDA_KERNEL(
    distributed_fused_lamb,
    ops::DistributedFusedLambOpKernel<plat::CUDADeviceContext, float>);