diff --git a/paddle/fluid/operators/optimizers/merged_adam_op.cc b/paddle/fluid/operators/optimizers/merged_adam_op.cc
index 69ca8ec3c6670a00a3c3a8083a317d952787020b..f49fc72d010304c308bbd70f0ee825f9f89631ff 100644
--- a/paddle/fluid/operators/optimizers/merged_adam_op.cc
+++ b/paddle/fluid/operators/optimizers/merged_adam_op.cc
@@ -10,7 +10,11 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/optimizers/merged_adam_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
 
 namespace paddle {
 namespace operators {
@@ -21,8 +25,6 @@ class MergedAdamOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {}
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto param_dtype =
@@ -128,13 +130,15 @@ $$
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(merged_adam,
-                             ops::MergedAdamOp,
-                             ops::MergedAdamOpMaker);
-REGISTER_OP_WITHOUT_GRADIENT(merged_adamw,
-                             ops::MergedAdamOp,
-                             ops::MergedAdamOpMaker);
-
-REGISTER_OP_CPU_KERNEL(merged_adam,
-                       ops::MergedAdamOpKernel<phi::CPUContext, float>,
-                       ops::MergedAdamOpKernel<phi::CPUContext, double>);
+
+DECLARE_INFER_SHAPE_FUNCTOR(merged_adam,
+                            MergedAdamInferMetaFunctor,
+                            PD_INFER_META(phi::MergedAdamInferMeta));
+
+REGISTER_OPERATOR(
+    merged_adam,
+    ops::MergedAdamOp,
+    ops::MergedAdamOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    MergedAdamInferMetaFunctor);
diff --git a/paddle/fluid/operators/optimizers/merged_adam_op.cu b/paddle/fluid/operators/optimizers/merged_adam_op.cu
deleted file mode 100644
index 578c9864fa42d6702c6993225effff6d7443f7e4..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/optimizers/merged_adam_op.cu
+++ /dev/null
@@ -1,230 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/optimizers/merged_adam_op.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T, typename MT>
-__global__ void AdamKernelREG(MT beta1,
-                              MT beta2,
-                              MT epsilon,
-                              MT beta1_pow_,
-                              MT beta2_pow_,
-                              const MT* moment1,
-                              MT* moment1_out,
-                              const MT* moment2,
-                              MT* moment2_out,
-                              const MT* lr_,
-                              const T* grad,
-                              const T* param,
-                              T* param_out,
-                              const MT* master_param,
-                              MT* master_param_out,
-                              int ndim) {
-  MT lr = *lr_;
-  MT beta1_pow = beta1_pow_;
-  MT beta2_pow = beta2_pow_;
-
-  int id = blockIdx.x * blockDim.x + threadIdx.x;
-
-  for (; id < ndim; id += gridDim.x * blockDim.x) {
-    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
-    MT g = static_cast<MT>(grad[id]);
-    MT mom1 = static_cast<MT>(moment1[id]);
-    MT mom2 = static_cast<MT>(moment2[id]);
-    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
-    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
-
-    MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
-    p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
-
-    moment1_out[id] = mom1;
-    moment2_out[id] = mom2;
-    param_out[id] = static_cast<T>(p);
-    if (master_param_out) {
-      master_param_out[id] = p;
-    }
-  }
-}
-
-template <typename T, typename MT>
-__global__ void AdamKernelMEM(MT beta1,
-                              MT beta2,
-                              MT epsilon,
-                              const MT* beta1_pow_,
-                              const MT* beta2_pow_,
-                              const MT* moment1,
-                              MT* moment1_out,
-                              const MT* moment2,
-                              MT* moment2_out,
-                              const MT* lr_,
-                              const T* grad,
-                              const T* param,
-                              T* param_out,
-                              const MT* master_param,
-                              MT* master_param_out,
-                              int ndim) {
-  MT lr = *lr_;
-  MT beta1_pow = *beta1_pow_;
-  MT beta2_pow = *beta2_pow_;
-
-  int id = blockIdx.x * blockDim.x + threadIdx.x;
-
-  for (; id < ndim; id += gridDim.x * blockDim.x) {
-    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
-    MT g = static_cast<MT>(grad[id]);
-    MT mom1 = static_cast<MT>(moment1[id]);
-    MT mom2 = static_cast<MT>(moment2[id]);
-    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
-    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
-
-    MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
-    p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
-
-    moment1_out[id] = mom1;
-    moment2_out[id] = mom2;
-    param_out[id] = static_cast<T>(p);
-    if (master_param_out) {
-      master_param_out[id] = p;
-    }
-  }
-}
-
-template <typename T>
-__global__ void UpdateBetaPow(T beta1,
-                              T beta2,
-                              const T* beta1_pow_,
-                              const T* beta2_pow_,
-                              T* beta1_pow_out,
-                              T* beta2_pow_out) {
-  *beta1_pow_out = beta1 * beta1_pow_[0];
-  *beta2_pow_out = beta2 * beta2_pow_[0];
-}
-
-template <typename T>
-class MergedAdamOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    using MPDType = typename details::MPTypeTrait<T>::Type;
-
-    auto param = ctx.MultiInput<framework::Tensor>("Param");
-    auto grad = ctx.MultiInput<framework::Tensor>("Grad");
-    auto lr = ctx.MultiInput<framework::Tensor>("LearningRate");
-    auto mom1 = ctx.MultiInput<framework::Tensor>("Moment1");
-    auto mom2 = ctx.MultiInput<framework::Tensor>("Moment2");
-    auto beta1_pow = ctx.MultiInput<framework::Tensor>("Beta1Pow");
-    auto beta2_pow = ctx.MultiInput<framework::Tensor>("Beta2Pow");
-
-    auto param_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
-    auto mom1_out = ctx.MultiOutput<framework::Tensor>("Moment1Out");
-    auto mom2_out = ctx.MultiOutput<framework::Tensor>("Moment2Out");
-    auto beta1_pow_out = ctx.MultiOutput<framework::Tensor>("Beta1PowOut");
-    auto beta2_pow_out = ctx.MultiOutput<framework::Tensor>("Beta2PowOut");
-
-    MPDType beta1 = static_cast<MPDType>(ctx.Attr<float>("beta1"));
-    MPDType beta2 = static_cast<MPDType>(ctx.Attr<float>("beta2"));
-    MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
-    bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
-    VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
-
-    const bool multi_precision = ctx.Attr<bool>("multi_precision");
-    auto master_param = ctx.MultiInput<framework::Tensor>("MasterParam");
-    auto master_param_out =
-        ctx.MultiOutput<framework::Tensor>("MasterParamOut");
-
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-
-    size_t param_num = param.size();
-    for (size_t idx = 0; idx < param_num; idx++) {
-      const MPDType* master_in_data =
-          multi_precision ? master_param[idx]->data<MPDType>() : nullptr;
-      MPDType* master_out_data =
-          multi_precision
-              ? master_param_out[idx]->mutable_data<MPDType>(ctx.GetPlace())
-              : nullptr;
-
-      // update param and moment
-      int threads = 512;
-      int blocks = (param[idx]->numel() + threads - 1) / threads;
-
-      if (beta1_pow[idx]->place() == platform::CPUPlace() &&
-          beta2_pow[idx]->place() == platform::CPUPlace()) {
-        // Compute with betapow in REG
-        AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-            beta1,
-            beta2,
-            epsilon,
-            *beta1_pow[idx]->data<MPDType>(),
-            *beta2_pow[idx]->data<MPDType>(),
-            mom1[idx]->data<MPDType>(),
-            mom1_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
-            mom2[idx]->data<MPDType>(),
-            mom2_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
-            lr[idx]->data<MPDType>(),
-            grad[idx]->data<T>(),
-            param[idx]->data<T>(),
-            param_out[idx]->mutable_data<T>(ctx.GetPlace()),
-            master_in_data,
-            master_out_data,
-            param[idx]->numel());
-        if (!use_global_beta_pow) {
-          // Cpu update
-          beta1_pow_out[idx]->mutable_data<MPDType>(platform::CPUPlace())[0] =
-              beta1 * beta1_pow[idx]->data<MPDType>()[0];
-          beta2_pow_out[idx]->mutable_data<MPDType>(platform::CPUPlace())[0] =
-              beta2 * beta2_pow[idx]->data<MPDType>()[0];
-        }
-      } else {
-        AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-            beta1,
-            beta2,
-            epsilon,
-            beta1_pow[idx]->data<MPDType>(),
-            beta2_pow[idx]->data<MPDType>(),
-            mom1[idx]->data<MPDType>(),
-            mom1_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
-            mom2[idx]->data<MPDType>(),
-            mom2_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
-            lr[idx]->data<MPDType>(),
-            grad[idx]->data<T>(),
-            param[idx]->data<T>(),
-            param_out[idx]->mutable_data<T>(ctx.GetPlace()),
-            master_in_data,
-            master_out_data,
-            param[idx]->numel());
-        if (!use_global_beta_pow) {
-          // Update with gpu
-          UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
-              beta1,
-              beta2,
-              beta1_pow[idx]->data<MPDType>(),
-              beta2_pow[idx]->data<MPDType>(),
-              beta1_pow_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
-              beta2_pow_out[idx]->mutable_data<MPDType>(ctx.GetPlace()));
-        }
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(merged_adam,
-                        ops::MergedAdamOpCUDAKernel<float>,
-                        ops::MergedAdamOpCUDAKernel<double>,
-                        ops::MergedAdamOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/optimizers/merged_adam_op.h b/paddle/fluid/operators/optimizers/merged_adam_op.h
deleted file mode 100644
index 3b7c8ab0286c35f007b2724db7246c4d68dcbc89..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/optimizers/merged_adam_op.h
+++ /dev/null
@@ -1,124 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/selected_rows_functor.h"
-#include "paddle/phi/kernels/funcs/adam_functors.h"
-
-namespace paddle {
-namespace operators {
-
-namespace scatter = paddle::operators::math::scatter;
-
-template <typename DeviceContext, typename T>
-class MergedAdamOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto param = ctx.MultiInput<framework::Tensor>("Param");
-    size_t n = param.size();
-    auto grad = ctx.MultiInput<framework::Tensor>("Grad");
-    PADDLE_ENFORCE_EQ(n,
-                      grad.size(),
-                      platform::errors::InvalidArgument(
-                          "The size of Input(Grad) must be equal to "
-                          "Input(Param), but got the size of Input(Grad) "
-                          "is %d, the size of Input(Param) is %d.",
-                          grad.size(),
-                          n));
-    auto lr = ctx.MultiInput<framework::Tensor>("LearningRate");
-    PADDLE_ENFORCE_EQ(
-        n,
-        lr.size(),
-        platform::errors::InvalidArgument(
-            "The size of Input(LearningRate) must be equal to "
-            "Input(Param), but got the size of Input(LearningRate) "
-            "is %d, the size of Input(Param) is %d.",
-            lr.size(),
-            n));
-    auto mom1 = ctx.MultiInput<framework::Tensor>("Moment1");
-    PADDLE_ENFORCE_EQ(n,
-                      mom1.size(),
-                      platform::errors::InvalidArgument(
-                          "The size of Input(Moment1) must be equal to "
-                          "Input(Param), but got the size of Input(Moment1) "
-                          "is %d, the size of Input(Param) is %d.",
-                          mom1.size(),
-                          n));
-    auto mom2 = ctx.MultiInput<framework::Tensor>("Moment2");
-    PADDLE_ENFORCE_EQ(n,
-                      mom2.size(),
-                      platform::errors::InvalidArgument(
-                          "The size of Input(Moment2) must be equal to "
-                          "Input(Param), but got the size of Input(Moment2) "
-                          "is %d, the size of Input(Param) is %d.",
-                          mom2.size(),
-                          n));
-    auto beta1_pow = ctx.MultiInput<framework::Tensor>("Beta1Pow");
-    PADDLE_ENFORCE_EQ(n,
-                      beta1_pow.size(),
-                      platform::errors::InvalidArgument(
-                          "The size of Input(Beta1Pow) must be equal to "
-                          "Input(Param), but got the size of Input(Beta1Pow) "
-                          "is %d, the size of Input(Param) is %d.",
-                          beta1_pow.size(),
-                          n));
-    auto beta2_pow = ctx.MultiInput<framework::Tensor>("Beta2Pow");
-    PADDLE_ENFORCE_EQ(n,
-                      beta2_pow.size(),
-                      platform::errors::InvalidArgument(
-                          "The size of Input(Beta2Pow) must be equal to "
-                          "Input(Param), but got the size of Input(Beta2Pow) "
-                          "is %d, the size of Input(Param) is %d.",
-                          beta2_pow.size(),
-                          n));
-
-    auto param_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
-    auto mom1_out = ctx.MultiOutput<framework::Tensor>("Moment1Out");
-    auto mom2_out = ctx.MultiOutput<framework::Tensor>("Moment2Out");
-    auto beta1_pow_out = ctx.MultiOutput<framework::Tensor>("Beta1PowOut");
-    auto beta2_pow_out = ctx.MultiOutput<framework::Tensor>("Beta2PowOut");
-
-    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
-    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
-    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
-    bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
-    VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
-
-    size_t param_num = param.size();
-    for (size_t idx = 0; idx < param_num; idx++) {
-      phi::funcs::AdamFunctor<T, phi::funcs::CPUAdam> functor(
-          beta1,
-          beta2,
-          epsilon,
-          beta1_pow[idx]->data<T>(),
-          beta2_pow[idx]->data<T>(),
-          mom1[idx]->data<T>(),
-          mom1_out[idx]->mutable_data<T>(ctx.GetPlace()),
-          mom2[idx]->data<T>(),
-          mom2_out[idx]->mutable_data<T>(ctx.GetPlace()),
-          lr[idx]->data<T>(),
-          grad[idx]->data<T>(),
-          param[idx]->data<T>(),
-          param_out[idx]->mutable_data<T>(ctx.GetPlace()));
-      functor(param[idx]->numel());
-      if (!use_global_beta_pow) {
-        beta1_pow_out[idx]->mutable_data<T>(ctx.GetPlace())[0] =
-            beta1 * beta1_pow[idx]->data<T>()[0];
-        beta2_pow_out[idx]->mutable_data<T>(ctx.GetPlace())[0] =
-            beta2 * beta2_pow[idx]->data<T>()[0];
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
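Note: the deleted fluid kernels above and the phi kernels added below perform the same per-tensor Adam update; only the registration and dispatch plumbing changes. As a reading aid, here is a minimal standalone C++ sketch of that update (function and variable names are illustrative, not part of the patch):

#include <cmath>
#include <cstddef>

// Scalar form of the update applied per element by AdamKernelREG/AdamKernelMEM
// and by phi::funcs::AdamFunctor. beta1_pow/beta2_pow carry beta^t from the
// previous step (Paddle initializes them to beta1/beta2, never to 1) and are
// advanced afterwards unless use_global_beta_pow is set.
void AdamUpdateOneTensor(float* param, float* mom1, float* mom2,
                         const float* grad, std::size_t n, float lr,
                         float beta1, float beta2, float epsilon,
                         float* beta1_pow, float* beta2_pow,
                         bool use_global_beta_pow) {
  for (std::size_t i = 0; i < n; ++i) {
    float g = grad[i];
    mom1[i] = beta1 * mom1[i] + (1.0f - beta1) * g;
    mom2[i] = beta2 * mom2[i] + (1.0f - beta2) * g * g;
    float denom = std::sqrt(mom2[i]) / std::sqrt(1.0f - *beta2_pow) + epsilon;
    param[i] += (mom1[i] / denom) * (-(lr / (1.0f - *beta1_pow)));
  }
  if (!use_global_beta_pow) {
    *beta1_pow *= beta1;
    *beta2_pow *= beta2;
  }
}

The merged_adam op runs this update once per entry of its input lists, so one op invocation covers a whole parameter group instead of dispatching a separate adam op per parameter.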
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 61c57981f94b51872c94a356b31dda7454148c8e..575e60923cd2146710982412b1e0f53f08cd919e 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -1528,6 +1528,27 @@ void LogspaceInferMeta(const MetaTensor& start,
   out->set_dtype(start.dtype());
 }
 
+void MergedAdamInferMeta(
+    const std::vector<const MetaTensor*>& param,
+    const std::vector<const MetaTensor*>& grad,
+    const std::vector<const MetaTensor*>& learning_rate,
+    const std::vector<const MetaTensor*>& moment1,
+    const std::vector<const MetaTensor*>& moment2,
+    const std::vector<const MetaTensor*>& beta1_pow,
+    const std::vector<const MetaTensor*>& beta2_pow,
+    const paddle::optional<std::vector<const MetaTensor*>>& master_param,
+    const Scalar& beta1,
+    const Scalar& beta2,
+    const Scalar& epsilon,
+    bool multi_precision,
+    bool use_global_beta_pow,
+    std::vector<MetaTensor*> param_out,
+    std::vector<MetaTensor*> moment1_out,
+    std::vector<MetaTensor*> moment2_out,
+    std::vector<MetaTensor*> beta1_pow_out,
+    std::vector<MetaTensor*> beta2_pow_out,
+    std::vector<MetaTensor*> master_param_out) {}
+
 void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
                        std::vector<MetaTensor*> outputs) {
   const size_t inputs_num = inputs.size();
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index 54c6fccceb9c1ebdbdb4a2ab109480d1a90c13ea..c0972816f3ba2816f312d9623ed937257ea60efb 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -234,6 +234,27 @@ void LogspaceInferMeta(const MetaTensor& start,
                        const MetaTensor& base,
                        MetaTensor* out);
 
+void MergedAdamInferMeta(
+    const std::vector<const MetaTensor*>& param,
+    const std::vector<const MetaTensor*>& grad,
+    const std::vector<const MetaTensor*>& learning_rate,
+    const std::vector<const MetaTensor*>& moment1,
+    const std::vector<const MetaTensor*>& moment2,
+    const std::vector<const MetaTensor*>& beta1_pow,
+    const std::vector<const MetaTensor*>& beta2_pow,
+    const paddle::optional<std::vector<const MetaTensor*>>& master_param,
+    const Scalar& beta1,
+    const Scalar& beta2,
+    const Scalar& epsilon,
+    bool multi_precision,
+    bool use_global_beta_pow,
+    std::vector<MetaTensor*> param_out,
+    std::vector<MetaTensor*> moment1_out,
+    std::vector<MetaTensor*> moment2_out,
+    std::vector<MetaTensor*> beta1_pow_out,
+    std::vector<MetaTensor*> beta2_pow_out,
+    std::vector<MetaTensor*> master_param_out);
+
 void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
                        std::vector<MetaTensor*> outputs);
 
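The MergedAdamInferMeta added above is intentionally an empty stub, mirroring the empty InferShape removed from the fluid op. Purely as an illustration of what per-tensor metadata propagation could look like if it were filled in later, a hedged sketch (the name and the choice to forward only param/moment metadata are assumptions, not part of the patch):

#include <cstddef>
#include <vector>

#include "paddle/phi/core/meta_tensor.h"

// Illustrative only: forward each input tensor's dims/dtype to its output.
void MergedAdamInferMetaSketch(
    const std::vector<const phi::MetaTensor*>& param,
    const std::vector<const phi::MetaTensor*>& moment1,
    const std::vector<const phi::MetaTensor*>& moment2,
    std::vector<phi::MetaTensor*> param_out,
    std::vector<phi::MetaTensor*> moment1_out,
    std::vector<phi::MetaTensor*> moment2_out) {
  for (std::size_t i = 0; i < param.size(); ++i) {
    param_out[i]->set_dims(param[i]->dims());
    param_out[i]->set_dtype(param[i]->dtype());
    moment1_out[i]->set_dims(moment1[i]->dims());
    moment1_out[i]->set_dtype(moment1[i]->dtype());
    moment2_out[i]->set_dims(moment2[i]->dims());
    moment2_out[i]->set_dtype(moment2[i]->dtype());
  }
}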
diff --git a/paddle/phi/kernels/adam_kernel.h b/paddle/phi/kernels/adam_kernel.h
index 0bdf05f8e5123ba26df7618c94d06a081dfa11e0..b1a7f5a686530cd591c4409f3ae2f30d7321b8d1 100644
--- a/paddle/phi/kernels/adam_kernel.h
+++ b/paddle/phi/kernels/adam_kernel.h
@@ -44,4 +44,27 @@ void AdamDenseKernel(const Context& dev_ctx,
                      DenseTensor* beta2_pow_out,
                      DenseTensor* master_param_outs);
 
+template <typename T, typename Context>
+void MergedAdamKernel(
+    const Context& dev_ctx,
+    const std::vector<const DenseTensor*>& param,
+    const std::vector<const DenseTensor*>& grad,
+    const std::vector<const DenseTensor*>& learning_rate,
+    const std::vector<const DenseTensor*>& moment1,
+    const std::vector<const DenseTensor*>& moment2,
+    const std::vector<const DenseTensor*>& beta1_pow,
+    const std::vector<const DenseTensor*>& beta2_pow,
+    const paddle::optional<std::vector<const DenseTensor*>>& master_param,
+    const Scalar& beta1,
+    const Scalar& beta2,
+    const Scalar& epsilon,
+    bool multi_precision,
+    bool use_global_beta_pow,
+    std::vector<DenseTensor*> param_out,
+    std::vector<DenseTensor*> moment1_out,
+    std::vector<DenseTensor*> moment2_out,
+    std::vector<DenseTensor*> beta1_pow_out,
+    std::vector<DenseTensor*> beta2_pow_out,
+    std::vector<DenseTensor*> master_param_out);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/adam_kernel.cc b/paddle/phi/kernels/cpu/adam_kernel.cc
index 03e2a539640ea6720c41f9158c76d3396d32736f..03a75bd36156f78a00ae758ed99cdb1185a8c782 100644
--- a/paddle/phi/kernels/cpu/adam_kernel.cc
+++ b/paddle/phi/kernels/cpu/adam_kernel.cc
@@ -167,7 +167,111 @@ void AdamDenseKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void MergedAdamKernel(
+    const Context& dev_ctx,
+    const std::vector<const DenseTensor*>& param,
+    const std::vector<const DenseTensor*>& grad,
+    const std::vector<const DenseTensor*>& learning_rate,
+    const std::vector<const DenseTensor*>& moment1,
+    const std::vector<const DenseTensor*>& moment2,
+    const std::vector<const DenseTensor*>& beta1_pow,
+    const std::vector<const DenseTensor*>& beta2_pow,
+    const paddle::optional<std::vector<const DenseTensor*>>& master_param,
+    const Scalar& beta1,
+    const Scalar& beta2,
+    const Scalar& epsilon,
+    bool multi_precision,
+    bool use_global_beta_pow,
+    std::vector<DenseTensor*> param_out,
+    std::vector<DenseTensor*> moment1_out,
+    std::vector<DenseTensor*> moment2_out,
+    std::vector<DenseTensor*> beta1_pow_out,
+    std::vector<DenseTensor*> beta2_pow_out,
+    std::vector<DenseTensor*> master_param_out) {
+  size_t param_num = param.size();
+  PADDLE_ENFORCE_EQ(
+      param_num,
+      grad.size(),
+      errors::InvalidArgument("The size of Input(grad) must be equal to "
+                              "Input(param), but got the size of Input(grad) "
+                              "is %d, the size of Input(param) is %d.",
+                              grad.size(),
+                              param_num));
+  PADDLE_ENFORCE_EQ(
+      param_num,
+      learning_rate.size(),
+      errors::InvalidArgument(
+          "The size of Input(learning_rate) must be equal to "
+          "Input(param), but got the size of Input(learning_rate) "
+          "is %d, the size of Input(param) is %d.",
+          learning_rate.size(),
+          param_num));
+  PADDLE_ENFORCE_EQ(param_num,
+                    moment1.size(),
+                    errors::InvalidArgument(
+                        "The size of Input(moment1) must be equal to "
+                        "Input(param), but got the size of Input(moment1) "
+                        "is %d, the size of Input(param) is %d.",
+                        moment1.size(),
+                        param_num));
+  PADDLE_ENFORCE_EQ(param_num,
+                    moment2.size(),
+                    errors::InvalidArgument(
+                        "The size of Input(moment2) must be equal to "
+                        "Input(param), but got the size of Input(moment2) "
+                        "is %d, the size of Input(param) is %d.",
+                        moment2.size(),
+                        param_num));
+  PADDLE_ENFORCE_EQ(param_num,
+                    beta1_pow.size(),
+                    errors::InvalidArgument(
+                        "The size of Input(beta1_pow) must be equal to "
+                        "Input(param), but got the size of Input(beta1_pow) "
+                        "is %d, the size of Input(param) is %d.",
+                        beta1_pow.size(),
+                        param_num));
+  PADDLE_ENFORCE_EQ(param_num,
+                    beta2_pow.size(),
+                    errors::InvalidArgument(
+                        "The size of Input(beta2_pow) must be equal to "
+                        "Input(param), but got the size of Input(beta2_pow) "
+                        "is %d, the size of Input(param) is %d.",
+                        beta2_pow.size(),
+                        param_num));
+  T beta1_ = beta1.to<T>();
+  T beta2_ = beta2.to<T>();
+  T epsilon_ = epsilon.to<T>();
+
+  for (size_t idx = 0; idx < param_num; idx++) {
+    phi::funcs::AdamFunctor<T, phi::funcs::CPUAdam> functor(
+        beta1_,
+        beta2_,
+        epsilon_,
+        beta1_pow[idx]->data<T>(),
+        beta2_pow[idx]->data<T>(),
+        moment1[idx]->data<T>(),
+        dev_ctx.template Alloc<T>(moment1_out[idx]),
+        moment2[idx]->data<T>(),
+        dev_ctx.template Alloc<T>(moment2_out[idx]),
+        learning_rate[idx]->data<T>(),
+        grad[idx]->data<T>(),
+        param[idx]->data<T>(),
+        dev_ctx.template Alloc<T>(param_out[idx]));
+    functor(param[idx]->numel());
+    if (!use_global_beta_pow) {
+      dev_ctx.template Alloc<T>(beta1_pow_out[idx])[0] =
+          beta1_ * beta1_pow[idx]->data<T>()[0];
+      dev_ctx.template Alloc<T>(beta2_pow_out[idx])[0] =
+          beta2_ * beta2_pow[idx]->data<T>()[0];
+    }
+  }
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(adam, CPU, ALL_LAYOUT, phi::AdamDenseKernel, float, double) {
 }
+
+PD_REGISTER_KERNEL(
+    merged_adam, CPU, ALL_LAYOUT, phi::MergedAdamKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu
index 59aa4cf597e86606b759810c87eb1af11eea5c5c..b20e8610fefaf22f671d5342b0cffddf507f46bb 100644
--- a/paddle/phi/kernels/gpu/adam_kernel.cu
+++ b/paddle/phi/kernels/gpu/adam_kernel.cu
@@ -265,6 +265,106 @@ void AdamDenseKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void MergedAdamKernel(
+    const Context& dev_ctx,
+    const std::vector<const DenseTensor*>& param,
+    const std::vector<const DenseTensor*>& grad,
+    const std::vector<const DenseTensor*>& learning_rate,
+    const std::vector<const DenseTensor*>& moment1,
+    const std::vector<const DenseTensor*>& moment2,
+    const std::vector<const DenseTensor*>& beta1_pow,
+    const std::vector<const DenseTensor*>& beta2_pow,
+    const paddle::optional<std::vector<const DenseTensor*>>& master_param,
+    const Scalar& beta1,
+    const Scalar& beta2,
+    const Scalar& epsilon,
+    bool multi_precision,
+    bool use_global_beta_pow,
+    std::vector<DenseTensor*> param_out,
+    std::vector<DenseTensor*> moment1_out,
+    std::vector<DenseTensor*> moment2_out,
+    std::vector<DenseTensor*> beta1_pow_out,
+    std::vector<DenseTensor*> beta2_pow_out,
+    std::vector<DenseTensor*> master_param_out) {
+  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
+  VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
+  MPDType beta1_ = beta1.to<MPDType>();
+  MPDType beta2_ = beta2.to<MPDType>();
+  MPDType epsilon_ = epsilon.to<MPDType>();
+
+  size_t param_num = param.size();
+
+  for (size_t idx = 0; idx < param_num; idx++) {
+    const MPDType* master_in_data =
+        multi_precision ? master_param.get()[idx]->data<MPDType>() : nullptr;
+    MPDType* master_out_data =
+        multi_precision
+            ? dev_ctx.template Alloc<MPDType>(master_param_out[idx])
+            : nullptr;
+
+    // update param and moment
+    int threads = 512;
+    int blocks = (param[idx]->numel() + threads - 1) / threads;
+
+    if (beta1_pow[idx]->place() == CPUPlace() &&
+        beta2_pow[idx]->place() == CPUPlace()) {
+      // Compute with betapow in REG
+      AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          *beta1_pow[idx]->data<MPDType>(),
+          *beta2_pow[idx]->data<MPDType>(),
+          moment1[idx]->data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
+          moment2[idx]->data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
+          learning_rate[idx]->data<MPDType>(),
+          grad[idx]->data<T>(),
+          param[idx]->data<T>(),
+          dev_ctx.template Alloc<T>(param_out[idx]),
+          master_in_data,
+          master_out_data,
+          param[idx]->numel());
+      if (!use_global_beta_pow) {
+        // Cpu update
+        dev_ctx.template HostAlloc<MPDType>(beta1_pow_out[idx])[0] =
+            beta1_ * beta1_pow[idx]->data<MPDType>()[0];
+        dev_ctx.template HostAlloc<MPDType>(beta2_pow_out[idx])[0] =
+            beta2_ * beta2_pow[idx]->data<MPDType>()[0];
+      }
+    } else {
+      AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          beta1_pow[idx]->data<MPDType>(),
+          beta2_pow[idx]->data<MPDType>(),
+          moment1[idx]->data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
+          moment2[idx]->data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
+          learning_rate[idx]->data<MPDType>(),
+          grad[idx]->data<T>(),
+          param[idx]->data<T>(),
+          dev_ctx.template Alloc<T>(param_out[idx]),
+          master_in_data,
+          master_out_data,
+          param[idx]->numel());
+      if (!use_global_beta_pow) {
+        // Update with gpu
+        UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
+            beta1_,
+            beta2_,
+            beta1_pow[idx]->data<MPDType>(),
+            beta2_pow[idx]->data<MPDType>(),
+            dev_ctx.template Alloc<MPDType>(beta1_pow_out[idx]),
+            dev_ctx.template Alloc<MPDType>(beta2_pow_out[idx]));
+      }
+    }
+  }
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(adam,
@@ -279,3 +379,15 @@ PD_REGISTER_KERNEL(adam,
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
 }
+
+PD_REGISTER_KERNEL(merged_adam,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MergedAdamKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {
+  // Skip beta1_pow, beta2_pow data transform
+  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
+}
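A design note on the GPU kernel above: Beta1Pow and Beta2Pow are one-element tensors that callers may keep on the host, which is also why the registration marks inputs 5 and 6 as ALL_BACKEND to skip the usual data transform. A hedged sketch of the place check that selects between the two launch paths (helper name is illustrative, not part of the patch):

#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"

// When both power tensors live on the CPU, their scalar values are passed by
// value to AdamKernelREG and updated on the host; otherwise device pointers
// are handed to AdamKernelMEM and UpdateBetaPow advances them on the GPU.
bool BetaPowOnHost(const phi::DenseTensor& beta1_pow,
                   const phi::DenseTensor& beta2_pow) {
  return beta1_pow.place() == phi::CPUPlace() &&
         beta2_pow.place() == phi::CPUPlace();
}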
diff --git a/paddle/phi/ops/compat/merged_adam_sig.cc b/paddle/phi/ops/compat/merged_adam_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..38f56bad08d853470f9e416f421518aa8b653099
--- /dev/null
+++ b/paddle/phi/ops/compat/merged_adam_sig.cc
@@ -0,0 +1,47 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <string>
+
+#include "paddle/phi/core/compat/op_utils.h"
+#include "paddle/utils/small_vector.h"
+
+namespace phi {
+
+KernelSignature MergedAdamOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  paddle::small_vector<const char*> in_names = {"Param",
+                                                "Grad",
+                                                "LearningRate",
+                                                "Moment1",
+                                                "Moment2",
+                                                "Beta1Pow",
+                                                "Beta2Pow",
+                                                "MasterParam"};
+  paddle::small_vector<const char*> out_names = {"ParamOut",
+                                                 "Moment1Out",
+                                                 "Moment2Out",
+                                                 "Beta1PowOut",
+                                                 "Beta2PowOut",
+                                                 "MasterParamOut"};
+  paddle::small_vector<const char*> attr_names = {
+      "beta1", "beta2", "epsilon", "multi_precision", "use_global_beta_pow"};
+
+  return KernelSignature("merged_adam",
+                         std::move(in_names),
+                         std::move(attr_names),
+                         std::move(out_names));
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(merged_adam, phi::MergedAdamOpArgumentMapping);