distributed_fused_lamb_op.cu 87.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
16

17
#include "paddle/fluid/memory/buffer.h"
18
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
19 20
#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
21
#include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
22 23 24
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/string/string_helper.h"
25
#include "paddle/phi/core/utils/data_type.h"
26
#include "paddle/phi/kernels/funcs/aligned_vector.h"
T
Thomas Young 已提交
27
#include "paddle/phi/kernels/funcs/tensor_to_string.h"
28 29 30 31 32 33 34 35

#ifdef __NVCC__
#include "cub/cub.cuh"
#include "math.h"  // NOLINT
#endif

#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
36

37 38 39 40 41 42 43 44 45
#include "math.h"  // NOLINT
namespace cub = hipcub;
#endif

namespace paddle {
namespace operators {

// MasterT<T> is the "master" type associated with T by
// details::MPTypeTrait (presumably the higher-precision type used for
// accumulation and optimizer state, e.g. float for float16 — confirm in
// fp16_type_traits.h).
template <typename T>
using MasterT = typename details::MPTypeTrait<T>::Type;

using phi::funcs::FlattenToString;
using phi::funcs::ToVector;
48

49 50 51 52 53 54 55 56 57 58 59 60
// Enqueues an asynchronous byte-wise memset that zeroes `n` elements of type T
// starting at `x` on `stream`. Because the fill is byte-wise it is only
// meaningful for writing zeros.
template <typename T>
static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
  static_assert(!std::is_same<T, void>::value, "T cannot be void.");
  const size_t num_bytes = n * sizeof(T);
#ifndef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, num_bytes, stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, num_bytes, stream));
#endif
}

// Per-chunk functor for MultiTensorApply. It sums the squares of
// x[offset, offset + size) in MasterT<T> precision and writes the block-wide
// partial sum to y[tensor_id * max_chunk_num + chunk_id].
//
// Requirements:
//   * BlockDim must equal the launched blockDim.x (it sizes cub::BlockReduce
//     and the loop stride).
//   * x + offset is loaded through phi::AlignedVector<T, VecSize>, so it is
//     assumed to be aligned for that width (the caller picks VecSize with
//     GetChunkedVecSize).
template <typename T, int BlockDim, int VecSize>
struct L2NormFunctor {
  DEVICE void operator()(int tensor_id,
                         int chunk_id,
                         int offset,
                         int size,
                         const T *x,
                         MasterT<T> *y,
                         int max_chunk_num) const {
    using MT = MasterT<T>;
    const T *ptr = x + offset;

    using BlockReduce = cub::BlockReduce<MT, BlockDim>;
    __shared__ typename BlockReduce::TempStorage storage;

    MT square_sum = static_cast<MT>(0);
    int i;
    // Vectorized main loop: each thread consumes VecSize contiguous elements
    // per iteration and then strides over the whole block.
    for (i = threadIdx.x * VecSize; i + VecSize <= size;
         i += (BlockDim * VecSize)) {
      phi::AlignedVector<T, VecSize> tmp_vec;
      phi::Load(ptr + i, &tmp_vec);
#pragma unroll
      for (int j = 0; j < VecSize; ++j) {
        auto tmp = static_cast<MT>(tmp_vec[j]);
        square_sum += (tmp * tmp);
      }
    }

    // Scalar tail. Threads leave the loop above at distinct multiples of
    // VecSize, so at most one lands in (size - VecSize, size) and picks up
    // the remaining (< VecSize) elements exactly once.
    for (; i < size; ++i) {
      auto tmp = static_cast<MT>(ptr[i]);
      square_sum += (tmp * tmp);
    }

    square_sum = BlockReduce(storage).Reduce(square_sum, cub::Sum());
    // The block-wide reduction result is only valid in thread 0.
    if (threadIdx.x == 0) {
      y[tensor_id * max_chunk_num + chunk_id] = square_sum;
    }
  }
};

99
// Second reduction pass of MultiTensorL2Norm. It is launched with one block
// per tensor (blockIdx.x is the tensor id). Each block sums the max_chunk_num
// per-chunk partial sums
// x[tensor_id * max_chunk_num, (tensor_id + 1) * max_chunk_num) and writes the
// total, converted to OutT, to y[tensor_id]. BlockDim must equal the launched
// blockDim.x because it sizes cub::BlockReduce and the loop stride.
template <typename InT, typename OutT, int BlockDim>
static __global__ void MultiTensorL2NormReduceAgainCUDAKernel(
    const InT *x, OutT *y, int max_chunk_num) {
  int tensor_id = blockIdx.x;
  x += (tensor_id * max_chunk_num);
  using BlockReduce = cub::BlockReduce<InT, BlockDim>;
  __shared__ typename BlockReduce::TempStorage storage;
  InT sum = static_cast<InT>(0);
  for (int i = threadIdx.x; i < max_chunk_num; i += BlockDim) {
    sum += x[i];
  }
  sum = BlockReduce(storage).Reduce(sum, cub::Sum());
  // Only thread 0 holds the valid block-wide sum.
  if (threadIdx.x == 0) {
    y[tensor_id] = static_cast<OutT>(sum);
  }
}

// Returns the widest vector size (8, 4, 2 or 1 elements of T) that satisfies
// both of the following:
//   * `ptr` is aligned to alignof(phi::AlignedVector<T, width>);
//   * a chunk of `chunk_size` elements spans a whole multiple of that
//     alignment in bytes, so every chunk start stays aligned as well.
// The result is additionally capped so that one vector never exceeds 128
// bits. With chunk_size == 0 only the alignment of `ptr` matters.
template <typename T>
static int GetChunkedVecSize(const T *ptr, int chunk_size) {
  static_assert(!std::is_same<T, void>::value, "T cannot be void.");

  constexpr int max_load_bits = 128;
  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
  auto address = reinterpret_cast<uintptr_t>(ptr);
  constexpr int vec8 = alignof(phi::AlignedVector<T, 8>);
  constexpr int vec4 = alignof(phi::AlignedVector<T, 4>);
  constexpr int vec2 = alignof(phi::AlignedVector<T, 2>);
  // Keep the byte count in size_t. Writing it back into the int parameter
  // (as before) narrows the product and can overflow for large chunks.
  const size_t chunk_bytes = static_cast<size_t>(chunk_size) * sizeof(T);
  if (address % vec8 == 0 && chunk_bytes % vec8 == 0) {
    return std::min(8, valid_vec_size);
  } else if (address % vec4 == 0 && chunk_bytes % vec4 == 0) {
    return std::min(4, valid_vec_size);
  } else if (address % vec2 == 0 && chunk_bytes % vec2 == 0) {
    return std::min(2, valid_vec_size);
  } else {
    return 1;
  }
}

138 139 140 141 142
// One `case` of the switch in PD_VEC_LAUNCH_KERNEL. It binds the compile-time
// constant kVecSize to __vec_size and runs the statement passed in
// __VA_ARGS__.
#define PD_VEC_LAUNCH_KERNEL_CASE(__vec_size, ...) \
  case __vec_size: {                               \
    constexpr int kVecSize = __vec_size;           \
    __VA_ARGS__;                                   \
    break;                                         \
  }

// Dispatches a runtime vector size (8, 4, 2 or 1) to a compile-time kVecSize
// and executes __VA_ARGS__ with it. Any other value raises InvalidArgument;
// previously it fell through the switch and silently skipped the launch.
#define PD_VEC_LAUNCH_KERNEL(__vec_size, ...)                       \
  do {                                                              \
    switch (__vec_size) {                                           \
      PD_VEC_LAUNCH_KERNEL_CASE(8, __VA_ARGS__);                    \
      PD_VEC_LAUNCH_KERNEL_CASE(4, __VA_ARGS__);                    \
      PD_VEC_LAUNCH_KERNEL_CASE(2, __VA_ARGS__);                    \
      PD_VEC_LAUNCH_KERNEL_CASE(1, __VA_ARGS__);                    \
      default:                                                      \
        PADDLE_THROW(platform::errors::InvalidArgument(             \
            "Unsupported vectorized size %d.",                      \
            static_cast<int>(__vec_size)));                         \
    }                                                               \
  } while (0)

// TODO(zengjinle): which chunk_size is better?
//
// Computes, for each of the n tensors packed in `x`, the *squared* L2 norm
// (sum of squares accumulated in MasterT<InT>; no square root is taken) and
// stores it as OutT in y[k].
//
// Layout:
//   * Tensor k has offsets[k + 1] - offsets[k] elements and starts at
//     x + (offsets[k] - offsets[0]).
//   * `offsets` is read on the host and needs n + 1 entries.
//   * x presumably points at the element with global offset offsets[0]
//     (TODO confirm at call sites).
//
// Work is done in two passes on `stream`:
//   1. MultiTensorApply + L2NormFunctor. Each chunk of `chunk_size` elements
//      writes one partial sum into a temporary buffer of n * max_chunk_num
//      entries.
//   2. MultiTensorL2NormReduceAgainCUDAKernel. One block per tensor sums that
//      tensor's max_chunk_num partials into y[k].
// The vector width is the smallest one that is valid for every tensor start.
template <typename InT,
          typename OutT,
          int MaxTensorNumPerLaunch = 160,
          int MaxChunkNumPerLaunch = 780>
static void MultiTensorL2Norm(const platform::CUDAPlace &place,
                              gpuStream_t stream,
                              const InT *x,
                              const int *offsets,
                              int n,
                              OutT *y,
                              int chunk_size = 65536) {
  if (n <= 0) return;

  constexpr int kNumTensor = MaxTensorNumPerLaunch;
  constexpr int kNumChunk = MaxChunkNumPerLaunch;
#ifdef PADDLE_WITH_HIP
  constexpr int kBlockDim = 256;
#else
  constexpr int kBlockDim = 512;
#endif

  int max_chunk_num = -1;
  int vec_size = 8;
  int total_chunk_num = 0;
  for (int i = 0; i < n; ++i) {
    // Form the relative offset first so the pointer never steps past the
    // buffer in an intermediate expression.
    vec_size = std::min(
        vec_size, GetChunkedVecSize(x + (offsets[i] - offsets[0]), chunk_size));
    int length = offsets[i + 1] - offsets[i];
    auto tmp_chunk_num = (length + chunk_size - 1) / chunk_size;
    max_chunk_num = std::max(max_chunk_num, tmp_chunk_num);
    total_chunk_num += tmp_chunk_num;
  }

  VLOG(1) << "MultiTensorL2Norm max_chunk_num = " << max_chunk_num
          << " , total_chunk_num = " << total_chunk_num
          << " , tensor_num = " << n;

  // Tensors with fewer than max_chunk_num chunks leave some slots unwritten,
  // and the second pass sums every slot, so the buffer must start at zero.
  using MT = MasterT<InT>;
  memory::Buffer tmp_out(place);
  auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
  FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream);

#define PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL                   \
  do {                                                                \
    using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>;         \
    VLOG(10) << __func__ << " " << typeid(InT).name()                 \
             << " VecSize = " << kVecSize;                            \
    MultiTensorApply<FunctorT, kNumTensor, kNumChunk>(FunctorT(),     \
                                                      stream,         \
                                                      offsets,        \
                                                      n,              \
                                                      chunk_size,     \
                                                      kBlockDim,      \
                                                      x,              \
                                                      tmp_out_ptr,    \
                                                      max_chunk_num); \
  } while (0)

  PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL);
#undef PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL

  MultiTensorL2NormReduceAgainCUDAKernel<MT, OutT, kBlockDim>
      <<<n, kBlockDim, 0, stream>>>(tmp_out_ptr, y, max_chunk_num);
}

221 222
// Logs, at VLOG level LogLevel, the squared norm of every parameter and of
// its trust-ratio divisor.
//
// `param_square_norm` and `trust_ratio_div_square_norm` each hold one float
// per "Param" input and are converted with ToVector (presumably a copy to the
// host). Entry i is reported against Param input order[i], where `order` is
// the "ParamOrder" input. The function does nothing when the log level is
// disabled or there are no "Param" inputs.
template <int LogLevel>
static void LogParamAndTrustRatioDivSquareNorm(
    const framework::ExecutionContext &ctx,
    const float *param_square_norm,
    const float *trust_ratio_div_square_norm) {
  if (!VLOG_IS_ON(LogLevel)) return;

  auto tensors = ctx.MultiInput<framework::Tensor>("Param");
  if (tensors.empty()) return;

  const auto *order = ctx.Input<framework::Tensor>("ParamOrder")->data<int>();

  size_t n = tensors.size();
  auto place = tensors[0]->place();

  auto pn_vec = ToVector(param_square_norm, n, place);
  auto tn_vec = ToVector(trust_ratio_div_square_norm, n, place);

  const auto &names = ctx.GetOp().Inputs("Param");
  for (size_t i = 0; i < n; ++i) {
    // The norms are indexed in fused order; map back to the original input
    // index to find the matching tensor and name.
    auto idx = order[i];
    VLOG(LogLevel) << "Param " << tensors[idx]->dtype() << " " << names[idx]
                   << " pn = " << pn_vec[i] << " , tn = " << tn_vec[i];
  }
}

L
Leo Chen 已提交
247
// Copies the single float at `ptr` (device memory) to the host and returns
// whether it is finite.
//
// This call is blocking: it synchronizes the dev_ctx stream before reading
// the value. It also logs the value at INFO level on every call.
static bool IsFinite(const phi::GPUContext &dev_ctx, const float *ptr) {
  auto stream = dev_ctx.stream();
  float cpu_value;
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(
      &cpu_value, ptr, sizeof(float), hipMemcpyDeviceToHost, stream));
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamSynchronize(stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
      &cpu_value, ptr, sizeof(float), cudaMemcpyDeviceToHost, stream));
  // The async copy must complete before cpu_value is read on the host.
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#endif
  LOG(INFO) << "NAN_INF indicator value: " << cpu_value;
  return isfinite(cpu_value);
}

