Unverified commit f5a041e6 authored by Aurelius84, committed by GitHub

[XPU]Migrate adamw XPU kernel into Phi (#45609)

* [XPU]Migrate adamw XPU kernel into Phi

* test=kunlun

* test=kunlun
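
For context, both files in this diff implement the standard decoupled-weight-decay AdamW update (falling back to plain Adam when with_decay is false). A reference sketch of the per-step rule, with $g_t$ the gradient, $\eta$ the learning rate, and $\lambda$ the coeff attribute:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t) \\
\theta_t &= \theta_{t-1} - \eta \left( \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \lambda\, \theta_{t-1} \right)
\end{aligned}
$$

The Beta1Pow/Beta2Pow tensors carry the running $\beta_1^t$ and $\beta_2^t$ factors; the exact arithmetic inside xpu::adamw is internal to the XDNN library, so this is a reference formulation, not the literal device code.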
Parent 02afb925
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gflags/gflags.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/optimizers/adam_op_functor.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
#ifdef PADDLE_WITH_XPU
template <typename T>
class AdamwOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(),
true,
platform::errors::InvalidArgument(
"Tensor holds the wrong type,Expected Var(%s)'s "
"type is LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor;
auto& param = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Param"), "Input", "Param", "Adam");
// auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
auto* grad_var = ctx.InputVar("Grad");
auto& mom1 = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Moment1"), "Input", "Moment1", "Adam");
auto& mom2 = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Moment2"), "Input", "Moment2", "Adam");
auto& lr = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("LearningRate"), "Input", "LearningRate", "Adam");
auto& beta1_pow = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Beta1Pow"), "Input", "Beta1Pow", "Adam");
auto& beta2_pow = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Beta2Pow"), "Input", "Beta2Pow", "Adam");
auto& param_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("ParamOut"), "Output", "ParamOut", "Adam");
auto& mom1_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("Moment1Out"), "Output", "Moment1Out", "Adam");
auto& mom2_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("Moment2Out"), "Output", "Moment2Out", "Adam");
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
bool skip_update = false;
if (ctx.HasInput("SkipUpdate")) {
auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
PADDLE_ENFORCE_EQ(skip_update_tensor->numel(),
1,
platform::errors::InvalidArgument(
"Input(SkipUpdate) size must be 1, but get %d",
skip_update_tensor->numel()));
std::vector<bool> skip_update_vec;
paddle::framework::TensorToVector(
*skip_update_tensor, ctx.device_context(), &skip_update_vec);
skip_update = skip_update_vec[0];
}
auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
// skip_update=true, just copy input to output, and TensorCopy will call
// mutable_data
if (skip_update) {
VLOG(4) << "Adam skip update";
framework::TensorCopy(param, ctx.GetPlace(), dev_ctx, &param_out);
framework::TensorCopy(mom1, ctx.GetPlace(), dev_ctx, &mom1_out);
framework::TensorCopy(mom2, ctx.GetPlace(), dev_ctx, &mom2_out);
framework::TensorCopy(beta1_pow, ctx.GetPlace(), dev_ctx, beta1_pow_out);
framework::TensorCopy(beta2_pow, ctx.GetPlace(), dev_ctx, beta2_pow_out);
return;
}
bool with_decay = ctx.Attr<bool>("with_decay");
PADDLE_ENFORCE_EQ(beta1_pow_out->numel(),
1,
platform::errors::InvalidArgument(
"Tensor holds the wrong size, Expected beta1 pow "
"output size is 1, but received "
"value is:%d.",
beta1_pow_out->numel()));
PADDLE_ENFORCE_EQ(beta2_pow_out->numel(),
1,
platform::errors::InvalidArgument(
"Tensor holds the wrong size, Expected beta2 pow "
"output size is 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
float beta1 = static_cast<float>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
beta1 = static_cast<float>(GetAttrFromTensor(beta1_tensor));
}
float beta2 = static_cast<float>(ctx.Attr<float>("beta2"));
if (ctx.HasInput("Beta2Tensor")) {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
beta2 = static_cast<float>(GetAttrFromTensor(beta2_tensor));
}
    float epsilon = static_cast<float>(ctx.Attr<float>("epsilon"));
if (ctx.HasInput("EpsilonTensor")) {
auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
epsilon = static_cast<float>(GetAttrFromTensor(epsilon_tensor));
}
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Grad"), "Input", "Grad", "Adam");
const float* beta1_pow_ptr = beta1_pow.template data<float>();
const float* beta2_pow_ptr = beta2_pow.template data<float>();
Tensor xpu_beta1_pow;
Tensor xpu_beta2_pow;
if (beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
paddle::framework::TensorCopy(
beta1_pow, ctx.GetPlace(), dev_ctx, &xpu_beta1_pow);
paddle::framework::TensorCopy(
beta2_pow, ctx.GetPlace(), dev_ctx, &xpu_beta2_pow);
dev_ctx.Wait();
beta1_pow_ptr = xpu_beta1_pow.template data<float>();
beta2_pow_ptr = xpu_beta2_pow.template data<float>();
}
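      // When the beta1_pow/beta2_pow accumulators arrive on the CPU, the
      // branch above stages copies onto the XPU so the device kernel reads
      // device memory; otherwise the original device pointers are used.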
if (with_decay) {
float coeff = ctx.Attr<float>("coeff");
int r =
xpu::adamw(dev_ctx.x_context(),
grad.template data<T>(),
mom1.template data<float>(),
mom2.template data<float>(),
param.template data<T>(),
beta1_pow_ptr,
beta2_pow_ptr,
lr.template data<float>(),
mom1_out.template mutable_data<float>(ctx.GetPlace()),
mom2_out.template mutable_data<float>(ctx.GetPlace()),
param_out.template mutable_data<T>(ctx.GetPlace()),
beta1,
beta2,
epsilon,
coeff,
param.numel());
PADDLE_ENFORCE_EQ(
r,
xpu::SUCCESS,
platform::errors::External(
"XPU kernel adamw occur error in adamw error code ",
r,
XPUAPIErrorMsg[r]));
} else {
int r = xpu::adam(dev_ctx.x_context(),
grad.template data<T>(),
mom1.template data<float>(),
mom2.template data<float>(),
param.template data<T>(),
beta1_pow_ptr,
beta2_pow_ptr,
lr.template data<float>(),
mom1_out.template mutable_data<float>(ctx.GetPlace()),
mom2_out.template mutable_data<float>(ctx.GetPlace()),
param_out.template mutable_data<T>(ctx.GetPlace()),
beta1,
beta2,
epsilon,
param.numel());
PADDLE_ENFORCE_EQ(
r,
xpu::SUCCESS,
platform::errors::External(
"XPU kernel adam occur error in adamw error code ",
r,
XPUAPIErrorMsg[r]));
}
if (!use_global_beta_pow) {
// update in cpu and then copy to xpu
if (beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
const float* beta1_pow_p = beta1_pow.template data<float>();
beta1_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
beta1 * beta1_pow_p[0];
const float* beta2_pow_p = beta2_pow.template data<float>();
beta2_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
beta2 * beta2_pow_p[0];
xpu_wait(dev_ctx.x_context()->xpu_stream);
} else {
float* beta1_pow_out_p =
beta1_pow_out->mutable_data<float>(ctx.GetPlace());
float* beta2_pow_out_p =
beta2_pow_out->mutable_data<float>(ctx.GetPlace());
int r = xpu::scale(dev_ctx.x_context(),
beta1_pow_ptr,
beta1_pow_out_p,
beta1_pow.numel(),
false,
beta1,
0.0f);
PADDLE_ENFORCE_EQ(
r,
xpu::SUCCESS,
platform::errors::External(
"XPU kernel scale occur error in adamw error code ",
r,
XPUAPIErrorMsg[r]));
r = xpu::scale(dev_ctx.x_context(),
beta2_pow_ptr,
beta2_pow_out_p,
beta2_pow.numel(),
false,
beta2,
0.0f);
PADDLE_ENFORCE_EQ(
r,
xpu::SUCCESS,
platform::errors::External(
"XPU kernel scale occur error in adamw error code ",
r,
XPUAPIErrorMsg[r]));
}
}
} else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Variable type not supported by adamw_op"));
}
}
};
#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(adamw, ops::AdamwOpXPUKernel<float>);
#endif
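
The file above is the legacy fluid XPU operator that this commit deletes; everything below is the new Phi kernel that replaces it.
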
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/adamw_kernel.h"
#include <vector>
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
// for TensorToVector
#include "paddle/fluid/framework/tensor_util.h"
namespace phi {
template <typename T, typename Context>
void AdamwDenseKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& moment1,
const DenseTensor& moment2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param,
const paddle::optional<DenseTensor>& skip_update,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
float lr_ratio,
float coeff,
bool with_decay,
bool lazy_mode,
int64_t min_row_size_to_use_multithread,
bool multi_precision,
bool use_global_beta_pow,
DenseTensor* param_out,
DenseTensor* moment1_out,
DenseTensor* moment2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs) {
bool skip_update_ = false;
if (skip_update.is_initialized()) {
PADDLE_ENFORCE_EQ(
skip_update->numel(),
1,
errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
skip_update->numel()));
std::vector<bool> skip_update_vec;
paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
skip_update_ = skip_update_vec[0];
}
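  // If SkipUpdate was set, forward inputs to outputs unchanged and return
  // without touching the moments or the parameter.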
if (skip_update_) {
VLOG(4) << "Adamw skip update";
phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
return;
}
auto beta1_ = beta1.to<float>();
auto beta2_ = beta2.to<float>();
auto epsilon_ = epsilon.to<float>();
const float* beta1_pow_ptr = beta1_pow.template data<float>();
const float* beta2_pow_ptr = beta2_pow.template data<float>();
DenseTensor xpu_beta1_pow;
DenseTensor xpu_beta2_pow;
if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, &xpu_beta1_pow);
phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, &xpu_beta2_pow);
dev_ctx.Wait();
beta1_pow_ptr = xpu_beta1_pow.template data<float>();
beta2_pow_ptr = xpu_beta2_pow.template data<float>();
}
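  // As in the legacy kernel, CPU-resident beta_pow accumulators are first
  // staged onto the XPU, since the xpu::adamw/xpu::adam calls below consume
  // device memory.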
if (with_decay) {
int r = xpu::adamw(dev_ctx.x_context(),
grad.template data<T>(),
moment1.template data<float>(),
moment2.template data<float>(),
param.template data<T>(),
beta1_pow_ptr,
beta2_pow_ptr,
learning_rate.template data<float>(),
dev_ctx.template Alloc<float>(moment1_out),
dev_ctx.template Alloc<float>(moment2_out),
dev_ctx.template Alloc<T>(param_out),
beta1_,
beta2_,
epsilon_,
coeff,
param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
} else {
int r = xpu::adam(dev_ctx.x_context(),
grad.template data<T>(),
moment1.template data<float>(),
moment2.template data<float>(),
param.template data<T>(),
beta1_pow_ptr,
beta2_pow_ptr,
learning_rate.template data<float>(),
dev_ctx.template Alloc<float>(moment1_out),
dev_ctx.template Alloc<float>(moment2_out),
dev_ctx.template Alloc<T>(param_out),
beta1_,
beta2_,
epsilon_,
param.numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
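  // When use_global_beta_pow is set, the beta_pow accumulators are maintained
  // outside this kernel, so they are only advanced here when the flag is off.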
if (!use_global_beta_pow) {
// update in cpu and then copy to xpu
if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
const float* beta1_pow_p = beta1_pow.template data<float>();
dev_ctx.template HostAlloc<float>(beta1_pow_out)[0] =
beta1_ * beta1_pow_p[0];
const float* beta2_pow_p = beta2_pow.template data<float>();
dev_ctx.template HostAlloc<float>(beta2_pow_out)[0] =
beta2_ * beta2_pow_p[0];
xpu_wait(dev_ctx.x_context()->xpu_stream);
} else {
float* beta1_pow_out_p = dev_ctx.template Alloc<float>(beta1_pow_out);
float* beta2_pow_out_p = dev_ctx.template Alloc<float>(beta2_pow_out);
int r = xpu::scale(dev_ctx.x_context(),
beta1_pow_ptr,
beta1_pow_out_p,
beta1_pow.numel(),
false,
beta1_,
0.0f);
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
r = xpu::scale(dev_ctx.x_context(),
beta2_pow_ptr,
beta2_pow_out_p,
beta2_pow.numel(),
false,
beta2_,
0.0f);
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(adamw, XPU, ALL_LAYOUT, phi::AdamwDenseKernel, float) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
}
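
For reference, the InputAt indices in the registration above follow the argument order of AdamwDenseKernel; a quick mapping derived from the kernel signature (not from separate documentation):

// Input index -> kernel argument:
//   0: param       1: grad          2: learning_rate
//   3: moment1     4: moment2       5: beta1_pow
//   6: beta2_pow   7: master_param  8: skip_update
// Indices 5, 6 and 8 are pinned to Backend::ALL_BACKEND so the framework
// skips its usual device transform for tensors that may legitimately live
// on the CPU (see the beta_pow handling in the kernel body).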