Unverified commit 290aa368, authored by shaojie_wang, committed by GitHub

add fp32 grad plus fp16 param in adamw (#51141)

* add fp32 grad plus fp16 param in adamw

* add python UT

* fix test case

* in test_adamw_op.py, force the moment2 value to be GE 0

* add a compare option

* remove bf16 fused adam kernel case
Parent c07c7712
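
For context, here is a minimal sketch of the scenario this commit enables, adapted from the new unit test further below. The device, shape, and hyper-parameter values are illustrative assumptions, not part of the commit; the call mirrors the test's argument order for paddle._C_ops.adamw_. Previously the GPU kernels read the gradient with the parameter dtype T; after this change a float32 gradient is dispatched to AdamWKernelREG/MEM<T, float, MPDType>, so a bf16/fp16 parameter can be updated directly from an fp32 "main grad".

# Minimal sketch (not part of the commit): bf16 param + fp32 main grad,
# mirroring the new unit test below. Requires a CUDA build of Paddle.
import paddle

paddle.set_device("gpu")
shape = [64, 64]

param = paddle.randn(shape).astype(paddle.bfloat16)      # low-precision parameter
master_weight = param.astype(paddle.float32)             # fp32 master copy
main_grad = paddle.randn(shape).astype(paddle.float32)   # fp32 gradient ("main grad")
moment1 = paddle.zeros(shape, dtype=paddle.float32)
moment2 = paddle.zeros(shape, dtype=paddle.float32)
lr = paddle.full([1], 1e-3, dtype=paddle.float32)
beta1_pow = paddle.full([1], 0.9, dtype=paddle.float32)
beta2_pow = paddle.full([1], 0.99, dtype=paddle.float32)

paddle._C_ops.adamw_(
    param,
    main_grad,       # fp32 grad with a bf16 param: the case this commit adds
    lr,
    moment1,
    moment2,
    beta1_pow,
    beta2_pow,
    master_weight,
    None,            # found_inf
    0.9,             # beta1
    0.99,            # beta2
    1e-8,            # epsilon
    1.0,             # lr_ratio
    0.01,            # weight decay coefficient
    True,            # with_decay
    False,           # lazy_mode
    1000,
    True,            # find_master / multi_precision
    False,
)

The corresponding kernel and test changes follow.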
@@ -29,7 +29,7 @@
 #include "paddle/phi/kernels/funcs/selected_rows_functor.h"

 namespace phi {

-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelREG(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -42,7 +42,7 @@ __global__ void AdamWKernelREG(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -78,7 +78,7 @@ __global__ void AdamWKernelREG(MT beta1,
   }
 }
-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelMEM(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -91,7 +91,7 @@ __global__ void AdamWKernelMEM(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -167,6 +167,8 @@ void AdamwDenseKernel(const Context& dev_ctx,
                       DenseTensor* master_param_outs) {
   using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

+  const auto grad_type = grad.dtype();
+
   VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

   MPDType coeff_ = static_cast<MPDType>(coeff);
@@ -235,7 +237,31 @@ void AdamwDenseKernel(const Context& dev_ctx,
   if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
     // Compute with betapow in REG
-    AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelREG<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              *beta1_pow.data<MPDType>(),
+              *beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    else
+      AdamWKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
           beta1_,
           beta2_,
           epsilon_,
@@ -262,7 +288,29 @@ void AdamwDenseKernel(const Context& dev_ctx,
           beta2_ * beta2_pow.data<MPDType>()[0];
     }
   } else {
-    AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelMEM<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              beta1_pow.data<MPDType>(),
+              beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    else
+      AdamWKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
          beta1_,
          beta2_,
          epsilon_,
...
@@ -492,6 +492,7 @@ PD_REGISTER_KERNEL(fused_adam,
                    ALL_LAYOUT,
                    phi::FusedAdamKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
...
@@ -320,6 +320,145 @@ class TestAdamWOpGroup(TestAdamWOp):
             adam.clear_gradients()


+class TestAdamWOpMultiPrecisonWithMainGrad(unittest.TestCase):
+    def _test_adamw_op_dygraph_place_amp_with_maingrad(
+        self, place, shape, use_main_grad
+    ):
+        paddle.disable_static()
+        paddle.seed(10)
+        paddle.set_device(place)
+
+        found_inf = None
+
+        _weight_decay = 0.1
+        with_decay = True
+        _lazy_mode = False
+        find_master = True
+
+        _epsilon = 1e-8
+        _beta1 = 0.9
+        _beta2 = 0.99
+        lr_ratio_ = 1.0
+        lr_rate = 1e-8
+
+        param = paddle.randn(shape).astype(paddle.bfloat16)
+        master_weight = param.astype(paddle.float32)
+        grad = paddle.randn(shape).astype(paddle.bfloat16)
+        main_grad = grad.astype(paddle.float32)
+        moment1 = paddle.randn(shape).astype(paddle.float32)
+        moment2 = paddle.randn(shape).astype(paddle.float32).abs()
+        lr = paddle.zeros([1]).astype(paddle.float32)
+        lr[0] = lr_rate
+        beta1_pow_acc = paddle.ones([1]).astype(paddle.float32)
+        beta1_pow_acc[0] = _beta1**10
+        beta2_pow_acc = paddle.ones([1]).astype(paddle.float32)
+        beta2_pow_acc[0] = _beta2**10
+
+        ref_param = param.astype(paddle.float32)
+        ref_beta1_pow_acc = beta1_pow_acc.astype(paddle.float32)
+        ref_beta2_pow_acc = beta2_pow_acc.astype(paddle.float32)
+        ref_moment_1 = moment1.astype(paddle.float32)
+        ref_moment_2 = moment2.astype(paddle.float32)
+
+        # reference code
+        _, _, _, _, _, _ = paddle._C_ops.adamw_(
+            ref_param,
+            main_grad,
+            lr,
+            ref_moment_1,
+            ref_moment_2,
+            ref_beta1_pow_acc,
+            ref_beta2_pow_acc,
+            master_weight,
+            found_inf,
+            _beta1,
+            _beta2,
+            _epsilon,
+            lr_ratio_,
+            _weight_decay,
+            with_decay,
+            _lazy_mode,
+            1000,
+            False,
+            False,
+        )
+
+        if use_main_grad:
+            _, _, _, _, _, _ = paddle._C_ops.adamw_(
+                param,
+                main_grad,
+                lr,
+                moment1,
+                moment2,
+                beta1_pow_acc,
+                beta2_pow_acc,
+                master_weight,
+                found_inf,
+                _beta1,
+                _beta2,
+                _epsilon,
+                lr_ratio_,
+                _weight_decay,
+                with_decay,
+                _lazy_mode,
+                1000,
+                find_master,
+                False,
+            )
+            np.testing.assert_allclose(
+                param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2
+            )
+            np.testing.assert_allclose(
+                master_weight.numpy(), ref_param.numpy(), rtol=1e-6
+            )
+        else:
+            _, _, _, _, _, _ = paddle._C_ops.adamw_(
+                param,
+                grad,
+                lr,
+                moment1,
+                moment2,
+                beta1_pow_acc,
+                beta2_pow_acc,
+                master_weight,
+                found_inf,
+                _beta1,
+                _beta2,
+                _epsilon,
+                lr_ratio_,
+                _weight_decay,
+                with_decay,
+                _lazy_mode,
+                1000,
+                find_master,
+                False,
+            )
+            np.testing.assert_allclose(
+                param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2
+            )
+            np.testing.assert_allclose(
+                master_weight.numpy(), ref_param.numpy(), rtol=1e-6
+            )
+
+    def _get_places(self):
+        places = []
+        if paddle.is_compiled_with_cuda():
+            places.append('gpu')
+        return places
+
+    def test_main(self):
+        for _ in range(10):
+            shape = paddle.randint(1, 1024, [2])
+            for place in self._get_places():
+                use_main_grad_list = [True, False]
+                for use_main_grad in use_main_grad_list:
+                    self._test_adamw_op_dygraph_place_amp_with_maingrad(
+                        place, shape, use_main_grad
+                    )
+
+
 class TestAdamWOpMultiPrecison(unittest.TestCase):
     def _test_adamw_op_dygraph_place_amp(self, place, use_amp=False):
         paddle.disable_static()
...