// Returns a read-only data pointer (as T) for input `in_name`.
//
// The input must exist, otherwise InvalidArgument is raised. If it exists but
// is not initialized, nullptr is returned. When `numel` is non-null it
// receives the element count (0 for an uninitialized input).
template <typename T>
static const T *GetInputTensorPtr(const framework::ExecutionContext &ctx,
                                  const char *in_name,
                                  int64_t *numel = nullptr) {
  const auto *in_tensor = ctx.Input<framework::Tensor>(in_name);
  PADDLE_ENFORCE_NOT_NULL(
      in_tensor,
      platform::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
  if (in_tensor->IsInitialized()) {
    if (numel) *numel = in_tensor->numel();
    return in_tensor->data<T>();
  } else {
    if (numel) *numel = 0;
    return nullptr;
  }
}

// Returns a mutable data pointer for an in-place input/output pair.
//
// Behavior:
//   * Missing or uninitialized input: returns nullptr and sets *numel = 0
//     when AllowNotExist is true; raises InvalidArgument otherwise.
//   * Otherwise the output must exist, and its buffer (after
//     mutable_data<T>(place)) must be the same as the input's data pointer.
//   * On success, *numel (if non-null) receives the output's element count.
template <typename T, bool AllowNotExist = false>
static T *GetSameInOutTensorPtr(const framework::ExecutionContext &ctx,
                                const platform::Place &place,
                                const char *in_name,
                                const char *out_name,
                                int64_t *numel = nullptr) {
  const auto *in_tensor = ctx.Input<framework::Tensor>(in_name);
  if (in_tensor == nullptr || !in_tensor->IsInitialized()) {
    PADDLE_ENFORCE_EQ(AllowNotExist,
                      true,
                      platform::errors::InvalidArgument(
                          "Input(%s) cannot be NULL.", in_name));
    if (numel) *numel = 0;
    return nullptr;
  }

  // in_tensor is known to be non-null here, so only the output needs a
  // null check.
  auto *out_tensor = ctx.Output<framework::Tensor>(out_name);
  PADDLE_ENFORCE_NOT_NULL(out_tensor,
                          platform::errors::InvalidArgument(
                              "Output(%s) cannot be NULL.", out_name));
  const T *in_data = in_tensor->data<T>();
  T *out_data = out_tensor->mutable_data<T>(place);
  PADDLE_ENFORCE_EQ(in_data,
                    out_data,
                    platform::errors::InvalidArgument(
                        "Input(%s) and Output(%s) must be the same Tensor.",
                        in_name,
                        out_name));
  if (numel) *numel = out_tensor->numel();
  return out_data;
}

// Promotes x to MasterT<T> and returns its square, so the multiplication is
// done in the master precision.
template <typename T>
struct SquareFunctor {
  HOSTDEVICE MasterT<T> operator()(T x) const {
    const MasterT<T> promoted = static_cast<MasterT<T>>(x);
    return promoted * promoted;
  }
};

// Predicate: true when x is NaN or +/-Inf (i.e. not finite).
template <typename T>
struct IsNanInfFunctor {
  HOSTDEVICE bool operator()(T x) const {
    const bool finite = isfinite(x);
    return !finite;
  }
};

// Logical OR of two flags (y is not evaluated when x is already true).
struct OrFunctor {
  HOSTDEVICE bool operator()(bool x, bool y) const { return x ? true : y; }
};

// Logical AND of two flags (y is not evaluated when x is already false).
struct AndFunctor {
  HOSTDEVICE bool operator()(bool x, bool y) const { return x ? y : false; }
};

S
sneaxiy 已提交
336
// Element-wise y[i] = T1(T2(x[i]) * scale[0]) for i in [0, num).
//
// Uses a grid-stride loop over VecSize-wide vectors followed by a scalar
// tail. The multiplication is done in T2, which must be at least as wide as
// T1 (enforced by static_assert).
// NOTE(review): x and y are accessed as phi::AlignedVector<T1, VecSize>, so
// they presumably must be aligned for that width — confirm at call sites.
template <typename T1, typename T2, int VecSize>
static __global__ void ScaleCUDAKernel(const T1 *__restrict__ x,
                                       const T2 *__restrict__ scale,
                                       T1 *__restrict__ y,
                                       int num) {
  static_assert(sizeof(T1) <= sizeof(T2),
                "sizeof(T1) must be not greater than sizeof(T2).");
  T2 s = scale[0];

  int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
  int stride = blockDim.x * gridDim.x * VecSize;

  for (; i + VecSize <= num; i += stride) {
    phi::AlignedVector<T1, VecSize> x_vec;
    phi::AlignedVector<T1, VecSize> y_vec;

    phi::Load(x + i, &x_vec);
#pragma unroll
    for (int j = 0; j < VecSize; ++j) {
      y_vec[j] = static_cast<T1>(static_cast<T2>(x_vec[j]) * s);
    }
    phi::Store(y_vec, y + i);
  }

  // Threads leave the vector loop at distinct multiples of VecSize, so at
  // most one lands in (num - VecSize, num) and handles the remaining
  // (< VecSize) elements.
  for (; i < num; ++i) {
    y[i] = static_cast<T1>(static_cast<T2>(x[i]) * s);
  }
}

// Accumulates x[0] into y[0]. The read-modify-write is unsynchronized, so the
// kernel is presumably launched with a single thread.
template <typename T>
static __global__ void AddToCUDAKernel(const T *__restrict__ x,
                                       T *__restrict__ y) {
  const T addend = *x;
  *y += addend;
}

// Computes the gradient-clipping coefficient used when clipping happens
// before the all-reduce:
//
//   grad_norm = sqrt(*square_grad_norm) * clip_rescale_grad
//   coeff     = global_scale[0] * max_global_grad_norm / (1e-6 + grad_norm)
//   scale     = (coeff >= 1 || coeff is NaN/Inf) ? 1 : coeff
//
// The result is written to *out1 (as T1) and to *out2 (as T2) when the
// respective pointer is non-null. Every thread writes the same value, so the
// kernel is presumably launched with a single thread.
template <typename T1, typename T2>
static __global__ void CalcGradNormClipBeforeAllReduceScale(
    const T1 *__restrict__ global_scale,
    T1 max_global_grad_norm,
    const T1 *__restrict__ square_grad_norm,
    T1 *__restrict__ out1,
    T2 *__restrict__ out2,
    T1 clip_rescale_grad) {
  T1 grad_norm = static_cast<T1>(sqrtf(*square_grad_norm)) * clip_rescale_grad;
  // The 1e-6 literal is a double, so the division is evaluated in double
  // before narrowing to T1. It is kept as-is to preserve existing results.
  T1 scale = global_scale[0] * max_global_grad_norm / (1e-6 + grad_norm);
  bool found_nan_inf = !isfinite(scale);
  if (scale >= 1 || found_nan_inf) {
    scale = static_cast<T1>(1.0);
  }

  if (out1) {
    *out1 = scale;
  }
  if (out2) {
    *out2 = static_cast<T2>(scale);
  }
}

// Writes a quiet NaN (bit pattern 0x7fffffff) to *out_p when *in_flag_p is
// set, and 0.0f otherwise. Presumably launched with a single thread.
static __global__ void SetNanInfValueCUDAKernelOneFlag(const bool *in_flag_p,
                                                       float *out_p) {
  const float nan_value = __int_as_float(0x7fffffffU);
  if (*in_flag_p) {
    *out_p = nan_value;
  } else {
    *out_p = 0.0f;
  }
}

// Writes a quiet NaN (bit pattern 0x7fffffff) to *out_p when either flag is
// set, and 0.0f otherwise. The second flag is only read if the first is
// false. Presumably launched with a single thread.
static __global__ void SetNanInfValueCUDAKernelTwoFlag(const bool *in_flag_p_1,
                                                       const bool *in_flag_p_2,
                                                       float *out_p) {
  const bool any_flag_set = (*in_flag_p_1) || (*in_flag_p_2);
  *out_p = any_flag_set ? __int_as_float(0x7fffffffU) : 0.0f;
}

411 412
// Computes the LAMB moments and the trust-ratio divisor for `num` flattened
// elements.
//
// For every element i:
//   g    = grad[i] * scale
//   mom1 = beta1 * mom1 + (1 - beta1) * g
//   mom2 = beta2 * mom2 + (1 - beta2) * g * g
//   trust_ratio_div[i] = (mom1 / (1 - beta1pow)) /
//                        (sqrt(mom2 / (1 - beta2pow)) + epsilon) + wd * param[i]
// where:
//   * wd = weight_decay for i < weight_decay_end_numel, and 0 otherwise.
//   * scale = rescale_grad / global_scale[0]. When max_global_grad_norm > 0,
//     scale is further multiplied by
//     max_global_grad_norm / (sqrt(square_grad_norm) * scale + 1e-6)
//     if that factor is < 1.
//
// found_inf / step handling:
//   * If *square_grad_norm_p is not finite, nothing is updated. Thread 0 of
//     block 0 sets *found_inf = true (when found_inf is non-null).
//   * Otherwise that same thread sets *found_inf = false and increments
//     *step (again only when found_inf is non-null).
//
// Elements are processed VecSize at a time in a grid-stride loop, followed by
// a scalar tail. The vectorized pointers are assumed aligned for
// phi::AlignedVector<..., VecSize>.
template <typename T, typename GradT, int VecSize>
static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
    const T *__restrict__ param_p,
    const GradT *__restrict__ grad_p,
    const T *__restrict__ square_grad_norm_p,
    const T *__restrict__ global_scale,
    const T *__restrict__ beta1pow_p,
    const T *__restrict__ beta2pow_p,
    T *__restrict__ mom1_p,
    T *__restrict__ mom2_p,
    T *__restrict__ trust_ratio_div_p,
    bool *__restrict__ found_inf,
    int64_t *__restrict__ step,
    T weight_decay,
    int weight_decay_end_numel,
    T beta1,
    T beta2,
    T epsilon,
    T max_global_grad_norm,
    int num,
    T rescale_grad) {
  T square_grad_norm = *square_grad_norm_p;
  // Exactly one thread in the grid publishes found_inf and step.
  bool need_update_found_inf =
      (found_inf && threadIdx.x == 0 && blockIdx.x == 0);
  if (!isfinite(square_grad_norm)) {
    // Non-finite global gradient norm: skip the whole update.
    if (need_update_found_inf) *found_inf = true;
    return;
  } else if (need_update_found_inf) {
    *found_inf = false;
    ++(*step);
  }

  // Gradient scale (divides by global_scale[0], presumably the loss scale),
  // optionally reduced so the norm of the scaled gradient does not exceed
  // max_global_grad_norm.
  T scale = rescale_grad / global_scale[0];
  if (max_global_grad_norm > 0) {
    T clip_scale =
        max_global_grad_norm / (sqrtf(square_grad_norm) * scale + 1e-6);
    if (clip_scale < static_cast<T>(1)) {
      scale *= clip_scale;
    }
  }

  // Bias-correction denominators.
  T one_minus_beta1pow = 1 - beta1pow_p[0];
  T one_minus_beta2pow = 1 - beta2pow_p[0];

  int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
  int stride = blockDim.x * gridDim.x * VecSize;

  for (; i + VecSize <= num; i += stride) {
    phi::AlignedVector<T, VecSize> param_vec;
    phi::AlignedVector<GradT, VecSize> grad_vec;
    phi::AlignedVector<T, VecSize> mom1_vec;
    phi::AlignedVector<T, VecSize> mom2_vec;
    phi::AlignedVector<T, VecSize> trust_ratio_div_vec;

    // The whole vector shares the weight-decay decision of its first
    // element. The host wrapper chooses VecSize so that every tensor length,
    // and hence weight_decay_end_numel, is a multiple of it. The param is
    // only used multiplied by cur_weight_decay, so its load is skipped when
    // that factor is zero.
    T cur_weight_decay = (i < weight_decay_end_numel) * weight_decay;
    if (cur_weight_decay != static_cast<T>(0.0)) {
      phi::Load(param_p + i, &param_vec);
    } else {
#pragma unroll
      for (int j = 0; j < VecSize; ++j) {
        param_vec[j] = static_cast<T>(0);
      }
    }
    phi::Load(grad_p + i, &grad_vec);
    phi::Load(mom1_p + i, &mom1_vec);
    phi::Load(mom2_p + i, &mom2_vec);

// Updates one element: reads param/grad/moments at __idx, writes the new
// moments back and stores the trust-ratio divisor. It relies on the locals
// scale, beta1, beta2, epsilon, one_minus_beta{1,2}pow and cur_weight_decay.
#define PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(                                    \
    __param, __grad, __mom1, __mom2, __trust_ratio_div, __idx)                 \
  T p = __param[__idx];                                                        \
  T g = static_cast<T>(__grad[__idx]) * scale;                                 \
  T mom1 = __mom1[__idx];                                                      \
  T mom2 = __mom2[__idx];                                                      \
  mom1 = beta1 * mom1 + (1 - beta1) * g;                                       \
  mom2 = beta2 * mom2 + (1 - beta2) * g * g;                                   \
  T mom1_unbiased = mom1 / one_minus_beta1pow;                                 \
  T mom2_unbiased = mom2 / one_minus_beta2pow;                                 \
  __trust_ratio_div[__idx] =                                                   \
      mom1_unbiased / (sqrtf(mom2_unbiased) + epsilon) + cur_weight_decay * p; \
  __mom1[__idx] = mom1;                                                        \
  __mom2[__idx] = mom2;

#pragma unroll
    for (int j = 0; j < VecSize; ++j) {
      PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(
          param_vec, grad_vec, mom1_vec, mom2_vec, trust_ratio_div_vec, j);
    }

    phi::Store(mom1_vec, mom1_p + i);
    phi::Store(mom2_vec, mom2_p + i);
    phi::Store(trust_ratio_div_vec, trust_ratio_div_p + i);
  }

  // Scalar tail for the remaining (< VecSize) elements.
  for (; i < num; ++i) {
    T cur_weight_decay = (i < weight_decay_end_numel) * weight_decay;
    PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(
        param_p, grad_p, mom1_p, mom2_p, trust_ratio_div_p, i);
  }
}

