diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc index aeda36e537c7f0c1df971c22360f0c0b6d80e26c..59313bc95af0ae0304cf8cecaaf21372e382c973 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" namespace paddle { namespace operators { @@ -116,9 +117,3 @@ namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb_init, ops::DistributedFusedLambInitOp, ops::DistributedFusedLambInitOpMaker); - -PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb_init, - CPU, - ALL_LAYOUT, - ops::DistributedFusedLambInitOpKernel, - float) {} diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu deleted file mode 100644 index 8841544366e87c876a58c2b955a318c5d8b0b97a..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu +++ /dev/null @@ -1,797 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h" -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/operators/optimizers/cast_with_ptr.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/kernels/funcs/algorithm.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/tensor_to_string.h" - -namespace paddle { -namespace operators { - -using phi::funcs::FlattenToString; -using phi::funcs::ToVector; - -struct ParamGradInfo { - phi::DenseTensor *param_t{nullptr}; - phi::DenseTensor *grad_t{nullptr}; - size_t idx{0}; - size_t numel{0}; - size_t numel_with_padding{0}; - size_t numel_offset{0}; -}; - -static std::ostream &operator<<(std::ostream &os, const ParamGradInfo &info) { - return os << "{Param(" << info.param_t << "),Grad(" << info.grad_t << "),idx(" - << info.idx << "),numel(" << info.numel << "),numel_with_padding(" - << info.numel_with_padding << "),numel_offset(" << info.numel_offset - << "),padding(" << info.numel_offset + info.numel_with_padding - << "-" << info.numel_offset + info.numel << "=" - << info.numel_with_padding - info.numel << ")}"; -} - -struct ParamGradInfoNumelOffsetCompFunctor { - bool operator()(const ParamGradInfo &x, const ParamGradInfo &y) const { - return x.numel_offset < y.numel_offset; - } - - bool operator()(const ParamGradInfo &x, size_t y) const { - return x.numel_offset < y; - } - - bool operator()(size_t x, const ParamGradInfo &y) const { - return x < y.numel_offset; - } - - bool operator()(size_t x, size_t y) const { return x < y; } -}; - -static size_t GetAlignSize(size_t n, size_t alignment) { - auto remainder = n % alignment; - return remainder == 0 ? n : n + alignment - remainder; -} - -// Shard the ParamGradInfo list by the numel size [start_size, end_size) -// The final results should be: -// -// start_size = sum(infos[0:i].numel_with_padding) + start_numel_offset, where -// start_numel_offset <= infos[i].numel_with_padding -// -// end_size = sum(infos[0:j].numel_with_padding) + end_numel_offset, where -// end_numel_offset <= infos[j].numel_with_padding -static void GetParamGradShardInfo(const std::vector &infos, - size_t start_size, - size_t end_size, - size_t *start_idx, - size_t *end_idx, - size_t *start_numel_offset, - size_t *end_numel_offset) { - VLOG(10) << "NumelOffset: " - << string::join_strings(infos, ",", [](const ParamGradInfo &info) { - return info.numel_offset; - }); - VLOG(10) << "start_size = " << start_size << " , end_size = " << end_size; - - if (infos.empty()) { - PADDLE_ENFORCE_EQ( - start_size, - 0, - platform::errors::InvalidArgument("start_size should be 0.")); - PADDLE_ENFORCE_EQ( - end_size, - 0, - platform::errors::InvalidArgument("end_size should be 0.")); - *start_idx = 0; - *end_idx = 0; - *start_numel_offset = 0; - *end_numel_offset = 0; - return; - } - - PADDLE_ENFORCE_LT(start_size, - end_size, - platform::errors::InvalidArgument( - "start_size should be less than end_size.")); - size_t n = infos.size(); - ParamGradInfoNumelOffsetCompFunctor comp; - auto i = static_cast( - std::lower_bound(infos.begin(), infos.end(), start_size, comp) - - infos.begin()); - if (i == n || infos[i].numel_offset != start_size) { - PADDLE_ENFORCE_GT( - i, - 0, - platform::errors::InvalidArgument( - "Cannot find suitable sharding which is between [%d, %d)", - start_size, - end_size)); - --i; - } - PADDLE_ENFORCE_LT( - i, - n, - platform::errors::InvalidArgument( - "Cannot find 
suitable sharding which is between [%d, %d)", - start_size, - end_size)); - *start_idx = i; - *start_numel_offset = start_size - infos[i].numel_offset; - auto j = static_cast( - std::lower_bound(infos.begin(), infos.end(), end_size, comp) - - infos.begin()); - *end_idx = j - 1; - *end_numel_offset = end_size - infos[j - 1].numel_offset; - PADDLE_ENFORCE_GT(*end_numel_offset, - 0, - platform::errors::InvalidArgument( - "Internal error when sharding, this may be a bug " - "caused by empty parameter.")); - VLOG(10) << "Sharding [start_size=" << start_size << ", end_size=" << end_size - << "): " << (*start_idx) << ":" << (*start_numel_offset) << " -> " - << (*end_idx) << ":" << (*end_numel_offset); -} - -static size_t FillAlignmentPaddingInfo(std::vector *infos, - size_t alignment, - size_t nranks, - phi::DataType dtype) { - auto sizeof_dtype = phi::SizeOf(dtype); - PADDLE_ENFORCE_EQ( - alignment % sizeof_dtype, - 0, - platform::errors::InvalidArgument( - "The attr(alignment) should be exactly divided by sizeof(T) %d.", - sizeof_dtype)); - alignment /= sizeof_dtype; - - size_t total_numel_sum_with_padding = 0; - size_t n = infos->size(); - for (size_t i = 0; i < n; ++i) { - auto &info = (*infos)[i]; - size_t numel_with_padding; - if (i + 1 == n) { - // the total fused numel must be a factor of alignment * nranks - numel_with_padding = - GetAlignSize(info.numel + total_numel_sum_with_padding, - alignment * nranks) - - total_numel_sum_with_padding; - } else { - numel_with_padding = GetAlignSize(info.numel, alignment); - } - info.numel_with_padding = numel_with_padding; - info.numel_offset = total_numel_sum_with_padding; - total_numel_sum_with_padding += numel_with_padding; - } - return total_numel_sum_with_padding; -} - -template -static T *TensorFillConstant(const phi::GPUContext &dev_ctx, - phi::DenseTensor *tensor, - const framework::DDim &dims, - T value) { - tensor->Resize(dims); - auto *ptr = tensor->mutable_data(dev_ctx.GetPlace()); - phi::funcs::SetConstant set_constant; - set_constant(dev_ctx, tensor, value); - return ptr; -} - -static phi::DenseTensor CastDataForInitedTensor(const phi::GPUContext &dev_ctx, - phi::DenseTensor *origin, - phi::DenseTensor *fused_out, - size_t numel_offset) { - PADDLE_ENFORCE_EQ(origin->IsInitialized(), - true, - platform::errors::InvalidArgument( - "The tensor to be cast should be initialized.")); - - PADDLE_ENFORCE_EQ(fused_out->dtype(), - phi::DataType::FLOAT32, - platform::errors::InvalidArgument( - "The dst tensor to be cast should be FP32 tensor.")); - PADDLE_ENFORCE_EQ(origin->dtype(), - phi::DataType::FLOAT16, - platform::errors::InvalidArgument( - "The src tensor to be cast should be FP16 tensor.")); - auto *dst = fused_out->data() + numel_offset; - auto *src = origin->data(); - auto numel = origin->numel(); - LaunchCastKernel(dev_ctx, src, dst, numel); - VLOG(10) << "Cast from FP32 -> FP16, range: [" << numel_offset << ", " - << numel_offset + numel << ")" - << " , total: [0, " << fused_out->numel() << ")"; - framework::DDim fused_out_dim = fused_out->dims(); - auto fused_out_numel = fused_out->numel(); - fused_out->Resize({fused_out_numel}); - auto sliced_tensor = fused_out->Slice(numel_offset, numel + numel_offset); - fused_out->Resize(fused_out_dim); - return sliced_tensor; -} - -static phi::DenseTensor CopyAndShareBufferForInitedTensor( - phi::DenseTensor *origin, - phi::DenseTensor *fused_out, - size_t numel_offset, - gpuStream_t stream) { - PADDLE_ENFORCE_EQ( - origin->IsInitialized(), - true, - platform::errors::InvalidArgument( - "The 
tensor to be copied and shared data should be initialized.")); - auto dtype = fused_out->type(); - PADDLE_ENFORCE_EQ(origin->type(), - dtype, - platform::errors::InvalidArgument( - "The tensor to be copied and shared data should be " - "have the same data type.")); - auto place = fused_out->place(); - PADDLE_ENFORCE_EQ( - origin->place(), - place, - platform::errors::InvalidArgument("The tensor to be copied and shared " - "data should be have the same place.")); - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(place), - true, - platform::errors::InvalidArgument( - "The tensor to be copied and shared data should be on GPU place.")); - - auto numel = origin->numel(); - framework::DDim fused_out_dim = fused_out->dims(); - auto fused_out_numel = fused_out->numel(); - auto sliced_tensor = fused_out->Resize({fused_out_numel}) - .Slice(numel_offset, numel + numel_offset); - memory::Copy(place, - sliced_tensor.data(), - place, - origin->data(), - numel * phi::SizeOf(dtype), - stream); - origin->ShareBufferWith(sliced_tensor); - fused_out->Resize(fused_out_dim); - VLOG(10) << "Copy and share buffer, range: [" << numel_offset << ", " - << numel_offset + numel << ") , total: [0, " << fused_out->numel() - << ") , dtype = " << dtype; - return sliced_tensor; -} - -static void ShareBufferForNonInitedTensor(phi::DenseTensor *origin, - phi::DenseTensor *fused_out, - size_t numel_offset, - const framework::DDim &dims) { - PADDLE_ENFORCE_EQ( - origin->IsInitialized(), - false, - platform::errors::InvalidArgument( - "The tensor to be shared data should not be initialized.")); - - framework::DDim fused_out_dim = fused_out->dims(); - auto fused_out_numel = fused_out->numel(); - auto numel = phi::product(dims); - *origin = fused_out->Resize({fused_out_numel}) - .Slice(numel_offset, numel + numel_offset); - origin->Resize(dims); - fused_out->Resize(fused_out_dim); - VLOG(10) << "Share buffer for non-inited, range: [" << numel_offset << ", " - << numel_offset + numel << "), total: [0, " << fused_out->numel() - << ") , dtype = " << fused_out->dtype(); -} - -template -static void CopyVectorToCPUTensor(const std::vector &src, - phi::DenseTensor *dst) { - dst->Resize({static_cast(src.size())}); - T *dst_ptr = dst->mutable_data(platform::CPUPlace()); - const T *src_ptr = src.data(); - auto nbytes = src.size() * sizeof(T); - std::memcpy(dst_ptr, src_ptr, nbytes); -} - -static size_t ReorderParamGradInfoList(const std::vector &flags, - std::vector *infos) { - size_t n = infos->size(); - std::vector cur_flags; - cur_flags.reserve(n); - for (size_t i = 0; i < n; ++i) { - auto idx = (*infos)[i].idx; - cur_flags.push_back(flags[idx]); - } - - auto origin_infos = *infos; - size_t j = 0; - for (size_t i = 0; i < n; ++i) { - if (cur_flags[i]) { - (*infos)[j] = origin_infos[i]; - ++j; - } - } - size_t ret_idx = j; - - for (size_t i = 0; i < n; ++i) { - if (!cur_flags[i]) { - (*infos)[j] = origin_infos[i]; - ++j; - } - } - return ret_idx; -} - -template -static T ClipByBound(T x, T low_value, T high_value) { - if (x < low_value) return low_value; - if (x > high_value) return high_value; - return x; -} - -template -class DistributedFusedLambInitOpKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - VLOG(10) << "starts to run DistributedFusedLambInitOp"; - auto &dev_ctx = ctx.template device_context(); - auto place = ctx.GetPlace(); - auto stream = dev_ctx.stream(); - - // Step 1: Check Input(Param) and Output(ParamOut), Input(Grad) and - // Output(GradOut) - auto 
params = ctx.MultiInput("Param"); - auto grads = ctx.MultiInput("Grad"); - auto master_params = ctx.MultiOutput("MasterParamOut"); - std::vector fp32_infos, fp16_infos; - { - PADDLE_ENFORCE_EQ(params.size(), - grads.size(), - platform::errors::InvalidArgument( - "The parameter number and parameter gradient " - "number should be the same.")); - - auto params_out = ctx.MultiOutput("ParamOut"); - auto grads_out = ctx.MultiOutput("GradOut"); - PADDLE_ENFORCE_EQ( - params.size(), - params_out.size(), - platform::errors::InvalidArgument("Input(Param) and Output(ParamOut) " - "should have the same number.")); - PADDLE_ENFORCE_EQ( - grads.size(), - grads_out.size(), - platform::errors::InvalidArgument( - "Input(Grad) and Output(GradOut) should have the same number.")); - size_t n = params.size(); - VLOG(10) << "parameter number: " << n; - for (size_t i = 0; i < n; ++i) { - auto *p = params[i]; - auto *g = grads[i]; - auto *p_out = params_out[i]; - auto *g_out = grads_out[i]; - - PADDLE_ENFORCE_NOT_NULL( - p, - platform::errors::InvalidArgument( - "The %d-th parameter should not be nullptr.", i)); - PADDLE_ENFORCE_EQ(p->IsInitialized(), - true, - platform::errors::InvalidArgument( - "The %d-th parameter should be initialized.", i)); - PADDLE_ENFORCE_EQ( - p->place(), - place, - platform::errors::InvalidArgument( - "The %d-th parameter is not initialized on the right place.", - i)); - PADDLE_ENFORCE_EQ(p, - p_out, - platform::errors::InvalidArgument( - "The %d-th Input(Param) and Output(ParamOut) " - "should be the same tensor.", - i)); - - auto dtype = p->dtype(); - PADDLE_ENFORCE_NOT_NULL( - g, - platform::errors::InvalidArgument( - "The %d-th gradient should not be nullptr.", i)); - PADDLE_ENFORCE_EQ(g, - g_out, - platform::errors::InvalidArgument( - "The %d-th Input(Grad) and Output(Grad) should " - "be the same tensor.")); - auto numel = p->numel(); - PADDLE_ENFORCE_GT(numel, - 0, - platform::errors::InvalidArgument( - "The %d-th Input(Param) have no elements.")); - - void *g_data = nullptr; - if (g->IsInitialized()) { - PADDLE_ENFORCE_EQ(g->dtype(), - dtype, - platform::errors::InvalidArgument( - "The %d-th Input(Param) and Input(Grad) should " - "have the same data type %s.", - i, - dtype)); - PADDLE_ENFORCE_EQ(g->dims(), - p->dims(), - platform::errors::InvalidArgument( - "The %d-th Input(Param) and Input(Grad) should " - "have the same shape.", - i)); - g_data = g_out->data(); - } - - ParamGradInfo *info; - if (dtype == phi::DataType::FLOAT32) { - fp32_infos.emplace_back(); - info = &fp32_infos.back(); - } else if (dtype == phi::DataType::FLOAT16) { - fp16_infos.emplace_back(); - info = &fp16_infos.back(); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported data type %s.", dtype)); - } - - VLOG(10) << "Found " << dtype << " parameter " << i << " shape=[" - << p_out->dims() << "] numel=" << numel - << " grad.IsInitialized()=" - << (g_out->IsInitialized() ? 
"true" : "false"); - - info->param_t = p_out; - info->grad_t = g_out; - info->idx = i; - info->numel = numel; - info->numel_with_padding = 0; // not determined yet - info->numel_offset = 0; // not determined yet - } - } - const auto &apply_weight_decay = - ctx.Attr>("apply_weight_decay"); - size_t fp32_wd_end_idx = - ReorderParamGradInfoList(apply_weight_decay, &fp32_infos); - size_t fp16_wd_end_idx = - ReorderParamGradInfoList(apply_weight_decay, &fp16_infos); - - auto *param_order_t = ctx.Output("ParamOrder"); - auto param_num = fp32_infos.size() + fp16_infos.size(); - param_order_t->Resize({static_cast(param_num)}); - auto *param_order = param_order_t->mutable_data(platform::CPUPlace()); - for (size_t i = 0; i < fp32_infos.size(); ++i) { - param_order[i] = static_cast(fp32_infos[i].idx); - } - for (size_t i = 0; i < fp16_infos.size(); ++i) { - param_order[i + fp32_infos.size()] = static_cast(fp16_infos[i].idx); - } - - VLOG(10) << "Fill ParamGradInfo ends"; - - // Step 2: determine the numel_with_padding and numel_offset - auto rank = ctx.Attr("rank"); - auto nranks = ctx.Attr("nranks"); - auto alignment = ctx.Attr("alignment"); - VLOG(10) << "rank = " << rank << ", nranks = " << nranks - << " , alignment = " << alignment; - if (alignment <= 0) { - alignment = platform::GpuMinChunkSize(); - } - PADDLE_ENFORCE_GE(alignment, - 1, - platform::errors::InvalidArgument( - "The attr(alignment) should be larger than 0.")); - PADDLE_ENFORCE_EQ(alignment & (alignment - 1), - 0, - platform::errors::InvalidArgument( - "The attr(alignment) should be the power of 2.")); - PADDLE_ENFORCE_GE( - rank, - 0, - platform::errors::InvalidArgument( - "The attr(rank) should be equal to or larger than 0.")); - PADDLE_ENFORCE_LT( - rank, - nranks, - platform::errors::InvalidArgument( - "The attr(rank) should be less than the attr(nranks).")); - // NOTE: We guarantee that both fp32_numel and fp16_numel can be exactly - // divided by alignment and nranks. - auto fp32_numel = FillAlignmentPaddingInfo( - &fp32_infos, alignment, nranks, phi::DataType::FLOAT32); - VLOG(10) << "FP32 ParamGradInfo: " << string::join_strings(fp32_infos, " "); - auto fp16_numel = FillAlignmentPaddingInfo( - &fp16_infos, alignment, nranks, phi::DataType::FLOAT16); - VLOG(10) << "FP16 ParamGradInfo: " << string::join_strings(fp16_infos, " "); - auto total_numel = fp32_numel + fp16_numel; - PADDLE_ENFORCE_LT( - total_numel, - std::numeric_limits::max(), - platform::errors::InvalidArgument("Too many parameter number.")); - - auto fp32_numel_each_device = fp32_numel / nranks; - auto fp16_numel_each_device = fp16_numel / nranks; - auto numel_each_device = fp32_numel_each_device + fp16_numel_each_device; - VLOG(10) << "Fill padding ends. 
total_numel = " << total_numel - << ", fp32_numel = " << fp32_numel - << ", fp16_numel = " << fp16_numel - << ", fp32_numel_each_device = " << fp32_numel_each_device - << ", fp16_numel_each_device = " << fp16_numel_each_device; - - // Step 3: allocate output tensor and do initialization - float *fused_fp32_param = nullptr, *fused_fp32_grad = nullptr; - platform::float16 *fused_fp16_param = nullptr, *fused_fp16_grad = nullptr; - phi::DenseTensor *fp32_p_t = nullptr, *fp16_p_t = nullptr, - *fp32_g_t = nullptr, *fp16_g_t = nullptr; - std::vector fp16_master_params; - if (total_numel > 0) { - fp32_p_t = ctx.Output("FP32FusedParam"); - fused_fp32_param = TensorFillConstant( - dev_ctx, fp32_p_t, {static_cast(total_numel)}, 0.0f); - } - - if (fp32_numel > 0) { - fp32_g_t = ctx.Output("FP32FusedGrad"); - fused_fp32_grad = TensorFillConstant( - dev_ctx, fp32_g_t, {static_cast(fp32_numel)}, 0.0f); - } - - if (fp16_numel > 0) { - fp16_p_t = ctx.Output("FP16FusedParam"); - fused_fp16_param = TensorFillConstant( - dev_ctx, - fp16_p_t, - {static_cast(fp16_numel)}, - static_cast(0)); - - fp16_g_t = ctx.Output("FP16FusedGrad"); - fused_fp16_grad = TensorFillConstant( - dev_ctx, - fp16_g_t, - {static_cast(fp16_numel)}, - static_cast(0)); - } - VLOG(10) << "Allocate FP32FusedParam/Grad, FP16FusedParam/Grad ends"; - - // (1) For FP32FusedParam, memcpy for fp32 param and then share data, cast - // for fp16 master weight - // (2) For FP16FusedParam, memcpy and then share data - // (3) For FP32FusedGrad/FP16FusedGrad, memcpy if gradient has been inited - for (const auto &info : fp32_infos) { - auto sliced_tensor = CopyAndShareBufferForInitedTensor( - info.param_t, fp32_p_t, info.numel_offset, stream); - master_params[info.idx]->Resize(info.param_t->dims()); - master_params[info.idx]->ShareBufferWith(sliced_tensor); - PADDLE_ENFORCE_EQ(master_params[info.idx]->mutable_data(place), - sliced_tensor.data(), - platform::errors::InvalidArgument( - "Invalid master weight tensor pointer.")); - if (info.grad_t->IsInitialized()) { - CopyAndShareBufferForInitedTensor( - info.grad_t, fp32_g_t, info.numel_offset, stream); - } else { - ShareBufferForNonInitedTensor( - info.grad_t, fp32_g_t, info.numel_offset, info.param_t->dims()); - } - } - - size_t fp16_numel_offset = 0; - if (fp32_numel > 0) { - auto last_fp32_info = fp32_infos.back(); - fp16_numel_offset = - last_fp32_info.numel_offset + last_fp32_info.numel_with_padding; - } - - for (const auto &info : fp16_infos) { - auto master_weight_offset = info.numel_offset + fp16_numel_offset; - auto sliced_tensor = CastDataForInitedTensor( - dev_ctx, info.param_t, fp32_p_t, master_weight_offset); - master_params[info.idx]->Resize(info.param_t->dims()); - master_params[info.idx]->ShareBufferWith(sliced_tensor); - - CopyAndShareBufferForInitedTensor( - info.param_t, fp16_p_t, info.numel_offset, stream); - PADDLE_ENFORCE_EQ(master_params[info.idx]->mutable_data(place), - sliced_tensor.data(), - platform::errors::InvalidArgument( - "Invalid master weight tensor pointer.")); - - if (info.grad_t->IsInitialized()) { - CopyAndShareBufferForInitedTensor( - info.grad_t, fp16_g_t, info.numel_offset, stream); - } else { - ShareBufferForNonInitedTensor( - info.grad_t, fp16_g_t, info.numel_offset, info.param_t->dims()); - } - } - VLOG(10) << "Copy/share data for Param/Grad ends"; - - // Step 4: For Moment1, Moment2, Beta1Pow, Beta2Pow, just fill constant - TensorFillConstant(dev_ctx, - ctx.Output("Moment1"), - {static_cast(numel_each_device)}, - 0.0f); - TensorFillConstant(dev_ctx, - 
ctx.Output("Moment2"), - {static_cast(numel_each_device)}, - 0.0f); - TensorFillConstant(dev_ctx, - ctx.Output("Beta1Pow"), - {1}, - ctx.Attr("beta1")); - TensorFillConstant(dev_ctx, - ctx.Output("Beta2Pow"), - {1}, - ctx.Attr("beta2")); - VLOG(10) << "Init Moment and BetaPow ends"; - - // Step 5: Do sharding - size_t fp32_start_idx, fp32_end_idx, fp32_start_numel_offset, - fp32_end_numel_offset; - GetParamGradShardInfo(fp32_infos, - rank * fp32_numel_each_device, - (rank + 1) * fp32_numel_each_device, - &fp32_start_idx, - &fp32_end_idx, - &fp32_start_numel_offset, - &fp32_end_numel_offset); - size_t fp16_start_idx, fp16_end_idx, fp16_start_numel_offset, - fp16_end_numel_offset; - GetParamGradShardInfo(fp16_infos, - rank * fp16_numel_each_device, - (rank + 1) * fp16_numel_each_device, - &fp16_start_idx, - &fp16_end_idx, - &fp16_start_numel_offset, - &fp16_end_numel_offset); - size_t fp32_local_param_num = - fp32_numel_each_device > 0 ? fp32_end_idx - fp32_start_idx + 1 : 0; - size_t fp16_local_param_num = - fp16_numel_each_device > 0 ? fp16_end_idx - fp16_start_idx + 1 : 0; - size_t total_local_param_num = fp32_local_param_num + fp16_local_param_num; - VLOG(10) << "Found the sharding arguments"; - - auto *param_info_t = ctx.Output("ParamInfo"); - param_info_t->Resize({8}); - auto *param_info = param_info_t->mutable_data(platform::CPUPlace()); - param_info[0] = static_cast(fp32_start_idx); - param_info[1] = static_cast(fp32_local_param_num); - param_info[2] = static_cast(fp32_infos.size()); - param_info[3] = ClipByBound(fp32_wd_end_idx, - fp32_start_idx, - fp32_start_idx + fp32_local_param_num) - - static_cast(fp32_start_idx); - param_info[4] = static_cast(fp16_start_idx + fp32_infos.size()); - param_info[5] = static_cast(fp16_local_param_num); - param_info[6] = static_cast(fp16_infos.size()); - param_info[7] = ClipByBound(fp16_wd_end_idx, - fp16_start_idx, - fp16_start_idx + fp16_local_param_num) - - static_cast(fp16_start_idx); - - VLOG(10) << "Start FP32 idx: " << param_info[0]; - VLOG(10) << "Local FP32 param num: " << param_info[1]; - VLOG(10) << "Global FP32 param num: " << param_info[2]; - - VLOG(10) << "Start FP16 idx: " << param_info[4]; - VLOG(10) << "Local FP16 param num: " << param_info[5]; - VLOG(10) << "Global FP16 param num: " << param_info[6]; - - std::vector numel_offsets; - numel_offsets.reserve(params.size() + 1); - for (const auto &info : fp32_infos) { - numel_offsets.push_back(info.numel_offset); - } - for (const auto &info : fp16_infos) { - numel_offsets.push_back(info.numel_offset + fp16_numel_offset); - } - numel_offsets.push_back(fp32_numel + fp16_numel); - PADDLE_ENFORCE_EQ(numel_offsets.size(), - params.size() + 1, - platform::errors::InvalidArgument( - "The numel_offsets number must be one larger than " - "the parameter number.")); - VLOG(10) << "Total numel offset: " << FlattenToString(numel_offsets); - - std::vector fp32_partial_numel_offsets; - fp32_partial_numel_offsets.reserve(fp32_local_param_num + 1); - fp32_partial_numel_offsets.push_back(0); - // Fill the partial_numel_offsets - for (size_t i = fp32_start_idx; i < fp32_start_idx + fp32_local_param_num; - ++i) { - size_t valid_start_n = 0; - if (i == fp32_start_idx) { - valid_start_n = fp32_start_numel_offset; - } - - size_t end_n = fp32_infos[i].numel_with_padding; - if (i + 1 == fp32_start_idx + fp32_local_param_num) { - end_n = std::min(end_n, fp32_end_numel_offset); - } - - PADDLE_ENFORCE_NE(valid_start_n, - end_n, - platform::errors::InvalidArgument( - "Indices sharding error. 
This may be a bug.")); - VLOG(10) << "FP32 Partial numel = [" - << valid_start_n + fp32_infos[i].numel << "," - << end_n + fp32_infos[i].numel; - auto len = end_n - valid_start_n; - fp32_partial_numel_offsets.push_back(fp32_partial_numel_offsets.back() + - len); - } - - std::vector fp16_partial_numel_offsets; - fp16_partial_numel_offsets.reserve(fp16_local_param_num + 1); - fp16_partial_numel_offsets.push_back(0); - for (size_t i = fp16_start_idx; i < fp16_start_idx + fp16_local_param_num; - ++i) { - size_t valid_start_n = 0; - if (i == fp16_start_idx) { - valid_start_n = fp16_start_numel_offset; - } - - size_t end_n = fp16_infos[i].numel_with_padding; - if (i + 1 == fp16_start_idx + fp16_local_param_num) { - end_n = std::min(end_n, fp16_end_numel_offset); - } - - PADDLE_ENFORCE_NE(valid_start_n, - end_n, - platform::errors::InvalidArgument( - "Indices sharding error. This may be a bug.")); - auto len = end_n - valid_start_n; - fp16_partial_numel_offsets.push_back(fp16_partial_numel_offsets.back() + - len); - } - - CopyVectorToCPUTensor(numel_offsets, - ctx.Output("FusedParamOffsets")); - CopyVectorToCPUTensor( - fp32_partial_numel_offsets, - ctx.Output("FP32ShardFusedParamOffsets")); - CopyVectorToCPUTensor( - fp16_partial_numel_offsets, - ctx.Output("FP16ShardFusedParamOffsets")); - - auto *global_scale = ctx.Output("GlobalScale"); - if (!global_scale->IsInitialized()) { - TensorFillConstant(dev_ctx, global_scale, {1}, 1.0f); - } - VLOG(10) << "Init global scale ends"; - - TensorFillConstant(dev_ctx, - ctx.Output("Step"), - {1}, - static_cast(0)); - - dev_ctx.Wait(); - VLOG(10) << "Wait for H2D copy"; - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb_init, - GPU, - ALL_LAYOUT, - ops::DistributedFusedLambInitOpKernel, - float) {} diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h deleted file mode 100644 index 7c314cd9e379080090a621f5a9269c63946d056c..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { - -template -class DistributedFusedLambInitOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "The distributed_fused_lamb_init operator does not support CPU yet.")); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/distributed_fused_lamb_init_kernel.h b/paddle/phi/kernels/distributed_fused_lamb_init_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..182c79ab80319f8daf72d5f5d8a4db3cfab38b20 --- /dev/null +++ b/paddle/phi/kernels/distributed_fused_lamb_init_kernel.h @@ -0,0 +1,52 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DistributedFusedLambInitOpKernel( + const Context& dev_ctx, + const std::vector& param, + const std::vector& grad, + float beta1, + float beta2, + const std::vector& apply_weight_decay, + int alignment, + int rank, + int nranks, + DenseTensor* fp32_fused_param, + DenseTensor* fp32_fused_grad, + DenseTensor* fp16_fused_param, + DenseTensor* fp16_fused_grad, + DenseTensor* moment1, + DenseTensor* moment2, + DenseTensor* beta1_pow, + DenseTensor* beta2_pow, + DenseTensor* fused_param_offsets, + DenseTensor* fp32_shard_fused_param_offsets, + DenseTensor* fp16_shard_fused_param_offsets, + DenseTensor* param_info, + DenseTensor* param_order, + std::vector param_out, + std::vector master_param_out, + std::vector grad_out, + DenseTensor* global_scale, + DenseTensor* step); + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cpu/distributed_fused_lamb_init_kernel.cc b/paddle/phi/kernels/fusion/cpu/distributed_fused_lamb_init_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..3cb37ccf2ed89daec5648d3be2cef102ff202871 --- /dev/null +++ b/paddle/phi/kernels/fusion/cpu/distributed_fused_lamb_init_kernel.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
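// Editor's annotation (not part of the diff): the numeric indices used by the
// SetDataType(...) calls in the CPU and GPU kernel registrations below follow
// the order in which the outputs appear in the kernel signature declared in
// distributed_fused_lamb_init_kernel.h above, i.e. 0 = fp32_fused_param,
// 1 = fp32_fused_grad, 2 = fp16_fused_param, 3 = fp16_fused_grad,
// 4 = moment1, 5 = moment2, 6 = beta1_pow, 7 = beta2_pow,
// 8 = fused_param_offsets, 9 = fp32_shard_fused_param_offsets,
// 10 = fp16_shard_fused_param_offsets, 11 = param_info, 12 = param_order,
// 13 = param_out, 14 = master_param_out, 15 = grad_out, 16 = global_scale,
// 17 = step.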
+ +#include "paddle/phi/kernels/distributed_fused_lamb_init_kernel.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +template +void DistributedFusedLambInitOpKernel( + const Context& dev_ctx, + const std::vector& param, + const std::vector& grad, + float beta1, + float beta2, + const std::vector& apply_weight_decay, + int alignment, + int rank, + int nranks, + DenseTensor* fp32_fused_param, + DenseTensor* fp32_fused_grad, + DenseTensor* fp16_fused_param, + DenseTensor* fp16_fused_grad, + DenseTensor* moment1, + DenseTensor* moment2, + DenseTensor* beta1_pow, + DenseTensor* beta2_pow, + DenseTensor* fused_param_offsets, + DenseTensor* fp32_shard_fused_param_offsets, + DenseTensor* fp16_shard_fused_param_offsets, + DenseTensor* param_info, + DenseTensor* param_order, + std::vector param_out, + std::vector master_param_out, + std::vector grad_out, + DenseTensor* global_scale, + DenseTensor* step) { + PADDLE_THROW(phi::errors::Unavailable( + "Do not support expert count op for cpu kernel now.")); +} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(distributed_fused_lamb_init, + CPU, + ALL_LAYOUT, + phi::fusion::DistributedFusedLambInitOpKernel, + float) { + kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT16); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT16); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(8).SetDataType(phi::DataType::INT32); + kernel->OutputAt(9).SetDataType(phi::DataType::INT32); + kernel->OutputAt(10).SetDataType(phi::DataType::INT32); + kernel->OutputAt(11).SetDataType(phi::DataType::INT32); + kernel->OutputAt(12).SetDataType(phi::DataType::INT32); + kernel->OutputAt(13).SetDataType(kernel_key.dtype()); + kernel->OutputAt(14).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(15).SetDataType(kernel_key.dtype()); + kernel->OutputAt(16).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(17).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/fluid/operators/optimizers/cast_with_ptr.h b/paddle/phi/kernels/fusion/gpu/cast_with_ptr.h similarity index 74% rename from paddle/fluid/operators/optimizers/cast_with_ptr.h rename to paddle/phi/kernels/fusion/gpu/cast_with_ptr.h index 205eb2853a3419fd6ae2816f39fc1001d7e94895..5ae8aed256ccddbc234f0139396ade42f446ff45 100644 --- a/paddle/fluid/operators/optimizers/cast_with_ptr.h +++ b/paddle/phi/kernels/fusion/gpu/cast_with_ptr.h @@ -14,28 +14,24 @@ #pragma once -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/enforce.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" -namespace paddle { -namespace operators { -namespace details { +namespace phi { template struct CastFunctor { HOSTDEVICE OutT operator()(InT x) const { return static_cast(x); } }; - template static void VecCastKernel(const phi::GPUContext &ctx, const InT *x, OutT *y, size_t n) { - auto config = platform::GetGpuLaunchConfig1D(ctx, n, VecSize); + auto config = 
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, n, VecSize); auto block = config.GetGridSize(); auto thread = config.GetBlockSize(); auto main_offset = n / (VecSize * thread) * VecSize * thread; @@ -50,8 +46,6 @@ static void VecCastKernel(const phi::GPUContext &ctx, in_arr, out_arr, n, main_offset, VecSize, FunctorT()); } -} // namespace details - template static void LaunchCastKernel(const phi::GPUContext &ctx, const InT *x, @@ -61,20 +55,19 @@ static void LaunchCastKernel(const phi::GPUContext &ctx, PADDLE_ENFORCE_NE( static_cast(x), static_cast(y), - platform::errors::InvalidArgument("Inplace cast is not supported yet.")); + errors::InvalidArgument("Inplace cast is not supported yet.")); int vec_size = std::min(phi::GetVectorizedSize(x), phi::GetVectorizedSize(y)); switch (vec_size) { case 4: - return details::VecCastKernel(ctx, x, y, n); + return VecCastKernel(ctx, x, y, n); case 2: - return details::VecCastKernel(ctx, x, y, n); + return VecCastKernel(ctx, x, y, n); case 1: - return details::VecCastKernel(ctx, x, y, n); + return VecCastKernel(ctx, x, y, n); default: - PADDLE_THROW(platform::errors::InvalidArgument( - "The vectorized size must be 1, 2 or 4.")); + PADDLE_THROW( + errors::InvalidArgument("The vectorized size must be 1, 2 or 4.")); } } -} // namespace operators -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/distributed_fused_lamb_init_kernel.cu b/paddle/phi/kernels/fusion/gpu/distributed_fused_lamb_init_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..3ae7f0682bc75b64bcf25f6aa902820de1196b6d --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/distributed_fused_lamb_init_kernel.cu @@ -0,0 +1,804 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
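// Editor's note (standalone sketch, not part of the PR): the GPU kernel below
// rounds every parameter up to `alignment` elements and pads the last one so
// that the fused total divides evenly by alignment * nranks (see GetAlignSize
// and FillAlignmentPaddingInfo). The small program below replays that
// arithmetic on made-up sizes; all numbers here are illustrative only.
#include <cstddef>
#include <iostream>
#include <vector>

// Same rounding rule as GetAlignSize in the kernel.
static std::size_t AlignUp(std::size_t n, std::size_t alignment) {
  std::size_t r = n % alignment;
  return r == 0 ? n : n + alignment - r;
}

int main() {
  const std::size_t alignment = 256;  // in elements, not bytes
  const std::size_t nranks = 4;
  const std::vector<std::size_t> numels = {1000, 513, 7};  // hypothetical

  std::size_t total = 0;
  for (std::size_t i = 0; i < numels.size(); ++i) {
    std::size_t padded;
    if (i + 1 == numels.size()) {
      // Last entry: pad so the fused total is a multiple of
      // alignment * nranks, mirroring the kernel's special case.
      padded = AlignUp(total + numels[i], alignment * nranks) - total;
    } else {
      padded = AlignUp(numels[i], alignment);
    }
    std::cout << "param " << i << ": numel=" << numels[i]
              << ", numel_offset=" << total
              << ", numel_with_padding=" << padded << "\n";
    total += padded;
  }
  // Prints offsets 0 / 1024 / 1792 and a fused total of 2048 = 2 * 256 * 4,
  // so each of the 4 ranks owns exactly 512 contiguous elements.
  std::cout << "fused total = " << total << "\n";
  return 0;
}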
+ +#include "paddle/phi/kernels/distributed_fused_lamb_init_kernel.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/algorithm.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/tensor_to_string.h" +#include "paddle/phi/kernels/fusion/gpu/cast_with_ptr.h" + +namespace phi { +namespace fusion { + +using phi::funcs::FlattenToString; +using phi::funcs::ToVector; + +struct ParamGradInfo { + DenseTensor *param_t{nullptr}; + DenseTensor *grad_t{nullptr}; + size_t idx{0}; + size_t numel{0}; + size_t numel_with_padding{0}; + size_t numel_offset{0}; +}; + +static std::ostream &operator<<(std::ostream &os, const ParamGradInfo &info) { + return os << "{Param(" << info.param_t << "),Grad(" << info.grad_t << "),idx(" + << info.idx << "),numel(" << info.numel << "),numel_with_padding(" + << info.numel_with_padding << "),numel_offset(" << info.numel_offset + << "),padding(" << info.numel_offset + info.numel_with_padding + << "-" << info.numel_offset + info.numel << "=" + << info.numel_with_padding - info.numel << ")}"; +} + +struct ParamGradInfoNumelOffsetCompFunctor { + bool operator()(const ParamGradInfo &x, const ParamGradInfo &y) const { + return x.numel_offset < y.numel_offset; + } + + bool operator()(const ParamGradInfo &x, size_t y) const { + return x.numel_offset < y; + } + + bool operator()(size_t x, const ParamGradInfo &y) const { + return x < y.numel_offset; + } + + bool operator()(size_t x, size_t y) const { return x < y; } +}; + +static size_t GetAlignSize(size_t n, size_t alignment) { + auto remainder = n % alignment; + return remainder == 0 ? n : n + alignment - remainder; +} + +// Shard the ParamGradInfo list by the numel size [start_size, end_size) +// The final results should be: +// +// start_size = sum(infos[0:i].numel_with_padding) + start_numel_offset, where +// start_numel_offset <= infos[i].numel_with_padding +// +// end_size = sum(infos[0:j].numel_with_padding) + end_numel_offset, where +// end_numel_offset <= infos[j].numel_with_padding +static void GetParamGradShardInfo(const std::vector &infos, + size_t start_size, + size_t end_size, + size_t *start_idx, + size_t *end_idx, + size_t *start_numel_offset, + size_t *end_numel_offset) { + VLOG(10) << "NumelOffset: " + << paddle::string::join_strings( + infos, ",", [](const ParamGradInfo &info) { + return info.numel_offset; + }); + VLOG(10) << "start_size = " << start_size << " , end_size = " << end_size; + + if (infos.empty()) { + PADDLE_ENFORCE_EQ( + start_size, 0, errors::InvalidArgument("start_size should be 0.")); + PADDLE_ENFORCE_EQ( + end_size, 0, errors::InvalidArgument("end_size should be 0.")); + *start_idx = 0; + *end_idx = 0; + *start_numel_offset = 0; + *end_numel_offset = 0; + return; + } + + PADDLE_ENFORCE_LT( + start_size, + end_size, + errors::InvalidArgument("start_size should be less than end_size.")); + size_t n = infos.size(); + ParamGradInfoNumelOffsetCompFunctor comp; + auto i = static_cast( + std::lower_bound(infos.begin(), infos.end(), start_size, comp) - + infos.begin()); + if (i == n || infos[i].numel_offset != start_size) { + PADDLE_ENFORCE_GT( + i, + 0, + errors::InvalidArgument( + "Cannot find suitable sharding which is between [%d, %d)", + start_size, + end_size)); + --i; + } + PADDLE_ENFORCE_LT( + i, + n, + errors::InvalidArgument( + "Cannot find suitable sharding which is between [%d, %d)", + start_size, + end_size)); + *start_idx = i; + 
*start_numel_offset = start_size - infos[i].numel_offset; + auto j = static_cast( + std::lower_bound(infos.begin(), infos.end(), end_size, comp) - + infos.begin()); + *end_idx = j - 1; + *end_numel_offset = end_size - infos[j - 1].numel_offset; + PADDLE_ENFORCE_GT( + *end_numel_offset, + 0, + errors::InvalidArgument("Internal error when sharding, this may be a bug " + "caused by empty parameter.")); + VLOG(10) << "Sharding [start_size=" << start_size << ", end_size=" << end_size + << "): " << (*start_idx) << ":" << (*start_numel_offset) << " -> " + << (*end_idx) << ":" << (*end_numel_offset); +} + +static size_t FillAlignmentPaddingInfo(std::vector *infos, + size_t alignment, + size_t nranks, + phi::DataType dtype) { + auto sizeof_dtype = phi::SizeOf(dtype); + PADDLE_ENFORCE_EQ( + alignment % sizeof_dtype, + 0, + errors::InvalidArgument( + "The attr(alignment) should be exactly divided by sizeof(T) %d.", + sizeof_dtype)); + alignment /= sizeof_dtype; + + size_t total_numel_sum_with_padding = 0; + size_t n = infos->size(); + for (size_t i = 0; i < n; ++i) { + auto &info = (*infos)[i]; + size_t numel_with_padding; + if (i + 1 == n) { + // the total fused numel must be a factor of alignment * nranks + numel_with_padding = + GetAlignSize(info.numel + total_numel_sum_with_padding, + alignment * nranks) - + total_numel_sum_with_padding; + } else { + numel_with_padding = GetAlignSize(info.numel, alignment); + } + info.numel_with_padding = numel_with_padding; + info.numel_offset = total_numel_sum_with_padding; + total_numel_sum_with_padding += numel_with_padding; + } + return total_numel_sum_with_padding; +} + +template +static T *TensorFillConstant(const phi::GPUContext &dev_ctx, + DenseTensor *tensor, + const DDim &dims, + T value) { + tensor->Resize(dims); + auto *ptr = dev_ctx.template Alloc(tensor); + phi::funcs::SetConstant set_constant; + set_constant(dev_ctx, tensor, value); + return ptr; +} + +static DenseTensor CastDataForInitedTensor(const phi::GPUContext &dev_ctx, + DenseTensor *origin, + DenseTensor *fused_out, + size_t numel_offset) { + PADDLE_ENFORCE_EQ( + origin->IsInitialized(), + true, + errors::InvalidArgument("The tensor to be cast should be initialized.")); + + PADDLE_ENFORCE_EQ(fused_out->dtype(), + phi::DataType::FLOAT32, + errors::InvalidArgument( + "The dst tensor to be cast should be FP32 tensor.")); + PADDLE_ENFORCE_EQ(origin->dtype(), + phi::DataType::FLOAT16, + errors::InvalidArgument( + "The src tensor to be cast should be FP16 tensor.")); + auto *dst = fused_out->data() + numel_offset; + auto *src = origin->data(); + auto numel = origin->numel(); + LaunchCastKernel(dev_ctx, src, dst, numel); + VLOG(10) << "Cast from FP32 -> FP16, range: [" << numel_offset << ", " + << numel_offset + numel << ")" + << " , total: [0, " << fused_out->numel() << ")"; + DDim fused_out_dim = fused_out->dims(); + auto fused_out_numel = fused_out->numel(); + fused_out->Resize({fused_out_numel}); + auto sliced_tensor = fused_out->Slice(numel_offset, numel + numel_offset); + fused_out->Resize(fused_out_dim); + return sliced_tensor; +} + +static DenseTensor CopyAndShareBufferForInitedTensor( + const phi::GPUContext &dev_ctx, + DenseTensor *origin, + DenseTensor *fused_out, + size_t numel_offset) { + PADDLE_ENFORCE_EQ( + origin->IsInitialized(), + true, + errors::InvalidArgument( + "The tensor to be copied and shared data should be initialized.")); + auto dtype = fused_out->type(); + PADDLE_ENFORCE_EQ(origin->type(), + dtype, + errors::InvalidArgument( + "The tensor to be copied and shared data 
should be " + "have the same data type.")); + auto place = fused_out->place(); + PADDLE_ENFORCE_EQ( + origin->place(), + place, + errors::InvalidArgument("The tensor to be copied and shared " + "data should be have the same place.")); + PADDLE_ENFORCE_EQ( + dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU, + true, + errors::InvalidArgument( + "The tensor to be copied and shared data should be on GPU place.")); + + auto numel = origin->numel(); + DDim fused_out_dim = fused_out->dims(); + auto fused_out_numel = fused_out->numel(); + auto sliced_tensor = fused_out->Resize({fused_out_numel}) + .Slice(numel_offset, numel + numel_offset); + phi::Copy(dev_ctx, *origin, dev_ctx.GetPlace(), false, &sliced_tensor); + origin->ShareBufferWith(sliced_tensor); + fused_out->Resize(fused_out_dim); + VLOG(10) << "Copy and share buffer, range: [" << numel_offset << ", " + << numel_offset + numel << ") , total: [0, " << fused_out->numel() + << ") , dtype = " << dtype; + return sliced_tensor; +} + +static void ShareBufferForNonInitedTensor(DenseTensor *origin, + DenseTensor *fused_out, + size_t numel_offset, + const DDim &dims) { + PADDLE_ENFORCE_EQ( + origin->IsInitialized(), + false, + errors::InvalidArgument( + "The tensor to be shared data should not be initialized.")); + + DDim fused_out_dim = fused_out->dims(); + auto fused_out_numel = fused_out->numel(); + auto numel = phi::product(dims); + *origin = fused_out->Resize({fused_out_numel}) + .Slice(numel_offset, numel + numel_offset); + origin->Resize(dims); + fused_out->Resize(fused_out_dim); + VLOG(10) << "Share buffer for non-inited, range: [" << numel_offset << ", " + << numel_offset + numel << "), total: [0, " << fused_out->numel() + << ") , dtype = " << fused_out->dtype(); +} + +template +static void CopyVectorToCPUTensor(const phi::GPUContext &dev_ctx, + const std::vector &src, + DenseTensor *dst) { + dst->Resize({static_cast(src.size())}); + T *dst_ptr = dev_ctx.template HostAlloc(dst); + const T *src_ptr = src.data(); + auto nbytes = src.size() * sizeof(T); + std::memcpy(dst_ptr, src_ptr, nbytes); +} + +static size_t ReorderParamGradInfoList(const std::vector &flags, + std::vector *infos) { + size_t n = infos->size(); + std::vector cur_flags; + cur_flags.reserve(n); + for (size_t i = 0; i < n; ++i) { + auto idx = (*infos)[i].idx; + cur_flags.push_back(flags[idx]); + } + + auto origin_infos = *infos; + size_t j = 0; + for (size_t i = 0; i < n; ++i) { + if (cur_flags[i]) { + (*infos)[j] = origin_infos[i]; + ++j; + } + } + size_t ret_idx = j; + + for (size_t i = 0; i < n; ++i) { + if (!cur_flags[i]) { + (*infos)[j] = origin_infos[i]; + ++j; + } + } + return ret_idx; +} + +template +static T ClipByBound(T x, T low_value, T high_value) { + if (x < low_value) return low_value; + if (x > high_value) return high_value; + return x; +} + +template +void DistributedFusedLambInitOpKernel( + const Context &dev_ctx, + const std::vector ¶m, + const std::vector &grad, + float beta1, + float beta2, + const std::vector &apply_weight_decay, + int alignment, + int rank, + int nranks, + DenseTensor *fp32_fused_param, + DenseTensor *fp32_fused_grad, + DenseTensor *fp16_fused_param, + DenseTensor *fp16_fused_grad, + DenseTensor *moment1, + DenseTensor *moment2, + DenseTensor *beta1_pow, + DenseTensor *beta2_pow, + DenseTensor *fused_param_offsets, + DenseTensor *fp32_shard_fused_param_offsets, + DenseTensor *fp16_shard_fused_param_offsets, + DenseTensor *param_info, + DenseTensor *param_order, + std::vector param_out, + std::vector master_param_out, + 
std::vector grad_out, + DenseTensor *global_scale, + DenseTensor *step) { + VLOG(10) << "starts to run DistributedFusedLambInitOp"; + auto place = dev_ctx.GetPlace(); + auto stream = dev_ctx.stream(); + + // Step 1: Check Input(Param) and Output(ParamOut), Input(Grad) and + // Output(GradOut) + std::vector fp32_infos, fp16_infos; + { + PADDLE_ENFORCE_EQ( + param.size(), + grad.size(), + errors::InvalidArgument("The parameter number and parameter gradient " + "number should be the same.")); + + PADDLE_ENFORCE_EQ( + param.size(), + param_out.size(), + errors::InvalidArgument("Input(Param) and Output(ParamOut) " + "should have the same number.")); + PADDLE_ENFORCE_EQ( + grad.size(), + grad_out.size(), + errors::InvalidArgument( + "Input(Grad) and Output(GradOut) should have the same number.")); + size_t n = param.size(); + VLOG(10) << "parameter number: " << n; + for (size_t i = 0; i < n; ++i) { + auto *p = param[i]; + auto *g = grad[i]; + auto *p_out = param_out[i]; + auto *g_out = grad_out[i]; + + PADDLE_ENFORCE_NOT_NULL( + p, + errors::InvalidArgument("The %d-th parameter should not be nullptr.", + i)); + PADDLE_ENFORCE_EQ(p->IsInitialized(), + true, + errors::InvalidArgument( + "The %d-th parameter should be initialized.", i)); + PADDLE_ENFORCE_EQ( + p->place(), + place, + errors::InvalidArgument( + "The %d-th parameter is not initialized on the right place.", i)); + PADDLE_ENFORCE_EQ( + p, + p_out, + errors::InvalidArgument("The %d-th Input(Param) and Output(ParamOut) " + "should be the same tensor.", + i)); + + auto dtype = p->dtype(); + PADDLE_ENFORCE_NOT_NULL( + g, + errors::InvalidArgument("The %d-th gradient should not be nullptr.", + i)); + PADDLE_ENFORCE_EQ(g, + g_out, + errors::InvalidArgument( + "The %d-th Input(Grad) and Output(Grad) should " + "be the same tensor.")); + auto numel = p->numel(); + PADDLE_ENFORCE_GT( + numel, + 0, + errors::InvalidArgument("The %d-th Input(Param) have no elements.")); + + void *g_data = nullptr; + if (g->IsInitialized()) { + PADDLE_ENFORCE_EQ(g->dtype(), + dtype, + errors::InvalidArgument( + "The %d-th Input(Param) and Input(Grad) should " + "have the same data type %s.", + i, + dtype)); + PADDLE_ENFORCE_EQ(g->dims(), + p->dims(), + errors::InvalidArgument( + "The %d-th Input(Param) and Input(Grad) should " + "have the same shape.", + i)); + g_data = g_out->data(); + } + + ParamGradInfo *info; + if (dtype == phi::DataType::FLOAT32) { + fp32_infos.emplace_back(); + info = &fp32_infos.back(); + } else if (dtype == phi::DataType::FLOAT16) { + fp16_infos.emplace_back(); + info = &fp16_infos.back(); + } else { + PADDLE_THROW( + errors::InvalidArgument("Unsupported data type %s.", dtype)); + } + + VLOG(10) << "Found " << dtype << " parameter " << i << " shape=[" + << p_out->dims() << "] numel=" << numel + << " grad.IsInitialized()=" + << (g_out->IsInitialized() ? 
"true" : "false"); + + info->param_t = p_out; + info->grad_t = g_out; + info->idx = i; + info->numel = numel; + info->numel_with_padding = 0; // not determined yet + info->numel_offset = 0; // not determined yet + } + } + + size_t fp32_wd_end_idx = + ReorderParamGradInfoList(apply_weight_decay, &fp32_infos); + size_t fp16_wd_end_idx = + ReorderParamGradInfoList(apply_weight_decay, &fp16_infos); + + auto param_num = fp32_infos.size() + fp16_infos.size(); + param_order->Resize({static_cast(param_num)}); + auto *param_order_t = dev_ctx.template HostAlloc(param_order); + for (size_t i = 0; i < fp32_infos.size(); ++i) { + param_order_t[i] = static_cast(fp32_infos[i].idx); + } + for (size_t i = 0; i < fp16_infos.size(); ++i) { + param_order_t[i + fp32_infos.size()] = static_cast(fp16_infos[i].idx); + } + + VLOG(10) << "Fill ParamGradInfo ends"; + + // Step 2: determine the numel_with_padding and numel_offset + VLOG(10) << "rank = " << rank << ", nranks = " << nranks + << " , alignment = " << alignment; + if (alignment <= 0) { + alignment = phi::backends::gpu::GpuMinChunkSize(); + } + PADDLE_ENFORCE_GE( + alignment, + 1, + errors::InvalidArgument("The attr(alignment) should be larger than 0.")); + PADDLE_ENFORCE_EQ( + alignment & (alignment - 1), + 0, + errors::InvalidArgument("The attr(alignment) should be the power of 2.")); + PADDLE_ENFORCE_GE(rank, + 0, + errors::InvalidArgument( + "The attr(rank) should be equal to or larger than 0.")); + PADDLE_ENFORCE_LT( + rank, + nranks, + errors::InvalidArgument( + "The attr(rank) should be less than the attr(nranks).")); + // NOTE: We guarantee that both fp32_numel and fp16_numel can be exactly + // divided by alignment and nranks. + auto fp32_numel = FillAlignmentPaddingInfo( + &fp32_infos, alignment, nranks, phi::DataType::FLOAT32); + VLOG(10) << "FP32 ParamGradInfo: " + << paddle::string::join_strings(fp32_infos, " "); + auto fp16_numel = FillAlignmentPaddingInfo( + &fp16_infos, alignment, nranks, phi::DataType::FLOAT16); + VLOG(10) << "FP16 ParamGradInfo: " + << paddle::string::join_strings(fp16_infos, " "); + auto total_numel = fp32_numel + fp16_numel; + PADDLE_ENFORCE_LT(total_numel, + std::numeric_limits::max(), + errors::InvalidArgument("Too many parameter number.")); + + auto fp32_numel_each_device = fp32_numel / nranks; + auto fp16_numel_each_device = fp16_numel / nranks; + auto numel_each_device = fp32_numel_each_device + fp16_numel_each_device; + VLOG(10) << "Fill padding ends. 
total_numel = " << total_numel + << ", fp32_numel = " << fp32_numel << ", fp16_numel = " << fp16_numel + << ", fp32_numel_each_device = " << fp32_numel_each_device + << ", fp16_numel_each_device = " << fp16_numel_each_device; + + // Step 3: allocate output tensor and do initialization + float *fused_fp32_param = nullptr, *fused_fp32_grad = nullptr; + dtype::float16 *fused_fp16_param = nullptr, *fused_fp16_grad = nullptr; + DenseTensor *fp32_p_t = nullptr, *fp16_p_t = nullptr, *fp32_g_t = nullptr, + *fp16_g_t = nullptr; + std::vector fp16_master_params; + if (total_numel > 0) { + fp32_p_t = fp32_fused_param; + fused_fp32_param = TensorFillConstant( + dev_ctx, fp32_p_t, {static_cast(total_numel)}, 0.0f); + } + + if (fp32_numel > 0) { + fp32_g_t = fp32_fused_grad; + fused_fp32_grad = TensorFillConstant( + dev_ctx, fp32_g_t, {static_cast(fp32_numel)}, 0.0f); + } + + if (fp16_numel > 0) { + fp16_p_t = fp16_fused_param; + fused_fp16_param = + TensorFillConstant(dev_ctx, + fp16_p_t, + {static_cast(fp16_numel)}, + static_cast(0)); + + fp16_g_t = fp16_fused_grad; + fused_fp16_grad = + TensorFillConstant(dev_ctx, + fp16_g_t, + {static_cast(fp16_numel)}, + static_cast(0)); + } + VLOG(10) << "Allocate FP32FusedParam/Grad, FP16FusedParam/Grad ends"; + + // (1) For FP32FusedParam, memcpy for fp32 param and then share data, cast + // for fp16 master weight + // (2) For FP16FusedParam, memcpy and then share data + // (3) For FP32FusedGrad/FP16FusedGrad, memcpy if gradient has been inited + for (const auto &info : fp32_infos) { + auto sliced_tensor = CopyAndShareBufferForInitedTensor( + dev_ctx, info.param_t, fp32_p_t, info.numel_offset); + master_param_out[info.idx]->Resize(info.param_t->dims()); + master_param_out[info.idx]->ShareBufferWith(sliced_tensor); + float *master_param_tmp = + dev_ctx.template Alloc(master_param_out[info.idx]); + float *sliced_tensor_tmp = reinterpret_cast(sliced_tensor.data()); + PADDLE_ENFORCE_EQ( + master_param_tmp, + sliced_tensor_tmp, + errors::InvalidArgument("Invalid master weight tensor pointer.")); + + if (info.grad_t->IsInitialized()) { + CopyAndShareBufferForInitedTensor( + dev_ctx, info.grad_t, fp32_g_t, info.numel_offset); + } else { + ShareBufferForNonInitedTensor( + info.grad_t, fp32_g_t, info.numel_offset, info.param_t->dims()); + } + } + + size_t fp16_numel_offset = 0; + if (fp32_numel > 0) { + auto last_fp32_info = fp32_infos.back(); + fp16_numel_offset = + last_fp32_info.numel_offset + last_fp32_info.numel_with_padding; + } + + for (const auto &info : fp16_infos) { + auto master_weight_offset = info.numel_offset + fp16_numel_offset; + auto sliced_tensor = CastDataForInitedTensor( + dev_ctx, info.param_t, fp32_p_t, master_weight_offset); + master_param_out[info.idx]->Resize(info.param_t->dims()); + master_param_out[info.idx]->ShareBufferWith(sliced_tensor); + + CopyAndShareBufferForInitedTensor( + dev_ctx, info.param_t, fp16_p_t, info.numel_offset); + float *master_param_tmp = + dev_ctx.template Alloc(master_param_out[info.idx]); + float *sliced_tensor_tmp = reinterpret_cast(sliced_tensor.data()); + PADDLE_ENFORCE_EQ( + master_param_tmp, + sliced_tensor_tmp, + errors::InvalidArgument("Invalid master weight tensor pointer.")); + + if (info.grad_t->IsInitialized()) { + CopyAndShareBufferForInitedTensor( + dev_ctx, info.grad_t, fp16_g_t, info.numel_offset); + } else { + ShareBufferForNonInitedTensor( + info.grad_t, fp16_g_t, info.numel_offset, info.param_t->dims()); + } + } + VLOG(10) << "Copy/share data for Param/Grad ends"; + + // Step 4: For Moment1, Moment2, 
+  TensorFillConstant<float>(
+      dev_ctx, moment1, {static_cast<int64_t>(numel_each_device)}, 0.0f);
+  TensorFillConstant<float>(
+      dev_ctx, moment2, {static_cast<int64_t>(numel_each_device)}, 0.0f);
+  TensorFillConstant<float>(dev_ctx, beta1_pow, {1}, beta1);
+  TensorFillConstant<float>(dev_ctx, beta2_pow, {1}, beta2);
+  VLOG(10) << "Init Moment and BetaPow ends";
+
+  // Step 5: Do sharding
+  size_t fp32_start_idx, fp32_end_idx, fp32_start_numel_offset,
+      fp32_end_numel_offset;
+  GetParamGradShardInfo(fp32_infos,
+                        rank * fp32_numel_each_device,
+                        (rank + 1) * fp32_numel_each_device,
+                        &fp32_start_idx,
+                        &fp32_end_idx,
+                        &fp32_start_numel_offset,
+                        &fp32_end_numel_offset);
+  size_t fp16_start_idx, fp16_end_idx, fp16_start_numel_offset,
+      fp16_end_numel_offset;
+  GetParamGradShardInfo(fp16_infos,
+                        rank * fp16_numel_each_device,
+                        (rank + 1) * fp16_numel_each_device,
+                        &fp16_start_idx,
+                        &fp16_end_idx,
+                        &fp16_start_numel_offset,
+                        &fp16_end_numel_offset);
+  size_t fp32_local_param_num =
+      fp32_numel_each_device > 0 ? fp32_end_idx - fp32_start_idx + 1 : 0;
+  size_t fp16_local_param_num =
+      fp16_numel_each_device > 0 ? fp16_end_idx - fp16_start_idx + 1 : 0;
+  size_t total_local_param_num = fp32_local_param_num + fp16_local_param_num;
+  VLOG(10) << "Found the sharding arguments";
+
+  param_info->Resize({8});
+  auto *param_info_t = dev_ctx.template HostAlloc<int>(param_info);
+  param_info_t[0] = static_cast<int>(fp32_start_idx);
+  param_info_t[1] = static_cast<int>(fp32_local_param_num);
+  param_info_t[2] = static_cast<int>(fp32_infos.size());
+  param_info_t[3] = ClipByBound(fp32_wd_end_idx,
+                                fp32_start_idx,
+                                fp32_start_idx + fp32_local_param_num) -
+                    static_cast<int>(fp32_start_idx);
+  param_info_t[4] = static_cast<int>(fp16_start_idx + fp32_infos.size());
+  param_info_t[5] = static_cast<int>(fp16_local_param_num);
+  param_info_t[6] = static_cast<int>(fp16_infos.size());
+  param_info_t[7] = ClipByBound(fp16_wd_end_idx,
+                                fp16_start_idx,
+                                fp16_start_idx + fp16_local_param_num) -
+                    static_cast<int>(fp16_start_idx);
+
+  VLOG(10) << "Start FP32 idx: " << param_info_t[0];
+  VLOG(10) << "Local FP32 param num: " << param_info_t[1];
+  VLOG(10) << "Global FP32 param num: " << param_info_t[2];
+
+  VLOG(10) << "Start FP16 idx: " << param_info_t[4];
+  VLOG(10) << "Local FP16 param num: " << param_info_t[5];
+  VLOG(10) << "Global FP16 param num: " << param_info_t[6];
+
+  std::vector<int> numel_offsets;
+  numel_offsets.reserve(param.size() + 1);
+  for (const auto &info : fp32_infos) {
+    numel_offsets.push_back(info.numel_offset);
+  }
+  for (const auto &info : fp16_infos) {
+    numel_offsets.push_back(info.numel_offset + fp16_numel_offset);
+  }
+  numel_offsets.push_back(fp32_numel + fp16_numel);
+  PADDLE_ENFORCE_EQ(numel_offsets.size(),
+                    param.size() + 1,
+                    errors::InvalidArgument(
+                        "The numel_offsets number must be one larger than "
+                        "the parameter number."));
+  VLOG(10) << "Total numel offset: " << FlattenToString(numel_offsets);
+
+  std::vector<int> fp32_partial_numel_offsets;
+  fp32_partial_numel_offsets.reserve(fp32_local_param_num + 1);
+  fp32_partial_numel_offsets.push_back(0);
+  // Fill the partial_numel_offsets
+  for (size_t i = fp32_start_idx; i < fp32_start_idx + fp32_local_param_num;
+       ++i) {
+    size_t valid_start_n = 0;
+    if (i == fp32_start_idx) {
+      valid_start_n = fp32_start_numel_offset;
+    }
+
+    size_t end_n = fp32_infos[i].numel_with_padding;
+    if (i + 1 == fp32_start_idx + fp32_local_param_num) {
+      end_n = std::min(end_n, fp32_end_numel_offset);
+    }
+
+    PADDLE_ENFORCE_NE(
+        valid_start_n,
+        end_n,
+        errors::InvalidArgument("Indices sharding error. This may be a bug."));
+    VLOG(10) << "FP32 Partial numel = [" << valid_start_n + fp32_infos[i].numel
+             << "," << end_n + fp32_infos[i].numel;
+    auto len = end_n - valid_start_n;
+    fp32_partial_numel_offsets.push_back(fp32_partial_numel_offsets.back() +
+                                         len);
+  }
+
+  std::vector<int> fp16_partial_numel_offsets;
+  fp16_partial_numel_offsets.reserve(fp16_local_param_num + 1);
+  fp16_partial_numel_offsets.push_back(0);
+  for (size_t i = fp16_start_idx; i < fp16_start_idx + fp16_local_param_num;
+       ++i) {
+    size_t valid_start_n = 0;
+    if (i == fp16_start_idx) {
+      valid_start_n = fp16_start_numel_offset;
+    }
+
+    size_t end_n = fp16_infos[i].numel_with_padding;
+    if (i + 1 == fp16_start_idx + fp16_local_param_num) {
+      end_n = std::min(end_n, fp16_end_numel_offset);
+    }
+
+    PADDLE_ENFORCE_NE(
+        valid_start_n,
+        end_n,
+        errors::InvalidArgument("Indices sharding error. This may be a bug."));
+    auto len = end_n - valid_start_n;
+    fp16_partial_numel_offsets.push_back(fp16_partial_numel_offsets.back() +
+                                         len);
+  }
+
+  CopyVectorToCPUTensor(dev_ctx, numel_offsets, fused_param_offsets);
+  CopyVectorToCPUTensor(
+      dev_ctx, fp32_partial_numel_offsets, fp32_shard_fused_param_offsets);
+  CopyVectorToCPUTensor(
+      dev_ctx, fp16_partial_numel_offsets, fp16_shard_fused_param_offsets);
+
+  if (!global_scale->IsInitialized()) {
+    TensorFillConstant<float>(dev_ctx, global_scale, {1}, 1.0f);
+  }
+  VLOG(10) << "Init global scale ends";
+
+  TensorFillConstant<int64_t>(dev_ctx, step, {1}, static_cast<int64_t>(0));
+
+  dev_ctx.Wait();
+  VLOG(10) << "Wait for H2D copy";
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(distributed_fused_lamb_init,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::DistributedFusedLambInitOpKernel,
+                   float) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT16);
+  kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT16);
+  kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(8).SetDataType(phi::DataType::INT32);
+  kernel->OutputAt(9).SetDataType(phi::DataType::INT32);
+  kernel->OutputAt(10).SetDataType(phi::DataType::INT32);
+  kernel->OutputAt(11).SetDataType(phi::DataType::INT32);
+  kernel->OutputAt(12).SetDataType(phi::DataType::INT32);
+  kernel->OutputAt(13).SetDataType(kernel_key.dtype());
+  kernel->OutputAt(14).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(15).SetDataType(kernel_key.dtype());
+  kernel->OutputAt(16).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(17).SetDataType(phi::DataType::INT64);
+}
diff --git a/paddle/phi/ops/compat/distributed_fused_lamb_init_sig.cc b/paddle/phi/ops/compat/distributed_fused_lamb_init_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..90c64a1d2ef7c05cf80974e17ddb8ca8ecc2710b
--- /dev/null
+++ b/paddle/phi/ops/compat/distributed_fused_lamb_init_sig.cc
@@ -0,0 +1,48 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature DistributedFusedLambInitOpArgumentMapping(
+    const ArgumentMappingContext& ctx UNUSED) {
+  return KernelSignature(
+      "distributed_fused_lamb_init",
+      {"Param", "Grad"},
+      {"beta1", "beta2", "apply_weight_decay", "alignment", "rank", "nranks"},
+      {"FP32FusedParam",
+       "FP32FusedGrad",
+       "FP16FusedParam",
+       "FP16FusedGrad",
+       "Moment1",
+       "Moment2",
+       "Beta1Pow",
+       "Beta2Pow",
+       "FusedParamOffsets",
+       "FP32ShardFusedParamOffsets",
+       "FP16ShardFusedParamOffsets",
+       "ParamInfo",
+       "ParamOrder",
+       "ParamOut",
+       "MasterParamOut",
+       "GradOut",
+       "GlobalScale",
+       "Step"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(distributed_fused_lamb_init,
+                           phi::DistributedFusedLambInitOpArgumentMapping);
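
Note on the padding/sharding arithmetic used by the GPU kernel above: each parameter is padded so that fp32_numel and fp16_numel are exact multiples of both attr(alignment) and attr(nranks) before the buffers are split into numel_each_device-sized per-rank shards. The following standalone sketch only illustrates that arithmetic; it is not the kernel's FillAlignmentPaddingInfo or GetParamGradShardInfo implementation, and the helper name RoundUp plus the sample sizes are hypothetical.

// Standalone illustration (not Paddle code) of padding a fused buffer so
// that every rank receives an equal, alignment-sized shard.
#include <cstddef>
#include <iostream>
#include <vector>

// Round n up to the next multiple of a power-of-2 alignment, mirroring the
// power-of-2 check enforced on attr(alignment) in the kernel above.
static std::size_t RoundUp(std::size_t n, std::size_t alignment) {
  return (n + alignment - 1) & ~(alignment - 1);
}

int main() {
  const std::size_t alignment = 256;  // hypothetical attr(alignment)
  const std::size_t nranks = 4;       // hypothetical attr(nranks)
  const std::vector<std::size_t> param_numels = {1000, 4096, 37};

  std::size_t fused_numel = 0;
  for (std::size_t numel : param_numels) {
    fused_numel += RoundUp(numel, alignment);  // per-parameter padding
  }
  // Pad the whole buffer so it divides evenly by alignment * nranks.
  fused_numel = RoundUp(fused_numel, alignment * nranks);
  const std::size_t numel_each_device = fused_numel / nranks;

  std::cout << "fused_numel = " << fused_numel
            << ", numel_each_device = " << numel_each_device << "\n";
  return 0;
}

Because every shard boundary is then a multiple of alignment, rank r owns the contiguous element range [r * numel_each_device, (r + 1) * numel_each_device), which is exactly the numel range that GetParamGradShardInfo maps back to parameter indices and offsets in the kernel.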