Unverified commit 27e252d9, authored by taixiurong, committed by GitHub

add adamw support for xpu, test=kunlun (#48114)

Parent: 394a7179
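This change registers a float16 variant of the `adamw` kernel for Kunlun 2 (KL2) XPUs and casts the parameter and gradient pointers to the xdnn `XPUType` before calling `xpu::adamw`/`xpu::adam`. As a loose illustration of the user-facing effect, the sketch below runs AdamW on an XPU device under O2 mixed precision, where parameters reach the optimizer as float16. The device string, model, and AMP level are illustrative assumptions and are not part of this commit.

```python
import paddle

paddle.set_device("xpu")  # assumption: a PaddlePaddle build with Kunlun/XPU support

model = paddle.nn.Linear(16, 4)  # hypothetical tiny model
opt = paddle.optimizer.AdamW(
    learning_rate=0.004, parameters=model.parameters(), weight_decay=0.01
)
# Under O2 mixed precision the weights handed to the optimizer are float16,
# which is the dtype this commit adds an XPU kernel for.
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level="O2")

x = paddle.randn([8, 16])
with paddle.amp.auto_cast(level="O2"):
    loss = model(x).mean()
loss.backward()
opt.step()        # dispatches to the XPU adamw kernel
opt.clear_grad()
```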
@@ -34,7 +34,9 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"adadelta", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adamw",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"adam",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
......
@@ -52,6 +52,7 @@ void AdamwDenseKernel(const Context& dev_ctx,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
bool skip_update_ = false;
if (skip_update.is_initialized()) {
PADDLE_ENFORCE_EQ(
@@ -89,40 +90,42 @@ void AdamwDenseKernel(const Context& dev_ctx,
beta2_pow_ptr = xpu_beta2_pow.template data<float>();
}
if (with_decay) {
-    int r = xpu::adamw(dev_ctx.x_context(),
-                       grad.template data<T>(),
-                       moment1.template data<float>(),
-                       moment2.template data<float>(),
-                       param.template data<T>(),
-                       beta1_pow_ptr,
-                       beta2_pow_ptr,
-                       learning_rate.template data<float>(),
-                       dev_ctx.template Alloc<float>(moment1_out),
-                       dev_ctx.template Alloc<float>(moment2_out),
-                       dev_ctx.template Alloc<T>(param_out),
-                       beta1_,
-                       beta2_,
-                       epsilon_,
-                       coeff,
-                       param.numel());
+    int r = xpu::adamw(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(grad.template data<T>()),
+        moment1.template data<float>(),
+        moment2.template data<float>(),
+        reinterpret_cast<const XPUType*>(param.template data<T>()),
+        beta1_pow_ptr,
+        beta2_pow_ptr,
+        learning_rate.template data<float>(),
+        dev_ctx.template Alloc<float>(moment1_out),
+        dev_ctx.template Alloc<float>(moment2_out),
+        reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
+        beta1_,
+        beta2_,
+        epsilon_,
+        coeff,
+        param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
} else {
-    int r = xpu::adam(dev_ctx.x_context(),
-                      grad.template data<T>(),
-                      moment1.template data<float>(),
-                      moment2.template data<float>(),
-                      param.template data<T>(),
-                      beta1_pow_ptr,
-                      beta2_pow_ptr,
-                      learning_rate.template data<float>(),
-                      dev_ctx.template Alloc<float>(moment1_out),
-                      dev_ctx.template Alloc<float>(moment2_out),
-                      dev_ctx.template Alloc<T>(param_out),
-                      beta1_,
-                      beta2_,
-                      epsilon_,
-                      param.numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+    int r = xpu::adam(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(grad.template data<T>()),
+        moment1.template data<float>(),
+        moment2.template data<float>(),
+        reinterpret_cast<const XPUType*>(param.template data<T>()),
+        beta1_pow_ptr,
+        beta2_pow_ptr,
+        learning_rate.template data<float>(),
+        dev_ctx.template Alloc<float>(moment1_out),
+        dev_ctx.template Alloc<float>(moment2_out),
+        reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
+        beta1_,
+        beta2_,
+        epsilon_,
+        param.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
if (!use_global_beta_pow) {
@@ -145,7 +148,7 @@ void AdamwDenseKernel(const Context& dev_ctx,
false,
beta1_,
0.0f);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
r = xpu::scale(dev_ctx.x_context(),
beta2_pow_ptr,
beta2_pow_out_p,
@@ -153,14 +156,15 @@ void AdamwDenseKernel(const Context& dev_ctx,
false,
beta2_,
0.0f);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
}
}
}
} // namespace phi
-PD_REGISTER_KERNEL(adamw, XPU, ALL_LAYOUT, phi::AdamwDenseKernel, float) {
+PD_REGISTER_KERNEL(
+    adamw, XPU, ALL_LAYOUT, phi::AdamwDenseKernel, float, phi::dtype::float16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
......
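For reference, the `with_decay` branch above asks xdnn's `xpu::adamw` for a decoupled weight-decay update controlled by `coeff`, and the trailing `xpu::scale` calls advance the beta-power accumulators by one step. A hedged NumPy sketch of that update follows; it uses a common formulation matching Paddle's adamw test references, and the exact epsilon placement or rounding inside the xdnn kernel may differ.

```python
import numpy as np

def adamw_step(param, grad, m1, m2, lr, beta1, beta2, eps, coeff,
               beta1_pow, beta2_pow):
    # Decoupled weight decay: shrink the parameter before the Adam update.
    param = param * (1.0 - lr * coeff)

    # First and second moment estimates (kept in float32 on XPU).
    m1 = beta1 * m1 + (1.0 - beta1) * grad
    m2 = beta2 * m2 + (1.0 - beta2) * np.square(grad)

    # Bias-corrected step size.
    lr_t = lr * np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)
    param = param - lr_t * m1 / (np.sqrt(m2) + eps)

    # Beta-power accumulators advance by one step (the xpu::scale calls).
    return param, m1, m2, beta1_pow * beta1, beta2_pow * beta2
```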
@@ -90,9 +90,9 @@ class XPUTestAdamwOp1(XPUOpTestWrapper):
self.dtype = self.in_type_str
param = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
grad = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
-        moment1 = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        moment1 = np.random.uniform(-1, 1, self.shape).astype("float32")
         # The second moment is positive
-        moment2 = np.random.random(self.shape).astype(self.dtype)
+        moment2 = np.random.random(self.shape).astype("float32")
learning_rate = 0.004
beta1 = 0.78
@@ -106,9 +106,9 @@ class XPUTestAdamwOp1(XPUOpTestWrapper):
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
-            'LearningRate': np.array([learning_rate]).astype(self.dtype),
-            'Beta1Pow': np.array([beta1_pow]).astype(self.dtype),
-            'Beta2Pow': np.array([beta2_pow]).astype(self.dtype),
+            'LearningRate': np.array([learning_rate]).astype("float32"),
+            'Beta1Pow': np.array([beta1_pow]).astype("float32"),
+            'Beta2Pow': np.array([beta2_pow]).astype("float32"),
}
self.attrs = {
@@ -127,8 +127,8 @@ class XPUTestAdamwOp1(XPUOpTestWrapper):
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
-            'Beta1PowOut': np.array([beta1_pow]).astype(self.dtype) * beta1,
-            'Beta2PowOut': np.array([beta2_pow]).astype(self.dtype) * beta2,
+            'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
+            'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2,
}
def init_shape(self):
@@ -305,7 +305,8 @@ class XPUTestAdamwOp2(XPUOpTestWrapper):
support_types = get_xpu_op_support_types('adamw')
for stype in support_types:
create_test_class(globals(), XPUTestAdamwOp1, stype)
-    create_test_class(globals(), XPUTestAdamwOp2, stype)
+    if stype == "float32":
+        create_test_class(globals(), XPUTestAdamwOp2, stype)
if __name__ == "__main__":
paddle.enable_static()
......