511 512
// Host launcher for UpdateLambMomentAndTrustRatioDivCUDAKernel over n packed
// tensors.
//
// Layout and parameters:
//   * Tensor k covers flattened indices
//     [offsets[k] - offsets[0], offsets[k + 1] - offsets[0]).
//   * `offsets` is read on the host and needs n + 1 entries.
//   * Weight decay applies to tensors [0, weight_decay_end_idx), with
//     0 <= weight_decay_end_idx <= n.
//
// Vector width: it starts from the widest alignment shared by all five
// vectorized buffers and is halved until every tensor length is a multiple of
// it. As a result, no vector spans two tensors and the weight-decay boundary
// is vector-aligned.
//
// found_inf_p and step must be both null or both non-null.
template <typename T, typename GradT>
static void MultiTensorUpdateLambMomentAndTrustRatioDiv(
    const phi::GPUContext &dev_ctx,
    const int *offsets,
    int n,
    const T *param_p,
    const GradT *grad_p,
    const T *square_grad_norm_p,
    const T *global_scale,
    const T *beta1pow_p,
    const T *beta2pow_p,
    T *mom1_p,
    T *mom2_p,
    T *trust_ratio_div_p,
    bool *found_inf_p,
    int64_t *step,
    T weight_decay,
    int weight_decay_end_idx,
    T beta1,
    T beta2,
    T epsilon,
    T max_global_grad_norm,
    T rescale_grad) {
  if (n <= 0) return;
  int numel = offsets[n] - offsets[0];
  PADDLE_ENFORCE_GE(weight_decay_end_idx,
                    0,
                    platform::errors::InvalidArgument(
                        "The weight decay end index should be >= 0."));
  // The bound is inclusive (weight_decay_end_idx == n decays every tensor),
  // so the message must say "<=" to match the check.
  PADDLE_ENFORCE_LE(weight_decay_end_idx,
                    n,
                    platform::errors::InvalidArgument(
                        "The weight decay end index should be <= %d.", n));
  auto weight_decay_end_numel = offsets[weight_decay_end_idx] - offsets[0];

  int vec_size = GetChunkedVecSize(param_p, 0);
  vec_size = std::min(vec_size, GetChunkedVecSize(grad_p, 0));
  vec_size = std::min(vec_size, GetChunkedVecSize(mom1_p, 0));
  vec_size = std::min(vec_size, GetChunkedVecSize(mom2_p, 0));
  vec_size = std::min(vec_size, GetChunkedVecSize(trust_ratio_div_p, 0));
  for (int i = 0; i < n; ++i) {
    auto length = offsets[i + 1] - offsets[i];
    while (length % vec_size != 0) {
      vec_size /= 2;
    }
  }

  VLOG(1) << __func__ << " VecSize = " << vec_size;

  auto stream = dev_ctx.stream();
  auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
  if (found_inf_p == nullptr) {
    PADDLE_ENFORCE_EQ(
        step,
        nullptr,
        platform::errors::InvalidArgument(
            "Output(Step) cannot be updated twice in one mini-batch."));
  } else {
    PADDLE_ENFORCE_NOT_NULL(
        step,
        platform::errors::InvalidArgument("Output(Step) cannot be nullptr."));
  }

#define PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL                        \
  do {                                                                   \
    UpdateLambMomentAndTrustRatioDivCUDAKernel<T, GradT, kVecSize>       \
        <<<config.block_per_grid, config.thread_per_block, 0, stream>>>( \
            param_p,                                                     \
            grad_p,                                                      \
            square_grad_norm_p,                                          \
            global_scale,                                                \
            beta1pow_p,                                                  \
            beta2pow_p,                                                  \
            mom1_p,                                                      \
            mom2_p,                                                      \
            trust_ratio_div_p,                                           \
            found_inf_p,                                                 \
            step,                                                        \
            weight_decay,                                                \
            weight_decay_end_numel,                                      \
            beta1,                                                       \
            beta2,                                                       \
            epsilon,                                                     \
            max_global_grad_norm,                                        \
            numel,                                                       \
            rescale_grad);                                               \
  } while (0)

  PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL);
#undef PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL
}

603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633
// Value-type helper handed to the parameter-update functor so that it can
// advance the bias-correction accumulators: *beta1pow *= beta1 and
// *beta2pow *= beta2. This primary template (NeedUpdate = true) requires both
// pointers to be non-null.
template <typename T, bool NeedUpdate /*=true*/>
struct LambBetaPowUpdateOnceHelper {
  LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2)
      : beta1pow_(beta1pow),
        beta2pow_(beta2pow),
        beta1_(beta1),
        beta2_(beta2) {
    PADDLE_ENFORCE_NOT_NULL(beta1pow,
                            platform::errors::InvalidArgument(
                                "The beta1pow should not be nullptr."));
    PADDLE_ENFORCE_NOT_NULL(beta2pow,
                            platform::errors::InvalidArgument(
                                "The beta2pow should not be nullptr."));
  }

  HOSTDEVICE void UpdateBetaPows() const {
    *beta1pow_ *= beta1_;
    *beta2pow_ *= beta2_;
  }

 private:
  T *__restrict__ beta1pow_;
  T *__restrict__ beta2pow_;
  T beta1_;
  T beta2_;
};

// Specialization for when the beta pows must not be updated. Both pointers
// are required to be nullptr, beta1/beta2 are ignored, and UpdateBetaPows()
// is a no-op.
template <typename T>
struct LambBetaPowUpdateOnceHelper<T, false> {
  LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2) {
    PADDLE_ENFORCE_EQ(
        beta1pow,
        nullptr,
        platform::errors::InvalidArgument("The beta1pow should be nullptr."));
    PADDLE_ENFORCE_EQ(
        beta2pow,
        nullptr,
        platform::errors::InvalidArgument("The beta2pow should be nullptr."));
  }

  HOSTDEVICE void UpdateBetaPows() const {}
};

// Gives the parameter-update functor access to the parameter (type T) and its
// master copy (type MasterT<T>). This primary template
// (HasMasterParam = true) requires T to differ from MasterT<T> and requires a
// non-null master_param.
template <typename T, bool HasMasterParam /*=true*/>
struct LambParamHelper {
  LambParamHelper(T *param, MasterT<T> *master_param) {
    constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value;
    PADDLE_ENFORCE_EQ(kIsSameType,
                      false,
                      platform::errors::InvalidArgument(
                          "T must not be the same with MasterT<T>."));
    PADDLE_ENFORCE_NOT_NULL(master_param,
                            platform::errors::InvalidArgument(
                                "Master parameter must be provided."));
    param_ = param;
    master_param_ = master_param;
  }

  HOSTDEVICE T *__restrict__ ParamPtr() { return param_; }

  HOSTDEVICE MasterT<T> *__restrict__ MasterParamPtr() { return master_param_; }

 private:
  T *__restrict__ param_;
  MasterT<T> *__restrict__ master_param_;
};

// Specialization for parameters already stored as MasterT<T>. Only `param` is
// kept; `master_param` must be nullptr or alias `param`. MasterParamPtr()
// always returns nullptr.
template <typename T>
struct LambParamHelper<T, false> {
  LambParamHelper(T *param, MasterT<T> *master_param) {
    constexpr bool kIsSameType = std::is_same<T, MasterT<T>>::value;
    PADDLE_ENFORCE_EQ(kIsSameType,
                      true,
                      platform::errors::InvalidArgument(
                          "T must be the same with MasterT<T>."));
    if (master_param != nullptr) {
      PADDLE_ENFORCE_EQ(static_cast<void *>(param),
                        static_cast<void *>(master_param),
                        platform::errors::InvalidArgument(
                            "Master parameter must be nullptr or the same as "
                            "non-master parameter."));
    }
    param_ = param;
  }

  HOSTDEVICE T *__restrict__ ParamPtr() { return param_; }

  HOSTDEVICE constexpr MasterT<T> *MasterParamPtr() { return nullptr; }

 private:
  T *__restrict__ param_;
};

696 697 698
// Per-chunk functor for MultiTensorApply. It applies the LAMB parameter
// update to elements [offset, offset + size) of tensor `tensor_id`:
//
//   ratio = lr * sqrt(param_square_norm / trust_ratio_div_square_norm)
//           (just lr when either squared norm is zero)
//   p    -= ratio * trust_ratio_div
//
// The arithmetic is done in MasterT<ParamT>:
//   * With a master copy, the master param is updated and the result is cast
//     back into the param.
//   * Otherwise the param itself is read, updated and written.
//
// Nothing is done when *found_inf is true. When NeedUpdateBetaPow is set,
// thread 0 of block 0 also advances the beta pows through betapow_helper.
// Vectorized accesses assume pointers aligned for
// phi::AlignedVector<..., VecSize>.
template <typename ParamT,
          bool HasMasterParam,
          bool NeedUpdateBetaPow,
          int VecSize>
struct LambUpdateParamAndBetaPowsFunctor {
  DEVICE void operator()(
      int tensor_id,
      int chunk_id,
      int offset,
      int size,
      LambParamHelper<ParamT, HasMasterParam> param_helper,
      const MasterT<ParamT> *trust_ratio_div,
      const MasterT<ParamT> *lr,
      const MasterT<ParamT> *param_square_norm,
      const MasterT<ParamT> *trust_ratio_div_square_norm,
      const bool *found_inf,
      LambBetaPowUpdateOnceHelper<MasterT<ParamT>, NeedUpdateBetaPow>
          betapow_helper) const {
    // A NaN/Inf was detected: skip the parameter and beta-pow updates.
    if (*found_inf) return;

    using MT = MasterT<ParamT>;

    MT p_square_norm = param_square_norm[tensor_id];
    MT t_square_norm = trust_ratio_div_square_norm[tensor_id];
    MT lr_value = *lr;
    MT ratio = (p_square_norm != static_cast<MT>(0) &&
                        t_square_norm != static_cast<MT>(0)
                    ? lr_value * sqrtf(p_square_norm / t_square_norm)
                    : lr_value);

    int i;
    int stride = blockDim.x * VecSize;

    ParamT *param = param_helper.ParamPtr() + offset;
    MT *master_param = HasMasterParam ? param_helper.MasterParamPtr() + offset
                                      : param_helper.MasterParamPtr();
    trust_ratio_div += offset;

    for (i = threadIdx.x * VecSize; i + VecSize <= size; i += stride) {
      phi::AlignedVector<MT, VecSize> trust_ratio_div_vec;
      phi::Load(trust_ratio_div + i, &trust_ratio_div_vec);
      if (HasMasterParam) {
        phi::AlignedVector<MT, VecSize> master_param_vec;
        phi::Load(master_param + i, &master_param_vec);
        phi::AlignedVector<ParamT, VecSize> param_vec;
#pragma unroll
        for (int j = 0; j < VecSize; ++j) {
          MT p = master_param_vec[j] - ratio * trust_ratio_div_vec[j];
          master_param_vec[j] = p;
          param_vec[j] = static_cast<ParamT>(p);
        }
        phi::Store(master_param_vec, master_param + i);
        phi::Store(param_vec, param + i);
      } else {
        phi::AlignedVector<ParamT, VecSize> param_vec;
        phi::Load(param + i, &param_vec);
#pragma unroll
        for (int j = 0; j < VecSize; ++j) {
          MT p = static_cast<MT>(param_vec[j]) - ratio * trust_ratio_div_vec[j];
          param_vec[j] = static_cast<ParamT>(p);
        }
        phi::Store(param_vec, param + i);
      }
    }

    // Scalar tail for the remaining (< VecSize) elements of this chunk.
    for (; i < size; ++i) {
      if (HasMasterParam) {
        MT p = master_param[i] - ratio * trust_ratio_div[i];
        master_param[i] = p;
        param[i] = static_cast<ParamT>(p);
      } else {
        MT p = static_cast<MT>(param[i]) - ratio * trust_ratio_div[i];
        param[i] = static_cast<ParamT>(p);
      }
    }

    // NOTE(review): this condition holds once per kernel launch, because
    // every launch has a block 0. If MultiTensorApply splits the work into
    // several launches (too many tensors or chunks for one launch), the beta
    // pows would be advanced more than once per step. Confirm
    // MultiTensorApply's launch behavior.
    if (NeedUpdateBetaPow && threadIdx.x == 0 && blockIdx.x == 0) {
      betapow_helper.UpdateBetaPows();
    }
  }
};
// TODO(zengjinle): which block_dim and chunk_size would be better?
//
// Applies the LAMB parameter update to the n fused tensors described by
// `offsets` by launching LambUpdateParamAndBetaPowsFunctor through
// MultiTensorApplyWithCallback, and optionally advances Beta1Pow/Beta2Pow.
//
//   offsets           - offsets[i] - offsets[0] is used as the element offset
//                       of tensor i inside param / master_param /
//                       trust_ratio_div (presumably n + 1 entries; TODO
//                       confirm against MultiTensorApplyWithCallback).
//   master_param      - only inspected when ParamT != MasterT<ParamT>.
//   beta1pow/beta2pow - must be both null or both non-null. When present,
//                       only the launch with index 0 updates them; both
//                       pointers are then reset to nullptr so later launches
//                       use the functor variant without the beta-pow update.
//   chunk_size        - chunk length forwarded to MultiTensorApplyWithCallback
//                       and to the GetChunkedVecSize alignment checks.
template <typename ParamT,
          int MaxTensorNumPerLaunch = 160,
          int MaxChunkNumPerLaunch = 780>
