Unverified commit 27e252d9, authored by taixiurong, committed by GitHub

add adamw support for xpu, test=kunlun (#48114)

Parent 394a7179
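The hunks below register a float16 AdamW kernel for the XPU (Kunlun) backend next to the existing float32 one, and cast the param/grad pointers through XPUTypeTrait accordingly. For orientation only, here is a minimal sketch of how this path might be exercised from Python; it assumes an XPU build of PaddlePaddle and the public paddle.optimizer.AdamW / paddle.amp APIs, and the toy model and numbers are illustrative, not part of this commit:

import paddle

paddle.set_device("xpu")  # Kunlun device; requires an XPU build of Paddle

# Toy model and optimizer, purely for illustration.
model = paddle.nn.Linear(16, 4)
opt = paddle.optimizer.AdamW(
    learning_rate=0.004, parameters=model.parameters(), weight_decay=0.01
)

# Level-O2 AMP keeps FP16 parameters, which is the case the FP16 kernel targets.
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level="O2")

x = paddle.randn([8, 16])
with paddle.amp.auto_cast(level="O2"):
    loss = model(x).mean()
loss.backward()
opt.step()        # dispatches to the XPU adamw kernel touched by this commit
opt.clear_grad()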
@@ -34,7 +34,9 @@ XPUOpMap& get_kl2_ops() {
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                     pOpKernelType(vartype::FP16, XPUPlace())})},
   {"adadelta", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-  {"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+  {"adamw",
+   XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                 pOpKernelType(vartype::FP16, XPUPlace())})},
   {"adam",
    XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                  pOpKernelType(vartype::FP16, XPUPlace())})},
...
@@ -52,6 +52,7 @@ void AdamwDenseKernel(const Context& dev_ctx,
                       DenseTensor* beta1_pow_out,
                       DenseTensor* beta2_pow_out,
                       DenseTensor* master_param_outs) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   bool skip_update_ = false;
   if (skip_update.is_initialized()) {
     PADDLE_ENFORCE_EQ(
@@ -89,17 +90,18 @@ void AdamwDenseKernel(const Context& dev_ctx,
     beta2_pow_ptr = xpu_beta2_pow.template data<float>();
   }
   if (with_decay) {
-    int r = xpu::adamw(dev_ctx.x_context(),
-                       grad.template data<T>(),
+    int r = xpu::adamw(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(grad.template data<T>()),
         moment1.template data<float>(),
         moment2.template data<float>(),
-        param.template data<T>(),
+        reinterpret_cast<const XPUType*>(param.template data<T>()),
         beta1_pow_ptr,
         beta2_pow_ptr,
         learning_rate.template data<float>(),
         dev_ctx.template Alloc<float>(moment1_out),
         dev_ctx.template Alloc<float>(moment2_out),
-        dev_ctx.template Alloc<T>(param_out),
+        reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
         beta1_,
         beta2_,
         epsilon_,
@@ -107,22 +109,23 @@ void AdamwDenseKernel(const Context& dev_ctx,
         param.numel());
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
   } else {
-    int r = xpu::adam(dev_ctx.x_context(),
-                      grad.template data<T>(),
+    int r = xpu::adam(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(grad.template data<T>()),
         moment1.template data<float>(),
         moment2.template data<float>(),
-        param.template data<T>(),
+        reinterpret_cast<const XPUType*>(param.template data<T>()),
         beta1_pow_ptr,
         beta2_pow_ptr,
         learning_rate.template data<float>(),
         dev_ctx.template Alloc<float>(moment1_out),
         dev_ctx.template Alloc<float>(moment2_out),
-        dev_ctx.template Alloc<T>(param_out),
+        reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
         beta1_,
         beta2_,
         epsilon_,
         param.numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
   }
   if (!use_global_beta_pow) {
@@ -145,7 +148,7 @@ void AdamwDenseKernel(const Context& dev_ctx,
                      false,
                      beta1_,
                      0.0f);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
       r = xpu::scale(dev_ctx.x_context(),
                      beta2_pow_ptr,
                      beta2_pow_out_p,
@@ -153,14 +156,15 @@ void AdamwDenseKernel(const Context& dev_ctx,
                      false,
                      beta2_,
                      0.0f);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
     }
   }
 }
 }  // namespace phi
 
-PD_REGISTER_KERNEL(adamw, XPU, ALL_LAYOUT, phi::AdamwDenseKernel, float) {
+PD_REGISTER_KERNEL(
+    adamw, XPU, ALL_LAYOUT, phi::AdamwDenseKernel, float, phi::dtype::float16) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
...
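For reference, the update this kernel delegates to xpu::adamw is the conventional decoupled-weight-decay Adam step, with the first and second moments kept in float32 even for FP16 parameters, and beta1_pow/beta2_pow advanced once per step by the xpu::scale calls above. The NumPy sketch below states that conventional formulation as an assumption about the math, not as a transcription of the XPU library internals:

import numpy as np

def adamw_step_ref(param, grad, m1, m2, lr, beta1, beta2, eps, coeff,
                   beta1_pow, beta2_pow):
    # Decoupled weight decay is applied to the parameter first.
    param = param * (1.0 - lr * coeff)
    # Standard Adam moment updates (accumulated in float32).
    m1 = beta1 * m1 + (1.0 - beta1) * grad
    m2 = beta2 * m2 + (1.0 - beta2) * grad * grad
    # Bias-corrected step size.
    lr_t = lr * np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)
    param = param - lr_t * m1 / (np.sqrt(m2) + eps)
    # beta*_pow are advanced once per step (the xpu::scale calls above).
    return param, m1, m2, beta1_pow * beta1, beta2_pow * beta2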
@@ -90,9 +90,9 @@ class XPUTestAdamwOp1(XPUOpTestWrapper):
             self.dtype = self.in_type_str
             param = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
             grad = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
-            moment1 = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+            moment1 = np.random.uniform(-1, 1, self.shape).astype("float32")
             # The second moment is positive
-            moment2 = np.random.random(self.shape).astype(self.dtype)
+            moment2 = np.random.random(self.shape).astype("float32")
 
             learning_rate = 0.004
             beta1 = 0.78
@@ -106,9 +106,9 @@ class XPUTestAdamwOp1(XPUOpTestWrapper):
                 'Grad': grad,
                 'Moment1': moment1,
                 'Moment2': moment2,
-                'LearningRate': np.array([learning_rate]).astype(self.dtype),
-                'Beta1Pow': np.array([beta1_pow]).astype(self.dtype),
-                'Beta2Pow': np.array([beta2_pow]).astype(self.dtype),
+                'LearningRate': np.array([learning_rate]).astype("float32"),
+                'Beta1Pow': np.array([beta1_pow]).astype("float32"),
+                'Beta2Pow': np.array([beta2_pow]).astype("float32"),
             }
 
             self.attrs = {
@@ -127,8 +127,8 @@ class XPUTestAdamwOp1(XPUOpTestWrapper):
                 'Moment1Out': moment1_out,
                 'Moment2Out': moment2_out,
                 'ParamOut': param_out,
-                'Beta1PowOut': np.array([beta1_pow]).astype(self.dtype) * beta1,
-                'Beta2PowOut': np.array([beta2_pow]).astype(self.dtype) * beta2,
+                'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
+                'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2,
             }
 
         def init_shape(self):
@@ -305,6 +305,7 @@ class XPUTestAdamwOp2(XPUOpTestWrapper):
 support_types = get_xpu_op_support_types('adamw')
 for stype in support_types:
     create_test_class(globals(), XPUTestAdamwOp1, stype)
-    create_test_class(globals(), XPUTestAdamwOp2, stype)
+    if stype == "float32":
+        create_test_class(globals(), XPUTestAdamwOp2, stype)
 
 if __name__ == "__main__":
...