From 5df3cd611056b88bac70f30f1830906b58ce49b8 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Sat, 19 Feb 2022 07:50:39 +0800 Subject: [PATCH] Add the DistributedFusedLamb optimizer (#39148) * add DistributedFusedLamb op * polish code * fix compile error * compatible with pten changement * fix rocm compile error * improve converage * update upstream/develop * fix cast_with_ptr.h * add FLAGS_distributed_lamb_divide_nranks_when_allreduce=1 * fix clip before allreduce * add use_master_param_norm * code polish * fix bug * fix ROCM ci --- paddle/fluid/memory/buffer.h | 2 + .../operators/optimizers/cast_with_ptr.h | 76 + .../distributed_fused_lamb_init_op.cc | 123 ++ .../distributed_fused_lamb_init_op.cu | 730 +++++++++ .../distributed_fused_lamb_init_op.h | 33 + .../optimizers/distributed_fused_lamb_op.cc | 154 ++ .../optimizers/distributed_fused_lamb_op.cu | 1305 +++++++++++++++++ .../optimizers/distributed_fused_lamb_op.h | 34 + paddle/fluid/operators/optimizers/lamb_op.h | 11 + paddle/fluid/operators/tensor_to_string.h | 65 + python/paddle/fluid/clip.py | 41 +- .../contrib/mixed_precision/decorator.py | 175 ++- .../fluid/tests/unittests/CMakeLists.txt | 4 + .../distributed_fused_lamb_test_base.py | 309 ++++ ...est_distributed_fused_lamb_op_with_clip.py | 80 + ..._distributed_fused_lamb_op_without_clip.py | 28 + python/paddle/incubate/__init__.py | 1 + python/paddle/incubate/optimizer/__init__.py | 1 + .../optimizer/distributed_fused_lamb.py | 305 ++++ 19 files changed, 3418 insertions(+), 59 deletions(-) create mode 100644 paddle/fluid/operators/optimizers/cast_with_ptr.h create mode 100644 paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc create mode 100644 paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu create mode 100644 paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h create mode 100644 paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc create mode 100644 paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu create mode 100644 paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h create mode 100644 paddle/fluid/operators/tensor_to_string.h create mode 100644 python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py create mode 100644 python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py create mode 100644 python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_without_clip.py create mode 100644 python/paddle/incubate/optimizer/distributed_fused_lamb.py diff --git a/paddle/fluid/memory/buffer.h b/paddle/fluid/memory/buffer.h index 127d6357e4a..99b25ca289c 100644 --- a/paddle/fluid/memory/buffer.h +++ b/paddle/fluid/memory/buffer.h @@ -51,6 +51,8 @@ class Buffer { size_t Size() const { return allocation_ ? 0 : allocation_->size(); } + platform::Place GetPlace() const { return place_; } + private: AllocationPtr allocation_; platform::Place place_; diff --git a/paddle/fluid/operators/optimizers/cast_with_ptr.h b/paddle/fluid/operators/optimizers/cast_with_ptr.h new file mode 100644 index 00000000000..555b9ed27dd --- /dev/null +++ b/paddle/fluid/operators/optimizers/cast_with_ptr.h @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/pten/api/include/tensor.h" +#include "paddle/pten/kernels/funcs/elementwise_base.h" + +namespace paddle { +namespace operators { +namespace details { + +template +struct CastFunctor { + HOSTDEVICE OutT operator()(InT x) const { return static_cast(x); } +}; + +template +static void VecCastKernel(const platform::CUDADeviceContext &ctx, const InT *x, + OutT *y, size_t n) { + auto config = platform::GetGpuLaunchConfig1D(ctx, n, VecSize); + auto block = config.GetGridSize(); + auto thread = config.GetBlockSize(); + auto main_offset = n / (VecSize * thread) * VecSize * thread; + auto stream = ctx.stream(); + using FunctorT = CastFunctor; + pten::framework::Array in_arr; + in_arr[0] = reinterpret_cast(x); + pten::framework::Array<_ptr_ OutT *, 1> out_arr; + out_arr[0] = y; + pten::funcs::VectorizedElementwiseKernel< + OutT, FunctorT, 1, 1, VecSize><<>>( + in_arr, out_arr, n, main_offset, FunctorT()); +} + +} // namespace details + +template +static void LaunchCastKernel(const platform::CUDADeviceContext &ctx, + const InT *x, OutT *y, size_t n) { + if (n == 0) return; + PADDLE_ENFORCE_NE( + static_cast(x), static_cast(y), + platform::errors::InvalidArgument("Inplace cast is not supported yet.")); + int vec_size = + std::min(platform::GetVectorizedSize(x), platform::GetVectorizedSize(y)); + switch (vec_size) { + case 4: + return details::VecCastKernel(ctx, x, y, n); + case 2: + return details::VecCastKernel(ctx, x, y, n); + case 1: + return details::VecCastKernel(ctx, x, y, n); + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "The vectorized size must be 1, 2 or 4.")); + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc new file mode 100644 index 00000000000..28c6efef141 --- /dev/null +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
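LaunchCastKernel above dispatches VecCastKernel on the widest vector width (4, 2 or 1 elements) that both the source and destination pointers support, taking the minimum of the two widths before instantiating the kernel. A minimal sketch of that alignment-based width selection, using a hypothetical helper name (the real platform::GetVectorizedSize may be implemented differently):

#include <cstdint>

// Illustrative only: return the widest element count (4, 2 or 1) for which a
// vector load/store from this pointer would be naturally aligned. The real
// launcher takes std::min of the widths of x and y, exactly as in the switch
// statement of LaunchCastKernel above.
template <typename T>
static int GetVectorizedSizeSketch(const T *ptr) {
  auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  if (addr % (sizeof(T) * 4) == 0) return 4;
  if (addr % (sizeof(T) * 2) == 0) return 2;
  return 1;
}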
+ +#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h" + +namespace paddle { +namespace operators { + +class DistributedFusedLambInitOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override {} + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto dtype = framework::proto::VarType::FP32; // dtype is not important + return framework::OpKernelType(dtype, ctx.GetPlace()); + } +}; + +class DistributedFusedLambInitOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Param", "The initial parameter list.").AsDuplicable(); + AddInput("Grad", "The initial gradient list.").AsDuplicable(); + + AddOutput("FP32FusedParam", + "The fp32 fused param and fp16 fused master weight tensor. Its " + "shape is [M1+M2], where M1 is the fp32 fused parameter size and " + "M2 is the fp16 fused master weight parameter size. Note that M1 " + "and M2 should be exactly divided by N (guaranteed by extra " + "padding 0), where N is the world size.") + .AsDispensable(); + AddOutput("FP32FusedGrad", "The fp32 fused grad tensor. Its shape is [M1].") + .AsDispensable(); + AddOutput("FP16FusedParam", + "The fp16 fused param tensor. Its shape is [M2].") + .AsDispensable(); + AddOutput("FP16FusedGrad", "The fp16 fused grad tensor. Its shape is [M2].") + .AsDispensable(); + + AddOutput("Moment1", + "The sharded fp32 moment1 tensor. Its shape is [(M1+M2)/N]."); + AddOutput("Moment2", + "The sharded fp32 moment2 tensor. Its shape is [(M1+M2)/N]."); + AddOutput("Beta1Pow", + "The fp32 beta1 power accumulator tensor. Its shape is [1]."); + AddOutput("Beta2Pow", + "The fp32 beta2 power accumulator tensor. Its shape is [1]."); + AddOutput("FusedIndices", + "The param index of each element in FP32FusedParam. Its shape is " + "[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...]."); + AddOutput( + "FusedParamOffsets", + "The numel offset of each parameter inside the FP32FusedParam. Its " + "shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 " + "+ n_2, ...]."); + AddOutput("FP32ShardFusedParamOffsets", + "The sharded numel offset of each parameter in the local rank. " + "Its shape is [fp32_local_param_num + 1]."); + AddOutput("FP16ShardFusedParamOffsets", + "The sharded numel offset of each parameter in the local rank. " + "Its shape is [fp16_local_param_num + 1]."); + AddOutput( + "WeightDecay", + "The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N]."); + AddOutput("ParamInfo", + "The param info. It should be in CPUPlace, and its shape is [6]" + "CPUPlace, and its shape is [6]. It is " + "[fp32_shard_param_start_idx, fp32_local_param_num, " + "fp32_global_param_num, fp16_shard_param_start_idx, " + "fp16_local_param_num, fp16_global_param_num]."); + + AddOutput("ParamOut", "The output parameter list.").AsDuplicable(); + AddOutput("MasterParamOut", + "The output master parameter list. It would share the memory of " + "each fp32 parameter and fp16 master parameter.") + .AsDuplicable(); + AddOutput("GradOut", "The output gradient list.").AsDuplicable(); + AddOutput("GlobalScale", + "The global scale. It is usually the scale factor for AMP."); + + AddAttr("beta1", "The initial value of Beta1Pow."); + AddAttr("beta2", "The initial value of Beta2Pow."); + AddAttr>( + "weight_decay", + "The weight decay for each parameter. 
Its " + "shape is equal to the global parameter number."); + AddAttr("alignment", "The alignment in bytes for the fused tensors."); + AddAttr("rank", "The global rank of the current process."); + AddAttr("nranks", "The global world size."); + AddComment( + R"DOC(The init operator for the DistributedFusedLamb optimizer.)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb_init, + ops::DistributedFusedLambInitOp, + ops::DistributedFusedLambInitOpMaker); + +REGISTER_OP_CPU_KERNEL( + distributed_fused_lamb_init, + ops::DistributedFusedLambInitOpKernel); diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu new file mode 100644 index 00000000000..614c48ae397 --- /dev/null +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu @@ -0,0 +1,730 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/optimizers/cast_with_ptr.h" +#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h" +#include "paddle/fluid/operators/tensor_to_string.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/kernels/funcs/algorithm.h" +#include "paddle/pten/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { + +struct ParamGradInfo { + framework::Tensor *param_t{nullptr}; + framework::Tensor *grad_t{nullptr}; + size_t idx{0}; + size_t numel{0}; + size_t numel_with_padding{0}; + size_t numel_offset{0}; +}; + +static std::ostream &operator<<(std::ostream &os, const ParamGradInfo &info) { + return os << "{Param(" << info.param_t << "),Grad(" << info.grad_t << "),idx(" + << info.idx << "),numel(" << info.numel << "),numel_with_padding(" + << info.numel_with_padding << "),numel_offset(" << info.numel_offset + << "),padding(" << info.numel_offset + info.numel_with_padding + << "-" << info.numel_offset + info.numel << "=" + << info.numel_with_padding - info.numel << ")}"; +} + +struct ParamGradInfoNumelOffsetCompFunctor { + bool operator()(const ParamGradInfo &x, const ParamGradInfo &y) const { + return x.numel_offset < y.numel_offset; + } + + bool operator()(const ParamGradInfo &x, size_t y) const { + return x.numel_offset < y; + } + + bool operator()(size_t x, const ParamGradInfo &y) const { + return x < y.numel_offset; + } + + bool operator()(size_t x, size_t y) const { return x < y; } +}; + +static size_t GetAlignSize(size_t n, size_t alignment) { + auto remainder = n % alignment; + return remainder == 0 ? 
n : n + alignment - remainder; +} + +// gcd(x, y) = gcd(y, x % y) +// gcd(x, 0) = x +static size_t GCD(size_t x, size_t y) { + while (y > 0) { + auto tmp = x; + x = y; + y = tmp % y; + } + return x; +} + +static size_t LCM(size_t x, size_t y) { return x / GCD(x, y) * y; } + +// Shard the ParamGradInfo list by the numel size [start_size, end_size) +// The final results should be: +// +// start_size = sum(infos[0:i].numel_with_padding) + start_numel_offset, where +// start_numel_offset <= infos[i].numel_with_padding +// +// end_size = sum(infos[0:j].numel_with_padding) + end_numel_offset, where +// end_numel_offset <= infos[j].numel_with_padding +static void GetParamGradShardInfo(const std::vector &infos, + size_t start_size, size_t end_size, + size_t *start_idx, size_t *end_idx, + size_t *start_numel_offset, + size_t *end_numel_offset) { + VLOG(10) << "NumelOffset: " + << string::join_strings(infos, ",", [](const ParamGradInfo &info) { + return info.numel_offset; + }); + VLOG(10) << "start_size = " << start_size << " , end_size = " << end_size; + + if (infos.empty()) { + PADDLE_ENFORCE_EQ(start_size, 0, platform::errors::InvalidArgument( + "start_size should be 0.")); + PADDLE_ENFORCE_EQ(end_size, 0, platform::errors::InvalidArgument( + "end_size should be 0.")); + *start_idx = 0; + *end_idx = 0; + *start_numel_offset = 0; + *end_numel_offset = 0; + return; + } + + PADDLE_ENFORCE_LT(start_size, end_size, + platform::errors::InvalidArgument( + "start_size should be less than end_size.")); + size_t n = infos.size(); + ParamGradInfoNumelOffsetCompFunctor comp; + auto i = static_cast( + std::lower_bound(infos.begin(), infos.end(), start_size, comp) - + infos.begin()); + if (i == n || infos[i].numel_offset != start_size) { + PADDLE_ENFORCE_GT( + i, 0, platform::errors::InvalidArgument( + "Cannot find suitable sharding which is between [%d, %d)", + start_size, end_size)); + --i; + } + PADDLE_ENFORCE_LT( + i, n, platform::errors::InvalidArgument( + "Cannot find suitable sharding which is between [%d, %d)", + start_size, end_size)); + *start_idx = i; + *start_numel_offset = start_size - infos[i].numel_offset; + auto j = static_cast( + std::lower_bound(infos.begin(), infos.end(), end_size, comp) - + infos.begin()); + *end_idx = j - 1; + *end_numel_offset = end_size - infos[j - 1].numel_offset; + PADDLE_ENFORCE_GT(*end_numel_offset, 0, + platform::errors::InvalidArgument( + "Internal error when sharding, this may be a bug " + "caused by empty parameter.")); + VLOG(10) << "Sharding [start_size=" << start_size << ", end_size=" << end_size + << "): " << (*start_idx) << ":" << (*start_numel_offset) << " -> " + << (*end_idx) << ":" << (*end_numel_offset); +} + +static size_t FillAlignmentPaddingInfo(std::vector *infos, + size_t alignment, size_t nranks, + pten::DataType dtype) { + auto sizeof_dtype = paddle::experimental::SizeOf(dtype); + PADDLE_ENFORCE_EQ( + alignment % sizeof_dtype, 0, + platform::errors::InvalidArgument( + "The attr(alignment) should be exactly divided by sizeof(T) %d.", + sizeof_dtype)); + alignment /= sizeof_dtype; + + size_t total_numel_sum_with_padding = 0; + size_t n = infos->size(); + auto lcm = LCM(alignment, nranks); + for (size_t i = 0; i < n; ++i) { + auto &info = (*infos)[i]; + size_t numel_with_padding = + GetAlignSize(info.numel, i + 1 == n ? 
lcm : alignment); + info.numel_with_padding = numel_with_padding; + info.numel_offset = total_numel_sum_with_padding; + total_numel_sum_with_padding += numel_with_padding; + } + return total_numel_sum_with_padding; +} + +template +static T *TensorFillConstant(const platform::CUDADeviceContext &dev_ctx, + framework::Tensor *tensor, + const framework::DDim &dims, T value) { + tensor->Resize(dims); + auto *ptr = tensor->mutable_data(dev_ctx.GetPlace()); + pten::funcs::SetConstant set_constant; + set_constant(dev_ctx, tensor, value); + return ptr; +} + +static framework::Tensor CastDataForInitedTensor( + const platform::CUDADeviceContext &dev_ctx, framework::Tensor *origin, + framework::Tensor *fused_out, size_t numel_offset) { + PADDLE_ENFORCE_EQ(origin->IsInitialized(), true, + platform::errors::InvalidArgument( + "The tensor to be cast should be initialized.")); + + PADDLE_ENFORCE_EQ(fused_out->dtype(), pten::DataType::FLOAT32, + platform::errors::InvalidArgument( + "The dst tensor to be cast should be FP32 tensor.")); + PADDLE_ENFORCE_EQ(origin->dtype(), pten::DataType::FLOAT16, + platform::errors::InvalidArgument( + "The src tensor to be cast should be FP16 tensor.")); + auto *dst = fused_out->data() + numel_offset; + auto *src = origin->data(); + auto numel = origin->numel(); + LaunchCastKernel(dev_ctx, src, dst, numel); + VLOG(10) << "Cast from FP32 -> FP16, range: [" << numel_offset << ", " + << numel_offset + numel << ")" + << " , total: [0, " << fused_out->numel() << ")"; + framework::DDim fused_out_dim = fused_out->dims(); + auto fused_out_numel = fused_out->numel(); + fused_out->Resize({fused_out_numel}); + auto sliced_tensor = fused_out->Slice(numel_offset, numel + numel_offset); + fused_out->Resize(fused_out_dim); + return sliced_tensor; +} + +static framework::Tensor CopyAndShareBufferForInitedTensor( + framework::Tensor *origin, framework::Tensor *fused_out, + size_t numel_offset, gpuStream_t stream) { + PADDLE_ENFORCE_EQ( + origin->IsInitialized(), true, + platform::errors::InvalidArgument( + "The tensor to be copied and shared data should be initialized.")); + auto dtype = fused_out->type(); + PADDLE_ENFORCE_EQ(origin->type(), dtype, + platform::errors::InvalidArgument( + "The tensor to be copied and shared data should be " + "have the same data type.")); + auto place = fused_out->place(); + PADDLE_ENFORCE_EQ( + origin->place(), place, + platform::errors::InvalidArgument("The tensor to be copied and shared " + "data should be have the same place.")); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(place), true, + platform::errors::InvalidArgument( + "The tensor to be copied and shared data should be on GPU place.")); + + auto numel = origin->numel(); + framework::DDim fused_out_dim = fused_out->dims(); + auto fused_out_numel = fused_out->numel(); + auto sliced_tensor = fused_out->Resize({fused_out_numel}) + .Slice(numel_offset, numel + numel_offset); + memory::Copy(place, sliced_tensor.data(), place, origin->data(), + numel * paddle::experimental::SizeOf(dtype), stream); + origin->ShareBufferWith(sliced_tensor); + fused_out->Resize(fused_out_dim); + VLOG(10) << "Copy and share buffer, range: [" << numel_offset << ", " + << numel_offset + numel << ") , total: [0, " << fused_out->numel() + << ") , dtype = " << dtype; + return sliced_tensor; +} + +static void ShareBufferForNonInitedTensor(framework::Tensor *origin, + framework::Tensor *fused_out, + size_t numel_offset, + const framework::DDim &dims) { + PADDLE_ENFORCE_EQ( + origin->IsInitialized(), false, + 
platform::errors::InvalidArgument( + "The tensor to be shared data should not be initialized.")); + + framework::DDim fused_out_dim = fused_out->dims(); + auto fused_out_numel = fused_out->numel(); + auto numel = framework::product(dims); + *origin = fused_out->Resize({fused_out_numel}) + .Slice(numel_offset, numel + numel_offset); + origin->Resize(dims); + fused_out->Resize(fused_out_dim); + VLOG(10) << "Share buffer for non-inited, range: [" << numel_offset << ", " + << numel_offset + numel << "), total: [0, " << fused_out->numel() + << ") , dtype = " << fused_out->dtype(); +} + +template +static __global__ void LambFillFusedIndicesCUDAKernel(const OffsetT *offsets, + IndexT *out, + int offset_num, + int out_num) { + CUDA_KERNEL_LOOP_TYPE(i, out_num, int) { + auto idx = pten::funcs::LowerBound(offsets, offset_num, i); + if (idx == offset_num || offsets[idx] != i) { + --idx; + } + out[i] = idx; + } +} + +template +static void CopyVectorToTensor(const std::vector &src, + framework::Tensor *dst, + const platform::Place &place, + gpuStream_t stream) { + dst->Resize({static_cast(src.size())}); + T *dst_ptr = dst->mutable_data(place); + const T *src_ptr = src.data(); + auto nbytes = src.size() * sizeof(T); + memory::Copy(place, dst_ptr, platform::CPUPlace(), src_ptr, nbytes, stream); +} + +template +class DistributedFusedLambInitOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + VLOG(10) << "starts to run DistributedFusedLambInitOp"; + auto &dev_ctx = ctx.template device_context(); + auto place = ctx.GetPlace(); + auto stream = dev_ctx.stream(); + + // Step 1: Check Input(Param) and Output(ParamOut), Input(Grad) and + // Output(GradOut) + auto params = ctx.MultiInput("Param"); + auto grads = ctx.MultiInput("Grad"); + auto master_params = ctx.MultiOutput("MasterParamOut"); + std::vector fp32_infos, fp16_infos; + { + PADDLE_ENFORCE_EQ(params.size(), grads.size(), + platform::errors::InvalidArgument( + "The parameter number and parameter gradient " + "number should be the same.")); + + auto params_out = ctx.MultiOutput("ParamOut"); + auto grads_out = ctx.MultiOutput("GradOut"); + PADDLE_ENFORCE_EQ( + params.size(), params_out.size(), + platform::errors::InvalidArgument("Input(Param) and Output(ParamOut) " + "should have the same number.")); + PADDLE_ENFORCE_EQ( + grads.size(), grads_out.size(), + platform::errors::InvalidArgument( + "Input(Grad) and Output(GradOut) should have the same number.")); + size_t n = params.size(); + VLOG(10) << "parameter number: " << n; + for (size_t i = 0; i < n; ++i) { + auto *p = params[i]; + auto *g = grads[i]; + auto *p_out = params_out[i]; + auto *g_out = grads_out[i]; + + PADDLE_ENFORCE_NOT_NULL( + p, platform::errors::InvalidArgument( + "The %d-th parameter should not be nullptr.", i)); + PADDLE_ENFORCE_EQ(p->IsInitialized(), true, + platform::errors::InvalidArgument( + "The %d-th parameter should be initialized.", i)); + PADDLE_ENFORCE_EQ( + p->place(), place, + platform::errors::InvalidArgument( + "The %d-th parameter is not initialized on the right place.", + i)); + PADDLE_ENFORCE_EQ(p, p_out, + platform::errors::InvalidArgument( + "The %d-th Input(Param) and Output(ParamOut) " + "should be the same tensor.", + i)); + + auto dtype = p->dtype(); + PADDLE_ENFORCE_NOT_NULL( + g, platform::errors::InvalidArgument( + "The %d-th gradient should not be nullptr.", i)); + PADDLE_ENFORCE_EQ(g, g_out, + platform::errors::InvalidArgument( + "The %d-th Input(Grad) and Output(Grad) should " + "be 
the same tensor.")); + auto numel = p->numel(); + PADDLE_ENFORCE_GT(numel, 0, + platform::errors::InvalidArgument( + "The %d-th Input(Param) have no elements.")); + + void *g_data = nullptr; + if (g->IsInitialized()) { + PADDLE_ENFORCE_EQ(g->dtype(), dtype, + platform::errors::InvalidArgument( + "The %d-th Input(Param) and Input(Grad) should " + "have the same data type %s.", + i, dtype)); + PADDLE_ENFORCE_EQ(g->dims(), p->dims(), + platform::errors::InvalidArgument( + "The %d-th Input(Param) and Input(Grad) should " + "have the same shape.", + i)); + g_data = g_out->data(); + } + + ParamGradInfo *info; + if (dtype == pten::DataType::FLOAT32) { + fp32_infos.emplace_back(); + info = &fp32_infos.back(); + } else if (dtype == pten::DataType::FLOAT16) { + fp16_infos.emplace_back(); + info = &fp16_infos.back(); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported data type %s.", dtype)); + } + + VLOG(10) << "Found " << dtype << " parameter " << i << " shape=[" + << p_out->dims() << "] numel=" << numel + << " grad.IsInitialized()=" + << (g_out->IsInitialized() ? "true" : "false"); + + info->param_t = p_out; + info->grad_t = g_out; + info->idx = i; + info->numel = numel; + info->numel_with_padding = 0; // not determined yet + info->numel_offset = 0; // not determined yet + } + } + VLOG(10) << "Fill ParamGradInfo ends"; + + // Step 2: determine the numel_with_padding and numel_offset + auto rank = ctx.Attr("rank"); + auto nranks = ctx.Attr("nranks"); + auto alignment = ctx.Attr("alignment"); + VLOG(10) << "rank = " << rank << ", nranks = " << nranks + << " , alignment = " << alignment; + if (alignment <= 0) { + alignment = platform::GpuMinChunkSize(); + } + PADDLE_ENFORCE_GE(alignment, 1, + platform::errors::InvalidArgument( + "The attr(alignment) should be larger than 0.")); + PADDLE_ENFORCE_EQ(alignment & (alignment - 1), 0, + platform::errors::InvalidArgument( + "The attr(alignment) should be the power of 2.")); + PADDLE_ENFORCE_GE( + rank, 0, platform::errors::InvalidArgument( + "The attr(rank) should be equal to or larger than 0.")); + PADDLE_ENFORCE_LT( + rank, nranks, + platform::errors::InvalidArgument( + "The attr(rank) should be less than the attr(nranks).")); + // NOTE: We guarantee that both fp32_numel and fp16_numel can be exactly + // divided by alignment and nranks. + auto fp32_numel = FillAlignmentPaddingInfo(&fp32_infos, alignment, nranks, + pten::DataType::FLOAT32); + VLOG(10) << "FP32 ParamGradInfo: " << string::join_strings(fp32_infos, " "); + auto fp16_numel = FillAlignmentPaddingInfo(&fp16_infos, alignment, nranks, + pten::DataType::FLOAT16); + VLOG(10) << "FP16 ParamGradInfo: " << string::join_strings(fp16_infos, " "); + auto total_numel = fp32_numel + fp16_numel; + PADDLE_ENFORCE_LT( + total_numel, std::numeric_limits::max(), + platform::errors::InvalidArgument("Too many parameter number.")); + + auto fp32_numel_each_device = fp32_numel / nranks; + auto fp16_numel_each_device = fp16_numel / nranks; + auto numel_each_device = fp32_numel_each_device + fp16_numel_each_device; + VLOG(10) << "Fill padding ends. 
total_numel = " << total_numel + << ", fp32_numel = " << fp32_numel + << ", fp16_numel = " << fp16_numel + << ", fp32_numel_each_device = " << fp32_numel_each_device + << ", fp16_numel_each_device = " << fp16_numel_each_device; + + // Step 3: allocate output tensor and do initialization + float *fused_fp32_param = nullptr, *fused_fp32_grad = nullptr; + platform::float16 *fused_fp16_param = nullptr, *fused_fp16_grad = nullptr; + framework::Tensor *fp32_p_t = nullptr, *fp16_p_t = nullptr, + *fp32_g_t = nullptr, *fp16_g_t = nullptr; + std::vector fp16_master_params; + if (total_numel > 0) { + fp32_p_t = ctx.Output("FP32FusedParam"); + fused_fp32_param = TensorFillConstant( + dev_ctx, fp32_p_t, {static_cast(total_numel)}, 0.0f); + } + + if (fp32_numel > 0) { + fp32_g_t = ctx.Output("FP32FusedGrad"); + fused_fp32_grad = TensorFillConstant( + dev_ctx, fp32_g_t, {static_cast(fp32_numel)}, 0.0f); + } + + if (fp16_numel > 0) { + fp16_p_t = ctx.Output("FP16FusedParam"); + fused_fp16_param = TensorFillConstant( + dev_ctx, fp16_p_t, {static_cast(fp16_numel)}, + static_cast(0)); + + fp16_g_t = ctx.Output("FP16FusedGrad"); + fused_fp16_grad = TensorFillConstant( + dev_ctx, fp16_g_t, {static_cast(fp16_numel)}, + static_cast(0)); + } + VLOG(10) << "Allocate FP32FusedParam/Grad, FP16FusedParam/Grad ends"; + + // (1) For FP32FusedParam, memcpy for fp32 param and then share data, cast + // for fp16 master weight + // (2) For FP16FusedParam, memcpy and then share data + // (3) For FP32FusedGrad/FP16FusedGrad, memcpy if gradient has been inited + for (const auto &info : fp32_infos) { + auto sliced_tensor = CopyAndShareBufferForInitedTensor( + info.param_t, fp32_p_t, info.numel_offset, stream); + master_params[info.idx]->Resize(info.param_t->dims()); + master_params[info.idx]->ShareBufferWith(sliced_tensor); + PADDLE_ENFORCE_EQ(master_params[info.idx]->mutable_data(place), + sliced_tensor.data(), + platform::errors::InvalidArgument( + "Invalid master weight tensor pointer.")); + if (info.grad_t->IsInitialized()) { + CopyAndShareBufferForInitedTensor(info.grad_t, fp32_g_t, + info.numel_offset, stream); + } else { + ShareBufferForNonInitedTensor(info.grad_t, fp32_g_t, info.numel_offset, + info.param_t->dims()); + } + } + + size_t fp16_numel_offset = 0; + if (fp32_numel > 0) { + auto last_fp32_info = fp32_infos.back(); + fp16_numel_offset = + last_fp32_info.numel_offset + last_fp32_info.numel_with_padding; + } + + for (const auto &info : fp16_infos) { + auto master_weight_offset = info.numel_offset + fp16_numel_offset; + auto sliced_tensor = CastDataForInitedTensor( + dev_ctx, info.param_t, fp32_p_t, master_weight_offset); + master_params[info.idx]->Resize(info.param_t->dims()); + master_params[info.idx]->ShareBufferWith(sliced_tensor); + + CopyAndShareBufferForInitedTensor(info.param_t, fp16_p_t, + info.numel_offset, stream); + PADDLE_ENFORCE_EQ(master_params[info.idx]->mutable_data(place), + sliced_tensor.data(), + platform::errors::InvalidArgument( + "Invalid master weight tensor pointer.")); + + if (info.grad_t->IsInitialized()) { + CopyAndShareBufferForInitedTensor(info.grad_t, fp16_g_t, + info.numel_offset, stream); + } else { + ShareBufferForNonInitedTensor(info.grad_t, fp16_g_t, info.numel_offset, + info.param_t->dims()); + } + } + VLOG(10) << "Copy/share data for Param/Grad ends"; + + // Step 4: For Moment1, Moment2, Beta1Pow, Beta2Pow, just fill constant + TensorFillConstant(dev_ctx, ctx.Output("Moment1"), + {static_cast(numel_each_device)}, 0.0f); + TensorFillConstant(dev_ctx, ctx.Output("Moment2"), + 
{static_cast(numel_each_device)}, 0.0f); + TensorFillConstant(dev_ctx, + ctx.Output("Beta1Pow"), {1}, + ctx.Attr("beta1")); + TensorFillConstant(dev_ctx, + ctx.Output("Beta2Pow"), {1}, + ctx.Attr("beta2")); + VLOG(10) << "Init Moment and BetaPow ends"; + + // Step 5: Do sharding + size_t fp32_start_idx, fp32_end_idx, fp32_start_numel_offset, + fp32_end_numel_offset; + GetParamGradShardInfo(fp32_infos, rank * fp32_numel_each_device, + (rank + 1) * fp32_numel_each_device, &fp32_start_idx, + &fp32_end_idx, &fp32_start_numel_offset, + &fp32_end_numel_offset); + size_t fp16_start_idx, fp16_end_idx, fp16_start_numel_offset, + fp16_end_numel_offset; + GetParamGradShardInfo(fp16_infos, rank * fp16_numel_each_device, + (rank + 1) * fp16_numel_each_device, &fp16_start_idx, + &fp16_end_idx, &fp16_start_numel_offset, + &fp16_end_numel_offset); + size_t fp32_local_param_num = + fp32_numel_each_device > 0 ? fp32_end_idx - fp32_start_idx + 1 : 0; + size_t fp16_local_param_num = + fp16_numel_each_device > 0 ? fp16_end_idx - fp16_start_idx + 1 : 0; + size_t total_local_param_num = fp32_local_param_num + fp16_local_param_num; + VLOG(10) << "Found the sharding arguments"; + + auto *param_info_t = ctx.Output("ParamInfo"); + param_info_t->Resize({6}); + auto *param_info = param_info_t->mutable_data(platform::CPUPlace()); + param_info[0] = static_cast(fp32_start_idx); + param_info[1] = static_cast(fp32_local_param_num); + param_info[2] = static_cast(fp32_infos.size()); + param_info[3] = static_cast(fp16_start_idx + fp32_infos.size()); + param_info[4] = static_cast(fp16_local_param_num); + param_info[5] = static_cast(fp16_infos.size()); + + VLOG(10) << "Start FP32 idx: " << param_info[0]; + VLOG(10) << "Local FP32 param num: " << param_info[1]; + VLOG(10) << "Global FP32 param num: " << param_info[2]; + + VLOG(10) << "Start FP16 idx: " << param_info[3]; + VLOG(10) << "Local FP16 param num: " << param_info[4]; + VLOG(10) << "Global FP16 param num: " << param_info[5]; + + // For WeightDecay, shard and perform H2D copy + const auto &origin_weight_decay = + ctx.Attr>("weight_decay"); + PADDLE_ENFORCE_EQ(params.size(), origin_weight_decay.size(), + platform::errors::InvalidArgument( + "The attr(weight_decay) should have the " + "same length with Input(Param).")); + std::vector shard_weight_decay; + shard_weight_decay.reserve(total_local_param_num); + for (size_t i = 0; i < fp32_local_param_num; ++i) { + shard_weight_decay.push_back( + origin_weight_decay[fp32_infos[i + fp32_start_idx].idx]); + } + for (size_t i = 0; i < fp16_local_param_num; ++i) { + shard_weight_decay.push_back( + origin_weight_decay[fp16_infos[i + fp16_start_idx].idx]); + } + + // For FusedIndices, launch CUDA kernel to do binary search + auto *fused_indices_t = ctx.Output("FusedIndices"); + fused_indices_t->Resize({static_cast(total_numel)}); + auto *fused_indices = fused_indices_t->mutable_data(place); + std::vector numel_offsets; + numel_offsets.reserve(params.size() + 1); + for (const auto &info : fp32_infos) { + numel_offsets.push_back(info.numel_offset); + } + for (const auto &info : fp16_infos) { + numel_offsets.push_back(info.numel_offset + fp16_numel_offset); + } + numel_offsets.push_back(fp32_numel + fp16_numel); + PADDLE_ENFORCE_EQ(numel_offsets.size(), params.size() + 1, + platform::errors::InvalidArgument( + "The numel_offsets number must be one larger than " + "the parameter number.")); + VLOG(10) << "Total numel offset: " << FlattenToString(numel_offsets); + auto *fused_param_offset_t = + ctx.Output("FusedParamOffsets"); + 
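// The block below uploads numel_offsets (the prefix sums
// [0, n_0, n_0 + n_1, ...] over the padded parameter sizes) to the GPU as
// FusedParamOffsets, then launches LambFillFusedIndicesCUDAKernel. For each
// fused element i, that kernel binary-searches the offsets (LowerBound) to
// find the parameter the element belongs to, producing FusedIndices of the
// form [0,0,0,1,1,1,1,2,2,...] as documented in the op maker.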
fused_param_offset_t->Resize({static_cast(numel_offsets.size())}); + auto *fused_param_offset = fused_param_offset_t->mutable_data(place); + memory::Copy(place, fused_param_offset, platform::CPUPlace(), + numel_offsets.data(), + numel_offsets.size() * sizeof(numel_offsets[0]), stream); + auto config = platform::GetGpuLaunchConfig1D(dev_ctx, total_numel); + LambFillFusedIndicesCUDAKernel<<>>( + fused_param_offset, fused_indices, numel_offsets.size() - 1, + total_numel); + + std::vector lengths; + lengths.reserve(fp32_local_param_num + fp16_local_param_num); + + std::vector fp32_partial_numel_offsets; + fp32_partial_numel_offsets.reserve(fp32_local_param_num + 1); + fp32_partial_numel_offsets.push_back(0); + // Fill the partial_numel_offsets + for (size_t i = fp32_start_idx; i < fp32_start_idx + fp32_local_param_num; + ++i) { + size_t valid_start_n = 0; + if (i == fp32_start_idx) { + valid_start_n = fp32_start_numel_offset; + } + + size_t end_n = fp32_infos[i].numel_with_padding; + if (i + 1 == fp32_start_idx + fp32_local_param_num) { + end_n = std::min(end_n, fp32_end_numel_offset); + } + + PADDLE_ENFORCE_NE(valid_start_n, end_n, + platform::errors::InvalidArgument( + "Indices sharding error. This may be a bug.")); + VLOG(10) << "FP32 Partial numel = [" + << valid_start_n + fp32_infos[i].numel << "," + << end_n + fp32_infos[i].numel; + lengths.push_back(end_n - valid_start_n); + fp32_partial_numel_offsets.push_back(fp32_partial_numel_offsets.back() + + lengths.back()); + } + + std::vector fp16_partial_numel_offsets; + fp16_partial_numel_offsets.reserve(fp16_local_param_num + 1); + fp16_partial_numel_offsets.push_back(0); + for (size_t i = fp16_start_idx; i < fp16_start_idx + fp16_local_param_num; + ++i) { + size_t valid_start_n = 0; + if (i == fp16_start_idx) { + valid_start_n = fp16_start_numel_offset; + } + + size_t end_n = fp16_infos[i].numel_with_padding; + if (i + 1 == fp16_start_idx + fp16_local_param_num) { + end_n = std::min(end_n, fp16_end_numel_offset); + } + + PADDLE_ENFORCE_NE(valid_start_n, end_n, + platform::errors::InvalidArgument( + "Indices sharding error. This may be a bug.")); + lengths.push_back(end_n - valid_start_n); + fp16_partial_numel_offsets.push_back(fp16_partial_numel_offsets.back() + + lengths.back()); + } + + CopyVectorToTensor( + fp32_partial_numel_offsets, + ctx.Output("FP32ShardFusedParamOffsets"), place, + stream); + CopyVectorToTensor( + fp16_partial_numel_offsets, + ctx.Output("FP16ShardFusedParamOffsets"), place, + stream); + + // Fill the weight decay tensor + PADDLE_ENFORCE_EQ(lengths.size(), shard_weight_decay.size(), + platform::errors::InvalidArgument( + "Invalid weight decay sharding. This may be a bug.")); + std::vector wd_cpu; + for (size_t i = 0; i < shard_weight_decay.size(); ++i) { + int len = lengths[i]; + for (int j = 0; j < len; ++j) { + wd_cpu.push_back(shard_weight_decay[i]); + } + } + PADDLE_ENFORCE_EQ(wd_cpu.size() * nranks, fp32_numel + fp16_numel, + platform::errors::InvalidArgument( + "Invalid weight decay sharding. 
This may be a bug.")); + CopyVectorToTensor(wd_cpu, ctx.Output("WeightDecay"), + place, stream); + + auto *global_scale = ctx.Output("GlobalScale"); + if (!global_scale->IsInitialized()) { + TensorFillConstant(dev_ctx, global_scale, {1}, 1.0f); + } + VLOG(10) << "Init global scale ends"; + dev_ctx.Wait(); + VLOG(10) << "Wait for H2D copy"; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + distributed_fused_lamb_init, + ops::DistributedFusedLambInitOpKernel); diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h new file mode 100644 index 00000000000..bbb8a28a5b2 --- /dev/null +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h @@ -0,0 +1,33 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +template +class DistributedFusedLambInitOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "The distributed_fused_lamb_init operator does not support CPU yet.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc new file mode 100644 index 00000000000..748f8206adb --- /dev/null +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
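A small worked example of the padding and sharding performed by FillAlignmentPaddingInfo and GetParamGradShardInfo in the init kernel above (numbers chosen purely for illustration): with three fp32 parameters of numel 5, 6 and 3, alignment = 16 bytes (4 float elements) and nranks = 2, so LCM(4, 2) = 4:

  param0: numel 5 -> numel_with_padding 8, numel_offset 0
  param1: numel 6 -> numel_with_padding 8, numel_offset 8
  param2: numel 3 -> numel_with_padding 4, numel_offset 16  (last param, padded to the LCM)

fp32_numel = 20, which is divisible by both the element alignment and nranks, and fp32_numel_each_device = 10. Sharding rank 0 over [0, 10) then gives start_idx = 0, start_numel_offset = 0, end_idx = 1, end_numel_offset = 2; that is, rank 0 owns all of param0's padded slot plus the first 2 elements of param1, and fp32_local_param_num = 2.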
+ +#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h" + +namespace paddle { +namespace operators { + +class DistributedFusedLambOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override {} + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto dtype = framework::proto::VarType::FP32; // dtype is not important + return framework::OpKernelType(dtype, ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "ParamInfo") { + return expected_kernel_type; + } else { + return framework::OperatorWithKernel::GetKernelTypeForVar( + var_name, tensor, expected_kernel_type); + } + } +}; + +class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Param", "The initial parameter list.").AsDuplicable(); + AddInput("Grad", "The initial gradient list.").AsDuplicable(); + + AddInput("FP32FusedParam", + "The fp32 fused param and fp16 fused master weight tensor. Its " + "shape is [M1+M2], where M1 is the fp32 fused parameter size and " + "M2 is the fp16 fused master weight parameter size. Note that M1 " + "and M2 should be exactly divided by N (guaranteed by extra " + "padding 0), where N is the world size.") + .AsDispensable(); + AddInput("FP32FusedGrad", "The fp32 fused grad tensor. Its shape is [M1].") + .AsDispensable(); + AddInput("FP16FusedParam", + "The fp16 fused param tensor. Its shape is [M2].") + .AsDispensable(); + AddInput("FP16FusedGrad", "The fp16 fused grad tensor. Its shape is [M2].") + .AsDispensable(); + + AddInput("Moment1", + "The sharded fp32 moment1 tensor. Its shape is [(M1+M2)/N]."); + AddInput("Moment2", + "The sharded fp32 moment2 tensor. Its shape is [(M1+M2)/N]."); + AddInput("Beta1Pow", + "The fp32 beta1 power accumulator tensor. Its shape is [1]."); + AddInput("Beta2Pow", + "The fp32 beta2 power accumulator tensor. Its shape is [1]."); + AddInput("FusedIndices", + "The param index of each element in FP32FusedParam. Its shape is " + "[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...]."); + AddInput( + "FusedParamOffsets", + "The numel offset of each parameter inside the FP32FusedParam. Its " + "shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 " + "+ n_2, ...]."); + AddInput("FP32ShardFusedParamOffsets", + "The sharded numel offset of each parameter in the local rank. " + "Its shape is [fp32_local_param_num + 1]."); + AddInput("FP16ShardFusedParamOffsets", + "The sharded numel offset of each parameter in the local rank. " + "Its shape is [fp16_local_param_num + 1]."); + AddInput("WeightDecay", + "The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N]."); + AddInput("ParamInfo", + "The param info. It should be in CPUPlace, and its shape is [6]" + "CPUPlace, and its shape is [6]. It is " + "[fp32_shard_param_start_idx, fp32_local_param_num, " + "fp32_global_param_num, fp16_shard_param_start_idx, " + "fp16_local_param_num, fp16_global_param_num]."); + + AddInput("LearningRate", + "The fp32 learning rate tensor. Its shape is [1]."); + AddInput("GlobalScale", "The fp32 global scale tensor. 
Its shape is [1]."); + + AddOutput("FP32FusedParamOut", "The updated FP32FusedParam.") + .AsDispensable(); + AddOutput("FP16FusedParamOut", "The updated FP16FusedParam.") + .AsDispensable(); + + AddOutput("Moment1Out", "The updated Moment1."); + AddOutput("Moment2Out", "The updated Moment2."); + AddOutput("Beta1PowOut", "The updated Beta1Pow."); + AddOutput("Beta2PowOut", "The updated Beta2Pow."); + + AddOutput("ParamOut", "The updated output parameter tensor list.") + .AsDuplicable(); + + AddOutput("FoundInf", "Whether there is NaN/Inf"); + + AddAttr("beta1", "The initial Beta1Pow value."); + AddAttr("beta2", "The initial Beta2Pow value."); + AddAttr("epsilon", + "The epsilon value to maintain numeric stability."); + AddAttr( + "max_global_grad_norm", + "The maximum global gradient l2-norm value for clipping. If " + "max_global_grad_norm <= 0, no clipping would be performed."); + AddAttr("clip_after_allreduce", + "Whether to clip before allreduce, only valid when the " + "world size is larger than 1."); + AddAttr( + "use_master_param_norm", + "Whether to use master parameter to calculate " + "the L2-Norm. If it is true, it would be more accurate but be more " + "NCCL communication data. If it is false, it would be less accurate " + "and be less NCCL communication data.") + .SetDefault(true); + AddAttr("is_grad_scaled_by_nranks", + "Whether the input gradient has been scaled by nranks.") + .SetDefault(true); + AddAttr("ring_id", "The ring id of the NCCL communicator.") + .SetDefault(0); + AddComment("The DistributedFusedLamb optimizer."); + } +}; + +} // namespace operators +} // namespace paddle + +namespace plat = paddle::platform; +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb, + ops::DistributedFusedLambOp, + ops::DistributedFusedLambOpMaker); + +REGISTER_OP_CPU_KERNEL( + distributed_fused_lamb, + ops::DistributedFusedLambOpKernel); diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu new file mode 100644 index 00000000000..15729207158 --- /dev/null +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -0,0 +1,1305 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
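For reference, the per-element update implemented by the CUDA kernels below (UpdateLambMoment computes the moments and trust_ratio_div; LambUpdateParamAndBetaPowsCUDAKernel applies the trust-ratio step), written out with the gradient g already rescaled by rescale_grad / global_scale and optionally clipped by max_global_grad_norm:

  m     = beta1 * m + (1 - beta1) * g
  v     = beta2 * v + (1 - beta2) * g * g
  m_hat = m / (1 - beta1pow),  v_hat = v / (1 - beta2pow)
  r     = m_hat / (sqrt(v_hat) + epsilon) + weight_decay * p
  p     = p - lr * (||p|| / ||r||) * r

where the norm ratio falls back to 1 when either norm is zero, and ||p|| and ||r|| are per-parameter norms taken from param_square_norm and trust_ratio_div_square_norm, indexed through FusedIndices.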
+ +#include +#include "paddle/fluid/memory/buffer.h" +#include "paddle/fluid/operators/optimizers/cast_with_ptr.h" +#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h" +#include "paddle/fluid/operators/tensor_to_string.h" +#include "paddle/fluid/platform/aligned_vector.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/for_range.h" +#include "paddle/fluid/string/string_helper.h" +#include "paddle/pten/core/utils/data_type.h" + +#ifdef __NVCC__ +#include "cub/cub.cuh" +#include "math.h" // NOLINT +#endif + +#ifdef __HIPCC__ +#include +#include "math.h" // NOLINT +namespace cub = hipcub; +#endif + +namespace paddle { +namespace operators { + +template +using MasterT = typename details::MPTypeTrait::Type; + +template +static void LogParamAndTrustRatioDivSquareNorm( + const framework::ExecutionContext &ctx, const float *param_square_norm, + const float *trust_ratio_div_square_norm) { + if (!VLOG_IS_ON(LogLevel)) return; + + auto tensors = ctx.MultiInput("Param"); + if (tensors.empty()) return; + + size_t n = tensors.size(); + auto place = tensors[0]->place(); + + auto pn_vec = ToVector(param_square_norm, n, place); + auto tn_vec = ToVector(trust_ratio_div_square_norm, n, place); + + std::vector fp32_indices, fp16_indices; + fp32_indices.reserve(n); + fp16_indices.reserve(n); + for (size_t i = 0; i < n; ++i) { + const auto *t = tensors[i]; + if (t->dtype() == pten::DataType::FLOAT32) { + fp32_indices.push_back(i); + } else if (t->dtype() == pten::DataType::FLOAT16) { + fp16_indices.push_back(i); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported data type %s.", t->dtype())); + } + } + + for (auto idx : fp16_indices) { + fp32_indices.push_back(idx); + } + + const auto &names = ctx.GetOp().Inputs("Param"); + for (size_t i = 0; i < fp32_indices.size(); ++i) { + auto idx = fp32_indices[i]; + VLOG(LogLevel) << "Param " << tensors[idx]->dtype() << " " << names[idx] + << " pn = " << pn_vec[i] << " , tn = " << tn_vec[i]; + } +} + +static bool IsFinite(const platform::CUDADeviceContext &dev_ctx, + const float *ptr) { + auto stream = dev_ctx.stream(); + float cpu_value; +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(&cpu_value, ptr, sizeof(float), + hipMemcpyDeviceToHost, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(hipStreamSynchronize(stream)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&cpu_value, ptr, sizeof(float), + cudaMemcpyDeviceToHost, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); +#endif + LOG(INFO) << "NAN_INF indicator value: " << cpu_value; + return isfinite(cpu_value); +} + +template +static const T *GetInputTensorPtr(const framework::ExecutionContext &ctx, + const char *in_name, + int64_t *numel = nullptr) { + const auto *in_tensor = ctx.Input(in_name); + PADDLE_ENFORCE_NOT_NULL(in_tensor, platform::errors::InvalidArgument( + "Input(%s) cannot be NULL.", in_name)); + if (in_tensor->IsInitialized()) { + if (numel) *numel = in_tensor->numel(); + return in_tensor->data(); + } else { + if (numel) *numel = 0; + return nullptr; + } +} + +template +static T *GetSameInOutTensorPtr(const framework::ExecutionContext &ctx, + const platform::Place &place, + const char *in_name, const char *out_name, + int64_t *numel = nullptr) { + const auto *in_tensor = ctx.Input(in_name); + if (in_tensor == nullptr || !in_tensor->IsInitialized()) { + PADDLE_ENFORCE_EQ(AllowNotExist, true, + platform::errors::InvalidArgument( + "Input(%s) cannot be NULL.", in_name)); + if (numel) 
*numel = 0; + return nullptr; + } + + auto *out_tensor = ctx.Output(out_name); + PADDLE_ENFORCE_NOT_NULL(in_tensor, platform::errors::InvalidArgument( + "Input(%s) cannot be NULL.", in_name)); + PADDLE_ENFORCE_NOT_NULL(out_tensor, + platform::errors::InvalidArgument( + "Output(%s) cannot be NULL.", out_name)); + const T *in_data = in_tensor->data(); + T *out_data = out_tensor->mutable_data(place); + PADDLE_ENFORCE_EQ(in_data, out_data, + platform::errors::InvalidArgument( + "Input(%s) and Output(%s) must be the same Tensor.", + in_name, out_name)); + if (numel) *numel = out_tensor->numel(); + return out_data; +} + +template +struct SquareFunctor { + HOSTDEVICE MasterT operator()(T x) const { + auto y = static_cast>(x); + return y * y; + } +}; + +template +struct IsNanInfFunctor { + HOSTDEVICE bool operator()(T x) const { return !isfinite(x); } +}; + +struct OrFunctor { + HOSTDEVICE bool operator()(bool x, bool y) const { return x || y; } +}; + +struct AndFunctor { + HOSTDEVICE bool operator()(bool x, bool y) const { return x && y; } +}; + +template +static __global__ void ScaleCUDAKernel(const T1 *__restrict__ x, + const T2 *__restrict__ scale, + T1 *__restrict__ y, int num) { + static_assert(sizeof(T1) <= sizeof(T2), + "sizeof(T1) must be not greater than sizeof(T2)."); + T2 s = scale[0]; + CUDA_KERNEL_LOOP(i, num) { + y[i] = static_cast(static_cast(x[i]) * s); + } +} + +template +static __global__ void AddToCUDAKernel(const T *__restrict__ x, + T *__restrict__ y) { + y[0] += x[0]; +} + +// If clip before allreduce, +// coeff = global_scale * max_global_grad_norm / (1e-6 + sqrt(square_grad_norm) +// * rescale_grad) +// if coeff >= 1 or coeff is Nan/Inf, scale = 1.0 +// else scale = coeff +template +static __global__ void CalcGradNormClipBeforeAllReduceScale( + const T1 *__restrict__ global_scale, T1 max_global_grad_norm, + const T1 *__restrict__ square_grad_norm, T1 *__restrict__ out1, + T2 *__restrict__ out2, T1 clip_rescale_grad) { + T1 grad_norm = static_cast(sqrt(*square_grad_norm)) * clip_rescale_grad; + T1 scale = global_scale[0] * max_global_grad_norm / (1e-6 + grad_norm); + bool found_nan_inf = !isfinite(scale); + if (scale >= 1 || found_nan_inf) { + scale = static_cast(1.0); + } + + if (out1) { + *out1 = scale; + } + if (out2) { + *out2 = static_cast(scale); + } +} + +static __global__ void SetNanInfValueCUDAKernelOneFlag(const bool *in_flag_p, + float *out_p) { + *out_p = (*in_flag_p) ? __int_as_float(0x7fffffffU) : 0.0f; +} + +static __global__ void SetNanInfValueCUDAKernelTwoFlag(const bool *in_flag_p_1, + const bool *in_flag_p_2, + float *out_p) { + *out_p = + ((*in_flag_p_1) || (*in_flag_p_2)) ? __int_as_float(0x7fffffffU) : 0.0f; +} + +// TODO(zengjinle): Vectorize this function +// NOTE: this method does not update Beta1Pow and Beta2Pow! 
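// UpdateLambMoment first forms the gradient scale rescale_grad / global_scale
// and, if max_global_grad_norm > 0, multiplies in the clip factor
// max_global_grad_norm / (sqrt(square_grad_norm) * scale + 1e-6) when that
// factor is below 1. With the rescaled gradient it updates Moment1/Moment2 in
// the usual Adam fashion and writes
// trust_ratio_div = m_hat / (sqrt(v_hat) + epsilon) + weight_decay * param.
// If square_grad_norm is NaN/Inf, the whole update is skipped.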
+template +static __global__ void UpdateLambMoment( + const T *__restrict__ param_p, const GradT *__restrict__ grad_p, + const T *__restrict__ square_grad_norm_p, + const T *__restrict__ global_scale, const IndexT *__restrict__ indices, + const T *__restrict__ weight_decay_p, const T *__restrict__ beta1pow_p, + const T *__restrict__ beta2pow_p, T *__restrict__ mom1_p, + T *__restrict__ mom2_p, T *__restrict__ trust_ratio_div_p, T beta1, T beta2, + T epsilon, T max_global_grad_norm, int num, T rescale_grad) { + T square_grad_norm = *square_grad_norm_p; + if (!isfinite(square_grad_norm)) return; + + T scale = rescale_grad / global_scale[0]; + if (max_global_grad_norm > 0) { + T clip_scale = + max_global_grad_norm / (sqrtf(square_grad_norm) * scale + 1e-6); + if (clip_scale < static_cast(1)) { + scale *= clip_scale; + } + } + + T one_minus_beta1pow = 1 - beta1pow_p[0]; + T one_minus_beta2pow = 1 - beta2pow_p[0]; + + CUDA_KERNEL_LOOP(i, num) { + T p = param_p[i]; + T g = static_cast(grad_p[i]) * scale; + T weight_decay = weight_decay_p[i]; + T mom1 = mom1_p[i]; + T mom2 = mom2_p[i]; + + mom1 = beta1 * mom1 + (1 - beta1) * g; + mom2 = beta2 * mom2 + (1 - beta2) * g * g; + + T mom1_unbiased = mom1 / one_minus_beta1pow; + T mom2_unbiased = mom2 / one_minus_beta2pow; + T trust_ratio_div = + mom1_unbiased / (sqrtf(mom2_unbiased) + epsilon) + weight_decay * p; + + mom1_p[i] = mom1; + mom2_p[i] = mom2; + trust_ratio_div_p[i] = trust_ratio_div; + } +} + +template +struct LambBetaPowUpdateOnceHelper { + LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2) { + PADDLE_ENFORCE_NOT_NULL(beta1pow, + platform::errors::InvalidArgument( + "The beta1pow should not be nullptr.")); + PADDLE_ENFORCE_NOT_NULL(beta2pow, + platform::errors::InvalidArgument( + "The beta2pow should not be nullptr.")); + beta1pow_ = beta1pow; + beta2pow_ = beta2pow; + beta1_ = beta1; + beta2_ = beta2; + } + + HOSTDEVICE void UpdateBetaPows() const { + beta1pow_[0] *= beta1_; + beta2pow_[0] *= beta2_; + } + + private: + T *__restrict__ beta1pow_; + T *__restrict__ beta2pow_; + T beta1_; + T beta2_; +}; + +template +struct LambBetaPowUpdateOnceHelper { + LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2) { + PADDLE_ENFORCE_EQ( + beta1pow, nullptr, + platform::errors::InvalidArgument("The beta1pow should be nullptr.")); + PADDLE_ENFORCE_EQ( + beta2pow, nullptr, + platform::errors::InvalidArgument("The beta2pow should be nullptr.")); + } + + HOSTDEVICE void UpdateBetaPows() const {} +}; + +template +struct LambFoundInfHelper { + public: + explicit LambFoundInfHelper(bool *found_inf) : found_inf_(found_inf) { + PADDLE_ENFORCE_NOT_NULL(found_inf, + platform::errors::InvalidArgument( + "The found_inf should not be nullptr.")); + } + + HOSTDEVICE void UpdateFoundInf(bool value) { *found_inf_ = value; } + + private: + bool *__restrict__ found_inf_; +}; + +template <> +struct LambFoundInfHelper { + public: + explicit LambFoundInfHelper(bool *found_inf) { + PADDLE_ENFORCE_EQ( + found_inf, nullptr, + platform::errors::InvalidArgument("The found_inf should be nullptr.")); + } + + HOSTDEVICE void UpdateFoundInf(bool) {} +}; + +template +struct LambParamHelper { + LambParamHelper(T *param, MasterT *master_param) { + constexpr bool kIsSameType = std::is_same>::value; + PADDLE_ENFORCE_EQ(kIsSameType, false, + platform::errors::InvalidArgument( + "T must not be the same with MasterT.")); + PADDLE_ENFORCE_NOT_NULL(master_param, + platform::errors::InvalidArgument( + "Master parameter must be provided.")); + param_ = 
param; + master_param_ = master_param; + } + + HOSTDEVICE void SetParam(int i, MasterT updated_p) { + param_[i] = static_cast(updated_p); + master_param_[i] = updated_p; + } + + HOSTDEVICE MasterT GetParam(int i) { return master_param_[i]; } + + private: + T *__restrict__ param_; + MasterT *__restrict__ master_param_; +}; + +template +struct LambParamHelper { + LambParamHelper(T *param, MasterT *master_param) { + constexpr bool kIsSameType = std::is_same>::value; + PADDLE_ENFORCE_EQ(kIsSameType, true, + platform::errors::InvalidArgument( + "T must be the same with MasterT.")); + if (master_param != nullptr) { + PADDLE_ENFORCE_EQ(static_cast(param), + static_cast(master_param), + platform::errors::InvalidArgument( + "Master parameter must be nullptr or the same as " + "non-master parameter.")); + } + param_ = param; + } + + HOSTDEVICE void SetParam(int i, MasterT updated_p) { + param_[i] = static_cast(updated_p); + } + + HOSTDEVICE MasterT GetParam(int i) { + return static_cast>(param_[i]); + } + + private: + T *__restrict__ param_; +}; + +template +struct LambParamAndBetaPowsUpdateHelper + : public LambParamHelper, + public LambBetaPowUpdateOnceHelper, NeedUpdateBetaPow>, + public LambFoundInfHelper { + LambParamAndBetaPowsUpdateHelper( + ParamT *param, MasterT *master_param, MasterT *beta1pow, + MasterT *beta2pow, MasterT beta1, MasterT beta2, + bool *found_inf, const MasterT *trust_ratio_div, + const MasterT *lr, const IndexT *index, + const MasterT *param_square_norm, + const MasterT *trust_ratio_div_square_norm, + const MasterT *update_flag) + : LambParamHelper(param, master_param), + LambBetaPowUpdateOnceHelper, NeedUpdateBetaPow>( + beta1pow, beta2pow, beta1, beta2), + LambFoundInfHelper(found_inf), + trust_ratio_div(trust_ratio_div), + lr(lr), + index(index), + param_square_norm(param_square_norm), + trust_ratio_div_square_norm(trust_ratio_div_square_norm), + update_flag(update_flag) {} + + const MasterT *__restrict__ trust_ratio_div; + const MasterT *__restrict__ lr; + const IndexT *__restrict__ index; + const MasterT *__restrict__ param_square_norm; + const MasterT *__restrict__ trust_ratio_div_square_norm; + const MasterT *__restrict__ update_flag; +}; + +template +static __global__ void LambUpdateParamAndBetaPowsCUDAKernel( + LambParamAndBetaPowsUpdateHelper + args, + int num) { + auto should_update = *args.update_flag; + if (!isfinite(should_update)) { + if (HasFoundInf && threadIdx.x == 0 && blockIdx.x == 0) { + args.UpdateFoundInf(true); + } + return; + } else if (HasFoundInf && threadIdx.x == 0 && blockIdx.x == 0) { + args.UpdateFoundInf(false); + } + + if (NeedUpdateBetaPow && threadIdx.x == 0 && blockIdx.x == 0) { + args.UpdateBetaPows(); + } + + using MT = MasterT; + + MT lr_value = *args.lr; + CUDA_KERNEL_LOOP(i, num) { + MT p = args.GetParam(i); + MT t = args.trust_ratio_div[i]; + auto norm_idx = args.index[i]; + MT p_square_norm = args.param_square_norm[norm_idx]; + MT t_square_norm = args.trust_ratio_div_square_norm[norm_idx]; + + MT p_norm = static_cast(sqrtf(p_square_norm)); + MT t_norm = static_cast(sqrtf(t_square_norm)); + + auto update = (p_norm != static_cast(0) && t_norm != static_cast(0)) + ? 
p_norm / t_norm + : static_cast(1); + + MT updated_p = p - lr_value * update * t; + args.SetParam(i, updated_p); + } +} + +template +static void LambUpdateParamAndBetaPows( + const platform::CUDADeviceContext &dev_ctx, + const MasterT *trust_ratio_div, const MasterT *lr, + const IndexT *index, const MasterT *param_square_norm, + const MasterT *trust_ratio_div_square_norm, + const MasterT *update_flag, MasterT **beta1pow, + MasterT **beta2pow, bool **found_inf, MasterT beta1, + MasterT beta2, int num, ParamT *param, + MasterT *master_param, gpuStream_t stream) { + if (num == 0) return; + + bool has_master_param = !(std::is_same>::value); + auto has_beta_pow = (*beta1pow) != nullptr && (*beta2pow) != nullptr; + auto has_found_inf = (*found_inf) != nullptr; + +#define PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL( \ + __has_master_param, __has_beta_pow, __has_found_inf) \ + do { \ + LambParamAndBetaPowsUpdateHelper \ + helper(param, master_param, *beta1pow, *beta2pow, beta1, beta2, \ + *found_inf, trust_ratio_div, lr, index, param_square_norm, \ + trust_ratio_div_square_norm, update_flag); \ + auto config = platform::GetGpuLaunchConfig1D(dev_ctx, num); \ + LambUpdateParamAndBetaPowsCUDAKernel<<< \ + config.block_per_grid, config.thread_per_block, 0, stream>>>(helper, \ + num); \ + } while (0) + + if (has_master_param) { + if (has_beta_pow) { + if (has_found_inf) { + PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, true, true); + } else { + PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, true, false); + } + } else { + if (has_found_inf) { + PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, false, true); + } else { + PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, false, false); + } + } + } else { + if (has_beta_pow) { + if (has_found_inf) { + PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, true, true); + } else { + PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, true, false); + } + } else { + if (has_found_inf) { + PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, false, true); + } else { + PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, false, false); + } + } + } + + *beta1pow = nullptr; + *beta2pow = nullptr; + *found_inf = nullptr; +#undef PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL +} + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype, + ncclComm_t comm, const void *scale, + ncclRedOp_t *op) { +#if NCCL_VERSION_CODE >= 21100 + int ver; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetVersion(&ver)); + if (ver >= 21100) { + VLOG(10) << "ncclRedOpCreatePreMulSum is supported."; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpCreatePreMulSum( + op, const_cast(scale), dtype, ncclScalarDevice, comm)); + return true; + } +#endif + VLOG(10) << "ncclRedOpCreatePreMulSum is not supported."; + return false; +} + +template +static void NCCLReduceScatterWithScale( + const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks, + ncclComm_t comm, gpuStream_t stream, + const platform::CUDADeviceContext &dev_ctx, const T *scale = nullptr) { + static_assert(std::is_same::value || + std::is_same::value, + "T must be either float32 or float16."); + if (recvcount == 0) return; + + if (comm == nullptr) { + if (scale != nullptr) { + PADDLE_ENFORCE_EQ(nranks, 1, + platform::errors::InvalidArgument( + "nranks must be 1 when scale != nullptr.")); + auto numel = recvcount * nranks; + auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel); + ScaleCUDAKernel<<>>(sendbuff, scale, recvbuff, numel); + } + return; + } + + ncclRedOp_t op = ncclSum; + 
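The kernels above fuse the standard LAMB update over the locally owned shard. As a reference for reviewers, here is a minimal NumPy sketch of the per-shard math they implement (illustrative names only; sharding, gradient clipping, and the found_inf bookkeeping are omitted):

import numpy as np

def lamb_step(p, g, m1, m2, beta1pow, beta2pow, lr, weight_decay,
              beta1=0.9, beta2=0.999, epsilon=1e-6):
    # Moment update with bias correction, as in UpdateLambMoment.
    m1 = beta1 * m1 + (1 - beta1) * g
    m2 = beta2 * m2 + (1 - beta2) * g * g
    m1_hat = m1 / (1 - beta1pow)
    m2_hat = m2 / (1 - beta2pow)
    trust_ratio_div = m1_hat / (np.sqrt(m2_hat) + epsilon) + weight_decay * p
    # Layer-wise trust ratio, as in LambUpdateParamAndBetaPowsCUDAKernel:
    # the ratio of the parameter norm to the trust_ratio_div norm.
    p_norm = np.sqrt(np.sum(p * p))
    t_norm = np.sqrt(np.sum(trust_ratio_div * trust_ratio_div))
    ratio = p_norm / t_norm if p_norm > 0 and t_norm > 0 else 1.0
    p = p - lr * ratio * trust_ratio_div
    return p, m1, m2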
ncclDataType_t dtype = + std::is_same::value ? ncclFloat32 : ncclFloat16; + bool should_destroy_op = + scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op); + memory::Buffer buffer(dev_ctx.GetPlace()); + if (scale && !should_destroy_op) { + size_t numel = recvcount * nranks; + T *new_sendbuff = buffer.Alloc(numel); + auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel); + ScaleCUDAKernel<<>>(sendbuff, scale, new_sendbuff, numel); + sendbuff = new_sendbuff; + } + + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter( + sendbuff, recvbuff, recvcount, dtype, op, comm, stream)); + +#if NCCL_VERSION_CODE >= 21100 + if (should_destroy_op) { + VLOG(10) << "ncclRedOpDestroy starts"; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRedOpDestroy(op, comm)); + VLOG(10) << "ncclRedOpDestroy ends"; + } +#endif +} +#endif + +template +static void CubDeviceReduce(InputIteratorT d_in, OutputIteratorT d_out, + int num_items, ReduceOpT reduction_op, T init, + gpuStream_t stream, memory::Buffer *buffer) { + void *d_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + PADDLE_ENFORCE_GPU_SUCCESS( + cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes, d_in, d_out, + num_items, reduction_op, init, stream)); + d_temp_storage = buffer->Alloc(temp_storage_bytes); + VLOG(10) << "cub::DeviceReduce::Reduce needs " << temp_storage_bytes + << " byte(s), ptr = " << d_temp_storage; + PADDLE_ENFORCE_GPU_SUCCESS( + cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes, d_in, d_out, + num_items, reduction_op, init, stream)); +} + +template +static void CubDeviceSegmentedReduce(InputIteratorT d_in, OutputIteratorT d_out, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + ReductionOp reduction_op, T initial_value, + gpuStream_t stream, + memory::Buffer *buffer) { + void *d_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce( + d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, + d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream)); + d_temp_storage = buffer->Alloc(temp_storage_bytes); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce( + d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, + d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream)); +} + +template +struct AddConstantFunctor { + explicit AddConstantFunctor(T bias) : bias_(bias) {} + + T operator()(T x) const { return x + bias_; } + + private: + T bias_; +}; + +template +struct OffsetWithBiasFunctor { + OffsetWithBiasFunctor(const T *offset, T bias) + : offset_(offset), bias_(bias) {} + + HOSTDEVICE T operator()(T idx) const { return offset_[idx] - bias_; } + + HOSTDEVICE constexpr bool operator==(const OffsetWithBiasFunctor &) const { + return true; + } + + private: + const T *offset_; + const T bias_; +}; + +template +static void CubDeviceSegmentedSquareNorm(const T *x, MasterT *y, int n, + const OffsetT *offset, + OffsetT init_offset, + gpuStream_t stream, + memory::Buffer *buffer) { + if (n <= 0) return; + cub::TransformInputIterator, SquareFunctor, const T *> iter( + x, SquareFunctor()); + if (init_offset == static_cast(0)) { + CubDeviceSegmentedReduce(iter, y, n, offset, offset + 1, cub::Sum(), + static_cast>(0), stream, buffer); + } else { + cub::CountingInputIterator cnt_iter(0); + OffsetWithBiasFunctor functor(offset, init_offset); + cub::TransformInputIterator, + cub::CountingInputIterator> + offset_iter(cnt_iter, functor); + 
CubDeviceSegmentedReduce(iter, y, n, offset_iter, offset_iter + 1, + cub::Sum(), static_cast>(0), stream, + buffer); + } +} + +template +static void GetSquareGradNormImpl(const T *grad, int n, float *square_norm, + gpuStream_t stream, + memory::Buffer *cub_tmp_buffer) { + using Iterator = + cub::TransformInputIterator, const T *>; + Iterator iter(grad, SquareFunctor()); + CubDeviceReduce(iter, square_norm, n, cub::Sum(), static_cast(0), + stream, cub_tmp_buffer); +} + +// square_norm is of length 2 at least +static void GetSquareGradNorm(const float *fp32_grad, int fp32_numel, + const platform::float16 *fp16_grad, + int fp16_numel, float *square_norm, + gpuStream_t stream, + memory::Buffer *cub_tmp_buffer) { + VLOG(10) << "GetSquareGradNorm starts, fp32_numel = " << fp32_numel + << " , fp16_numel = " << fp16_numel; + if (fp32_numel > 0) { + GetSquareGradNormImpl(fp32_grad, fp32_numel, square_norm, stream, + cub_tmp_buffer); + VLOG(10) << "FP32 square L2-Norm: " + << FlattenToString(square_norm, 1, cub_tmp_buffer->GetPlace()); + } + + if (fp16_numel > 0) { + float *fp16_square_norm = fp32_numel > 0 ? square_norm + 1 : square_norm; + GetSquareGradNormImpl(fp16_grad, fp16_numel, fp16_square_norm, stream, + cub_tmp_buffer); + VLOG(10) << "FP16 square L2-Norm: " + << FlattenToString(fp16_square_norm, 1, + cub_tmp_buffer->GetPlace()); + if (fp32_numel > 0) { + AddToCUDAKernel<<<1, 1, 0, stream>>>(fp16_square_norm, square_norm); + VLOG(10) << "FP32+FP16 square L2-Norm: " + << FlattenToString(square_norm, 1, cub_tmp_buffer->GetPlace()); + } + } + VLOG(10) << "GetSquareGradNorm ends, fp32_numel = " << fp32_numel + << " , fp16_numel = " << fp16_numel; +} + +template +std::string NumToString(T x) { + std::stringstream ss; + ss << x; + return ss.str(); +} + +template +static std::string GetMinMaxStr(const T *x, size_t n, + const platform::Place &place) { + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(place), true, + platform::errors::InvalidArgument("Only support CUDAPlace currently.")); + + auto *dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + auto stream = dev_ctx->stream(); + + memory::Buffer ret_buffer(place); + T *ret = ret_buffer.Alloc(2); + + if (n > 0) { + memory::Buffer cub_buffer(place); + CubDeviceReduce(x, ret, n, cub::Min(), std::numeric_limits::max(), + stream, &cub_buffer); + CubDeviceReduce(x, ret + 1, n, cub::Max(), std::numeric_limits::lowest(), + stream, &cub_buffer); + T ret_cpu[2]; +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(&ret_cpu[0], ret, 2 * sizeof(T), + hipMemcpyDeviceToHost, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(hipStreamSynchronize(stream)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&ret_cpu[0], ret, 2 * sizeof(T), + cudaMemcpyDeviceToHost, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); +#endif + return std::string("{\"min\": ") + NumToString(ret_cpu[0]) + + " , \"max\": " + NumToString(ret_cpu[1]) + "}"; + } else { + return "{\"min\": null, \"max\": null}"; + } +} + +struct VisitDTypeFunctor { + VisitDTypeFunctor(const framework::Tensor *x, std::string *s) + : x_(x), s_(s) {} + + template + void apply() const { + *s_ = GetMinMaxStr(x_->template data(), x_->numel(), x_->place()); + } + + private: + const framework::Tensor *x_; + std::string *s_; +}; + +static std::string GetMinMaxStr(const framework::Tensor *x) { + if (x == nullptr) return "null"; + if (!x->IsInitialized()) return "not_inited"; + if (!platform::is_gpu_place(x->place())) return "CPUTensor"; + std::string str; + 
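GetSquareGradNorm above accumulates a single scalar: the squared L2 norm of all gradients, with the fp16 partial norm added onto the fp32 one. A rough host-side equivalent, as a hypothetical helper for illustration only:

import numpy as np

def global_grad_square_norm(fp32_grad=None, fp16_grad=None):
    # Each dtype contributes its own squared L2 norm; the fp16 result is
    # accumulated in float32 and added to the fp32 result, matching
    # GetSquareGradNorm followed by the AddToCUDAKernel launch.
    total = np.float32(0)
    if fp32_grad is not None and fp32_grad.size > 0:
        total += np.sum(np.square(fp32_grad.astype(np.float32)))
    if fp16_grad is not None and fp16_grad.size > 0:
        total += np.sum(np.square(fp16_grad.astype(np.float32)))
    return total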
VisitDTypeFunctor functor(x, &str); + pten::VisitDataType(x->dtype(), functor); + return str; +} + +static void PrintAllMinMaxRange(const framework::ExecutionContext &ctx, + bool only_inputs) { + if (!VLOG_IS_ON(1)) return; + for (const auto &pair : ctx.GetOp().Inputs()) { + const auto &key = pair.first; + const auto tensors = ctx.MultiInput(key); + size_t n = tensors.size(); + for (size_t i = 0; i < n; ++i) { + VLOG(1) << "Input(" << key + ")[" << i << "] = " << pair.second[i] + << " , " << GetMinMaxStr(tensors[i]); + } + } + + if (only_inputs) return; + for (const auto &pair : ctx.GetOp().Outputs()) { + const auto &key = pair.first; + const auto tensors = ctx.MultiOutput(key); + size_t n = tensors.size(); + for (size_t i = 0; i < n; ++i) { + VLOG(1) << "Output(" << key + ")[" << i << "] = " << pair.second[i] + << " , " << GetMinMaxStr(tensors[i]); + } + } +} + +static void CheckHasNanInfGrad(const float *fp32_grad, int fp32_numel, + const platform::float16 *fp16_grad, + int fp16_numel, float *nan_inf_flag, + gpuStream_t stream, + memory::Buffer *cub_tmp_buffer) { + bool *fp32_has_nan_inf = nullptr; + bool *fp16_has_nan_inf = nullptr; + if (fp32_numel > 0) { + fp32_has_nan_inf = reinterpret_cast(nan_inf_flag + 1); + cub::TransformInputIterator, const float *> + iter(fp32_grad, IsNanInfFunctor()); + CubDeviceReduce(iter, fp32_has_nan_inf, fp32_numel, OrFunctor(), false, + stream, cub_tmp_buffer); + } + + if (fp16_numel > 0) { + fp16_has_nan_inf = reinterpret_cast(nan_inf_flag + 1) + 1; + cub::TransformInputIterator, + const platform::float16 *> + iter(fp16_grad, IsNanInfFunctor()); + CubDeviceReduce(iter, fp16_has_nan_inf, fp16_numel, OrFunctor(), false, + stream, cub_tmp_buffer); + } + + if (fp32_has_nan_inf && fp16_has_nan_inf) { + SetNanInfValueCUDAKernelTwoFlag<<<1, 1, 0, stream>>>( + fp32_has_nan_inf, fp16_has_nan_inf, nan_inf_flag); + } else if (fp32_has_nan_inf) { + SetNanInfValueCUDAKernelOneFlag<<<1, 1, 0, stream>>>(fp32_has_nan_inf, + nan_inf_flag); + } else { + SetNanInfValueCUDAKernelOneFlag<<<1, 1, 0, stream>>>(fp16_has_nan_inf, + nan_inf_flag); + } +} + +template +static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) { + static_assert(!std::is_same::value, "T cannot be void."); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream)); +#endif +} + +template +class DistributedFusedLambOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto &dev_ctx = ctx.template device_context(); + auto stream = dev_ctx.stream(); + auto place = dev_ctx.GetPlace(); + + // Step 1: Get fp16 param and grad tensors + int64_t fp16_numel; + auto *fp16_param = GetSameInOutTensorPtr( + ctx, place, "FP16FusedParam", "FP16FusedParamOut", &fp16_numel); + bool has_fp16_param = (fp16_numel > 0); + const platform::float16 *fp16_grad = nullptr; + if (has_fp16_param) { + fp16_grad = GetInputTensorPtr(ctx, "FP16FusedGrad"); + } else { + fp16_param = nullptr; + } + + // Step 2: Get fp32 param and grad tensors + int64_t fp32_numel = 0; + auto *fp32_param = GetSameInOutTensorPtr( + ctx, place, "FP32FusedParam", "FP32FusedParamOut", &fp32_numel); + PADDLE_ENFORCE_GE(fp32_numel, fp16_numel, + platform::errors::InvalidArgument( + "The element number in FP32FusedParam should be not " + "less than FP16FusedParam.")); + + fp32_numel -= fp16_numel; // the 
FP32FusedParam contains fp32 param and + // fp16 master weight + bool has_fp32_param = (fp32_numel > 0); + const float *fp32_grad = nullptr; + if (has_fp32_param) { + fp32_grad = GetInputTensorPtr(ctx, "FP32FusedGrad"); + } else { + PADDLE_ENFORCE_EQ( + has_fp16_param, true, + platform::errors::InvalidArgument( + "Either FP32FusedGrad or FP16FusedGrad cannot be NULL.")); + } + + auto numel = fp32_numel + fp16_numel; + VLOG(1) << "numel = " << numel << " , fp32_numel = " << fp32_numel + << " , fp16_numel = " << fp16_numel; + + // The NVIDIA cub library does not support number > INT32_MAX + PADDLE_ENFORCE_LE(numel, std::numeric_limits::max(), + platform::errors::Unimplemented( + "Too many parameter number. Only <= %d is supported.", + std::numeric_limits::max())); + + // Step 3: Get FusedIndices, ParamInfo + const auto *indices = GetInputTensorPtr(ctx, "FusedIndices"); + const auto *param_info_tensor = GetInputTensorPtr(ctx, "ParamInfo"); + auto fp32_local_start_idx = param_info_tensor[0]; + auto fp32_local_param_num = param_info_tensor[1]; + auto fp32_global_param_num = param_info_tensor[2]; + auto fp16_local_start_idx = param_info_tensor[3]; + auto fp16_local_param_num = param_info_tensor[4]; + auto fp16_global_param_num = param_info_tensor[5]; + + auto local_param_num = fp32_local_param_num + fp16_local_param_num; + auto param_num = fp32_global_param_num + fp16_global_param_num; + PADDLE_ENFORCE_LE(local_param_num, param_num, + platform::errors::InvalidArgument( + "The local parameter number should not exceed the " + "global parameter number.")); + VLOG(1) << "local_param_num = " << local_param_num + << " , global_param_num = " << param_num + << " , fp32_local_start_idx = " << fp32_local_start_idx + << " , fp32_local_param_num = " << fp32_local_param_num + << " , fp32_global_param_num = " << fp32_global_param_num + << " , fp16_local_start_idx = " << fp16_local_start_idx + << " , fp16_local_param_num = " << fp16_local_param_num + << " , fp16_global_param_num = " << fp16_global_param_num; + + // Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow, + // WeightDecay, GlobalScale, FoundInf + const auto *global_scale = GetInputTensorPtr(ctx, "GlobalScale"); + const auto *lr = GetInputTensorPtr(ctx, "LearningRate"); + int64_t partial_numel = 0; + auto *moment1 = GetSameInOutTensorPtr(ctx, place, "Moment1", + "Moment1Out", &partial_numel); + + PADDLE_ENFORCE_EQ(numel % partial_numel, 0, + platform::errors::InvalidArgument( + "The total parameter number %d should be divided " + "exactly by the element number %d of Moment1.", + numel, partial_numel)); + + int64_t num_devices = numel / partial_numel; + VLOG(1) << "num_devices = " << num_devices + << " , partial_numel = " << partial_numel; + + PADDLE_ENFORCE_EQ(fp32_numel % num_devices, 0, + platform::errors::InvalidArgument( + "The fp32 parameter number %d should be divided " + "exactly by the device number %d.", + fp32_numel, num_devices)); + PADDLE_ENFORCE_EQ(fp16_numel % num_devices, 0, + platform::errors::InvalidArgument( + "The fp16 parameter number %d should be divided " + "exactly by the device number %d.", + fp16_numel, num_devices)); + + auto *moment2 = + GetSameInOutTensorPtr(ctx, place, "Moment2", "Moment2Out"); + auto *beta1pow = + GetSameInOutTensorPtr(ctx, place, "Beta1Pow", "Beta1PowOut"); + auto *beta2pow = + GetSameInOutTensorPtr(ctx, place, "Beta2Pow", "Beta2PowOut"); + const float *weight_decay = GetInputTensorPtr(ctx, "WeightDecay"); + + auto *found_inf_t = ctx.Output("FoundInf"); + found_inf_t->Resize({1}); + auto 
*found_inf = found_inf_t->mutable_data(place); + + // Step 5: Get attributes beta1, beta2, epsilon, max_grad_norm, ring_id, + // use_master_param_norm, is_grad_scaled_by_nranks + auto beta1 = ctx.Attr("beta1"); + auto beta2 = ctx.Attr("beta2"); + auto epsilon = ctx.Attr("epsilon"); + auto max_global_grad_norm = ctx.Attr("max_global_grad_norm"); + auto clip_after_allreduce = ctx.Attr("clip_after_allreduce"); + auto ring_id = ctx.Attr("ring_id"); + auto use_master_param_norm = ctx.Attr("use_master_param_norm"); + auto is_grad_scaled_by_nranks = ctx.Attr("is_grad_scaled_by_nranks"); + VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm + << " , clip_after_allreduce = " << clip_after_allreduce + << " , use_master_param_norm = " << use_master_param_norm + << " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks; + + // Step 6: allreduce + global norm gradient clip + int rank = 0; + ncclComm_t comm = nullptr; + if (num_devices > 1) { + auto *nccl_comm_handle = + platform::NCCLCommContext::Instance().Get(ring_id, place); + comm = nccl_comm_handle->comm(); + rank = nccl_comm_handle->rank(); + } + + memory::Buffer grad_norm_square_buffer(place); + auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc(2); + memory::Buffer cub_tmp_buffer(place); + + memory::Buffer sum_grad_buffer(place); + float *fp32_sum_grad; + platform::float16 *fp16_sum_grad; + auto fp32_numel_each_device = fp32_numel / num_devices; + auto fp16_numel_each_device = fp16_numel / num_devices; + if (num_devices > 1) { + auto ptr = sum_grad_buffer.Alloc( + fp32_numel_each_device * sizeof(float) + + fp16_numel_each_device * sizeof(platform::float16)); + fp32_sum_grad = has_fp32_param ? reinterpret_cast(ptr) : nullptr; + fp16_sum_grad = has_fp16_param + ? reinterpret_cast( + ptr + fp32_numel_each_device * sizeof(float)) + : nullptr; + } else { + // NOTE: The const_cast here is not important. The fp32_sum_grad and + // fp16_sum_grad would not be changed when num_devices == 1 + // But if I do not perform const_cast here, there would be more + // if-else codes (num_devices > 1) when I write the following code. + // So I prefer to use const_cast to unify the following code to reduce + // the if-else codes. 
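The shape checks above encode the sharding scheme: Moment1 holds only one rank's shard, so num_devices = numel / partial_numel, and both fused buffers are split evenly across ranks. A small sketch of that layout, using a hypothetical helper purely for illustration:

def shard_layout(fp32_numel, fp16_numel, partial_numel, rank):
    numel = fp32_numel + fp16_numel
    assert numel % partial_numel == 0
    num_devices = numel // partial_numel
    assert fp32_numel % num_devices == 0 and fp16_numel % num_devices == 0
    fp32_per_dev = fp32_numel // num_devices
    fp16_per_dev = fp16_numel // num_devices
    # This rank's contiguous slice of each fused buffer, matching
    # fp32_offset / fp16_offset in the kernel below.
    return {
        'num_devices': num_devices,
        'fp32_shard': (rank * fp32_per_dev, (rank + 1) * fp32_per_dev),
        'fp16_shard': (rank * fp16_per_dev, (rank + 1) * fp16_per_dev),
    }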
+ fp32_sum_grad = const_cast(fp32_grad); + fp16_sum_grad = const_cast(fp16_grad); + } + + float rescale_grad = 1.0f; + if (!is_grad_scaled_by_nranks) { + rescale_grad /= num_devices; + } + + if (max_global_grad_norm > 0) { + if (clip_after_allreduce) { + // (1) ReduceScater first + NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad, + fp32_numel_each_device, num_devices, comm, + stream, dev_ctx); + NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad, + fp16_numel_each_device, num_devices, comm, + stream, dev_ctx); + // (2) Calculate the global grad norm + GetSquareGradNorm(fp32_sum_grad, fp32_numel_each_device, fp16_sum_grad, + fp16_numel_each_device, fp32_square_grad_norm, stream, + &cub_tmp_buffer); + VLOG(1) << "Grad square norm before all reduce: " + << FlattenToString(fp32_square_grad_norm, 1, place); + if (num_devices > 1) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32, + ncclSum, comm, stream)); + } + VLOG(1) << "Grad square norm after all reduce: " + << FlattenToString(fp32_square_grad_norm, 1, place); + } else { + // (1) Calculate the local grad norm + GetSquareGradNorm(fp32_grad, fp32_numel, fp16_grad, fp16_numel, + fp32_square_grad_norm, stream, &cub_tmp_buffer); + VLOG(1) << "Grad square norm before all reduce: " + << FlattenToString(fp32_square_grad_norm, 1, place); + // (2) Calculate the gradient clip scale + float *fp32_scale = nullptr; + platform::float16 *fp16_scale = nullptr; + if (has_fp32_param && has_fp16_param) { + auto *ptr = cub_tmp_buffer.Alloc(sizeof(float) + + sizeof(platform::float16)); + fp32_scale = reinterpret_cast(ptr); + fp16_scale = + reinterpret_cast(ptr + sizeof(float)); + } else if (has_fp32_param) { + fp32_scale = cub_tmp_buffer.Alloc(1); + } else { + fp16_scale = cub_tmp_buffer.Alloc(1); + } + + float clip_scale = 1.0f; + if (is_grad_scaled_by_nranks) { + clip_scale *= num_devices; + } + CalcGradNormClipBeforeAllReduceScale< + float, platform::float16><<<1, 1, 0, stream>>>( + global_scale, max_global_grad_norm, fp32_square_grad_norm, + fp32_scale, fp16_scale, clip_scale); + VLOG(1) << "Grad scale: " << FlattenToString(fp32_scale, 1, place); + if (num_devices > 1) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32, + ncclSum, comm, stream)); + } + // (3) Do ReduceScatter with scale + NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad, + fp32_numel_each_device, num_devices, comm, + stream, dev_ctx, fp32_scale); + NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad, + fp16_numel_each_device, num_devices, comm, + stream, dev_ctx, fp16_scale); + // (4) mark max_global_grad_norm as 0, meaning that clip has been + // already performed + max_global_grad_norm = 0; + } + } else { + NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad, + fp32_numel_each_device, num_devices, comm, + stream, dev_ctx); + NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad, + fp16_numel_each_device, num_devices, comm, + stream, dev_ctx); + CheckHasNanInfGrad(fp32_sum_grad, fp32_numel_each_device, fp16_sum_grad, + fp16_numel_each_device, fp32_square_grad_norm, stream, + &cub_tmp_buffer); + if (num_devices > 1) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32, + ncclSum, comm, stream)); + } + max_global_grad_norm = 0; + } + VLOG(10) << "ReduceScatter done"; + + // Step 7: update the moment1, moment2. 
Calcuate the trust_ratio_div + memory::Buffer trust_ratio_div_buffer(place); + auto *trust_ratio_div = trust_ratio_div_buffer.Alloc(partial_numel); + auto fp32_offset = rank * fp32_numel_each_device; + auto fp16_offset = rank * fp16_numel_each_device; + if (has_fp32_param) { + auto config = + platform::GetGpuLaunchConfig1D(dev_ctx, fp32_numel_each_device); + VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts"; + UpdateLambMoment<<>>( + fp32_param + fp32_offset, fp32_sum_grad, fp32_square_grad_norm, + global_scale, indices + fp32_offset, weight_decay, beta1pow, beta2pow, + moment1, moment2, trust_ratio_div, beta1, beta2, epsilon, + max_global_grad_norm, fp32_numel_each_device, rescale_grad); + VLOG(10) << "Update FP32 Moment and TrustRatioDiv done"; + } + float *master_param = nullptr; + if (has_fp16_param) { + master_param = fp32_param + fp32_numel; + auto config = + platform::GetGpuLaunchConfig1D(dev_ctx, fp16_numel_each_device); + VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts"; + UpdateLambMoment<<>>( + master_param + fp16_offset, fp16_sum_grad, fp32_square_grad_norm, + global_scale, indices + fp32_numel + fp16_offset, weight_decay, + beta1pow, beta2pow, moment1 + fp32_numel_each_device, + moment2 + fp32_numel_each_device, + trust_ratio_div + fp32_numel_each_device, beta1, beta2, epsilon, + max_global_grad_norm, fp16_numel_each_device, rescale_grad); + VLOG(10) << "Update FP16 Moment and TrustRatioDiv done"; + } + + VLOG(10) << "Update Moment and TrustRatioDiv done hehahaha"; + + // Step 8: calculate L2-Norm square of parameter and trust_ratio_div + memory::Buffer square_norm_buffer(place); + auto *param_square_norm = square_norm_buffer.Alloc(2 * param_num); + auto *trust_ratio_div_square_norm = param_square_norm + param_num; + + auto *fused_offsets_t = ctx.Input("FusedParamOffsets"); + auto *fused_offsets = fused_offsets_t->data(); + auto *fp32_partial_fused_offsets_t = + ctx.Input("FP32ShardFusedParamOffsets"); + const auto *fp32_partial_fused_offsets = + fp32_partial_fused_offsets_t->data(); + auto *fp16_partial_fused_offsets_t = + ctx.Input("FP16ShardFusedParamOffsets"); + const auto *fp16_partial_fused_offsets = + fp16_partial_fused_offsets_t->data(); + + VLOG(1) << "FusedParamOffsets: " + << FlattenToString(fused_offsets, fused_offsets_t->numel(), place); + VLOG(1) << "FP32ShardFusedParamOffsets: " + << FlattenToString(fp32_partial_fused_offsets, + fp32_partial_fused_offsets_t->numel(), place); + VLOG(1) << "FP16ShardFusedParamOffsets: " + << FlattenToString(fp16_partial_fused_offsets, + fp16_partial_fused_offsets_t->numel(), place); + + if (num_devices > 1) { + if (use_master_param_norm) { + FillZeroWithPtr(param_square_norm + fp32_global_param_num, + 2 * param_num - fp32_global_param_num, stream); + } else { + FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream); + } + } + CubDeviceSegmentedSquareNorm(fp32_param, param_square_norm, + fp32_global_param_num, fused_offsets, 0, + stream, &cub_tmp_buffer); + if (use_master_param_norm) { + CubDeviceSegmentedSquareNorm( + master_param + fp16_offset, param_square_norm + fp16_local_start_idx, + fp16_local_param_num, fp16_partial_fused_offsets, 0, stream, + &cub_tmp_buffer); + } else { + // NOTE: extra computation is performed. We can improve this performance + // if needed in the future. 
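Step 8 reduces the fused buffers into per-parameter squared norms using the FusedParamOffsets / *ShardFusedParamOffsets index tensors. Conceptually, under the assumption that the offsets delimit each parameter's slice of the flat buffer, the reduction is:

import numpy as np

def segmented_square_norm(flat_buffer, offsets):
    # offsets has one more entry than the number of parameters; segment i is
    # flat_buffer[offsets[i]:offsets[i + 1]], the shape consumed by
    # CubDeviceSegmentedSquareNorm via cub::DeviceSegmentedReduce.
    return np.array([
        np.sum(np.square(flat_buffer[offsets[i]:offsets[i + 1]].astype(np.float32)))
        for i in range(len(offsets) - 1)
    ], dtype=np.float32)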
+ CubDeviceSegmentedSquareNorm( + fp16_param, param_square_norm + fp32_global_param_num, + fp16_global_param_num, fused_offsets + fp32_global_param_num, + static_cast(fp32_numel), stream, &cub_tmp_buffer); + } + + CubDeviceSegmentedSquareNorm( + trust_ratio_div, trust_ratio_div_square_norm + fp32_local_start_idx, + fp32_local_param_num, fp32_partial_fused_offsets, 0, stream, + &cub_tmp_buffer); + CubDeviceSegmentedSquareNorm( + trust_ratio_div + fp32_numel_each_device, + trust_ratio_div_square_norm + fp16_local_start_idx, + fp16_local_param_num, fp16_partial_fused_offsets, 0, stream, + &cub_tmp_buffer); + + VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: " + << FlattenToString(trust_ratio_div_square_norm, param_num, place); + if (num_devices > 1) { + if (use_master_param_norm) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + param_square_norm + fp32_global_param_num, + param_square_norm + fp32_global_param_num, + 2 * param_num - fp32_global_param_num, ncclFloat32, ncclSum, comm, + stream)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + trust_ratio_div_square_norm, trust_ratio_div_square_norm, param_num, + ncclFloat32, ncclSum, comm, stream)); + } + VLOG(10) << "ncclAllReduce done"; + } + + LogParamAndTrustRatioDivSquareNorm<1>(ctx, param_square_norm, + trust_ratio_div_square_norm); + VLOG(10) << "Calculate L2-Norm of Param and TrustRatioDiv done"; + + // Step 9: update parameter, beta1pow, beta2pow. All gather parameters. + if (has_fp32_param) { + LambUpdateParamAndBetaPows( + dev_ctx, trust_ratio_div, lr, indices + fp32_offset, + param_square_norm, trust_ratio_div_square_norm, fp32_square_grad_norm, + &beta1pow, &beta2pow, &found_inf, beta1, beta2, + fp32_numel_each_device, fp32_param + fp32_offset, nullptr, stream); + if (num_devices > 1) { + // ncclAllGather + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( + fp32_param + fp32_offset, fp32_param, fp32_numel_each_device, + ncclFloat32, comm, stream)); + } + } + if (has_fp16_param) { + LambUpdateParamAndBetaPows( + dev_ctx, trust_ratio_div + fp32_numel_each_device, lr, + indices + fp32_numel + fp16_offset, param_square_norm, + trust_ratio_div_square_norm, fp32_square_grad_norm, &beta1pow, + &beta2pow, &found_inf, beta1, beta2, fp16_numel_each_device, + fp16_param + fp16_offset, master_param + fp16_offset, stream); + + if (num_devices > 1) { + // ncclAllGather + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( + fp16_param + fp16_offset, fp16_param, fp16_numel_each_device, + ncclFloat16, comm, stream)); + } + } + VLOG(10) << "Update Param done"; + + VLOG(1) << "IsFinite: " << IsFinite(dev_ctx, fp32_square_grad_norm); +#else + PADDLE_THROW(platform::errors::Unimplemented( + "distributed_fused_lamb op should be used with NCCL/RCCL.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +namespace plat = paddle::platform; +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + distributed_fused_lamb, + ops::DistributedFusedLambOpKernel); diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h new file mode 100644 index 00000000000..88b2beb185e --- /dev/null +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h @@ -0,0 +1,34 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { + +template +class DistributedFusedLambOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "The distributed_fused_lamb operator does not support CPU yet.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/optimizers/lamb_op.h b/paddle/fluid/operators/optimizers/lamb_op.h index 4bcacd45ff6..e8cf33d820b 100644 --- a/paddle/fluid/operators/optimizers/lamb_op.h +++ b/paddle/fluid/operators/optimizers/lamb_op.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/squared_l2_norm.h" +#include "paddle/fluid/operators/tensor_to_string.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/pten/kernels/funcs/algorithm.h" #include "paddle/pten/kernels/funcs/eigen/extensions.h" @@ -658,6 +659,16 @@ class LambOpKernel : public framework::OpKernel { math::SquaredL2Norm(dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer); + if (VLOG_IS_ON(1)) { + const auto& name = ctx.GetOp().Input("Param"); + auto pn = ToVector(p_norm_ptr, 1, dev_ctx.GetPlace()); + auto tn = ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace()); + auto dtype = + framework::DataTypeToString(framework::DataTypeTrait::DataType()); + VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0] + << " , tn = " << tn[0]; + } + #define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow) \ do { \ LambParamUpateFunctor \ diff --git a/paddle/fluid/operators/tensor_to_string.h b/paddle/fluid/operators/tensor_to_string.h new file mode 100644 index 00000000000..bd9e7f6219b --- /dev/null +++ b/paddle/fluid/operators/tensor_to_string.h @@ -0,0 +1,65 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace operators { + +template +static const std::vector &ToVector(const std::vector &vec) { + return vec; +} + +template +static std::vector ToVector(const T *x, size_t n, + const platform::Place &place) { +#ifdef __NVCC__ + if (platform::is_gpu_place(place)) { + using CopyT = typename std::conditional::value, + uint8_t, T>::type; + std::vector cpu_x(n); + auto *dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + memory::Copy(platform::CPUPlace(), cpu_x.data(), place, x, n * sizeof(T), + dev_ctx->stream()); + dev_ctx->Wait(); + return std::vector(cpu_x.data(), cpu_x.data() + n); + } +#endif + return std::vector(x, x + n); +} + +template +static std::vector ToVector(const framework::Tensor &src) { + if (!src.IsInitialized()) { + return {}; + } + return ToVector(src.template data(), src.numel(), src.place()); +} + +template +static std::string FlattenToString(Args &&... args) { + const auto &vec = ToVector(std::forward(args)...); + return "[" + string::join_strings(vec, ',') + "]"; +} + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 082a72af799..55556cf2640 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -36,12 +36,35 @@ __all__ = [ 'ClipGradByNorm', 'ClipGradByGlobalNorm' ] +_clip_by_global_norm_using_mp_type_flag = False + + +def _clip_by_global_norm_using_mp_type(*args): + global _clip_by_global_norm_using_mp_type_flag + assert len(args) <= 1 + if len(args) == 1: + assert isinstance(args[0], bool) + old_value = _clip_by_global_norm_using_mp_type_flag + _clip_by_global_norm_using_mp_type_flag = args[0] + return old_value + else: + return _clip_by_global_norm_using_mp_type_flag + + +def _cast_to_mp_type_if_enabled(x): + if x.dtype == core.VarDesc.VarType.FP16 and _clip_by_global_norm_using_mp_type( + ): + return x.astype(core.VarDesc.VarType.FP32) + else: + return x + def _squared_l2_norm(x): r""" This OP returns the squared L2 norm of a tensor. 
""" + x = _cast_to_mp_type_if_enabled(x) if core.is_compiled_with_xpu() or x.dtype == core.VarDesc.VarType.FP16: square = layers.square(x) sum_square = layers.reduce_sum(square) @@ -595,9 +618,10 @@ class ClipGradByGlobalNorm(ClipGradBase): continue with p.block.program._optimized_guard([p, g]): + new_g = _cast_to_mp_type_if_enabled(g) # inplace - scale_input = (scale_var.astype('float16') - if g.dtype == core.VarDesc.VarType.FP16 and + scale_input = (scale_var.astype('float16') if + new_g.dtype == core.VarDesc.VarType.FP16 and scale_var.dtype != core.VarDesc.VarType.FP16 else scale_var) # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g @@ -607,9 +631,18 @@ class ClipGradByGlobalNorm(ClipGradBase): block = default_main_program().current_block() block.append_op( type='elementwise_mul', - inputs={'X': g, + inputs={'X': new_g, 'Y': scale_input}, - outputs={'Out': g}) + outputs={'Out': new_g}) + if new_g is not g: + block.append_op( + type='cast', + inputs={'X': new_g}, + outputs={'Out': g}, + attrs={ + 'in_dtype': new_g.dtype, + 'out_dtype': g.dtype + }) param_new_grad_name_dict[p.name] = g.name params_and_grads.append((p, g)) diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index b737b14aa6d..c6e2bcb8b1a 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -108,6 +108,9 @@ class OptimizerWithMixedPrecision(object): """ return self._scaled_loss + def _supports_check_nan_inf(self): + return getattr(self._optimizer, "_supports_check_nan_inf", False) + def _init_amp_var(self): self._loss_scaling = layers.create_global_var( name=unique_name.generate("loss_scaling"), @@ -202,8 +205,34 @@ class OptimizerWithMixedPrecision(object): params_grads = self._optimizer.backward( self._scaled_loss, startup_program, parameter_list, no_grad_set, callbacks) + if self._supports_check_nan_inf(): + self._add_cast_ops_to_startup_program(startup_program) return params_grads + def _add_cast_ops_to_startup_program(self, startup_program): + names = list(self._to_fp16_var_names) if self._to_fp16_var_names else [] + names.sort() + startup_program = default_startup_program( + ) if startup_program is None else startup_program + block = startup_program.global_block() + param_names = [p.name for p in block.all_parameters()] + for name in names: + if name not in param_names: + continue + + tmp = block.create_var(dtype=core.VarDesc.VarType.FP32) + block.append_op( + type='assign', inputs={'X': [name]}, outputs={'Out': [tmp]}) + block.append_op( + type='cast', + inputs={'X': [tmp]}, + outputs={'Out': [name]}, + attrs={ + 'in_dtype': core.VarDesc.VarType.FP32, + 'out_dtype': core.VarDesc.VarType.FP16, + }) + self._to_fp16_var_names = None + def amp_init(self, place, scope=None, @@ -297,13 +326,47 @@ class OptimizerWithMixedPrecision(object): if not self._use_dynamic_loss_scaling and self._init_loss_scaling == 1.0: return self._optimizer.apply_gradients(params_grads) + if self._supports_check_nan_inf(): + self._optimizer._set_scale(self._loss_scaling) + optimize_ops = self._optimizer.apply_gradients(params_grads) + found_inf = self._optimizer._found_inf + self._add_dynamic_loss_scaling(params_grads, found_inf) + return optimize_ops + + found_inf = self._check_finite_and_unscale(params_grads) + if self._use_dynamic_loss_scaling: + self._add_dynamic_loss_scaling(params_grads, found_inf) + + # Pass found_inf to adam, to skip update for not only param, but 
also momentum and beta_pow + # With fleet, optimizers are nested and the real optimizer set by user is the inner most one. + real_optimizer = self._optimizer + while hasattr(real_optimizer, "inner_opt"): + real_optimizer = real_optimizer.inner_opt + if isinstance(real_optimizer, (paddle.fluid.optimizer.Adam, + paddle.optimizer.AdamW)): + # NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we + # copy it in advance to avoid multiple time copies. + with self._train_program._optimized_guard([]): + found_inf = paddle.tensor.creation._memcpy(found_inf, + paddle.CPUPlace()) + real_optimizer._set_auxiliary_var('found_inf', found_inf) + elif hasattr(real_optimizer, "_set_auxiliary_var"): + real_optimizer._set_auxiliary_var('found_inf', found_inf) + optimize_ops = self._optimizer.apply_gradients(params_grads) + return optimize_ops + + def _split_grads(self, params_grads): grads = [g for _, g in params_grads] fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] assert len(fp32_grads) + len(fp16_grads) == len(grads), \ "Data types of all grads must be either fp16 or fp32." + return grads, fp32_grads, fp16_grads + def _check_finite_and_unscale(self, params_grads): + grads, fp32_grads, fp16_grads = self._split_grads(params_grads) found_infs = [] + if self._is_distributed: # if distributed, split check_finite_and_unscale to overlap # unscale with communication @@ -349,46 +412,37 @@ class OptimizerWithMixedPrecision(object): name="find_infinite_scale", float_status=self._float_status) - if self._use_dynamic_loss_scaling: - if self._is_distributed or self._use_pure_fp16: - with self._train_program._optimized_guard([]): - all_infs = layers.concat(found_infs) - found_inf = layers.reduce_any(all_infs) + if self._is_distributed or self._use_pure_fp16: + with self._train_program._optimized_guard([]): + all_infs = layers.concat(found_infs) + found_inf = layers.reduce_any(all_infs) - if self._use_pure_fp16: - stop_update = False - with self._train_program._optimized_guard([]): - if fp32_grads: - update_loss_scaling( - fp32_grads, - found_inf, - self._loss_scaling, - self._num_good_steps, - self._num_bad_steps, - self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, - self._incr_ratio, - self._decr_ratio, - stop_update=stop_update, - name="update_loss_scaling_fp32") - stop_update = True - if fp16_grads: - update_loss_scaling( - fp16_grads, - found_inf, - self._loss_scaling, - self._num_good_steps, - self._num_bad_steps, - self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, - self._incr_ratio, - self._decr_ratio, - stop_update=stop_update, - name="update_loss_scaling_fp16") - else: - with self._train_program._optimized_guard([]): + return found_inf + + def _add_dynamic_loss_scaling(self, params_grads, found_inf): + if self._supports_check_nan_inf(): + with self._train_program._optimized_guard([]): + update_loss_scaling( + [], + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + stop_update=False, + name="update_loss_scaling") + return + + grads, fp32_grads, fp16_grads = self._split_grads(params_grads) + if self._use_pure_fp16: + stop_update = False + with self._train_program._optimized_guard([]): + if fp32_grads: update_loss_scaling( - grads, + fp32_grads, found_inf, self._loss_scaling, self._num_good_steps, @@ -397,24 +451,35 @@ class OptimizerWithMixedPrecision(object): 
self._decr_every_n_nan_or_inf, self._incr_ratio, self._decr_ratio, - name="update_loss_scaling") - # Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow - # With fleet, optimizers are nested and the real optimizer set by user is the inner most one. - real_optimizer = self._optimizer - while hasattr(real_optimizer, "inner_opt"): - real_optimizer = real_optimizer.inner_opt - if isinstance(real_optimizer, (paddle.fluid.optimizer.Adam, - paddle.optimizer.AdamW)): - # NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we - # copy it in advance to avoid multiple time copies. + stop_update=stop_update, + name="update_loss_scaling_fp32") + stop_update = True + if fp16_grads: + update_loss_scaling( + fp16_grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + stop_update=stop_update, + name="update_loss_scaling_fp16") + else: with self._train_program._optimized_guard([]): - found_inf = paddle.tensor.creation._memcpy(found_inf, - paddle.CPUPlace()) - real_optimizer._set_auxiliary_var('found_inf', found_inf) - elif hasattr(real_optimizer, "_set_auxiliary_var"): - real_optimizer._set_auxiliary_var('found_inf', found_inf) - optimize_ops = self._optimizer.apply_gradients(params_grads) - return optimize_ops + update_loss_scaling( + grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + name="update_loss_scaling") def apply_optimize(self, loss, startup_program, params_grads): program = loss.block.program diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index f11f894970d..0c81a0e9346 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -846,6 +846,8 @@ set_tests_properties(test_parallel_executor_crf test_sync_batch_norm_op test_inp test_parallel_executor_seresnext_base_gpu test_parallel_executor_seresnext_with_reduce_gpu test_parallel_executor_seresnext_with_fuse_all_reduce_gpu + test_distributed_fused_lamb_op_with_clip + test_distributed_fused_lamb_op_without_clip test_parallel_executor_fetch_isolated_var PROPERTIES LABELS "RUN_TYPE=DIST") @@ -974,6 +976,8 @@ set_tests_properties(test_nn_grad PROPERTIES TIMEOUT 120) set_tests_properties(test_elementwise_sub_op PROPERTIES TIMEOUT 120) set_tests_properties(test_row_conv_op PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_executor_seresnext_with_fuse_all_reduce_gpu PROPERTIES TIMEOUT 120) +set_tests_properties(test_distributed_fused_lamb_op_with_clip PROPERTIES TIMEOUT 120) +set_tests_properties(test_distributed_fused_lamb_op_without_clip PROPERTIES TIMEOUT 120) set_tests_properties(test_elementwise_min_op PROPERTIES TIMEOUT 120) set_tests_properties(test_nan_inf PROPERTIES TIMEOUT 120) set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py b/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py new file mode 100644 index 00000000000..e0529c5d5f8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py @@ -0,0 +1,309 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
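With this refactor, apply_gradients takes a short-cut when the wrapped optimizer reports _supports_check_nan_inf (as DistributedFusedLamb does): the NaN/Inf check and unscaling are delegated to the optimizer itself, and only the loss-scaling update stays in the decorator. A condensed sketch of that control flow, paraphrasing the code above rather than reproducing it:

def apply_gradients_flow(amp_opt, params_grads):
    if amp_opt._supports_check_nan_inf():
        # The inner optimizer unscales gradients and reports found_inf itself.
        amp_opt._optimizer._set_scale(amp_opt._loss_scaling)
        optimize_ops = amp_opt._optimizer.apply_gradients(params_grads)
        amp_opt._add_dynamic_loss_scaling(params_grads,
                                          amp_opt._optimizer._found_inf)
        return optimize_ops
    found_inf = amp_opt._check_finite_and_unscale(params_grads)
    if amp_opt._use_dynamic_loss_scaling:
        amp_opt._add_dynamic_loss_scaling(params_grads, found_inf)
    # found_inf is also handed to Adam-like inner optimizers so they can skip
    # the whole update, as in the code above.
    return amp_opt._optimizer.apply_gradients(params_grads)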
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import paddle +import paddle.fluid.core as core +import paddle.distributed.fleet as fleet +from paddle.incubate import DistributedFusedLamb +from paddle.vision.models import resnet18 as resnet +from paddle.distributed.fleet.meta_optimizers.common import CollectiveHelper +from paddle.fluid.clip import ClipGradBase +import paddle.nn as nn +import numpy as np +import os +import unittest +from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, is_backward_op +from paddle.fluid.clip import _clip_by_global_norm_using_mp_type +import distutils + + +def get_role_maker(): + return fleet.PaddleCloudRoleMaker(is_collective=True) + + +def set_seed(seed): + paddle.seed(seed) + rank = paddle.distributed.get_rank() + np_seed = seed + rank + np.random.seed(np_seed) + + +def set_gradient_persistable(program): + block = program.global_block() + params = [] + grads = [] + for p in block.all_parameters(): + p_name = p.name + g_name = p_name + '@GRAD' + g = block.vars.get(g_name) + if g is None: + continue + g.persistable = True + params.append(p) + grads.append(g) + return params, grads + + +def prune_fwd_bwd_ops(program, start_idx): + for i in reversed(range(start_idx)): + program.global_block()._remove_op(i, sync=False) + program._sync_with_cpp() + + ops = program.global_block().ops + all_vars = set(program.global_block().vars.keys()) + for op in ops: + args = op.input_arg_names + op.output_arg_names + for arg in args: + if arg in all_vars: + all_vars.remove(arg) + + for var in all_vars: + program.global_block()._remove_var(var) + program._sync_with_cpp() + + +class GradClipDecorator(ClipGradBase): + def __init__(self, clip, clip_after_allreduce): + self.clip = clip + self.clip_after_allreduce = clip_after_allreduce + + def _dygraph_clip(self, params_grads): + raise NotImplementedError() + + def _insert_allreduce_ops(self, params_grads): + world_size = paddle.distributed.get_world_size() + if world_size == 1: + return + block = params_grads[0][0].block + scale = 1.0 / world_size + # scale = 1.0 + for p, g in params_grads: + block.append_op( + type='c_allreduce_sum', + inputs={'X': [g]}, + outputs={'Out': [g]}, + attrs={'ring_id': 0, + 'use_calc_stream': True}) + block.append_op( + type='scale', + inputs={'X': [g]}, + outputs={'Out': [g]}, + attrs={'scale': scale}) + + def _static_clip(self, params_grads): + if self.clip_after_allreduce: + self._insert_allreduce_ops(params_grads) + + params_grads = self.clip(params_grads) + if not self.clip_after_allreduce: + self._insert_allreduce_ops(params_grads) + return params_grads + + +class IdentityGradClip(ClipGradBase): + def _dygraph_clip(self, params_grads): + return params_grads + + def _static_clip(self, params_grads): + return params_grads + + +def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs): + nranks = paddle.distributed.get_world_size() + + set_seed(1000) + main = paddle.static.Program() + startup = paddle.static.Program() + with 
paddle.static.program_guard(main, startup): + with paddle.fluid.unique_name.guard(): + with paddle.static.amp.fp16_guard(): + image = paddle.static.data( + name='image', + shape=[None, 3, 224, 224], + dtype=paddle.float32) + label = paddle.static.data( + name='label', shape=[None, 1], dtype=paddle.int64) + model = resnet() + pred = model(image) + loss_fn = paddle.nn.loss.CrossEntropyLoss() + loss = loss_fn(pred, label) + + grad_clip = kwargs.get('grad_clip', None) + clip_after_allreduce = kwargs.get('clip_after_allreduce', True) + + if use_distributed_lamb: + optimizer_class = DistributedFusedLamb + kwargs = dict(kwargs) + kwargs['is_grad_scaled_by_nranks'] = False + kwargs['use_master_param_norm'] = use_master_param_norm + else: + optimizer_class = paddle.optimizer.Lamb + kwargs = dict(kwargs) + kwargs.pop('clip_after_allreduce', None) + kwargs.pop('alignment', None) + base_clip = grad_clip if grad_clip is not None else IdentityGradClip( + ) + kwargs['grad_clip'] = GradClipDecorator(base_clip, + clip_after_allreduce) + + optimizer = optimizer_class(**kwargs) + get_parameter = optimizer._get_parameter + amp_list = paddle.static.amp.AutoMixedPrecisionLists( + custom_white_list=[ + 'batch_norm', 'batch_norm_grad', 'conv2d', 'conv2d_grad' + ]) + if use_fp16: + if not use_distributed_lamb: + optimizer._multi_precision = True + optimizer = paddle.static.amp.decorate( + optimizer, + amp_list, + init_loss_scaling=1.0, + use_dynamic_loss_scaling=False, + use_pure_fp16=use_fp16, + use_fp16_guard=use_fp16) + + params_grads = optimizer.backward(loss, startup) + op_num = len(main.global_block().ops) + if use_fp16: + optimizer.apply_optimize(loss, startup, params_grads) + else: + optimizer.apply_gradients(params_grads) + + if nranks > 1: + collective_helper = CollectiveHelper(role_maker=get_role_maker()) + collective_helper.update_startup_program(startup) + set_gradient_persistable(startup) + params, grads = set_gradient_persistable(main) + prune_fwd_bwd_ops(main, op_num) + + def pd_dtype_to_np_dtype(pd_dtype): + if pd_dtype == paddle.float32: + return np.float32 + elif pd_dtype == paddle.float16: + return np.float16 + else: + raise ValueError("supported dtype {}".format(pd_dtype)) + + def gen_random_grad_tensor(grad): + np_dtype = pd_dtype_to_np_dtype(grad.dtype) + grad_np = np.random.random(size=grad.shape).astype(np_dtype) + grad_t = core.Tensor() + grad_t.set(grad_np, paddle.CPUPlace()) + return grad_t + + def reader(): + for _ in range(5): + yield dict( + [(grad.name, gen_random_grad_tensor(grad)) for grad in grads]) + + scope = paddle.static.Scope() + fetch_list = params + fetches = None + with paddle.static.scope_guard(scope): + dev_id = int(os.environ.get('FLAGS_selected_gpus', 0)) + place = paddle.CUDAPlace(dev_id) + exe = paddle.static.Executor(place) + exe.run(startup) + if use_fp16: + optimizer.amp_init(place) + + master_p_ts = [] + for p in params: + p_ts = get_parameter(p.name) + assert len(p_ts) == 2 + if p_ts[1] is not None: + master_p_ts.append(p_ts[1]) + if use_fp16: + assert len(master_p_ts) > 0 + else: + assert len(master_p_ts) == 0 + + for feed in reader(): + fetches = exe.run(main, feed=feed, fetch_list=fetch_list) + return fetches + + +class TestDistributedFusedLamb(unittest.TestCase): + @classmethod + def setUpClass(cls): + if not paddle.is_compiled_with_cuda(): + return + + paddle.enable_static() + paddle.set_flags({'FLAGS_cudnn_deterministic': True}) + _clip_by_global_norm_using_mp_type(True) + fleet.init(role_maker=get_role_maker()) + + def config(self): + 
clip_after_allreduce = bool( + distutils.util.strtobool( + os.getenv('CLIP_AFTER_ALLREDUCE', 'True'))) + max_global_norm = float(os.getenv('MAX_GLOBAL_NORM', -1.0)) + print('clip_after_allreduce = {}, max_global_norm = {}'.format( + clip_after_allreduce, max_global_norm)) + return { + 'clip_after_allreduce': clip_after_allreduce, + 'grad_clip': paddle.nn.ClipGradByGlobalNorm(max_global_norm) + if max_global_norm > 0 else None, + } + + def run_main(self, use_fp16, use_master_param_norm=True): + if not paddle.is_compiled_with_cuda(): + return + + if not use_fp16: + self.assertTrue(use_master_param_norm) + + base_config = self.config() + config1 = dict(base_config) + config1['use_distributed_lamb'] = True + config1['use_fp16'] = use_fp16 + config1['use_master_param_norm'] = use_master_param_norm + + config2 = dict(base_config) + config2['use_distributed_lamb'] = False + config2['use_fp16'] = use_fp16 + config2['use_master_param_norm'] = use_master_param_norm + + result1 = run_model(**config1) + result2 = run_model(**config2) + self.assertEqual(len(result1), len(result2)) + + if use_fp16: + atol = 8e-4 if use_master_param_norm else 1e-3 + else: + atol = 1e-7 + for ret1, ret2 in zip(result1, result2): + max_diff = np.max(np.abs(ret1 - ret2)) + msg = 'max_diff = {} atol = {} when use_fp16 = {} , use_master_param_norm = {}'.format( + max_diff, atol, use_fp16, use_master_param_norm) + self.assertTrue(max_diff < atol, msg) + print(msg) + + def test_main(self): + self.run_main(use_fp16=False) + self.run_main(use_fp16=True, use_master_param_norm=True) + self.run_main(use_fp16=True, use_master_param_norm=False) + + touch_file_name = os.environ.get('SUCCESS_TOUCH_FILE') + if touch_file_name: + with open(touch_file_name, 'w') as f: + f.write('success') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py b/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py new file mode 100644 index 00000000000..060a790a6e5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
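The GradClipDecorator in the test base mirrors clip_after_allreduce for the reference (non-fused) LAMB run: gradients are allreduced and averaged either before or after clipping. In outline, its _static_clip logic amounts to:

def clip_with_allreduce(params_grads, clip_fn, allreduce_fn, clip_after_allreduce):
    # clip_fn is the wrapped ClipGradBase; allreduce_fn appends
    # c_allreduce_sum plus a 1/world_size scale op for every gradient.
    if clip_after_allreduce:
        allreduce_fn(params_grads)
    params_grads = clip_fn(params_grads)
    if not clip_after_allreduce:
        allreduce_fn(params_grads)
    return params_grads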
+ +import os +import shlex +import sys +import shutil +import unittest +import paddle + + +def get_test_file(): + dirname = os.path.dirname(os.path.abspath(__file__)) + return os.path.join(dirname, 'distributed_fused_lamb_test_base.py') + + +def remove_file_if_exists(file_name): + if not os.path.exists(file_name): + return + if os.path.isfile(file_name): + os.remove(file_name) + else: + shutil.rmtree(file_name) + + +def run_test(clip_after_allreduce=True, max_global_norm=-1.0): + if not paddle.is_compiled_with_cuda(): + return + if os.name == 'nt': + return + args = locals() + log_dir = 'log_{}'.format(os.getpid()) + cmd = [ + sys.executable, + '-u', + '-m', + 'paddle.distributed.launch', + '--log_dir', + log_dir, + get_test_file(), + ] + + cmd = ' '.join([shlex.quote(c) for c in cmd]) + + os.environ['CLIP_AFTER_ALLREDUCE'] = str(clip_after_allreduce) + os.environ['MAX_GLOBAL_NORM'] = str(max_global_norm) + + touch_file_env = 'SUCCESS_TOUCH_FILE' + touch_file_name = 'distributed_fused_lamb_touch_file_{}'.format(os.getpid()) + os.environ[touch_file_env] = touch_file_name + remove_file_if_exists(touch_file_name) + try: + assert os.system(cmd) == 0 and os.path.exists( + touch_file_name), 'Test failed when {}'.format(args) + finally: + remove_file_if_exists(touch_file_name) + remove_file_if_exists(log_dir) + + +class TestDistributedFusedLambWithClip(unittest.TestCase): + def test_1(self): + run_test(clip_after_allreduce=True, max_global_norm=0.01) + + def _test_2(self): + run_test(clip_after_allreduce=False, max_global_norm=0.01) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_without_clip.py b/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_without_clip.py new file mode 100644 index 00000000000..dbd2d72fd2f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_without_clip.py @@ -0,0 +1,28 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from test_distributed_fused_lamb_op_with_clip import run_test +import unittest + + +class TestDistributedFusedLambWithoutClip(unittest.TestCase): + def test_1(self): + run_test(clip_after_allreduce=True, max_global_norm=-1.0) + + def test_2(self): + run_test(clip_after_allreduce=False, max_global_norm=-1.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index d637def4054..83dad710bad 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -14,6 +14,7 @@ from .optimizer import LookAhead # noqa: F401 from .optimizer import ModelAverage # noqa: F401 +from .optimizer import DistributedFusedLamb # noqa: F401 from .checkpoint import auto_checkpoint # noqa: F401 from ..fluid.layer_helper import LayerHelper # noqa: F401 from .operators import softmax_mask_fuse_upper_triangle # noqa: F401 diff --git a/python/paddle/incubate/optimizer/__init__.py b/python/paddle/incubate/optimizer/__init__.py index d966d187f28..fd5332986ed 100644 --- a/python/paddle/incubate/optimizer/__init__.py +++ b/python/paddle/incubate/optimizer/__init__.py @@ -14,5 +14,6 @@ from .lookahead import LookAhead # noqa: F401 from .modelaverage import ModelAverage # noqa: F401 +from .distributed_fused_lamb import DistributedFusedLamb # noqa: F401 __all__ = [] diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py new file mode 100644 index 00000000000..74c481fb641 --- /dev/null +++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py @@ -0,0 +1,305 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
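+# DistributedFusedLamb flattens every parameter and gradient into large fused
+# FP16/FP32 buffers (built by the distributed_fused_lamb_init op appended to
+# the startup program), keeps the moment buffers sharded across ranks, and
+# performs the whole LAMB update with a single distributed_fused_lamb op per
+# training step in the main program.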
+ +from paddle.fluid import framework, core, layers, unique_name +from paddle.fluid.framework import Variable +from paddle.fluid.clip import ClipGradByGlobalNorm +from paddle.fluid.initializer import Constant +from paddle.fluid.layer_helper import LayerHelper +from paddle.optimizer import Optimizer +from paddle.distributed import get_rank, get_world_size +from paddle.fluid.executor import global_scope +from paddle.fluid.framework import name_scope +import numpy as np + + +class DistributedFusedLamb(Optimizer): + def __init__(self, + learning_rate=0.001, + lamb_weight_decay=0.01, + beta1=0.9, + beta2=0.999, + epsilon=1e-6, + parameters=None, + grad_clip=None, + exclude_from_weight_decay_fn=None, + clip_after_allreduce=True, + is_grad_scaled_by_nranks=True, + alignment=128, + use_master_param_norm=True, + name=None): + assert not framework.in_dygraph_mode( + ), "DistributedFusedLamb does not support dygraph mode" + super(DistributedFusedLamb, self).__init__( + learning_rate=learning_rate, + parameters=parameters, + weight_decay=None, + grad_clip=None, + name=name) + + self._beta1 = beta1 + self._beta2 = beta2 + self._epsilon = epsilon + self._weight_decay = lamb_weight_decay if lamb_weight_decay is not None else 0.0 + if grad_clip is not None: + assert isinstance( + grad_clip, ClipGradByGlobalNorm + ), "Only ClipGradByGlobalNorm is supported in DistributedFusedLamb" + max_global_grad_norm = grad_clip.clip_norm + else: + max_global_grad_norm = -1.0 + self._max_global_grad_norm = max_global_grad_norm + self._alignment = alignment if alignment is not None else -1 + self._clip_after_allreduce = clip_after_allreduce + self._is_grad_scaled_by_nranks = is_grad_scaled_by_nranks + self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn + self._scale = None + self._ring_id = 0 + self._use_master_param_norm = use_master_param_norm + self.helper = LayerHelper('distributed_fused_lamb') + self._supports_check_nan_inf = True # very import flag for AMP + + main_block = self.helper.main_program.global_block() + self._found_inf = main_block.create_var( + name=unique_name.generate('found_inf'), + shape=[1], + dtype=core.VarDesc.VarType.BOOL) + + self._param_to_master_param = {} + + def _set_scale(self, scale): + assert scale is not None + if not isinstance(scale, Variable): + scale = self._create_scale_from_constant(scale) + self._scale = scale + + def _create_scale_from_constant(self, value): + name = unique_name.generate('global_scale') + return layers.create_global_var( + name=name, + shape=[1], + dtype='float32', + value=float(value), + persistable=True) + + def _get_or_create_scale(self): + if self._scale is None: + self._scale = self._create_scale_from_constant(1.0) + return self._scale + + def _create_persistable_var(self, name=None, shape=[-1], dtype='float32'): + startup_block = self.helper.startup_program.global_block() + if name is not None: + name = unique_name.generate(name) + startup_var = startup_block.create_var( + name=name, + shape=shape, + dtype=dtype, + persistable=True, + stop_gradient=True) + main_block = self.helper.main_program.global_block() + main_var = main_block.create_var( + name=startup_var.name, + shape=startup_var.shape, + dtype=startup_var.dtype, + persistable=True, + stop_gradient=True) + return main_var + + def _get_parameter(self, name, scope=None): + if scope is None: + scope = global_scope() + + master_param = self._param_to_master_param.get(name) + assert master_param is not None + + master_param_t = scope.find_var(master_param).get_tensor() + assert 
master_param_t._dtype() == core.VarDesc.VarType.FP32 + + param_t = scope.find_var(name).get_tensor() + if param_t._dtype() == core.VarDesc.VarType.FP32: + assert param_t._ptr() == master_param_t._ptr() + return param_t, None + else: + assert param_t._dtype() == core.VarDesc.VarType.FP16 + assert param_t.shape() == master_param_t.shape() + return param_t, master_param_t + + def apply_optimize(self, params_grads): + self.apply_gradients(params_grads) + + def apply_gradients(self, params_grads): + flattened = [] + for p, g in params_grads: + flattened.extend([p, g]) + with flattened[0].block.program._optimized_guard(flattened), name_scope( + "optimizer"): + self._apply_gradients_impl(params_grads) + + def _apply_gradients_impl(self, params_grads): + for p, g in params_grads: + assert g.type == core.VarDesc.VarType.LOD_TENSOR, "Only support dense gradient" + g.persistable = True # the gradient must be persistable for fusion + + fp32_fused_param = self._create_persistable_var('fp32_fused_param') + fp32_fused_grad = self._create_persistable_var('fp32_fused_grad') + fp16_fused_param = self._create_persistable_var( + 'fp16_fused_param', dtype='float16') + fp16_fused_grad = self._create_persistable_var( + 'fp16_fused_grad', dtype='float16') + + master_params = [] + for p, g in params_grads: + master_p = self._create_persistable_var('master_weight') + self._param_to_master_param[p.name] = master_p.name + master_params.append(master_p) + + moment1 = self._create_persistable_var('moment1') + moment1.is_distributed = True + moment2 = self._create_persistable_var('moment2') + moment2.is_distributed = True + beta1pow = self._create_persistable_var('beta1pow') + beta2pow = self._create_persistable_var('beta2pow') + fused_indices = self._create_persistable_var( + 'fused_indices', dtype='int32') + weight_decay = self._create_persistable_var('weight_decay') + weight_decay.is_distributed = True + param_info = self._create_persistable_var('param_info', dtype='int32') + param_info.is_distributed = True + + fused_offsets = self._create_persistable_var('fused_offsets') + + fp32_partial_fused_offsets = self._create_persistable_var( + 'fp32_partial_fused_offsets', dtype='int32') + fp32_partial_fused_offsets.is_distributed = True + fp16_partial_fused_offsets = self._create_persistable_var( + 'fp16_partial_fused_offsets', dtype='int32') + fp16_partial_fused_offsets.is_distributed = True + + rank = get_rank() + nranks = get_world_size() + scale = self._get_or_create_scale() + + params = [p for p, _ in params_grads] + grads = [g for _, g in params_grads] + weight_decay_values = [self._weight_decay] * len(params) + if self._exclude_from_weight_decay_fn is not None: + for i, p in enumerate(params): + if self._exclude_from_weight_decay_fn(p): + weight_decay_values[i] = 0.0 + + startup_block = self.helper.startup_program.global_block() + for g in grads: + startup_block.create_var( + name=g.name, + type=g.type, + dtype=g.dtype, + persistable=g.persistable, + shape=g.shape) + + startup_block.append_op( + type='distributed_fused_lamb_init', + inputs={ + 'Param': params, + 'Grad': grads, + }, + outputs={ + 'FP32FusedParam': [fp32_fused_param], + 'FP32FusedGrad': [fp32_fused_grad], + 'FP16FusedParam': [fp16_fused_param], + 'FP16FusedGrad': [fp16_fused_grad], + 'Moment1': [moment1], + 'Moment2': [moment2], + 'Beta1Pow': [beta1pow], + 'Beta2Pow': [beta2pow], + 'FusedIndices': [fused_indices], + 'WeightDecay': [weight_decay], + 'GlobalScale': [scale], + 'ParamInfo': [param_info], + 'ParamOut': params, + 'MasterParamOut': 
master_params, + 'GradOut': grads, + 'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets], + 'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets], + 'FusedParamOffsets': [fused_offsets], + }, + attrs={ + 'alignment': self._alignment, + 'rank': rank, + 'nranks': nranks, + 'weight_decay': weight_decay_values, + 'moment1': 0.0, + 'moment2': 0.0, + 'beta1': self._beta1, + 'beta2': self._beta2, + }) + + main_block = self.helper.main_program.global_block() + self._create_global_learning_rate() + lr = None + for p_g in params_grads: + if lr is None: + lr = self._create_param_lr(p_g) + else: + new_lr = self._create_param_lr(p_g) + assert id(lr) == id( + new_lr + ), "The learning rate for each parameter should be the same" + assert lr is not None + + lamb_op = main_block.append_op( + type='distributed_fused_lamb', + inputs={ + 'FP32FusedParam': [fp32_fused_param], + 'FP32FusedGrad': [fp32_fused_grad], + 'FP16FusedParam': [fp16_fused_param], + 'FP16FusedGrad': [fp16_fused_grad], + 'LearningRate': [lr], + 'Moment1': [moment1], + 'Moment2': [moment2], + 'Beta1Pow': [beta1pow], + 'Beta2Pow': [beta2pow], + 'FusedIndices': [fused_indices], + 'WeightDecay': [weight_decay], + 'GlobalScale': [scale], + 'ParamInfo': [param_info], + 'Param': params, + 'Grad': grads, + 'FusedParamOffsets': [fused_offsets], + 'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets], + 'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets], + }, + outputs={ + 'FP32FusedParamOut': [fp32_fused_param], + 'FP16FusedParamOut': [fp16_fused_param], + 'Moment1Out': [moment1], + 'Moment2Out': [moment2], + 'Beta1PowOut': [beta1pow], + 'Beta2PowOut': [beta2pow], + 'ParamOut': params, + 'GradOut': grads, + 'FoundInf': [self._found_inf], + }, + attrs={ + 'beta1': self._beta1, + 'beta2': self._beta2, + 'epsilon': self._epsilon, + 'max_global_grad_norm': self._max_global_grad_norm, + 'clip_after_allreduce': self._clip_after_allreduce, + 'rank': rank, + 'ring_id': self._ring_id, + 'use_master_param_norm': self._use_master_param_norm, + 'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks, + }) + return [lamb_op] -- GitLab
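For reference, a minimal static-graph sketch of how the new optimizer can be wired into a program. This is illustrative only: the toy network, the bias-name weight-decay filter, the single-device executor setup and the omission of communicator initialization are assumptions of this note, not part of the patch; the unit tests above drive the optimizer through paddle.distributed.launch instead.

import paddle
from paddle.incubate import DistributedFusedLamb

paddle.enable_static()  # the optimizer asserts that dygraph mode is off

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # Toy network (assumed for illustration).
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    loss = paddle.mean(paddle.static.nn.fc(x, size=1))

    opt = DistributedFusedLamb(
        learning_rate=1e-3,
        lamb_weight_decay=0.01,
        # Only ClipGradByGlobalNorm is accepted as grad_clip.
        grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
        clip_after_allreduce=True,
        use_master_param_norm=True,
        # Hypothetical filter: skip weight decay for bias parameters.
        exclude_from_weight_decay_fn=lambda p: p.name.endswith('.b_0'))

    # backward() is inherited from paddle.optimizer.Optimizer;
    # apply_gradients() is the entry point added by this patch: it appends
    # distributed_fused_lamb_init to the startup program and a single
    # distributed_fused_lamb op to the main program.
    params_grads = opt.backward(loss, startup_prog)
    opt.apply_gradients(params_grads)

# Running the startup program executes distributed_fused_lamb_init, which
# builds the fused buffers; multi-rank NCCL/communicator setup under
# paddle.distributed.launch is assumed to happen elsewhere and is omitted.
exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(startup_prog)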