static void MultiTensorUpdateLambParamAndBetaPows(
    const phi::GPUContext &dev_ctx,
    const int *offsets,
    int n,
    const MasterT<ParamT> *trust_ratio_div,
    const MasterT<ParamT> *lr,
    const MasterT<ParamT> *param_square_norm,
    const MasterT<ParamT> *trust_ratio_div_square_norm,
    const bool *found_inf,
    ParamT *param,
    MasterT<ParamT> *master_param,
    MasterT<ParamT> *beta1pow,
    MasterT<ParamT> *beta2pow,
    MasterT<ParamT> beta1,
    MasterT<ParamT> beta2,
    int chunk_size = 65536) {
  // A separate master copy exists iff ParamT is not already the master type.
  constexpr bool kHasMasterParam =
      !std::is_same<ParamT, MasterT<ParamT>>::value;

  // beta1pow / beta2pow are an all-or-nothing pair.
  const bool has_beta_pow = beta1pow != nullptr;
  if (!has_beta_pow) {
    PADDLE_ENFORCE_EQ(
        beta2pow,
        nullptr,
        platform::errors::InvalidArgument("Beta2Pow should be nullptr."));
  } else {
    PADDLE_ENFORCE_NOT_NULL(
        beta2pow,
        platform::errors::InvalidArgument("Beta2Pow should not be nullptr."));
  }

#ifdef PADDLE_WITH_HIP
  const int block_dim = 256;
#else
  const int block_dim = 512;
#endif

  // Start from the widest vector and narrow it until the start of every
  // tensor in every accessed buffer supports that width.
  int vec_size = 8;
  auto shrink_vec_size = [&vec_size](int candidate) {
    vec_size = std::min(vec_size, candidate);
  };
  for (int t = 0; t < n; ++t) {
    const int rel_offset = offsets[t] - offsets[0];
    shrink_vec_size(GetChunkedVecSize(param + rel_offset, chunk_size));
    if (kHasMasterParam) {
      shrink_vec_size(
          GetChunkedVecSize(master_param + rel_offset, chunk_size));
    }
    shrink_vec_size(
        GetChunkedVecSize(trust_ratio_div + rel_offset, chunk_size));
  }
  VLOG(1) << __func__ << " VecSize = " << vec_size;

  constexpr auto kNumTensor = MaxTensorNumPerLaunch;
  constexpr auto kNumChunk = MaxChunkNumPerLaunch;
  auto stream = dev_ctx.stream();

  // One Launch() of the update functor. HAS_BETA_POW (true/false) selects
  // whether this launch also advances the beta powers; kVecSize is supplied
  // by PD_VEC_LAUNCH_KERNEL.
#define PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(HAS_BETA_POW)             \
  do {                                                                        \
    using UpdateFunctor = LambUpdateParamAndBetaPowsFunctor<ParamT,           \
                                                            kHasMasterParam,  \
                                                            HAS_BETA_POW,     \
                                                            kVecSize>;        \
    LambParamHelper<ParamT, kHasMasterParam> param_helper(param,              \
                                                          master_param);      \
    LambBetaPowUpdateOnceHelper<MasterT<ParamT>, HAS_BETA_POW>                \
        betapow_helper(beta1pow, beta2pow, beta1, beta2);                     \
    tensor_launcher.Launch(UpdateFunctor(),                                   \
                           param_helper,                                      \
                           trust_ratio_div,                                   \
                           lr,                                                \
                           param_square_norm,                                 \
                           trust_ratio_div_square_norm,                       \
                           found_inf,                                         \
                           betapow_helper);                                   \
  } while (0)

  // Per-launch callback: only the very first launch may touch the beta
  // powers; afterwards the pointers are cleared.
#define PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE                  \
  do {                                                                        \
    auto launch_callback =                                                    \
        [&](const MultiTensorLauncher<kNumTensor, kNumChunk>                  \
                &tensor_launcher,                                             \
            int launch_idx) {                                                 \
          if (!has_beta_pow || launch_idx != 0) {                             \
            PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(false);               \
          } else {                                                            \
            PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(true);                \
            beta1pow = nullptr;                                               \
            beta2pow = nullptr;                                               \
          }                                                                   \
        };                                                                    \
    MultiTensorApplyWithCallback<kNumTensor, kNumChunk>(                      \
        stream, offsets, n, chunk_size, block_dim, launch_callback);          \
  } while (0)

  PD_VEC_LAUNCH_KERNEL(vec_size,
                       PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE);

#undef PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW
#undef PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE
}

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
// Tries to create an NCCL reduction op that pre-multiplies every input by the
// device-side scalar `scale` (ncclScalarDevice) before summing.
//
// Returns true and stores the created op in *op only when both the NCCL
// headers this file is compiled against and the runtime library report a
// version >= 21100 (NCCL 2.11). The caller is then expected to release *op
// with ncclRedOpDestroy. Returns false otherwise and leaves *op untouched.
static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype,
                                           ncclComm_t comm,
                                           const void *scale,
                                           ncclRedOp_t *op) {
#if NCCL_VERSION_CODE >= 21100
  int runtime_version;
  PADDLE_ENFORCE_GPU_SUCCESS(
      platform::dynload::ncclGetVersion(&runtime_version));
  if (runtime_version < 21100) {
    // Compiled against a new header but running on an older library.
    VLOG(10) << "ncclRedOpCreatePreMulSum is not supported.";
    return false;
  }
  VLOG(10) << "ncclRedOpCreatePreMulSum is supported.";
  // The NCCL API takes a non-const pointer even though it only reads it.
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpCreatePreMulSum(
      op, const_cast<void *>(scale), dtype, ncclScalarDevice, comm));
  return true;
#else
  VLOG(10) << "ncclRedOpCreatePreMulSum is not supported.";
  return false;
#endif
}

// Launches ScaleCUDAKernel on `stream` to write a copy of the n elements of
// x, scaled by the device value *scale, into y. The exact arithmetic lives in
// ScaleCUDAKernel, which is defined elsewhere. The vector width is the
// widest one that both x and y are aligned for.
template <typename T1, typename T2>
static void LaunchScaleKernel(const phi::GPUContext &dev_ctx,
                              const T1 *x,
                              const T2 *scale,
                              T1 *y,
                              int n,
                              gpuStream_t stream) {
  const int x_vec_size = GetChunkedVecSize(x, 0);
  const int y_vec_size = GetChunkedVecSize(y, 0);
  const int vec_size = std::min(x_vec_size, y_vec_size);
  auto launch_cfg = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);

#define PD_LAMB_VEC_SCALE_KERNEL_CASE                                  \
  do {                                                                 \
    ScaleCUDAKernel<T1, T2, kVecSize><<<launch_cfg.block_per_grid,     \
                                        launch_cfg.thread_per_block,   \
                                        0,                             \
                                        stream>>>(x, scale, y, n);     \
  } while (0)

  PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAMB_VEC_SCALE_KERNEL_CASE);
#undef PD_LAMB_VEC_SCALE_KERNEL_CASE
}

// Sums `sendbuff` across the ranks of `comm` and writes the result to
// `recvbuff`. If `scale` is non-null, every element is first multiplied by
// the device scalar *scale.
//
//   UseReduceScatter == true : sendbuff holds recvcount * nranks elements;
//                              recvbuff receives this rank's recvcount-element
//                              shard (ncclReduceScatter).
//   UseReduceScatter == false: sendbuff and recvbuff hold recvcount elements
//                              (ncclAllReduce).
//
// With comm == nullptr no communication happens. recvbuff receives sendbuff,
// scaled when `scale` is given. nranks must be 1 whenever data has to be
// written. When NCCL supports ncclRedOpCreatePreMulSum, the scaling is fused
// into the reduction. Otherwise sendbuff is first scaled into a temporary
// buffer.
template <typename T, bool UseReduceScatter>
static void NCCLSumWithScaleBase(const T *sendbuff,
                                 T *recvbuff,
                                 size_t recvcount,
                                 size_t nranks,
                                 ncclComm_t comm,
                                 gpuStream_t stream,
                                 const phi::GPUContext &dev_ctx,
                                 const T *scale = nullptr) {
  static_assert(std::is_same<T, float>::value ||
                    std::is_same<T, platform::float16>::value,
                "T must be either float32 or float16.");
  if (recvcount == 0) return;

  auto numel = UseReduceScatter ? (recvcount * nranks) : recvcount;
  if (comm == nullptr) {
    if (scale != nullptr) {
      PADDLE_ENFORCE_EQ(nranks,
                        1,
                        platform::errors::InvalidArgument(
                            "nranks must be 1 when scale != nullptr."));
      LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, numel, stream);
    } else if (sendbuff != recvbuff) {
      // The sum over a single rank is sendbuff itself. Copy it so that
      // recvbuff is always written, as on the scaled and NCCL paths. The
      // in-place case needs no work.
      PADDLE_ENFORCE_EQ(
          nranks,
          1,
          platform::errors::InvalidArgument(
              "nranks must be 1 when comm == nullptr and "
              "sendbuff != recvbuff."));
      memory::Copy(dev_ctx.GetPlace(),
                   recvbuff,
                   dev_ctx.GetPlace(),
                   sendbuff,
                   numel * sizeof(T),
                   stream);
    }
    return;
  }

  ncclRedOp_t op = ncclSum;
  ncclDataType_t dtype =
      std::is_same<T, float>::value ? ncclFloat32 : ncclFloat16;
  // True iff a PreMulSum op was created; it must then be destroyed below.
  bool should_destroy_op =
      scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op);
  memory::Buffer buffer(dev_ctx.GetPlace());
  if (scale && !should_destroy_op) {
    // Fallback for older NCCL: scale into a temporary, then plain-sum it.
    T *new_sendbuff = buffer.Alloc<T>(numel);
    LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream);
    sendbuff = new_sendbuff;
  }

  if (UseReduceScatter) {
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
        sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
  } else {
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, recvcount, dtype, op, comm, stream));
  }

#if NCCL_VERSION_CODE >= 21100
  if (should_destroy_op) {
    VLOG(10) << "ncclRedOpDestroy starts";
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpDestroy(op, comm));
    VLOG(10) << "ncclRedOpDestroy ends";
  }
#endif
}

// Reduce-scatter variant of NCCLSumWithScaleBase. sendbuff holds
// recvcount * nranks elements, and recvbuff receives this rank's
// recvcount-element shard of the sum. The sum is pre-scaled by *scale when
// `scale` is non-null.
template <typename T>
static void NCCLReduceScatterWithScale(const T *sendbuff,
                                       T *recvbuff,
                                       size_t recvcount,
                                       size_t nranks,
                                       ncclComm_t comm,
                                       gpuStream_t stream,
                                       const phi::GPUContext &dev_ctx,
                                       const T *scale = nullptr) {
  constexpr bool kUseReduceScatter = true;
  return NCCLSumWithScaleBase<T, kUseReduceScatter>(
      sendbuff, recvbuff, recvcount, nranks, comm, stream, dev_ctx, scale);
}

// All-reduce variant of NCCLSumWithScaleBase. sendbuff and recvbuff both
// hold recvcount elements, and every rank receives the full sum. The sum is
// pre-scaled by *scale when `scale` is non-null.
template <typename T>
static void NCCLAllReduceWithScale(const T *sendbuff,
                                   T *recvbuff,
                                   size_t recvcount,
                                   size_t nranks,
                                   ncclComm_t comm,
                                   gpuStream_t stream,
                                   const phi::GPUContext &dev_ctx,
                                   const T *scale = nullptr) {
  constexpr bool kUseReduceScatter = false;
  return NCCLSumWithScaleBase<T, kUseReduceScatter>(
      sendbuff, recvbuff, recvcount, nranks, comm, stream, dev_ctx, scale);
}

#endif

// Reduces num_items values read from d_in with reduction_op, seeded with
// `init`, into *d_out on `stream` via cub::DeviceReduce::Reduce.
//
// cub is invoked twice. The first call, with null temp storage, only reports
// the scratch size needed. That scratch is allocated from *buffer and the
// second call does the actual reduction. The scratch memory is owned by
// *buffer.
template <typename InputIteratorT,
          typename OutputIteratorT,
          typename ReduceOpT,
          typename T>
static void CubDeviceReduce(InputIteratorT d_in,
                            OutputIteratorT d_out,
                            int num_items,
                            ReduceOpT reduction_op,
                            T init,
                            gpuStream_t stream,
                            memory::Buffer *buffer) {
  size_t temp_storage_bytes = 0;
  auto run_reduce = [&](void *temp_storage) {
    PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Reduce(temp_storage,
                                                         temp_storage_bytes,
                                                         d_in,
                                                         d_out,
                                                         num_items,
                                                         reduction_op,
                                                         init,
                                                         stream));
  };

  run_reduce(nullptr);  // size query only
  void *temp_storage = buffer->Alloc<void>(temp_storage_bytes);
  VLOG(10) << "cub::DeviceReduce::Reduce needs " << temp_storage_bytes
           << " byte(s), ptr = " << temp_storage;
  run_reduce(temp_storage);
}

