Unverified · Commit 290aa368 · Authored by shaojie_wang · Committed by GitHub

add fp32 grad plus fp16 param in adamw (#51141)

* add fp32 grad plus fp16 param in adamw

* add python UT

* fix test case

* in the test_adamw_op.py file, force the moment2 value to be >= 0

* add a compare option

* remove bf16 fused adam kernel case
Parent c07c7712
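For reference, here is a minimal dygraph sketch of the scenario this commit enables: an FP32 "main grad" updating a BF16 parameter through the AdamW op, mirroring the new unit test further down. The shape, hyperparameters, and zero-initialized moments are illustrative assumptions; the positional argument order of paddle._C_ops.adamw_ is taken from the test, and the parameter names in the comment reflect my reading of that signature. It assumes a CUDA build of Paddle, since the GPU AdamW kernel is the code path changed here.

import paddle

paddle.disable_static()
paddle.set_device('gpu')
paddle.seed(10)

shape = [8, 16]                                          # illustrative shape
param = paddle.randn(shape).astype(paddle.bfloat16)      # low-precision parameter
master_weight = param.astype(paddle.float32)             # FP32 master copy of the parameter
main_grad = paddle.randn(shape).astype(paddle.float32)   # FP32 gradient for a BF16 param
moment1 = paddle.zeros(shape, dtype=paddle.float32)
moment2 = paddle.zeros(shape, dtype=paddle.float32)
lr = paddle.full([1], 1e-3, dtype=paddle.float32)
beta1_pow = paddle.full([1], 0.9, dtype=paddle.float32)
beta2_pow = paddle.full([1], 0.99, dtype=paddle.float32)

# Argument order as in the unit test: param, grad, lr, moment1, moment2,
# beta1_pow, beta2_pow, master_weight, skip_update (found_inf in the test),
# beta1, beta2, epsilon, lr_ratio, coeff (weight decay), with_decay, lazy_mode,
# min_row_size_to_use_multithread, multi_precision (find_master), use_global_beta_pow.
paddle._C_ops.adamw_(
    param, main_grad, lr, moment1, moment2, beta1_pow, beta2_pow,
    master_weight, None,
    0.9, 0.99, 1e-8, 1.0, 0.01,
    True, False, 1000, True, False,
)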
@@ -29,7 +29,7 @@
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
namespace phi {
template <typename T, typename MT>
template <typename T, typename TG, typename MT>
__global__ void AdamWKernelREG(MT beta1,
MT beta2,
MT epsilon,
@@ -42,7 +42,7 @@ __global__ void AdamWKernelREG(MT beta1,
const MT* moment2,
MT* moment2_out,
const MT* lr_,
const T* grad,
const TG* grad,
const T* param,
T* param_out,
const MT* master_param,
@@ -78,7 +78,7 @@ __global__ void AdamWKernelREG(MT beta1,
}
}
template <typename T, typename MT>
template <typename T, typename TG, typename MT>
__global__ void AdamWKernelMEM(MT beta1,
MT beta2,
MT epsilon,
@@ -91,7 +91,7 @@ __global__ void AdamWKernelMEM(MT beta1,
const MT* moment2,
MT* moment2_out,
const MT* lr_,
const T* grad,
const TG* grad,
const T* param,
T* param_out,
const MT* master_param,
@@ -167,6 +167,8 @@ void AdamwDenseKernel(const Context& dev_ctx,
DenseTensor* master_param_outs) {
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
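// The gradient may be FP32 (an AMP "main grad") even when the parameter is FP16/BF16,
// so the kernels below are dispatched on the gradient dtype at runtime.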
const auto grad_type = grad.dtype();
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
MPDType coeff_ = static_cast<MPDType>(coeff);
@@ -235,25 +237,49 @@ void AdamwDenseKernel(const Context& dev_ctx,
if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
// Compute with betapow in REG
AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
*beta1_pow.data<MPDType>(),
*beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (grad_type == phi::DataType::FLOAT32)
AdamWKernelREG<T, float, MPDType>
<<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
*beta1_pow.data<MPDType>(),
*beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<float>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
else
AdamWKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
*beta1_pow.data<MPDType>(),
*beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (!use_global_beta_pow) {
// Cpu update
dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
@@ -262,25 +288,47 @@ void AdamwDenseKernel(const Context& dev_ctx,
beta2_ * beta2_pow.data<MPDType>()[0];
}
} else {
AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow.data<MPDType>(),
beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (grad_type == phi::DataType::FLOAT32)
AdamWKernelMEM<T, float, MPDType>
<<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow.data<MPDType>(),
beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<float>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
else
AdamWKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow.data<MPDType>(),
beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (!use_global_beta_pow) {
// Update with gpu
UpdateAdamWBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
......
@@ -492,6 +492,7 @@ PD_REGISTER_KERNEL(fused_adam,
ALL_LAYOUT,
phi::FusedAdamKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double) {
// Skip beta1_pow, beta2_pow, skip_update data transform
......
@@ -320,6 +320,145 @@ class TestAdamWOpGroup(TestAdamWOp):
adam.clear_gradients()
class TestAdamWOpMultiPrecisonWithMainGrad(unittest.TestCase):
def _test_adamw_op_dygraph_place_amp_with_maingrad(
self, place, shape, use_main_grad
):
paddle.disable_static()
paddle.seed(10)
paddle.set_device(place)
found_inf = None
_weight_decay = 0.1
with_decay = True
_lazy_mode = False
find_master = True
_epsilon = 1e-8
_beta1 = 0.9
_beta2 = 0.99
lr_ratio_ = 1.0
lr_rate = 1e-8
param = paddle.randn(shape).astype(paddle.bfloat16)
master_weight = param.astype(paddle.float32)
grad = paddle.randn(shape).astype(paddle.bfloat16)
main_grad = grad.astype(paddle.float32)
moment1 = paddle.randn(shape).astype(paddle.float32)
moment2 = paddle.randn(shape).astype(paddle.float32).abs()
lr = paddle.zeros([1]).astype(paddle.float32)
lr[0] = lr_rate
beta1_pow_acc = paddle.ones([1]).astype(paddle.float32)
beta1_pow_acc[0] = _beta1**10
beta2_pow_acc = paddle.ones([1]).astype(paddle.float32)
beta2_pow_acc[0] = _beta2**10
ref_param = param.astype(paddle.float32)
ref_beta1_pow_acc = beta1_pow_acc.astype(paddle.float32)
ref_beta2_pow_acc = beta2_pow_acc.astype(paddle.float32)
ref_moment_1 = moment1.astype(paddle.float32)
ref_moment_2 = moment2.astype(paddle.float32)
# reference code
_, _, _, _, _, _ = paddle._C_ops.adamw_(
ref_param,
main_grad,
lr,
ref_moment_1,
ref_moment_2,
ref_beta1_pow_acc,
ref_beta2_pow_acc,
master_weight,
found_inf,
_beta1,
_beta2,
_epsilon,
lr_ratio_,
_weight_decay,
with_decay,
_lazy_mode,
1000,
False,
False,
)
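# Case under test: update the low-precision (bfloat16) param directly, feeding either
# the FP32 main grad or the BF16 grad, with find_master enabled in both branches.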
if use_main_grad:
_, _, _, _, _, _ = paddle._C_ops.adamw_(
param,
main_grad,
lr,
moment1,
moment2,
beta1_pow_acc,
beta2_pow_acc,
master_weight,
found_inf,
_beta1,
_beta2,
_epsilon,
lr_ratio_,
_weight_decay,
with_decay,
_lazy_mode,
1000,
find_master,
False,
)
np.testing.assert_allclose(
param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2
)
np.testing.assert_allclose(
master_weight.numpy(), ref_param.numpy(), rtol=1e-6
)
else:
_, _, _, _, _, _ = paddle._C_ops.adamw_(
param,
grad,
lr,
moment1,
moment2,
beta1_pow_acc,
beta2_pow_acc,
master_weight,
found_inf,
_beta1,
_beta2,
_epsilon,
lr_ratio_,
_weight_decay,
with_decay,
_lazy_mode,
1000,
find_master,
False,
)
np.testing.assert_allclose(
param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2
)
np.testing.assert_allclose(
master_weight.numpy(), ref_param.numpy(), rtol=1e-6
)
def _get_places(self):
places = []
if paddle.is_compiled_with_cuda():
places.append('gpu')
return places
def test_main(self):
for _ in range(10):
shape = paddle.randint(1, 1024, [2])
for place in self._get_places():
use_main_grad_list = [True, False]
for use_main_grad in use_main_grad_list:
self._test_adamw_op_dygraph_place_amp_with_maingrad(
place, shape, use_main_grad
)
class TestAdamWOpMultiPrecison(unittest.TestCase):
def _test_adamw_op_dygraph_place_amp(self, place, use_amp=False):
paddle.disable_static()
......