From 290aa368e3b3be0884cbe3bffb56b951406f8f87 Mon Sep 17 00:00:00 2001
From: Shaojie WANG
Date: Thu, 16 Mar 2023 02:20:06 -0700
Subject: [PATCH] add fp32 grad plus fp16 param in adamw (#51141)

* add fp32 grad plus fp16 param in adamw
* add python UT
* fix test case
* in test_adamw_op py file, force the moment2 value LE 0
* add a compare option
* remove bf16 fused adam kernel case
---
 paddle/phi/kernels/gpu/adamw_kernel.cu        | 132 +++++++++++------
 paddle/phi/kernels/gpu/fused_adam_kernel.cu   |   1 +
 .../fluid/tests/unittests/test_adamw_op.py    | 139 ++++++++++++++++++
 3 files changed, 230 insertions(+), 42 deletions(-)

diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu
index 8a27df71956..448153ed219 100644
--- a/paddle/phi/kernels/gpu/adamw_kernel.cu
+++ b/paddle/phi/kernels/gpu/adamw_kernel.cu
@@ -29,7 +29,7 @@
 #include "paddle/phi/kernels/funcs/selected_rows_functor.h"
 
 namespace phi {
-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelREG(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -42,7 +42,7 @@ __global__ void AdamWKernelREG(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -78,7 +78,7 @@ __global__ void AdamWKernelREG(MT beta1,
   }
 }
 
-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelMEM(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -91,7 +91,7 @@ __global__ void AdamWKernelMEM(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -167,6 +167,8 @@ void AdamwDenseKernel(const Context& dev_ctx,
                       DenseTensor* master_param_outs) {
   using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
 
+  const auto grad_type = grad.dtype();
+
   VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
 
   MPDType coeff_ = static_cast<MPDType>(coeff);
@@ -235,25 +237,49 @@ void AdamwDenseKernel(const Context& dev_ctx,
 
   if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
     // Compute with betapow in REG
-    AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        coeff_,
-        lr_ratio_,
-        *beta1_pow.data<MPDType>(),
-        *beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelREG<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              *beta1_pow.data<MPDType>(),
+              *beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+
+    else
+
+      AdamWKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          *beta1_pow.data<MPDType>(),
+          *beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
     if (!use_global_beta_pow) {
       // Cpu update
       dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
@@ -262,25 +288,47 @@ void AdamwDenseKernel(const Context& dev_ctx,
           beta2_ * beta2_pow.data<MPDType>()[0];
     }
   } else {
-    AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        coeff_,
-        lr_ratio_,
-        beta1_pow.data<MPDType>(),
-        beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelMEM<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              beta1_pow.data<MPDType>(),
+              beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    else
+      AdamWKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          beta1_pow.data<MPDType>(),
+          beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
     if (!use_global_beta_pow) {
       // Update with gpu
       UpdateAdamWBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
diff --git a/paddle/phi/kernels/gpu/fused_adam_kernel.cu b/paddle/phi/kernels/gpu/fused_adam_kernel.cu
index 533ef6fd150..a87cdda5074 100644
--- a/paddle/phi/kernels/gpu/fused_adam_kernel.cu
+++ b/paddle/phi/kernels/gpu/fused_adam_kernel.cu
@@ -492,6 +492,7 @@ PD_REGISTER_KERNEL(fused_adam,
                    ALL_LAYOUT,
                    phi::FusedAdamKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
diff --git a/python/paddle/fluid/tests/unittests/test_adamw_op.py b/python/paddle/fluid/tests/unittests/test_adamw_op.py
index 9ab77d7c2a1..ead0a00ac11 100644
--- a/python/paddle/fluid/tests/unittests/test_adamw_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adamw_op.py
@@ -320,6 +320,145 @@ class TestAdamWOpGroup(TestAdamWOp):
         adam.clear_gradients()
 
 
+class TestAdamWOpMultiPrecisonWithMainGrad(unittest.TestCase):
+    def _test_adamw_op_dygraph_place_amp_with_maingrad(
+        self, place, shape, use_main_grad
+    ):
+        paddle.disable_static()
+        paddle.seed(10)
+        paddle.set_device(place)
+
+        found_inf = None
+
+        _weight_decay = 0.1
+        with_decay = True
+        _lazy_mode = False
+        find_master = True
+
+        _epsilon = 1e-8
+
+        _beta1 = 0.9
+        _beta2 = 0.99
+        lr_ratio_ = 1.0
+
+        lr_rate = 1e-8
+
+        param = paddle.randn(shape).astype(paddle.bfloat16)
+        master_weight = param.astype(paddle.float32)
+        grad = paddle.randn(shape).astype(paddle.bfloat16)
+        main_grad = grad.astype(paddle.float32)
+        moment1 = paddle.randn(shape).astype(paddle.float32)
+        moment2 = paddle.randn(shape).astype(paddle.float32).abs()
+        lr = paddle.zeros([1]).astype(paddle.float32)
+        lr[0] = lr_rate
+        beta1_pow_acc = paddle.ones([1]).astype(paddle.float32)
+        beta1_pow_acc[0] = _beta1**10
+        beta2_pow_acc = paddle.ones([1]).astype(paddle.float32)
+        beta2_pow_acc[0] = _beta2**10
+
+        ref_param = param.astype(paddle.float32)
+        ref_beta1_pow_acc = beta1_pow_acc.astype(paddle.float32)
+        ref_beta2_pow_acc = beta2_pow_acc.astype(paddle.float32)
+        ref_moment_1 = moment1.astype(paddle.float32)
+        ref_moment_2 = moment2.astype(paddle.float32)
+
+        # reference code
+        _, _, _, _, _, _ = paddle._C_ops.adamw_(
+            ref_param,
+            main_grad,
+            lr,
+            ref_moment_1,
+            ref_moment_2,
+            ref_beta1_pow_acc,
+            ref_beta2_pow_acc,
+            master_weight,
+            found_inf,
+            _beta1,
+            _beta2,
+            _epsilon,
+            lr_ratio_,
+            _weight_decay,
+            with_decay,
+            _lazy_mode,
+            1000,
+            False,
+            False,
+        )
+
+        if use_main_grad:
+            _, _, _, _, _, _ = paddle._C_ops.adamw_(
+                param,
+                main_grad,
+                lr,
+                moment1,
+                moment2,
+                beta1_pow_acc,
+                beta2_pow_acc,
+                master_weight,
+                found_inf,
+                _beta1,
+                _beta2,
+                _epsilon,
+                lr_ratio_,
+                _weight_decay,
+                with_decay,
+                _lazy_mode,
+                1000,
+                find_master,
+                False,
+            )
+            np.testing.assert_allclose(
+                param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2
+            )
+            np.testing.assert_allclose(
+                master_weight.numpy(), ref_param.numpy(), rtol=1e-6
+            )
+        else:
+            _, _, _, _, _, _ = paddle._C_ops.adamw_(
+                param,
+                grad,
+                lr,
+                moment1,
+                moment2,
+                beta1_pow_acc,
+                beta2_pow_acc,
+                master_weight,
+                found_inf,
+                _beta1,
+                _beta2,
+                _epsilon,
+                lr_ratio_,
+                _weight_decay,
+                with_decay,
+                _lazy_mode,
+                1000,
+                find_master,
+                False,
+            )
+            np.testing.assert_allclose(
+                param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2
+            )
+            np.testing.assert_allclose(
+                master_weight.numpy(), ref_param.numpy(), rtol=1e-6
+            )
+
+    def _get_places(self):
+        places = []
+        if paddle.is_compiled_with_cuda():
+            places.append('gpu')
+        return places
+
+    def test_main(self):
+        for _ in range(10):
+            shape = paddle.randint(1, 1024, [2])
+            for place in self._get_places():
+                use_main_grad_list = [True, False]
+                for use_main_grad in use_main_grad_list:
+                    self._test_adamw_op_dygraph_place_amp_with_maingrad(
+                        place, shape, use_main_grad
+                    )
+
+
 class TestAdamWOpMultiPrecison(unittest.TestCase):
     def _test_adamw_op_dygraph_place_amp(self, place, use_amp=False):
         paddle.disable_static()
-- 
GitLab
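
Usage sketch (illustrative, not part of the patch): a minimal call that exercises the path this change enables, i.e. a bfloat16 parameter updated from a float32 "main grad" through paddle._C_ops.adamw_, mirroring the unit test above. The device, shapes, and hyperparameter values are arbitrary, and it assumes a CUDA build of PaddlePaddle that contains this change and a GPU with bfloat16 support.

# Illustrative sketch only: fp32 grad + bf16 param path added to adamw_ by this patch.
import paddle

paddle.set_device('gpu')
shape = [8, 16]

param = paddle.randn(shape).astype(paddle.bfloat16)     # low-precision parameter
master_weight = param.astype(paddle.float32)            # fp32 master copy of the parameter
main_grad = paddle.randn(shape).astype(paddle.float32)  # fp32 gradient, dtype != param dtype
moment1 = paddle.zeros(shape, dtype=paddle.float32)
moment2 = paddle.zeros(shape, dtype=paddle.float32)
lr = paddle.full([1], 1e-3, dtype=paddle.float32)
beta1_pow = paddle.full([1], 0.9, dtype=paddle.float32)
beta2_pow = paddle.full([1], 0.999, dtype=paddle.float32)

# Positional argument order mirrors the calls in the test above; the trailing
# flags likewise mirror the test, with the second-to-last (True) enabling the
# multi-precision / master-weight update.
paddle._C_ops.adamw_(
    param, main_grad, lr, moment1, moment2, beta1_pow, beta2_pow,
    master_weight, None,
    0.9, 0.999, 1e-8, 1.0, 0.01, True, False, 1000, True, False,
)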