// Writes the float sum of SquareFunctor<T>(grad[i]) over the n elements of
// grad into *square_norm on `stream`. SquareFunctor is defined elsewhere and
// is presumably grad[i]^2 computed in float.
template <typename T>
static void GetSquareGradNormImpl(const T *grad,
                                  int n,
                                  float *square_norm,
                                  gpuStream_t stream,
                                  memory::Buffer *cub_tmp_buffer) {
  // Maps each gradient element to a float on the fly while cub reads it.
  cub::TransformInputIterator<float, SquareFunctor<T>, const T *> squared_grad(
      grad, SquareFunctor<T>());
  CubDeviceReduce(squared_grad,
                  square_norm,
                  n,
                  cub::Sum(),
                  0.0f,
                  stream,
                  cub_tmp_buffer);
}

// Computes the squared L2-norm of the fused fp32 and fp16 gradients.
//
// square_norm must have room for at least 2 floats. On return, square_norm[0]
// holds the squared norm of whichever gradients are non-empty. When both are
// non-empty, square_norm[1] receives the fp16 part, which AddToCUDAKernel
// then combines into square_norm[0]. When both are empty, nothing is written.
static void GetSquareGradNorm(const float *fp32_grad,
                              int fp32_numel,
                              const platform::float16 *fp16_grad,
                              int fp16_numel,
                              float *square_norm,
                              gpuStream_t stream,
                              memory::Buffer *cub_tmp_buffer) {
  VLOG(10) << "GetSquareGradNorm starts, fp32_numel = " << fp32_numel
           << " , fp16_numel = " << fp16_numel;
  const bool has_fp32 = fp32_numel > 0;
  const bool has_fp16 = fp16_numel > 0;

  if (has_fp32) {
    GetSquareGradNormImpl(
        fp32_grad, fp32_numel, square_norm, stream, cub_tmp_buffer);
    VLOG(10) << "FP32 square L2-Norm: "
             << FlattenToString(square_norm, 1, cub_tmp_buffer->GetPlace());
  }

  if (has_fp16) {
    // Without an fp32 part the fp16 result can land directly in slot 0.
    float *fp16_slot = has_fp32 ? square_norm + 1 : square_norm;
    GetSquareGradNormImpl(
        fp16_grad, fp16_numel, fp16_slot, stream, cub_tmp_buffer);
    VLOG(10) << "FP16 square L2-Norm: "
             << FlattenToString(fp16_slot, 1, cub_tmp_buffer->GetPlace());
    if (has_fp32) {
      AddToCUDAKernel<<<1, 1, 0, stream>>>(fp16_slot, square_norm);
      VLOG(10) << "FP32+FP16 square L2-Norm: "
               << FlattenToString(square_norm, 1, cub_tmp_buffer->GetPlace());
    }
  }

  VLOG(10) << "GetSquareGradNorm ends, fp32_numel = " << fp32_numel
           << " , fp16_numel = " << fp16_numel;
}

// Formats x with its operator<< into a std::string (debug-logging helper).
template <typename T>
std::string NumToString(T x) {
  std::ostringstream os;
  os << x;
  return os.str();
}

// Returns {"min": <min>, "max": <max>} for the n elements of the device
// array x, or {"min": null, "max": null} when n == 0. This is a debug
// helper: it synchronizes the device context's stream to read the results
// back.
//
// NOTE(review): n (size_t) is narrowed to CubDeviceReduce's int num_items,
// so arrays with more than INT_MAX elements would not be reduced correctly.
template <typename T>
static std::string GetMinMaxStr(const T *x,
                                size_t n,
                                const platform::Place &place) {
  PADDLE_ENFORCE_EQ(
      platform::is_gpu_place(place),
      true,
      platform::errors::InvalidArgument("Only support CUDAPlace currently."));

  auto *dev_ctx = static_cast<phi::GPUContext *>(
      platform::DeviceContextPool::Instance().Get(place));
  auto stream = dev_ctx->stream();

  // Device result slots: [0] = min, [1] = max.
  memory::Buffer ret_buffer(place);
  T *ret = ret_buffer.Alloc<T>(2);

  if (n == 0) {
    return "{\"min\": null, \"max\": null}";
  }

  memory::Buffer cub_buffer(place);
  CubDeviceReduce(x,
                  ret,
                  n,
                  cub::Min(),
                  std::numeric_limits<T>::max(),
                  stream,
                  &cub_buffer);
  CubDeviceReduce(x,
                  ret + 1,
                  n,
                  cub::Max(),
                  std::numeric_limits<T>::lowest(),
                  stream,
                  &cub_buffer);

  T host_min_max[2];
  const size_t result_bytes = 2 * sizeof(T);
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(
      host_min_max, ret, result_bytes, hipMemcpyDeviceToHost, stream));
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamSynchronize(stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
      host_min_max, ret, result_bytes, cudaMemcpyDeviceToHost, stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#endif

  std::string result = "{\"min\": ";
  result += NumToString(host_min_max[0]);
  result += " , \"max\": ";
  result += NumToString(host_min_max[1]);
  result += "}";
  return result;
}

// Visitor for phi::VisitDataType. For the tensor's runtime element type T it
// stores the GetMinMaxStr<T> summary of the tensor's data into *s.
struct VisitDTypeFunctor {
  VisitDTypeFunctor(const framework::Tensor *x, std::string *s)
      : tensor_(x), out_(s) {}

  template <typename T>
  void apply() const {
    const T *values = tensor_->template data<T>();
    *out_ = GetMinMaxStr<T>(values, tensor_->numel(), tensor_->place());
  }

 private:
  const framework::Tensor *tensor_;  // tensor being summarized (not owned)
  std::string *out_;                 // destination for the summary
};

// Describes the value range of a possibly-null or uninitialized tensor for
// debug logging. Only initialized GPU tensors are actually reduced; other
// cases return a fixed marker string.
static std::string GetMinMaxStr(const framework::Tensor *x) {
  if (x == nullptr) {
    return "null";
  } else if (!x->IsInitialized()) {
    return "not_inited";
  } else if (!platform::is_gpu_place(x->place())) {
    return "CPUTensor";
  }

  std::string summary;
  VisitDTypeFunctor visitor(x, &summary);
  phi::VisitDataType(x->dtype(), visitor);
  return summary;
}

// When VLOG level 1 is enabled, logs the variable name and min/max summary of
// every tensor bound to every input slot of the op. Output slots are logged
// too unless only_inputs is true.
static void PrintAllMinMaxRange(const framework::ExecutionContext &ctx,
                                bool only_inputs) {
  if (!VLOG_IS_ON(1)) return;

  // Logs every tensor of every slot in `slots`; fetch(slot_name) returns the
  // tensors bound to that slot.
  auto dump_slots = [](const char *tag, const auto &slots, auto fetch) {
    for (const auto &slot : slots) {
      const auto &key = slot.first;
      const auto tensors = fetch(key);
      for (size_t i = 0; i < tensors.size(); ++i) {
        VLOG(1) << tag << "(" << key << ")[" << i << "] = " << slot.second[i]
                << " , " << GetMinMaxStr(tensors[i]);
      }
    }
  };

  dump_slots("Input", ctx.GetOp().Inputs(), [&ctx](const std::string &key) {
    return ctx.MultiInput<framework::Tensor>(key);
  });
  if (only_inputs) return;
  dump_slots("Output", ctx.GetOp().Outputs(), [&ctx](const std::string &key) {
    return ctx.MultiOutput<framework::Tensor>(key);
  });
}

// Checks the fused fp32 and fp16 gradients for NaN/Inf and writes the verdict
// into nan_inf_flag[0] via SetNanInfValueCUDAKernel{One,Two}Flag.
//
// nan_inf_flag must have room for at least 2 floats. The bytes of
// nan_inf_flag[1] are reused as scratch for the bool results of the
// OR-reductions: the fp32 flag at byte 0 and the fp16 flag at byte 1.
//
// If both gradients are empty, a `false` flag is written into that scratch
// first. This is the identity of the OR-reduction and is what the reductions
// would produce for all-finite data. Without it the kernel would dereference
// a null flag pointer.
static void CheckHasNanInfGrad(const float *fp32_grad,
                               int fp32_numel,
                               const platform::float16 *fp16_grad,
                               int fp16_numel,
                               float *nan_inf_flag,
                               gpuStream_t stream,
                               memory::Buffer *cub_tmp_buffer) {
  bool *fp32_has_nan_inf = nullptr;
  bool *fp16_has_nan_inf = nullptr;
  if (fp32_numel > 0) {
    fp32_has_nan_inf = reinterpret_cast<bool *>(nan_inf_flag + 1);
    cub::TransformInputIterator<bool, IsNanInfFunctor<float>, const float *>
        iter(fp32_grad, IsNanInfFunctor<float>());
    CubDeviceReduce(iter,
                    fp32_has_nan_inf,
                    fp32_numel,
                    OrFunctor(),
                    false,
                    stream,
                    cub_tmp_buffer);
  }

  if (fp16_numel > 0) {
    fp16_has_nan_inf = reinterpret_cast<bool *>(nan_inf_flag + 1) + 1;
    cub::TransformInputIterator<bool,
                                IsNanInfFunctor<platform::float16>,
                                const platform::float16 *>
        iter(fp16_grad, IsNanInfFunctor<platform::float16>());
    CubDeviceReduce(iter,
                    fp16_has_nan_inf,
                    fp16_numel,
                    OrFunctor(),
                    false,
                    stream,
                    cub_tmp_buffer);
  }

  if (fp32_has_nan_inf && fp16_has_nan_inf) {
    SetNanInfValueCUDAKernelTwoFlag<<<1, 1, 0, stream>>>(
        fp32_has_nan_inf, fp16_has_nan_inf, nan_inf_flag);
  } else if (fp32_has_nan_inf) {
    SetNanInfValueCUDAKernelOneFlag<<<1, 1, 0, stream>>>(fp32_has_nan_inf,
                                                         nan_inf_flag);
  } else if (fp16_has_nan_inf) {
    SetNanInfValueCUDAKernelOneFlag<<<1, 1, 0, stream>>>(fp16_has_nan_inf,
                                                         nan_inf_flag);
  } else {
    // No gradient elements at all, so nothing can be NaN/Inf. Store a device
    // `false` in the scratch slot and reuse the same kernel, so that
    // nan_inf_flag[0] gets exactly the value it would get for finite data.
    bool *no_nan_inf = reinterpret_cast<bool *>(nan_inf_flag + 1);
    FillZeroWithPtr(no_nan_inf, 1, stream);
    SetNanInfValueCUDAKernelOneFlag<<<1, 1, 0, stream>>>(no_nan_inf,
                                                         nan_inf_flag);
  }
}

// Element-wise z[i] = T3(MT(x[i]) + MT(y[i])) for i in [0, n), where MT is
// MasterT<T2>, the accumulation type.
//
// Each thread starts at (blockIdx.x * blockDim.x + threadIdx.x) * VecSize and
// advances by gridDim.x * blockDim.x * VecSize, processing VecSize elements
// per step with aligned vector loads and stores. The trailing partial vector
// (fewer than VecSize elements) is handled element by element by the single
// thread whose cursor lands in it.
//
// Preconditions: x, y and z must be aligned for VecSize-wide vector access
// (LaunchElementwiseAddWithCastKernel derives VecSize from GetChunkedVecSize).
// In-place use (z == x or z == y) is safe because each element is read before
// being written, and only by the thread that writes it.
//
// Indices are 64-bit. When the grid covers n in one pass, the stride is
// roughly n. For n above about INT_MAX / 2, a 32-bit `i += stride` would
// overflow (undefined behavior, in practice a negative index and
// out-of-bounds access).
template <typename T1, typename T2, typename T3, int VecSize>
static __global__ void ElementwiseAddWithCastCUDAKernel(const T1 *x,
                                                        const T2 *y,
                                                        T3 *z,
                                                        int n) {
  static_assert(sizeof(T1) <= sizeof(T2),
                "sizeof(T1) must not be larger than sizeof(T2).");
  using MT = MasterT<T2>;

  int64_t i =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
  const int64_t stride =
      static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  for (; i + VecSize <= n; i += stride) {
    phi::AlignedVector<T1, VecSize> x_vec;
    phi::AlignedVector<T2, VecSize> y_vec;
    phi::AlignedVector<T3, VecSize> z_vec;
    phi::Load(x + i, &x_vec);
    phi::Load(y + i, &y_vec);
#pragma unroll
    for (int j = 0; j < VecSize; ++j) {
      auto x_tmp = static_cast<MT>(x_vec[j]);
      auto y_tmp = static_cast<MT>(y_vec[j]);
      z_vec[j] = static_cast<T3>(x_tmp + y_tmp);
    }
    phi::Store(z_vec, z + i);
  }

  // Scalar tail: only the thread whose cursor fell into the last partial
  // vector enters this loop.
  for (; i < n; ++i) {
    auto x_tmp = static_cast<MT>(x[i]);
    auto y_tmp = static_cast<MT>(y[i]);
    z[i] = static_cast<T3>(x_tmp + y_tmp);
  }
}

// Launches ElementwiseAddWithCastCUDAKernel on `stream` to compute
// z[i] = x[i] + y[i] for n elements, accumulated in MasterT<T2> and cast to
// T3. The kernel uses the widest vector width that x, y and z are all
// aligned for.
template <typename T1, typename T2, typename T3>
static void LaunchElementwiseAddWithCastKernel(const phi::GPUContext &dev_ctx,
                                               const T1 *x,
                                               const T2 *y,
                                               T3 *z,
                                               int n,
                                               gpuStream_t stream) {
  int vec_size = GetChunkedVecSize(x, 0);
  vec_size = std::min(vec_size, GetChunkedVecSize(y, 0));
  vec_size = std::min(vec_size, GetChunkedVecSize(z, 0));
  auto launch_cfg = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);

