From dda74715cb08a4f3a7eaa82a6ba38da975b1c514 Mon Sep 17 00:00:00 2001
From: taixiurong
Date: Tue, 28 Feb 2023 10:56:52 +0800
Subject: [PATCH] xpu-paddlepaddle-57 [task] adamw: support lr_ratio (#50979)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 paddle/phi/kernels/xpu/adamw_kernel.cc        |  74 ++--
 .../tests/unittests/xpu/test_adamw_op_xpu.py  | 340 ++++++++++++++++++
 python/paddle/optimizer/adamw.py              |   7 +-
 3 files changed, 380 insertions(+), 41 deletions(-)

diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc
index 0e27f686ada..c2f0f66b343 100644
--- a/paddle/phi/kernels/xpu/adamw_kernel.cc
+++ b/paddle/phi/kernels/xpu/adamw_kernel.cc
@@ -87,47 +87,43 @@ void AdamwDenseKernel(const Context& dev_ctx,
     beta1_pow_ptr = xpu_beta1_pow.template data<float>();
     beta2_pow_ptr = xpu_beta2_pow.template data<float>();
   }
-  if (with_decay) {
-    int r = xpu::adamw(
-        dev_ctx.x_context(),
-        reinterpret_cast<const XPUType*>(grad.template data<T>()),
-        moment1.template data<float>(),
-        moment2.template data<float>(),
-        reinterpret_cast<const XPUType*>(param.template data<T>()),
-        beta1_pow_ptr,
-        beta2_pow_ptr,
-        learning_rate.template data<float>(),
-        dev_ctx.template Alloc<float>(moment1_out),
-        dev_ctx.template Alloc<float>(moment2_out),
-        reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
-        beta1_,
-        beta2_,
-        epsilon_,
-        coeff,
-        param.numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
-  } else {
-    int r = xpu::adam(
-        dev_ctx.x_context(),
-        reinterpret_cast<const XPUType*>(grad.template data<T>()),
-        moment1.template data<float>(),
-        moment2.template data<float>(),
-        reinterpret_cast<const XPUType*>(param.template data<T>()),
-        beta1_pow_ptr,
-        beta2_pow_ptr,
-        learning_rate.template data<float>(),
-        dev_ctx.template Alloc<float>(moment1_out),
-        dev_ctx.template Alloc<float>(moment2_out),
-        reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
-        beta1_,
-        beta2_,
-        epsilon_,
-        param.numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
+  if (!with_decay) {
+    coeff = static_cast<float>(0.0);
   }
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  float* new_lr = RAII_GUARD.alloc_l3_or_gm<float>(learning_rate.numel());
+  PADDLE_ENFORCE_XDNN_NOT_NULL(new_lr);
+  int r = 0;
+  r = xpu::scale(dev_ctx.x_context(),
+                 learning_rate.template data<float>(),
+                 new_lr,
+                 learning_rate.numel(),
+                 false,
+                 lr_ratio,
+                 0.0f);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
+
+  r = xpu::adamw(
+      dev_ctx.x_context(),
+      reinterpret_cast<const XPUType*>(grad.template data<T>()),
+      moment1.template data<float>(),
+      moment2.template data<float>(),
+      reinterpret_cast<const XPUType*>(param.template data<T>()),
+      beta1_pow_ptr,
+      beta2_pow_ptr,
+      new_lr,
+      dev_ctx.template Alloc<float>(moment1_out),
+      dev_ctx.template Alloc<float>(moment2_out),
+      reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
+      beta1_,
+      beta2_,
+      epsilon_,
+      coeff,
+      param.numel());
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
   if (!use_global_beta_pow) {
-    // update in cpu and then copy to xpu
+    // update in cpu
     if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
       const float* beta1_pow_p = beta1_pow.template data<float>();
       dev_ctx.template HostAlloc<float>(beta1_pow_out)[0] =
           beta1_ * beta1_pow_p[0];
@@ -136,7 +132,7 @@ void AdamwDenseKernel(const Context& dev_ctx,
       dev_ctx.template HostAlloc<float>(beta2_pow_out)[0] =
           beta2_ * beta2_pow_p[0];
       xpu_wait(dev_ctx.x_context()->xpu_stream);
-    } else {
+    } else {  // update in xpu
       float* beta1_pow_out_p = dev_ctx.template Alloc<float>(beta1_pow_out);
       float* beta2_pow_out_p = dev_ctx.template Alloc<float>(beta2_pow_out);
       int r = xpu::scale(dev_ctx.x_context(),
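The kernel change above collapses the old `xpu::adamw` / `xpu::adam` branches into a single path: the learning rate is first multiplied by `lr_ratio` with `xpu::scale`, and `coeff` is zeroed when `with_decay` is false, so one `xpu::adamw` call handles both the decay and no-decay cases. A minimal NumPy sketch of the update this performs is shown below; it is illustrative only (the function name `adamw_lr_ratio_ref` and the exact epsilon placement in the denominator are assumptions, not the XDNN implementation).

```python
import numpy as np

def adamw_lr_ratio_ref(param, grad, m1, m2, lr, beta1_pow, beta2_pow,
                       beta1=0.9, beta2=0.999, epsilon=1e-8,
                       coeff=0.01, lr_ratio=1.0, with_decay=True):
    """One AdamW step with a per-parameter learning-rate ratio (sketch only)."""
    lr = lr * lr_ratio                  # mirrors the xpu::scale pre-scaling
    if not with_decay:
        coeff = 0.0                     # mirrors coeff = static_cast<float>(0.0)
    param = param * (1.0 - lr * coeff)  # decoupled weight decay
    m1_out = beta1 * m1 + (1.0 - beta1) * grad
    m2_out = beta2 * m2 + (1.0 - beta2) * np.square(grad)
    lr_t = lr * np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)
    denom = np.sqrt(m2_out) + epsilon * np.sqrt(1.0 - beta2_pow)
    return param - lr_t * m1_out / denom, m1_out, m2_out
```

Zeroing `coeff` instead of keeping a separate `xpu::adam` branch leaves the no-decay update mathematically unchanged, since the decay term `lr * coeff` simply vanishes.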
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_adamw_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_adamw_op_xpu.py
index 5b628498f00..57b57adce3b 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_adamw_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_adamw_op_xpu.py
@@ -17,6 +17,7 @@ import sys
 sys.path.append("..")
 import unittest
+from functools import partial

 import numpy as np
 from op_test_xpu import XPUOpTest
@@ -301,6 +302,345 @@ class XPUTestAdamwOp2(XPUOpTestWrapper):
             adam.step()
             adam.clear_gradients()

+    class TestAdamWOpLayerwiseLR(TestAdamWOp):
+        def setUp(self):
+            np.random.seed(2022)
+            paddle.seed(2022)
+
+        def test_adamw_op_dygraph(self):
+            paddle.disable_static()
+            linear1 = paddle.nn.Linear(
+                13, 8, bias_attr=paddle.nn.initializer.Constant(value=1.0)
+            )
+            linear2 = paddle.nn.Linear(
+                8, 5, bias_attr=paddle.nn.initializer.Constant(value=1.0)
+            )
+
+            # fix the linear name, simple_lr_setting function will use the name
+            linear1.weight.name = "linear_1.w_0"
+            linear1.bias.name = "linear_1.b_0"
+            linear2.weight.name = "linear_2.w_0"
+            linear2.bias.name = "linear_2.b_0"
+
+            fc1_w = np.array(linear1.weight)
+            fc1_w_mon1 = np.zeros_like(fc1_w)
+            fc1_w_mon2 = np.zeros_like(fc1_w)
+            fc1_b = np.array(linear1.bias)
+            fc1_b_mon1 = np.zeros_like(fc1_b)
+            fc1_b_mon2 = np.zeros_like(fc1_b)
+
+            fc2_w = np.array(linear2.weight)
+            fc2_w_mon1 = np.zeros_like(fc2_w)
+            fc2_w_mon2 = np.zeros_like(fc2_w)
+            fc2_b = np.array(linear2.bias)
+            fc2_b_mon1 = np.zeros_like(fc2_b)
+            fc2_b_mon2 = np.zeros_like(fc2_b)
+
+            simple_lr_fun = partial(
+                simple_lr_setting, decay_rate=0.8, n_layers=2
+            )
+            learning_rate = 0.001
+            weight_decay = 0.01
+            beta1 = 0.9
+            beta2 = 0.999
+
+            opt = paddle.optimizer.AdamW(
+                learning_rate=learning_rate,
+                parameters=[
+                    {'params': linear1.parameters()},
+                    {
+                        'params': linear2.parameters(),
+                    },
+                ],
+                apply_decay_param_fun=lambda name: True,
+                weight_decay=weight_decay,
+                lr_ratio=simple_lr_fun,
+            )
+
+            def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t):
+                np_inputs = {
+                    'Param': param,
+                    'Grad': grad,
+                    'Moment1': moment1,
+                    'Moment2': moment2,
+                    'LearningRate': np.array([learning_rate]).astype("float32"),
+                    'Beta1Pow': np.array([beta1**t]).astype("float32"),
+                    'Beta2Pow': np.array([beta2**t]).astype("float32"),
+                }
+
+                np_attrs = {
+                    'epsilon': 1e-8,
+                    'beta1': beta1,
+                    'beta2': beta2,
+                    "lr_ratio": lr_ratio,
+                    "coeff": weight_decay,
+                    "with_decay": True,
+                }
+                param_out, moment1_out, moment2_out = adamw_step(
+                    np_inputs, np_attrs
+                )
+                return param_out, moment1_out, moment2_out
+
+            for i in range(5):
+                a = paddle.to_tensor(
+                    np.random.uniform(-1, 1, (2, 13)).astype("float32")
+                )
+                a1 = linear1(a)
+                out = linear2(a1)
+                out = paddle.mean(out)
+                out.backward()
+
+                fc1_w, fc1_w_mon1, fc1_w_mon2 = get_numpy_output(
+                    fc1_w,
+                    np.array(linear1.weight.grad),
+                    fc1_w_mon1,
+                    fc1_w_mon2,
+                    simple_lr_fun(linear1.weight),
+                    i + 1,
+                )
+                fc1_b, fc1_b_mon1, fc1_b_mon2 = get_numpy_output(
+                    fc1_b,
+                    np.array(linear1.bias.grad),
+                    fc1_b_mon1,
+                    fc1_b_mon2,
+                    simple_lr_fun(linear1.bias),
+                    i + 1,
+                )
+                fc2_w, fc2_w_mon1, fc2_w_mon2 = get_numpy_output(
+                    fc2_w,
+                    np.array(linear2.weight.grad),
+                    fc2_w_mon1,
+                    fc2_w_mon2,
+                    simple_lr_fun(linear2.weight),
+                    i + 1,
+                )
+                fc2_b, fc2_b_mon1, fc2_b_mon2 = get_numpy_output(
+                    fc2_b,
+                    np.array(linear2.bias.grad),
+                    fc2_b_mon1,
+                    fc2_b_mon2,
+                    simple_lr_fun(linear2.bias),
+                    i + 1,
+                )
+
+                opt.step()
+                opt.clear_gradients()
+
+                np.testing.assert_allclose(
+                    linear1.weight.numpy(), fc1_w, rtol=1e-5, atol=1e-5
+                )
+                np.testing.assert_allclose(
+                    linear1.bias.numpy(), fc1_b, rtol=1e-5, atol=1e-5
+                )
+                np.testing.assert_allclose(
+                    linear2.weight.numpy(), fc2_w, rtol=1e-5, atol=1e-5
+                )
+                np.testing.assert_allclose(
+                    linear2.bias.numpy(), fc2_b, rtol=1e-5, atol=1e-5
+                )
+
+        def test_adamw_op(self):
+            paddle.enable_static()
+            place = fluid.XPUPlace(0)
+
+            learning_rate = 0.0001
+            beta1 = 0.85
+            beta2 = 0.95
+            weight_decay = 0.01
+            epsilon = 1e-8
+
+            train_prog = fluid.Program()
+            startup = fluid.Program()
+            with fluid.program_guard(train_prog, startup):
+                with fluid.unique_name.guard():
+                    x = fluid.data(name='x', shape=[None, 10], dtype='float32')
+                    y = fluid.data(name='y', shape=[None, 1], dtype='float32')
+
+                    weight_attr1 = paddle.framework.ParamAttr(
+                        name="linear_0.w_0"
+                    )
+                    bias_attr1 = paddle.framework.ParamAttr(
+                        name="linear_0.b_0",
+                        initializer=paddle.nn.initializer.Constant(value=1.0),
+                    )
+                    weight_attr2 = paddle.framework.ParamAttr(
+                        name="linear_1.w_0"
+                    )
+                    bias_attr2 = paddle.framework.ParamAttr(
+                        name="linear_1.b_0",
+                        initializer=paddle.nn.initializer.Constant(value=1.0),
+                    )
+                    linear1 = paddle.nn.Linear(
+                        10, 32, weight_attr=weight_attr1, bias_attr=bias_attr1
+                    )
+                    linear2 = paddle.nn.Linear(
+                        32, 1, weight_attr=weight_attr2, bias_attr=bias_attr2
+                    )
+
+                    out = linear1(x)
+                    out = linear2(out)
+
+                    fc1_w_mon1 = np.zeros((linear1.weight.shape)).astype(
+                        "float32"
+                    )
+                    fc1_w_mon2 = np.zeros((linear1.weight.shape)).astype(
+                        "float32"
+                    )
+                    fc1_b_mon1 = np.zeros((linear1.bias.shape)).astype(
+                        "float32"
+                    )
+                    fc1_b_mon2 = np.zeros((linear1.bias.shape)).astype(
+                        "float32"
+                    )
+                    fc2_w_mon1 = np.zeros((linear2.weight.shape)).astype(
+                        "float32"
+                    )
+                    fc2_w_mon2 = np.zeros((linear2.weight.shape)).astype(
+                        "float32"
+                    )
+                    fc2_b_mon1 = np.zeros((linear2.bias.shape)).astype(
+                        "float32"
+                    )
+                    fc2_b_mon2 = np.zeros((linear2.bias.shape)).astype(
+                        "float32"
+                    )
+
+                    cost = paddle.nn.functional.square_error_cost(
+                        input=out, label=y
+                    )
+                    avg_cost = paddle.mean(cost)
+
+                    simple_lr_fun = partial(
+                        simple_lr_setting, decay_rate=0.8, n_layers=2
+                    )
+
+                    opt = paddle.optimizer.AdamW(
+                        learning_rate=learning_rate,
+                        beta1=beta1,
+                        beta2=beta2,
+                        weight_decay=weight_decay,
+                        epsilon=epsilon,
+                        lr_ratio=simple_lr_fun,
+                    )
+                    opt.minimize(avg_cost)
+
+            def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t):
+                np_inputs = {
+                    'Param': param,
+                    'Grad': grad,
+                    'Moment1': moment1,
+                    'Moment2': moment2,
+                    'LearningRate': np.array([learning_rate]).astype("float32"),
+                    'Beta1Pow': np.array([beta1**t]).astype("float32"),
+                    'Beta2Pow': np.array([beta2**t]).astype("float32"),
+                }
+
+                np_attrs = {
+                    'epsilon': epsilon,
+                    'beta1': beta1,
+                    'beta2': beta2,
+                    "lr_ratio": lr_ratio,
+                    "coeff": weight_decay,
+                    "with_decay": True,
+                }
+                param_out, moment1_out, moment2_out = adamw_step(
+                    np_inputs, np_attrs
+                )
+                return param_out, moment1_out, moment2_out
+
+            fetch_list1 = [
+                "linear_0.w_0",
+                "linear_0.b_0",
+                "linear_1.w_0",
+                "linear_1.b_0",
+            ]
+            fetch_list2 = [
+                "linear_0.w_0",
+                "linear_0.w_0@GRAD",
+                "linear_0.b_0",
+                "linear_0.b_0@GRAD",
+                "linear_1.w_0",
+                "linear_1.w_0@GRAD",
+                "linear_1.b_0",
+                "linear_1.b_0@GRAD",
+            ]
+
+            exe = fluid.Executor(place)
+            exe.run(startup)
+            test_prog = train_prog.clone(for_test=True)
+
+            for i in range(5):
+                inputs = np.random.random(size=[8, 10]).astype('float32')
+                outputs = np.random.random(size=[8, 1]).astype('float32')
+
+                param = exe.run(
+                    test_prog,
+                    feed={"x": inputs, "y": outputs},
+                    fetch_list=fetch_list1,
+                )
+                params_and_gras = exe.run(
+                    train_prog,
+                    feed={"x": inputs, "y": outputs},
+                    fetch_list=fetch_list2,
+                )
+
+                fc1_w = param[0]
+                fc1_w_grad = params_and_gras[1]
+                fc1_b = param[1]
+                fc1_b_grad = params_and_gras[3]
+                fc2_w = param[2]
+                fc2_w_grad = params_and_gras[5]
+                fc2_b = param[3]
+                fc2_b_grad = params_and_gras[7]
+
+                fc1_w, fc1_w_mon1, fc1_w_mon2 = get_numpy_output(
+                    fc1_w,
+                    fc1_w_grad,
+                    fc1_w_mon1,
+                    fc1_w_mon2,
+                    simple_lr_fun(linear1.weight),
+                    i + 1,
+                )
+                fc1_b, fc1_b_mon1, fc1_b_mon2 = get_numpy_output(
+                    fc1_b,
+                    fc1_b_grad,
+                    fc1_b_mon1,
+                    fc1_b_mon2,
+                    simple_lr_fun(linear1.bias),
+                    i + 1,
+                )
+                fc2_w, fc2_w_mon1, fc2_w_mon2 = get_numpy_output(
+                    fc2_w,
+                    fc2_w_grad,
+                    fc2_w_mon1,
+                    fc2_w_mon2,
+                    simple_lr_fun(linear2.weight),
+                    i + 1,
+                )
+                fc2_b, fc2_b_mon1, fc2_b_mon2 = get_numpy_output(
+                    fc2_b,
+                    fc2_b_grad,
+                    fc2_b_mon1,
+                    fc2_b_mon2,
+                    simple_lr_fun(linear2.bias),
+                    i + 1,
+                )
+
+                np.testing.assert_allclose(
+                    params_and_gras[0], fc1_w, rtol=1e-5, atol=1e-5
+                )
+                np.testing.assert_allclose(
+                    params_and_gras[2], fc1_b, rtol=1e-5, atol=1e-5
+                )
+                np.testing.assert_allclose(
+                    params_and_gras[4], fc2_w, rtol=1e-5, atol=1e-5
+                )
+                np.testing.assert_allclose(
+                    params_and_gras[6], fc2_b, rtol=1e-5, atol=1e-5
+                )
+
+            paddle.disable_static()
+

 support_types = get_xpu_op_support_types('adamw')
 for stype in support_types:
diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py
index a4d304b451e..a5cb7798353 100644
--- a/python/paddle/optimizer/adamw.py
+++ b/python/paddle/optimizer/adamw.py
@@ -178,9 +178,12 @@ class AdamW(Optimizer):
             raise TypeError("weight_decay should be float or Tensor.")
         if lr_ratio is not None:
             assert isinstance(lr_ratio, Callable)
-            if not core.is_compiled_with_cuda():
+            if (
+                not core.is_compiled_with_cuda()
+                and not core.is_compiled_with_xpu()
+            ):
                 raise NotImplementedError(
-                    "'lr_ratio' is unimplemented in CPU, XPU and NPU"
+                    "'lr_ratio' is unimplemented in CPU and NPU"
                 )

         if parameters is not None:
--
GitLab
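With the guard in `adamw.py` relaxed, `lr_ratio` is now accepted on XPU builds as well as CUDA builds. A dygraph usage sketch follows; the `layerwise_lr_ratio` helper is illustrative only (it plays the role of the `simple_lr_setting` function used by the test and is not part of this patch), and the snippet assumes a Paddle build compiled with XPU or CUDA support.

```python
from functools import partial

import paddle

def layerwise_lr_ratio(param, decay_rate=0.8, n_layers=2):
    # Derive a layer index from names such as "linear_0.w_0" and shrink the
    # learning rate for earlier layers; unrecognized names get ratio 1.0.
    try:
        depth = int(param.name.split('_')[1].split('.')[0])
    except (IndexError, ValueError):
        depth = n_layers
    return decay_rate ** (n_layers - depth)

linear = paddle.nn.Linear(13, 5)
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=linear.parameters(),
    weight_decay=0.01,
    lr_ratio=partial(layerwise_lr_ratio, decay_rate=0.8, n_layers=2),
)

loss = linear(paddle.randn([2, 13])).mean()
loss.backward()
opt.step()        # on an XPU build this dispatches to the patched kernel
opt.clear_grad()
```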