未验证 提交 d55ee95f 编写于 作者: Z zhangbo9674 提交者: GitHub

[Phi] Migrate merged_adam_op into Phi (#44184)

* remov merged_adam_op to phi

* refine code
上级 636c6347
......@@ -10,7 +10,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/merged_adam_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -21,8 +25,6 @@ class MergedAdamOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto param_dtype =
......@@ -128,13 +130,15 @@ $$
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(merged_adam,
ops::MergedAdamOp,
ops::MergedAdamOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(merged_adamw,
ops::MergedAdamOp,
ops::MergedAdamOpMaker);
REGISTER_OP_CPU_KERNEL(merged_adam,
ops::MergedAdamOpKernel<phi::CPUContext, float>,
ops::MergedAdamOpKernel<phi::CPUContext, double>);
DECLARE_INFER_SHAPE_FUNCTOR(merged_adam,
MergedAdamInferMetaFunctor,
PD_INFER_META(phi::MergedAdamInferMeta));
REGISTER_OPERATOR(
merged_adam,
ops::MergedAdamOp,
ops::MergedAdamOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
MergedAdamInferMetaFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/merged_adam_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
namespace paddle {
namespace operators {
template <typename T, typename MT>
__global__ void AdamKernelREG(MT beta1,
MT beta2,
MT epsilon,
MT beta1_pow_,
MT beta2_pow_,
const MT* moment1,
MT* moment1_out,
const MT* moment2,
MT* moment2_out,
const MT* lr_,
const T* grad,
const T* param,
T* param_out,
const MT* master_param,
MT* master_param_out,
int ndim) {
MT lr = *lr_;
MT beta1_pow = beta1_pow_;
MT beta2_pow = beta2_pow_;
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
MT g = static_cast<MT>(grad[id]);
MT mom1 = static_cast<MT>(moment1[id]);
MT mom2 = static_cast<MT>(moment2[id]);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
param_out[id] = static_cast<T>(p);
if (master_param_out) {
master_param_out[id] = p;
}
}
}
template <typename T, typename MT>
__global__ void AdamKernelMEM(MT beta1,
MT beta2,
MT epsilon,
const MT* beta1_pow_,
const MT* beta2_pow_,
const MT* moment1,
MT* moment1_out,
const MT* moment2,
MT* moment2_out,
const MT* lr_,
const T* grad,
const T* param,
T* param_out,
const MT* master_param,
MT* master_param_out,
int ndim) {
MT lr = *lr_;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
MT g = static_cast<MT>(grad[id]);
MT mom1 = static_cast<MT>(moment1[id]);
MT mom2 = static_cast<MT>(moment2[id]);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
param_out[id] = static_cast<T>(p);
if (master_param_out) {
master_param_out[id] = p;
}
}
}
template <typename T>
__global__ void UpdateBetaPow(T beta1,
T beta2,
const T* beta1_pow_,
const T* beta2_pow_,
T* beta1_pow_out,
T* beta2_pow_out) {
*beta1_pow_out = beta1 * beta1_pow_[0];
*beta2_pow_out = beta2 * beta2_pow_[0];
}
template <typename T>
class MergedAdamOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using MPDType = typename details::MPTypeTrait<T>::Type;
auto param = ctx.MultiInput<framework::Tensor>("Param");
auto grad = ctx.MultiInput<framework::Tensor>("Grad");
auto lr = ctx.MultiInput<framework::Tensor>("LearningRate");
auto mom1 = ctx.MultiInput<framework::Tensor>("Moment1");
auto mom2 = ctx.MultiInput<framework::Tensor>("Moment2");
auto beta1_pow = ctx.MultiInput<framework::Tensor>("Beta1Pow");
auto beta2_pow = ctx.MultiInput<framework::Tensor>("Beta2Pow");
auto param_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
auto mom1_out = ctx.MultiOutput<framework::Tensor>("Moment1Out");
auto mom2_out = ctx.MultiOutput<framework::Tensor>("Moment2Out");
auto beta1_pow_out = ctx.MultiOutput<framework::Tensor>("Beta1PowOut");
auto beta2_pow_out = ctx.MultiOutput<framework::Tensor>("Beta2PowOut");
MPDType beta1 = static_cast<MPDType>(ctx.Attr<float>("beta1"));
MPDType beta2 = static_cast<MPDType>(ctx.Attr<float>("beta2"));
MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
const bool multi_precision = ctx.Attr<bool>("multi_precision");
auto master_param = ctx.MultiInput<framework::Tensor>("MasterParam");
auto master_param_out =
ctx.MultiOutput<framework::Tensor>("MasterParamOut");
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
size_t param_num = param.size();
for (size_t idx = 0; idx < param_num; idx++) {
const MPDType* master_in_data =
multi_precision ? master_param[idx]->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision
? master_param_out[idx]->mutable_data<MPDType>(ctx.GetPlace())
: nullptr;
// update param and moment
int threads = 512;
int blocks = (param[idx]->numel() + threads - 1) / threads;
if (beta1_pow[idx]->place() == platform::CPUPlace() &&
beta2_pow[idx]->place() == platform::CPUPlace()) {
// Compute with betapow in REG
AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1,
beta2,
epsilon,
*beta1_pow[idx]->data<MPDType>(),
*beta2_pow[idx]->data<MPDType>(),
mom1[idx]->data<MPDType>(),
mom1_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
mom2[idx]->data<MPDType>(),
mom2_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
lr[idx]->data<MPDType>(),
grad[idx]->data<T>(),
param[idx]->data<T>(),
param_out[idx]->mutable_data<T>(ctx.GetPlace()),
master_in_data,
master_out_data,
param[idx]->numel());
if (!use_global_beta_pow) {
// Cpu update
beta1_pow_out[idx]->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta1 * beta1_pow[idx]->data<MPDType>()[0];
beta2_pow_out[idx]->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta2 * beta2_pow[idx]->data<MPDType>()[0];
}
} else {
AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1,
beta2,
epsilon,
beta1_pow[idx]->data<MPDType>(),
beta2_pow[idx]->data<MPDType>(),
mom1[idx]->data<MPDType>(),
mom1_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
mom2[idx]->data<MPDType>(),
mom2_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
lr[idx]->data<MPDType>(),
grad[idx]->data<T>(),
param[idx]->data<T>(),
param_out[idx]->mutable_data<T>(ctx.GetPlace()),
master_in_data,
master_out_data,
param[idx]->numel());
if (!use_global_beta_pow) {
// Update with gpu
UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
beta1,
beta2,
beta1_pow[idx]->data<MPDType>(),
beta2_pow[idx]->data<MPDType>(),
beta1_pow_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
beta2_pow_out[idx]->mutable_data<MPDType>(ctx.GetPlace()));
}
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(merged_adam,
ops::MergedAdamOpCUDAKernel<float>,
ops::MergedAdamOpCUDAKernel<double>,
ops::MergedAdamOpCUDAKernel<plat::float16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/kernels/funcs/adam_functors.h"
namespace paddle {
namespace operators {
namespace scatter = paddle::operators::math::scatter;
template <typename DeviceContext, typename T>
class MergedAdamOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.MultiInput<framework::Tensor>("Param");
size_t n = param.size();
auto grad = ctx.MultiInput<framework::Tensor>("Grad");
PADDLE_ENFORCE_EQ(n,
grad.size(),
platform::errors::InvalidArgument(
"The size of Input(Grad) must be equal to "
"Input(Param), but got the size of Input(Grad) "
"is %d, the size of Input(Param) is %d.",
grad.size(),
n));
auto lr = ctx.MultiInput<framework::Tensor>("LearningRate");
PADDLE_ENFORCE_EQ(
n,
lr.size(),
platform::errors::InvalidArgument(
"The size of Input(LearningRate) must be equal to "
"Input(Param), but got the size of Input(LearningRate) "
"is %d, the size of Input(Param) is %d.",
lr.size(),
n));
auto mom1 = ctx.MultiInput<framework::Tensor>("Moment1");
PADDLE_ENFORCE_EQ(n,
mom1.size(),
platform::errors::InvalidArgument(
"The size of Input(Moment1) must be equal to "
"Input(Param), but got the size of Input(Moment1) "
"is %d, the size of Input(Param) is %d.",
mom1.size(),
n));
auto mom2 = ctx.MultiInput<framework::Tensor>("Moment2");
PADDLE_ENFORCE_EQ(n,
mom2.size(),
platform::errors::InvalidArgument(
"The size of Input(Moment2) must be equal to "
"Input(Param), but got the size of Input(Moment2) "
"is %d, the size of Input(Param) is %d.",
mom2.size(),
n));
auto beta1_pow = ctx.MultiInput<framework::Tensor>("Beta1Pow");
PADDLE_ENFORCE_EQ(n,
beta1_pow.size(),
platform::errors::InvalidArgument(
"The size of Input(Beta1Pow) must be equal to "
"Input(Param), but got the size of Input(Beta1Pow) "
"is %d, the size of Input(Param) is %d.",
beta1_pow.size(),
n));
auto beta2_pow = ctx.MultiInput<framework::Tensor>("Beta2Pow");
PADDLE_ENFORCE_EQ(n,
beta2_pow.size(),
platform::errors::InvalidArgument(
"The size of Input(Beta2Pow) must be equal to "
"Input(Param), but got the size of Input(Beta2Pow) "
"is %d, the size of Input(Param) is %d.",
beta2_pow.size(),
n));
auto param_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
auto mom1_out = ctx.MultiOutput<framework::Tensor>("Moment1Out");
auto mom2_out = ctx.MultiOutput<framework::Tensor>("Moment2Out");
auto beta1_pow_out = ctx.MultiOutput<framework::Tensor>("Beta1PowOut");
auto beta2_pow_out = ctx.MultiOutput<framework::Tensor>("Beta2PowOut");
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
size_t param_num = param.size();
for (size_t idx = 0; idx < param_num; idx++) {
phi::funcs::AdamFunctor<T, phi::funcs::CPUAdam> functor(
beta1,
beta2,
epsilon,
beta1_pow[idx]->data<T>(),
beta2_pow[idx]->data<T>(),
mom1[idx]->data<T>(),
mom1_out[idx]->mutable_data<T>(ctx.GetPlace()),
mom2[idx]->data<T>(),
mom2_out[idx]->mutable_data<T>(ctx.GetPlace()),
lr[idx]->data<T>(),
grad[idx]->data<T>(),
param[idx]->data<T>(),
param_out[idx]->mutable_data<T>(ctx.GetPlace()));
functor(param[idx]->numel());
if (!use_global_beta_pow) {
beta1_pow_out[idx]->mutable_data<T>(ctx.GetPlace())[0] =
beta1 * beta1_pow[idx]->data<T>()[0];
beta2_pow_out[idx]->mutable_data<T>(ctx.GetPlace())[0] =
beta2 * beta2_pow[idx]->data<T>()[0];
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -1528,6 +1528,27 @@ void LogspaceInferMeta(const MetaTensor& start,
out->set_dtype(start.dtype());
}
void MergedAdamInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& grad,
const std::vector<const MetaTensor*>& learning_rate,
const std::vector<const MetaTensor*>& moment1,
const std::vector<const MetaTensor*>& moment2,
const std::vector<const MetaTensor*>& beta1_pow,
const std::vector<const MetaTensor*>& beta2_pow,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
bool multi_precision,
bool use_global_beta_pow,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> moment1_out,
std::vector<MetaTensor*> moment2_out,
std::vector<MetaTensor*> beta1_pow_out,
std::vector<MetaTensor*> beta2_pow_out,
std::vector<MetaTensor*> master_param_out) {}
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs) {
const size_t inputs_num = inputs.size();
......
......@@ -234,6 +234,27 @@ void LogspaceInferMeta(const MetaTensor& start,
const MetaTensor& base,
MetaTensor* out);
void MergedAdamInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& grad,
const std::vector<const MetaTensor*>& learning_rate,
const std::vector<const MetaTensor*>& moment1,
const std::vector<const MetaTensor*>& moment2,
const std::vector<const MetaTensor*>& beta1_pow,
const std::vector<const MetaTensor*>& beta2_pow,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
bool multi_precision,
bool use_global_beta_pow,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> moment1_out,
std::vector<MetaTensor*> moment2_out,
std::vector<MetaTensor*> beta1_pow_out,
std::vector<MetaTensor*> beta2_pow_out,
std::vector<MetaTensor*> master_param_out);
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs);
......
......@@ -44,4 +44,27 @@ void AdamDenseKernel(const Context& dev_ctx,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs);
template <typename T, typename Context>
void MergedAdamKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& param,
const std::vector<const DenseTensor*>& grad,
const std::vector<const DenseTensor*>& learning_rate,
const std::vector<const DenseTensor*>& moment1,
const std::vector<const DenseTensor*>& moment2,
const std::vector<const DenseTensor*>& beta1_pow,
const std::vector<const DenseTensor*>& beta2_pow,
const paddle::optional<std::vector<const DenseTensor*>>& master_param,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
bool multi_precision,
bool use_global_beta_pow,
std::vector<DenseTensor*> param_out,
std::vector<DenseTensor*> moment1_out,
std::vector<DenseTensor*> moment2_out,
std::vector<DenseTensor*> beta1_pow_out,
std::vector<DenseTensor*> beta2_pow_out,
std::vector<DenseTensor*> master_param_out);
} // namespace phi
......@@ -167,7 +167,111 @@ void AdamDenseKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void MergedAdamKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& param,
const std::vector<const DenseTensor*>& grad,
const std::vector<const DenseTensor*>& learning_rate,
const std::vector<const DenseTensor*>& moment1,
const std::vector<const DenseTensor*>& moment2,
const std::vector<const DenseTensor*>& beta1_pow,
const std::vector<const DenseTensor*>& beta2_pow,
const paddle::optional<std::vector<const DenseTensor*>>& master_param,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
bool multi_precision,
bool use_global_beta_pow,
std::vector<DenseTensor*> param_out,
std::vector<DenseTensor*> moment1_out,
std::vector<DenseTensor*> moment2_out,
std::vector<DenseTensor*> beta1_pow_out,
std::vector<DenseTensor*> beta2_pow_out,
std::vector<DenseTensor*> master_param_out) {
size_t param_num = param.size();
PADDLE_ENFORCE_EQ(
param_num,
grad.size(),
errors::InvalidArgument("The size of Input(grad) must be equal to "
"Input(param), but got the size of Input(grad) "
"is %d, the size of Input(param) is %d.",
grad.size(),
param_num));
PADDLE_ENFORCE_EQ(
param_num,
learning_rate.size(),
errors::InvalidArgument(
"The size of Input(learning_rate) must be equal to "
"Input(param), but got the size of Input(learning_rate) "
"is %d, the size of Input(param) is %d.",
learning_rate.size(),
param_num));
PADDLE_ENFORCE_EQ(param_num,
moment1.size(),
errors::InvalidArgument(
"The size of Input(moment1) must be equal to "
"Input(param), but got the size of Input(moment1) "
"is %d, the size of Input(param) is %d.",
moment1.size(),
param_num));
PADDLE_ENFORCE_EQ(param_num,
moment2.size(),
errors::InvalidArgument(
"The size of Input(moment2) must be equal to "
"Input(param), but got the size of Input(moment2) "
"is %d, the size of Input(param) is %d.",
moment2.size(),
param_num));
PADDLE_ENFORCE_EQ(param_num,
beta1_pow.size(),
errors::InvalidArgument(
"The size of Input(beta1_pow) must be equal to "
"Input(param), but got the size of Input(beta1_pow) "
"is %d, the size of Input(param) is %d.",
beta1_pow.size(),
param_num));
PADDLE_ENFORCE_EQ(param_num,
beta2_pow.size(),
errors::InvalidArgument(
"The size of Input(beta2_pow) must be equal to "
"Input(param), but got the size of Input(beta2_pow) "
"is %d, the size of Input(param) is %d.",
beta2_pow.size(),
param_num));
T beta1_ = beta1.to<T>();
T beta2_ = beta2.to<T>();
T epsilon_ = epsilon.to<T>();
for (size_t idx = 0; idx < param_num; idx++) {
phi::funcs::AdamFunctor<T, phi::funcs::CPUAdam> functor(
beta1_,
beta2_,
epsilon_,
beta1_pow[idx]->data<T>(),
beta2_pow[idx]->data<T>(),
moment1[idx]->data<T>(),
dev_ctx.template Alloc<T>(moment1_out[idx]),
moment2[idx]->data<T>(),
dev_ctx.template Alloc<T>(moment2_out[idx]),
learning_rate[idx]->data<T>(),
grad[idx]->data<T>(),
param[idx]->data<T>(),
dev_ctx.template Alloc<T>(param_out[idx]));
functor(param[idx]->numel());
if (!use_global_beta_pow) {
dev_ctx.template Alloc<T>(beta1_pow_out[idx])[0] =
beta1_ * beta1_pow[idx]->data<T>()[0];
dev_ctx.template Alloc<T>(beta2_pow_out[idx])[0] =
beta2_ * beta2_pow[idx]->data<T>()[0];
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(adam, CPU, ALL_LAYOUT, phi::AdamDenseKernel, float, double) {
}
PD_REGISTER_KERNEL(
merged_adam, CPU, ALL_LAYOUT, phi::MergedAdamKernel, float, double) {}
......@@ -265,6 +265,106 @@ void AdamDenseKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void MergedAdamKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& param,
const std::vector<const DenseTensor*>& grad,
const std::vector<const DenseTensor*>& learning_rate,
const std::vector<const DenseTensor*>& moment1,
const std::vector<const DenseTensor*>& moment2,
const std::vector<const DenseTensor*>& beta1_pow,
const std::vector<const DenseTensor*>& beta2_pow,
const paddle::optional<std::vector<const DenseTensor*>>& master_param,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
bool multi_precision,
bool use_global_beta_pow,
std::vector<DenseTensor*> param_out,
std::vector<DenseTensor*> moment1_out,
std::vector<DenseTensor*> moment2_out,
std::vector<DenseTensor*> beta1_pow_out,
std::vector<DenseTensor*> beta2_pow_out,
std::vector<DenseTensor*> master_param_out) {
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
MPDType beta1_ = beta1.to<MPDType>();
MPDType beta2_ = beta2.to<MPDType>();
MPDType epsilon_ = epsilon.to<MPDType>();
size_t param_num = param.size();
for (size_t idx = 0; idx < param_num; idx++) {
const MPDType* master_in_data =
multi_precision ? master_param.get()[idx]->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision ? dev_ctx.template Alloc<MPDType>(master_param_out[idx])
: nullptr;
// update param and moment
int threads = 512;
int blocks = (param[idx]->numel() + threads - 1) / threads;
if (beta1_pow[idx]->place() == CPUPlace() &&
beta2_pow[idx]->place() == CPUPlace()) {
// Compute with betapow in REG
AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
*beta1_pow[idx]->data<MPDType>(),
*beta2_pow[idx]->data<MPDType>(),
moment1[idx]->data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
moment2[idx]->data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
learning_rate[idx]->data<MPDType>(),
grad[idx]->data<T>(),
param[idx]->data<T>(),
dev_ctx.template Alloc<T>(param_out[idx]),
master_in_data,
master_out_data,
param[idx]->numel());
if (!use_global_beta_pow) {
// Cpu update
dev_ctx.template HostAlloc<MPDType>(beta1_pow_out[idx])[0] =
beta1_ * beta1_pow[idx]->data<MPDType>()[0];
dev_ctx.template HostAlloc<MPDType>(beta2_pow_out[idx])[0] =
beta2_ * beta2_pow[idx]->data<MPDType>()[0];
}
} else {
AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
beta1_pow[idx]->data<MPDType>(),
beta2_pow[idx]->data<MPDType>(),
moment1[idx]->data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
moment2[idx]->data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
learning_rate[idx]->data<MPDType>(),
grad[idx]->data<T>(),
param[idx]->data<T>(),
dev_ctx.template Alloc<T>(param_out[idx]),
master_in_data,
master_out_data,
param[idx]->numel());
if (!use_global_beta_pow) {
// Update with gpu
UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
beta1_pow[idx]->data<MPDType>(),
beta2_pow[idx]->data<MPDType>(),
dev_ctx.template Alloc<MPDType>(beta1_pow_out[idx]),
dev_ctx.template Alloc<MPDType>(beta2_pow_out[idx]));
}
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(adam,
......@@ -279,3 +379,15 @@ PD_REGISTER_KERNEL(adam,
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(merged_adam,
GPU,
ALL_LAYOUT,
phi::MergedAdamKernel,
float,
double,
phi::dtype::float16) {
// Skip beta1_pow, beta2_pow data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"
namespace phi {
KernelSignature MergedAdamOpArgumentMapping(const ArgumentMappingContext& ctx) {
paddle::small_vector<const char*> in_names = {"Param",
"Grad",
"LearningRate",
"Moment1",
"Moment2",
"Beta1Pow",
"Beta2Pow",
"MasterParam"};
paddle::small_vector<const char*> out_names = {"ParamOut",
"Moment1Out",
"Moment2Out",
"Beta1PowOut",
"Beta2PowOut",
"MasterParamOut"};
paddle::small_vector<const char*> attr_names = {
"beta1", "beta2", "epsilon", "multi_precision", "use_global_beta_pow"};
return KernelSignature("merged_adam",
std::move(in_names),
std::move(attr_names),
std::move(out_names));
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(merged_adam, phi::MergedAdamOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册