#define PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL                        \
  do {                                                                    \
    ElementwiseAddWithCastCUDAKernel<T1, T2, T3, kVecSize>                \
        <<<launch_cfg.block_per_grid,                                     \
           launch_cfg.thread_per_block,                                   \
           0,                                                             \
           stream>>>(x, y, z, n);                                         \
  } while (0)

  PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL);
#undef PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL
}

template <typename T>
L
Leo Chen 已提交
1302
class DistributedFusedLambOpKernel<phi::GPUContext, T>
1303 1304 1305 1306
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
L
Leo Chen 已提交
1307
    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
1308 1309 1310
    auto stream = dev_ctx.stream();
    auto place = dev_ctx.GetPlace();

1311 1312 1313
    auto *found_inf_t = ctx.Output<framework::Tensor>("FoundInf");
    found_inf_t->Resize({1});

1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329
    // Step 1: Get fp16 param and grad tensors
    int64_t fp16_numel;
    auto *fp16_param = GetSameInOutTensorPtr<platform::float16, true>(
        ctx, place, "FP16FusedParam", "FP16FusedParamOut", &fp16_numel);
    bool has_fp16_param = (fp16_numel > 0);
    const platform::float16 *fp16_grad = nullptr;
    if (has_fp16_param) {
      fp16_grad = GetInputTensorPtr<platform::float16>(ctx, "FP16FusedGrad");
    } else {
      fp16_param = nullptr;
    }

    // Step 2: Get fp32 param and grad tensors
    int64_t fp32_numel = 0;
    auto *fp32_param = GetSameInOutTensorPtr<float, true>(
        ctx, place, "FP32FusedParam", "FP32FusedParamOut", &fp32_numel);
1330 1331
    PADDLE_ENFORCE_GE(fp32_numel,
                      fp16_numel,
1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343
                      platform::errors::InvalidArgument(
                          "The element number in FP32FusedParam should be not "
                          "less than FP16FusedParam."));

    fp32_numel -= fp16_numel;  // the FP32FusedParam contains fp32 param and
                               // fp16 master weight
    bool has_fp32_param = (fp32_numel > 0);
    const float *fp32_grad = nullptr;
    if (has_fp32_param) {
      fp32_grad = GetInputTensorPtr<float>(ctx, "FP32FusedGrad");
    } else {
      PADDLE_ENFORCE_EQ(
1344 1345
          has_fp16_param,
          true,
1346 1347 1348 1349 1350 1351 1352 1353 1354
          platform::errors::InvalidArgument(
              "Either FP32FusedGrad or FP16FusedGrad cannot be NULL."));
    }

    auto numel = fp32_numel + fp16_numel;
    VLOG(1) << "numel = " << numel << " , fp32_numel = " << fp32_numel
            << " , fp16_numel = " << fp16_numel;

    // The NVIDIA cub library does not support number > INT32_MAX
1355 1356
    PADDLE_ENFORCE_LE(numel,
                      std::numeric_limits<int>::max(),
1357 1358 1359 1360
                      platform::errors::Unimplemented(
                          "Too many parameter number. Only <= %d is supported.",
                          std::numeric_limits<int>::max()));

1361 1362
    auto acc_steps = ctx.Attr<int>("acc_steps");
    PADDLE_ENFORCE_GE(
1363 1364
        acc_steps,
        1,
1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389
        platform::errors::InvalidArgument(
            "The gradient accumulation steps should be not less than 1."));
    if (acc_steps > 1) {
      auto *step_t = ctx.Output<framework::Tensor>("AccStep");
      PADDLE_ENFORCE_NOT_NULL(
          step_t,
          platform::errors::InvalidArgument(
              "Output(AccStep) cannot be nullptr when Attr(acc_steps) > 1."));
      bool is_initialized = step_t->IsInitialized();
      int64_t *step_ptr;
      if (is_initialized) {
        step_ptr = step_t->mutable_data<int64_t>(platform::CPUPlace());
        ++(*step_ptr);
      } else {
        step_t->Resize({1});
        step_ptr = step_t->mutable_data<int64_t>(platform::CPUPlace());
        *step_ptr = 1;
      }
      int64_t rounded_step = (*step_ptr) % acc_steps;

      float *fp32_acc_grad = nullptr;
      if (has_fp32_param) {
        auto *fp32_acc_grad_t =
            ctx.Output<framework::Tensor>("FP32AccFusedGrad");
        PADDLE_ENFORCE_NOT_NULL(
1390 1391 1392 1393
            fp32_acc_grad_t,
            platform::errors::InvalidArgument(
                "Output(FP32AccFusedGrad) cannot be nullptr "
                "when Attr(acc_steps) > 1."));
1394 1395 1396 1397 1398 1399 1400 1401 1402 1403
        if (!fp32_acc_grad_t->IsInitialized()) {
          fp32_acc_grad_t->Resize({static_cast<int64_t>(fp32_numel)});
          fp32_acc_grad = fp32_acc_grad_t->mutable_data<float>(place);
        } else {
          fp32_acc_grad = fp32_acc_grad_t->data<float>();
        }
      }

      platform::float16 *fp16_acc_grad = nullptr;
      float *master_acc_grad = nullptr;
1404
      bool use_master_acc_grad = false;
1405
      if (has_fp16_param) {
1406
        use_master_acc_grad = ctx.Attr<bool>("use_master_acc_grad");
1407 1408 1409
        auto *fp16_acc_grad_t =
            ctx.Output<framework::Tensor>("FP16AccFusedGrad");
        PADDLE_ENFORCE_NOT_NULL(
1410 1411 1412 1413
            fp16_acc_grad_t,
            platform::errors::InvalidArgument(
                "Output(FP16AccFusedGrad) cannot be nullptr "
                "when Attr(acc_steps) > 1."));
1414
        if (!fp16_acc_grad_t->IsInitialized()) {
1415 1416 1417
          auto acc_grad_size =
              use_master_acc_grad ? (3 * fp16_numel) : fp16_numel;
          fp16_acc_grad_t->Resize({static_cast<int64_t>(acc_grad_size)});
1418 1419 1420 1421 1422
          fp16_acc_grad =
              fp16_acc_grad_t->mutable_data<platform::float16>(place);
        } else {
          fp16_acc_grad = fp16_acc_grad_t->data<platform::float16>();
        }
1423 1424 1425 1426
        if (use_master_acc_grad) {
          master_acc_grad =
              reinterpret_cast<float *>(fp16_acc_grad + fp16_numel);
        }
1427 1428 1429 1430 1431
      }

      // Inplace addto
      if (has_fp32_param) {
        if (rounded_step == 1) {
1432 1433 1434 1435 1436 1437
          memory::Copy(place,
                       fp32_acc_grad,
                       place,
                       fp32_grad,
                       fp32_numel * sizeof(float),
                       stream);
1438
        } else {
1439 1440 1441 1442 1443 1444
          LaunchElementwiseAddWithCastKernel(dev_ctx,
                                             fp32_grad,
                                             fp32_acc_grad,
                                             fp32_acc_grad,
                                             fp32_numel,
                                             stream);
1445 1446 1447 1448
        }
      }

      if (has_fp16_param) {
1449 1450
        if (acc_steps == 2 || !use_master_acc_grad) {
          if (rounded_step != 1) {
1451 1452 1453 1454 1455 1456
            LaunchElementwiseAddWithCastKernel(dev_ctx,
                                               fp16_acc_grad,
                                               fp16_grad,
                                               fp16_acc_grad,
                                               fp16_numel,
                                               stream);
1457
          } else {
1458 1459 1460 1461 1462 1463
            memory::Copy(place,
                         fp16_acc_grad,
                         place,
                         fp16_grad,
                         fp16_numel * sizeof(platform::float16),
                         stream);
1464 1465 1466
          }
        } else {  // acc_steps >= 3
          if (rounded_step == 0) {
1467 1468 1469 1470 1471 1472
            LaunchElementwiseAddWithCastKernel(dev_ctx,
                                               fp16_grad,
                                               master_acc_grad,
                                               fp16_acc_grad,
                                               fp16_numel,
                                               stream);
1473
          } else if (rounded_step == 1) {
1474 1475 1476 1477 1478 1479
            memory::Copy(place,
                         fp16_acc_grad,
                         place,
                         fp16_grad,
                         fp16_numel * sizeof(platform::float16),
                         stream);
1480
          } else if (rounded_step == 2) {
1481 1482 1483 1484 1485 1486
            LaunchElementwiseAddWithCastKernel(dev_ctx,
                                               fp16_grad,
                                               fp16_acc_grad,
                                               master_acc_grad,
                                               fp16_numel,
                                               stream);
1487
          } else {
1488 1489 1490 1491 1492 1493
            LaunchElementwiseAddWithCastKernel(dev_ctx,
                                               fp16_grad,
                                               master_acc_grad,
                                               master_acc_grad,
                                               fp16_numel,
                                               stream);
1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520
          }
        }
      }

      auto *stop_update_t = ctx.Output<framework::Tensor>("StopUpdate");
      stop_update_t->Resize({1});
      auto *stop_update =
          stop_update_t->mutable_data<bool>(platform::CPUPlace());

      auto *found_inf_cpu =
          found_inf_t->mutable_data<bool>(platform::CPUPlace());

      if (rounded_step != 0) {
        *stop_update = true;
        auto *found_inf_cpu =
            found_inf_t->mutable_data<bool>(platform::CPUPlace());
        *found_inf_cpu = false;
        return;
      } else {
        // swap pointer
        fp32_grad = fp32_acc_grad;
        fp16_grad = fp16_acc_grad;
        *stop_update = false;
        found_inf_t->clear();
      }
    }

1521
    // Step 3: Get ParamInfo
1522 1523 1524 1525
    const auto *param_info_tensor = GetInputTensorPtr<int>(ctx, "ParamInfo");
    auto fp32_local_start_idx = param_info_tensor[0];
    auto fp32_local_param_num = param_info_tensor[1];
    auto fp32_global_param_num = param_info_tensor[2];
1526 1527 1528 1529 1530
    auto fp32_weight_decay_end_idx = param_info_tensor[3];
    auto fp16_local_start_idx = param_info_tensor[4];
    auto fp16_local_param_num = param_info_tensor[5];
    auto fp16_global_param_num = param_info_tensor[6];
    auto fp16_weight_decay_end_idx = param_info_tensor[7];
1531 1532 1533

    auto local_param_num = fp32_local_param_num + fp16_local_param_num;
    auto param_num = fp32_global_param_num + fp16_global_param_num;
1534 1535
    PADDLE_ENFORCE_LE(local_param_num,
                      param_num,
1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548
                      platform::errors::InvalidArgument(
                          "The local parameter number should not exceed the "
                          "global parameter number."));
    VLOG(1) << "local_param_num = " << local_param_num
            << " , global_param_num = " << param_num
            << " , fp32_local_start_idx = " << fp32_local_start_idx
            << " , fp32_local_param_num = " << fp32_local_param_num
            << " , fp32_global_param_num = " << fp32_global_param_num
            << " , fp16_local_start_idx = " << fp16_local_start_idx
            << " , fp16_local_param_num = " << fp16_local_param_num
            << " , fp16_global_param_num = " << fp16_global_param_num;

    // Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow,
1549
    // GlobalScale
1550 1551 1552
    const auto *global_scale = GetInputTensorPtr<float>(ctx, "GlobalScale");
    const auto *lr = GetInputTensorPtr<float>(ctx, "LearningRate");
    int64_t partial_numel = 0;
1553 1554
    auto *moment1 = GetSameInOutTensorPtr<float>(
        ctx, place, "Moment1", "Moment1Out", &partial_numel);
1555

1556 1557
    PADDLE_ENFORCE_EQ(numel % partial_numel,
                      0,
1558 1559 1560
                      platform::errors::InvalidArgument(
                          "The total parameter number %d should be divided "
                          "exactly by the element number %d of Moment1.",
1561 1562
                          numel,
                          partial_numel));
1563

1564 1565 1566
    // The num_devices means the number of devices that shard a complete set
    // of all parameters. It may be num_devices < nranks or num_devices ==
    // nranks.
1567 1568 1569 1570
    int64_t num_devices = numel / partial_numel;
    VLOG(1) << "num_devices = " << num_devices
            << " , partial_numel = " << partial_numel;

1571 1572
    PADDLE_ENFORCE_EQ(fp32_numel % num_devices,
                      0,
1573 1574 1575
                      platform::errors::InvalidArgument(
                          "The fp32 parameter number %d should be divided "
                          "exactly by the device number %d.",
1576 1577 1578 1579
                          fp32_numel,
                          num_devices));
    PADDLE_ENFORCE_EQ(fp16_numel % num_devices,
                      0,
1580 1581 1582
                      platform::errors::InvalidArgument(
                          "The fp16 parameter number %d should be divided "
                          "exactly by the device number %d.",
1583 1584
                          fp16_numel,
                          num_devices));
1585 1586 1587 1588 1589 1590 1591 1592 1593 1594

    auto *moment2 =
        GetSameInOutTensorPtr<float>(ctx, place, "Moment2", "Moment2Out");
    auto *beta1pow =
        GetSameInOutTensorPtr<float>(ctx, place, "Beta1Pow", "Beta1PowOut");
    auto *beta2pow =
        GetSameInOutTensorPtr<float>(ctx, place, "Beta2Pow", "Beta2PowOut");

    auto *found_inf = found_inf_t->mutable_data<bool>(place);

1595 1596
    // Step 5: Get attributes weight_decay, beta1, beta2, epsilon,
    // max_grad_norm, ring_id,
1597
    // use_master_param_norm, is_grad_scaled_by_nranks
1598
    auto weight_decay = ctx.Attr<float>("weight_decay");
1599 1600 1601 1602 1603
    auto beta1 = ctx.Attr<float>("beta1");
    auto beta2 = ctx.Attr<float>("beta2");
    auto epsilon = ctx.Attr<float>("epsilon");
    auto max_global_grad_norm = ctx.Attr<float>("max_global_grad_norm");
    auto clip_after_allreduce = ctx.Attr<bool>("clip_after_allreduce");
1604
    auto nranks = ctx.Attr<int64_t>("nranks");
1605 1606
    PADDLE_ENFORCE_GE(nranks,
                      num_devices,
1607 1608 1609
                      phi::errors::InvalidArgument(
                          "The nranks must be not less than num_devices."));
    PADDLE_ENFORCE_EQ(
1610 1611
        nranks % num_devices,
        0,
1612 1613 1614 1615 1616
        phi::errors::InvalidArgument(
            "The nranks must be exactly divided by num_devices."));
    bool local_shard = (nranks > num_devices);

    const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_id");
1617 1618
    auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
    auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
1619 1620
    auto use_hierarchical_allreduce =
        ctx.Attr<bool>("use_hierarchical_allreduce");
1621 1622 1623
    VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
             << " , clip_after_allreduce = " << clip_after_allreduce
             << " , use_master_param_norm = " << use_master_param_norm
1624
             << " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
1625 1626 1627
             << " , local_shard = " << local_shard
             << " , use_hierarchical_allreduce = "
             << use_hierarchical_allreduce;
1628 1629

    // Step 6: allreduce + global norm gradient clip
1630
    int64_t global_rank = 0, local_rank = 0;
1631 1632
    ncclComm_t global_comm = nullptr, local_comm = nullptr,
               external_comm = nullptr;
1633
    if (nranks > 1) {
1634
      auto *nccl_comm_handle =
1635 1636 1637 1638 1639 1640 1641 1642 1643
          platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
      global_comm = nccl_comm_handle->comm();
      global_rank = nccl_comm_handle->rank();

      if (local_shard) {
        auto *local_nccl_comm_handle =
            platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
        local_comm = local_nccl_comm_handle->comm();
        local_rank = local_nccl_comm_handle->rank();
1644 1645 1646 1647 1648
        if (use_hierarchical_allreduce) {
          external_comm = platform::NCCLCommContext::Instance()
                              .Get(ring_ids[2], place)
                              ->comm();
        }
1649 1650 1651 1652
      } else {
        local_comm = global_comm;
        local_rank = global_rank;
      }
1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663
    }

    memory::Buffer grad_norm_square_buffer(place);
    auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2);
    memory::Buffer cub_tmp_buffer(place);

    memory::Buffer sum_grad_buffer(place);
    float *fp32_sum_grad;
    platform::float16 *fp16_sum_grad;
    auto fp32_numel_each_device = fp32_numel / num_devices;
    auto fp16_numel_each_device = fp16_numel / num_devices;
1664 1665 1666 1667 1668 1669 1670 1671 1672
    if (local_shard) {
      auto ptr = sum_grad_buffer.Alloc<uint8_t>(
          fp32_numel * sizeof(float) + fp16_numel * sizeof(platform::float16));
      fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
      fp16_sum_grad = has_fp16_param ? reinterpret_cast<platform::float16 *>(
                                           ptr + fp32_numel * sizeof(float))
                                     : nullptr;
    } else if (nranks > 1 ||
               (max_global_grad_norm > 0 && !clip_after_allreduce)) {
1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693
      auto ptr = sum_grad_buffer.Alloc<uint8_t>(
          fp32_numel_each_device * sizeof(float) +
          fp16_numel_each_device * sizeof(platform::float16));
      fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
      fp16_sum_grad = has_fp16_param
                          ? reinterpret_cast<platform::float16 *>(
                                ptr + fp32_numel_each_device * sizeof(float))
                          : nullptr;
    } else {
      // NOTE: The const_cast here is not important. The fp32_sum_grad and
      // fp16_sum_grad would not be changed when num_devices == 1
      // But if I do not perform const_cast here, there would be more
      // if-else codes (num_devices > 1) when I write the following code.
      // So I prefer to use const_cast to unify the following code to reduce
      // the if-else codes.
      fp32_sum_grad = const_cast<float *>(fp32_grad);
      fp16_sum_grad = const_cast<platform::float16 *>(fp16_grad);
    }

    float rescale_grad = 1.0f;
    if (!is_grad_scaled_by_nranks) {
1694
      rescale_grad /= nranks;
1695 1696 1697 1698 1699
    }

    if (max_global_grad_norm > 0) {
      if (clip_after_allreduce) {
        // (1) ReduceScatter first
1700
        if (local_shard) {
1701 1702
          if (use_hierarchical_allreduce) {
            NCCLReduceScatterWithScale(
S
sneaxiy 已提交
1703
                fp32_grad,
1704 1705 1706 1707 1708 1709
                fp32_sum_grad + local_rank * fp32_numel_each_device,
                fp32_numel_each_device,
                num_devices,
                local_comm,
                stream,
                dev_ctx);
S
sneaxiy 已提交
1710 1711 1712 1713 1714 1715 1716 1717
            NCCLAllReduceWithScale(
                fp32_sum_grad + local_rank * fp32_numel_each_device,
                fp32_sum_grad + local_rank * fp32_numel_each_device,
                fp32_numel_each_device,
                nranks / num_devices,
                external_comm,
                stream,
                dev_ctx);
1718 1719

            NCCLReduceScatterWithScale(
S
sneaxiy 已提交
1720
                fp16_grad,
1721 1722 1723 1724 1725 1726
                fp16_sum_grad + local_rank * fp16_numel_each_device,
                fp16_numel_each_device,
                num_devices,
                local_comm,
                stream,
                dev_ctx);
S
sneaxiy 已提交
1727 1728 1729 1730 1731 1732 1733 1734
            NCCLAllReduceWithScale(
                fp16_sum_grad + local_rank * fp16_numel_each_device,
                fp16_sum_grad + local_rank * fp16_numel_each_device,
                fp16_numel_each_device,
                nranks / num_devices,
                external_comm,
                stream,
                dev_ctx);
1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750
          } else {
            NCCLAllReduceWithScale(fp32_grad,
                                   fp32_sum_grad,
                                   fp32_numel,
                                   nranks,
                                   global_comm,
                                   stream,
                                   dev_ctx);
            NCCLAllReduceWithScale(fp16_grad,
                                   fp16_sum_grad,
                                   fp16_numel,
                                   nranks,
                                   global_comm,
                                   stream,
                                   dev_ctx);
          }
1751 1752 1753
          fp32_sum_grad += (local_rank * fp32_numel_each_device);
          fp16_sum_grad += (local_rank * fp16_numel_each_device);
        } else {
1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767
          NCCLReduceScatterWithScale(fp32_grad,
                                     fp32_sum_grad,
                                     fp32_numel_each_device,
                                     nranks,
                                     global_comm,
                                     stream,
                                     dev_ctx);
          NCCLReduceScatterWithScale(fp16_grad,
                                     fp16_sum_grad,
                                     fp16_numel_each_device,
                                     nranks,
                                     global_comm,
                                     stream,
                                     dev_ctx);
1768
        }
1769
        // (2) Calculate the global grad norm
1770 1771 1772 1773 1774 1775
        GetSquareGradNorm(fp32_sum_grad,
                          fp32_numel_each_device,
                          fp16_sum_grad,
                          fp16_numel_each_device,
                          fp32_square_grad_norm,
                          stream,
1776 1777 1778 1779
                          &cub_tmp_buffer);
        VLOG(1) << "Grad square norm before all reduce: "
                << FlattenToString(fp32_square_grad_norm, 1, place);
        if (num_devices > 1) {
1780 1781 1782 1783 1784 1785 1786 1787
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::ncclAllReduce(fp32_square_grad_norm,
                                               fp32_square_grad_norm,
                                               1,
                                               ncclFloat32,
                                               ncclSum,
                                               local_comm,
                                               stream));
1788 1789 1790 1791 1792
        }
        VLOG(1) << "Grad square norm after all reduce: "
                << FlattenToString(fp32_square_grad_norm, 1, place);
      } else {
        // (1) Calculate the local grad norm
1793 1794 1795 1796 1797 1798 1799
        GetSquareGradNorm(fp32_grad,
                          fp32_numel,
                          fp16_grad,
                          fp16_numel,
                          fp32_square_grad_norm,
                          stream,
                          &cub_tmp_buffer);
1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818
        VLOG(1) << "Grad square norm before all reduce: "
                << FlattenToString(fp32_square_grad_norm, 1, place);
        // (2) Calculate the gradient clip scale
        float *fp32_scale = nullptr;
        platform::float16 *fp16_scale = nullptr;
        if (has_fp32_param && has_fp16_param) {
          auto *ptr = cub_tmp_buffer.Alloc<uint8_t>(sizeof(float) +
                                                    sizeof(platform::float16));
          fp32_scale = reinterpret_cast<float *>(ptr);
          fp16_scale =
              reinterpret_cast<platform::float16 *>(ptr + sizeof(float));
        } else if (has_fp32_param) {
          fp32_scale = cub_tmp_buffer.Alloc<float>(1);
        } else {
          fp16_scale = cub_tmp_buffer.Alloc<platform::float16>(1);
        }

        float clip_scale = 1.0f;
        if (is_grad_scaled_by_nranks) {
1819
          clip_scale *= nranks;
1820
        }
1821
        CalcGradNormClipBeforeAllReduceScale<float, platform::float16>
1822 1823 1824 1825 1826
            <<<1, 1, 0, stream>>>(global_scale,
                                  max_global_grad_norm,
                                  fp32_square_grad_norm,
                                  fp32_scale,
                                  fp16_scale,
1827
                                  clip_scale);
1828 1829 1830 1831 1832
        if (fp32_scale) {
          VLOG(1) << "Grad scale: " << FlattenToString(fp32_scale, 1, place);
        } else {
          VLOG(1) << "Grad scale: " << FlattenToString(fp16_scale, 1, place);
        }
1833
        if (nranks > 1) {
1834 1835 1836 1837 1838 1839 1840 1841
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::ncclAllReduce(fp32_square_grad_norm,
                                               fp32_square_grad_norm,
                                               1,
                                               ncclFloat32,
                                               ncclSum,
                                               global_comm,
                                               stream));
1842 1843
        }
        // (3) Do ReduceScatter with scale
1844
        if (local_shard) {
1845 1846
          if (use_hierarchical_allreduce) {
            NCCLReduceScatterWithScale(
S
sneaxiy 已提交
1847
                fp32_grad,
1848 1849 1850 1851 1852
                fp32_sum_grad + local_rank * fp32_numel_each_device,
                fp32_numel_each_device,
                num_devices,
                local_comm,
                stream,
S
sneaxiy 已提交
1853 1854 1855 1856 1857 1858 1859 1860 1861
                dev_ctx,
                fp32_scale);
            NCCLAllReduceWithScale(
                fp32_sum_grad + local_rank * fp32_numel_each_device,
                fp32_sum_grad + local_rank * fp32_numel_each_device,
                fp32_numel_each_device,
                nranks / num_devices,
                external_comm,
                stream,
1862 1863 1864
                dev_ctx);

            NCCLReduceScatterWithScale(
S
sneaxiy 已提交
1865
                fp16_grad,
1866 1867 1868 1869 1870
                fp16_sum_grad + local_rank * fp16_numel_each_device,
                fp16_numel_each_device,
                num_devices,
                local_comm,
                stream,
S
sneaxiy 已提交
1871 1872 1873 1874 1875 1876 1877 1878 1879
                dev_ctx,
                fp16_scale);
            NCCLAllReduceWithScale(
                fp16_sum_grad + local_rank * fp16_numel_each_device,
                fp16_sum_grad + local_rank * fp16_numel_each_device,
                fp16_numel_each_device,
                nranks / num_devices,
                external_comm,
                stream,
1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898
                dev_ctx);
          } else {
            NCCLAllReduceWithScale(fp32_grad,
                                   fp32_sum_grad,
                                   fp32_numel,
                                   nranks,
                                   global_comm,
                                   stream,
                                   dev_ctx,
                                   fp32_scale);
            NCCLAllReduceWithScale(fp16_grad,
                                   fp16_sum_grad,
                                   fp16_numel,
                                   nranks,
                                   global_comm,
                                   stream,
                                   dev_ctx,
                                   fp16_scale);
          }
1899 1900 1901
          fp32_sum_grad += (local_rank * fp32_numel_each_device);
          fp16_sum_grad += (local_rank * fp16_numel_each_device);
        } else {
1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917
          NCCLReduceScatterWithScale(fp32_grad,
                                     fp32_sum_grad,
                                     fp32_numel_each_device,
                                     nranks,
                                     global_comm,
                                     stream,
                                     dev_ctx,
                                     fp32_scale);
          NCCLReduceScatterWithScale(fp16_grad,
                                     fp16_sum_grad,
                                     fp16_numel_each_device,
                                     nranks,
                                     global_comm,
                                     stream,
                                     dev_ctx,
                                     fp16_scale);
1918
        }
1919 1920 1921 1922 1923
        // (4) mark max_global_grad_norm as 0, meaning that clip has been
        // already performed
        max_global_grad_norm = 0;
      }
    } else {
1924
      if (local_shard) {
1925 1926
        if (use_hierarchical_allreduce) {
          NCCLReduceScatterWithScale(
S
sneaxiy 已提交
1927
              fp32_grad,
1928 1929 1930 1931 1932 1933
              fp32_sum_grad + local_rank * fp32_numel_each_device,
              fp32_numel_each_device,
              num_devices,
              local_comm,
              stream,
              dev_ctx);
S
sneaxiy 已提交
1934 1935 1936 1937 1938 1939 1940 1941
          NCCLAllReduceWithScale(
              fp32_sum_grad + local_rank * fp32_numel_each_device,
              fp32_sum_grad + local_rank * fp32_numel_each_device,
              fp32_numel_each_device,
              nranks / num_devices,
              external_comm,
              stream,
              dev_ctx);
1942 1943

          NCCLReduceScatterWithScale(
S
sneaxiy 已提交
1944
              fp16_grad,
1945 1946 1947 1948 1949 1950
              fp16_sum_grad + local_rank * fp16_numel_each_device,
              fp16_numel_each_device,
              num_devices,
              local_comm,
              stream,
              dev_ctx);
S
sneaxiy 已提交
1951 1952 1953 1954 1955 1956 1957 1958
          NCCLAllReduceWithScale(
              fp16_sum_grad + local_rank * fp16_numel_each_device,
              fp16_sum_grad + local_rank * fp16_numel_each_device,
              fp16_numel_each_device,
              nranks / num_devices,
              external_comm,
              stream,
              dev_ctx);
1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974
        } else {
          NCCLAllReduceWithScale(fp32_grad,
                                 fp32_sum_grad,
                                 fp32_numel,
                                 nranks,
                                 global_comm,
                                 stream,
                                 dev_ctx);
          NCCLAllReduceWithScale(fp16_grad,
                                 fp16_sum_grad,
                                 fp16_numel,
                                 nranks,
                                 global_comm,
                                 stream,
                                 dev_ctx);
        }
1975 1976 1977
        fp32_sum_grad += (local_rank * fp32_numel_each_device);
        fp16_sum_grad += (local_rank * fp16_numel_each_device);
      } else {
1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991
        NCCLReduceScatterWithScale(fp32_grad,
                                   fp32_sum_grad,
                                   fp32_numel_each_device,
                                   num_devices,
                                   global_comm,
                                   stream,
                                   dev_ctx);
        NCCLReduceScatterWithScale(fp16_grad,
                                   fp16_sum_grad,
                                   fp16_numel_each_device,
                                   num_devices,
                                   global_comm,
                                   stream,
                                   dev_ctx);
1992
      }
1993 1994 1995 1996 1997 1998
      CheckHasNanInfGrad(fp32_sum_grad,
                         fp32_numel_each_device,
                         fp16_sum_grad,
                         fp16_numel_each_device,
                         fp32_square_grad_norm,
                         stream,
1999 2000
                         &cub_tmp_buffer);
      if (num_devices > 1) {
2001 2002 2003 2004 2005 2006 2007 2008
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::ncclAllReduce(fp32_square_grad_norm,
                                             fp32_square_grad_norm,
                                             1,
                                             ncclFloat32,
                                             ncclSum,
                                             local_comm,
                                             stream));
2009 2010 2011 2012 2013 2014
      }
      max_global_grad_norm = 0;
    }
    VLOG(10) << "ReduceScatter done";

    // Step 7: update the moment1, moment2. Calculate the trust_ratio_div
2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025
    auto *fused_offsets_t = ctx.Input<framework::Tensor>("FusedParamOffsets");
    auto *fused_offsets = fused_offsets_t->data<int>();
    auto *fp32_partial_fused_offsets_t =
        ctx.Input<framework::Tensor>("FP32ShardFusedParamOffsets");
    const auto *fp32_partial_fused_offsets =
        fp32_partial_fused_offsets_t->data<int>();
    auto *fp16_partial_fused_offsets_t =
        ctx.Input<framework::Tensor>("FP16ShardFusedParamOffsets");
    const auto *fp16_partial_fused_offsets =
        fp16_partial_fused_offsets_t->data<int>();

2026 2027
    auto *step = ctx.Output<framework::Tensor>("Step")->data<int64_t>();

2028
    VLOG(1) << "FusedParamOffsets: "
2029 2030
            << FlattenToString(fused_offsets,
                               fused_offsets_t->numel(),
2031 2032 2033 2034 2035 2036 2037 2038 2039 2040
                               fused_offsets_t->place());
    VLOG(1) << "FP32ShardFusedParamOffsets: "
            << FlattenToString(fp32_partial_fused_offsets,
                               fp32_partial_fused_offsets_t->numel(),
                               fp32_partial_fused_offsets_t->place());
    VLOG(1) << "FP16ShardFusedParamOffsets: "
            << FlattenToString(fp16_partial_fused_offsets,
                               fp16_partial_fused_offsets_t->numel(),
                               fp16_partial_fused_offsets_t->place());

2041 2042
    memory::Buffer trust_ratio_div_buffer(place);
    auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
2043 2044
    auto fp32_offset = local_rank * fp32_numel_each_device;
    auto fp16_offset = local_rank * fp16_numel_each_device;
2045 2046
    if (has_fp32_param) {
      VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts";
2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067
      MultiTensorUpdateLambMomentAndTrustRatioDiv(dev_ctx,
                                                  fp32_partial_fused_offsets,
                                                  fp32_local_param_num,
                                                  fp32_param + fp32_offset,
                                                  fp32_sum_grad,
                                                  fp32_square_grad_norm,
                                                  global_scale,
                                                  beta1pow,
                                                  beta2pow,
                                                  moment1,
                                                  moment2,
                                                  trust_ratio_div,
                                                  found_inf,
                                                  step,
                                                  weight_decay,
                                                  fp32_weight_decay_end_idx,
                                                  beta1,
                                                  beta2,
                                                  epsilon,
                                                  max_global_grad_norm,
                                                  rescale_grad);
2068 2069 2070 2071 2072 2073
      VLOG(10) << "Update FP32 Moment and TrustRatioDiv done";
    }
    float *master_param = nullptr;
    if (has_fp16_param) {
      master_param = fp32_param + fp32_numel;
      VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts";
2074
      auto tmp_found_inf = has_fp32_param ? nullptr : found_inf;
2075
      auto tmp_step = has_fp32_param ? nullptr : step;
2076
      MultiTensorUpdateLambMomentAndTrustRatioDiv(
2077 2078 2079 2080 2081 2082 2083 2084 2085 2086
          dev_ctx,
          fp16_partial_fused_offsets,
          fp16_local_param_num,
          master_param + fp16_offset,
          fp16_sum_grad,
          fp32_square_grad_norm,
          global_scale,
          beta1pow,
          beta2pow,
          moment1 + fp32_numel_each_device,
2087
          moment2 + fp32_numel_each_device,
2088 2089 2090 2091 2092 2093 2094 2095 2096 2097
          trust_ratio_div + fp32_numel_each_device,
          tmp_found_inf,
          tmp_step,
          weight_decay,
          fp16_weight_decay_end_idx,
          beta1,
          beta2,
          epsilon,
          max_global_grad_norm,
          rescale_grad);
2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109
      VLOG(10) << "Update FP16 Moment and TrustRatioDiv done";
    }

    VLOG(10) << "Update Moment and TrustRatioDiv done hehahaha";

    // Step 8: calculate L2-Norm square of parameter and trust_ratio_div
    memory::Buffer square_norm_buffer(place);
    auto *param_square_norm = square_norm_buffer.Alloc<float>(2 * param_num);
    auto *trust_ratio_div_square_norm = param_square_norm + param_num;
    if (num_devices > 1) {
      if (use_master_param_norm) {
        FillZeroWithPtr(param_square_norm + fp32_global_param_num,
2110 2111
                        2 * param_num - fp32_global_param_num,
                        stream);
2112 2113 2114 2115
      } else {
        FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream);
      }
    }
2116 2117 2118 2119 2120 2121
    MultiTensorL2Norm(place,
                      stream,
                      fp32_param,
                      fused_offsets,
                      fp32_global_param_num,
                      param_square_norm);
2122
    if (use_master_param_norm) {
2123 2124 2125 2126 2127
      MultiTensorL2Norm(place,
                        stream,
                        master_param + fp16_offset,
                        fp16_partial_fused_offsets,
                        fp16_local_param_num,
2128
                        param_square_norm + fp16_local_start_idx);
2129
    } else {
2130 2131
      MultiTensorL2Norm(place,
                        stream,
2132 2133 2134 2135 2136
                        fp16_param + fused_offsets[fp16_local_start_idx] -
                            fused_offsets[fp32_global_param_num],
                        fused_offsets + fp16_local_start_idx,
                        fp16_local_param_num,
                        param_square_norm + fp16_local_start_idx);
2137 2138
    }

2139 2140 2141 2142 2143
    MultiTensorL2Norm(place,
                      stream,
                      trust_ratio_div,
                      fp32_partial_fused_offsets,
                      fp32_local_param_num,
2144
                      trust_ratio_div_square_norm + fp32_local_start_idx);
2145 2146 2147 2148 2149
    MultiTensorL2Norm(place,
                      stream,
                      trust_ratio_div + fp32_numel_each_device,
                      fp16_partial_fused_offsets,
                      fp16_local_param_num,
2150
                      trust_ratio_div_square_norm + fp16_local_start_idx);
2151 2152 2153 2154 2155 2156 2157 2158

    VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: "
            << FlattenToString(trust_ratio_div_square_norm, param_num, place);
    if (num_devices > 1) {
      if (use_master_param_norm) {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
            param_square_norm + fp32_global_param_num,
            param_square_norm + fp32_global_param_num,
2159 2160 2161 2162 2163
            2 * param_num - fp32_global_param_num,
            ncclFloat32,
            ncclSum,
            local_comm,
            stream));
2164
      } else {
2165 2166 2167 2168 2169 2170 2171 2172
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::ncclAllReduce(trust_ratio_div_square_norm,
                                             trust_ratio_div_square_norm,
                                             param_num,
                                             ncclFloat32,
                                             ncclSum,
                                             local_comm,
                                             stream));
2173 2174 2175 2176
      }
      VLOG(10) << "ncclAllReduce done";
    }

2177 2178
    LogParamAndTrustRatioDivSquareNorm<1>(
        ctx, param_square_norm, trust_ratio_div_square_norm);
2179 2180 2181 2182
    VLOG(10) << "Calculate L2-Norm of Param and TrustRatioDiv done";

    // Step 9: update parameter, beta1pow, beta2pow. All gather parameters.
    if (has_fp32_param) {
2183
      MultiTensorUpdateLambParamAndBetaPows<float>(
2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197
          dev_ctx,
          fp32_partial_fused_offsets,
          fp32_local_param_num,
          trust_ratio_div,
          lr,
          param_square_norm + fp32_local_start_idx,
          trust_ratio_div_square_norm + fp32_local_start_idx,
          found_inf,
          fp32_param + fp32_offset,
          nullptr,
          beta1pow,
          beta2pow,
          beta1,
          beta2);
2198 2199
      if (num_devices > 1) {
        // ncclAllGather
2200 2201 2202 2203 2204 2205 2206
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::ncclAllGather(fp32_param + fp32_offset,
                                             fp32_param,
                                             fp32_numel_each_device,
                                             ncclFloat32,
                                             local_comm,
                                             stream));
2207
      }
2208 2209 2210

      beta1pow = nullptr;
      beta2pow = nullptr;
2211 2212
    }
    if (has_fp16_param) {
2213
      MultiTensorUpdateLambParamAndBetaPows<platform::float16>(
2214 2215 2216 2217 2218
          dev_ctx,
          fp16_partial_fused_offsets,
          fp16_local_param_num,
          trust_ratio_div + fp32_numel_each_device,
          lr,
2219
          param_square_norm + fp16_local_start_idx,
2220 2221 2222 2223 2224 2225 2226 2227
          trust_ratio_div_square_norm + fp16_local_start_idx,
          found_inf,
          fp16_param + fp16_offset,
          master_param + fp16_offset,
          beta1pow,
          beta2pow,
          beta1,
          beta2);
2228 2229
      if (num_devices > 1) {
        // ncclAllGather
2230 2231 2232 2233 2234 2235 2236
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::ncclAllGather(fp16_param + fp16_offset,
                                             fp16_param,
                                             fp16_numel_each_device,
                                             ncclFloat16,
                                             local_comm,
                                             stream));
2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256
      }
    }
    VLOG(10) << "Update Param done";

    VLOG(1) << "IsFinite: " << IsFinite(dev_ctx, fp32_square_grad_norm);
#else
    PADDLE_THROW(platform::errors::Unimplemented(
        "distributed_fused_lamb op should be used with NCCL/RCCL."));
#endif
  }
};

}  // namespace operators
}  // namespace paddle

namespace plat = paddle::platform;
namespace ops = paddle::operators;

// Registers the CUDA kernel for the `distributed_fused_lamb` operator.
// Only the `float` instantiation of DistributedFusedLambOpKernel is listed
// here. Float16 parameters are updated inside the kernel's Compute (the
// `has_fp16_param` branch calls the platform::float16 update path), so no
// separate float16 registration appears in this list.
// NOTE(review): stray git-blame annotation tokens that had been pasted into
// this macro's argument list were removed; they made the file uncompilable.
REGISTER_OP_CUDA_KERNEL(
    distributed_fused_lamb,
    ops::DistributedFusedLambOpKernel<phi::GPUContext, float>);