未验证 提交 703a64a3 编写于 作者: Z Zhang Ting 提交者: GitHub

[AMP] support master_grad for adam and momentum (#54240)

* support master_grad for adam and momentum

Co-authored-by: zhangting2020 <zhangting_2017@163.com>
上级 cdbf62f8
......@@ -1254,10 +1254,9 @@ static PyObject* eager_api_set_master_grads(PyObject* self,
auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0);
for (auto& tensor : tensor_list) {
VLOG(6) << "set master_grad for tensor: " << tensor.name();
PADDLE_ENFORCE_EQ(
egr::egr_utils_api::IsLeafTensor(tensor),
true,
paddle::platform::errors::Fatal("Only leaf Tensor can be set grad."));
if (!egr::egr_utils_api::IsLeafTensor(tensor)) {
continue;
}
paddle::Tensor* grad = egr::EagerUtils::mutable_grad(tensor);
PADDLE_ENFORCE_NE(grad,
nullptr,
......@@ -1265,13 +1264,13 @@ static PyObject* eager_api_set_master_grads(PyObject* self,
"Detected NULL grad"
"Please check if you have manually cleared"
"the grad inside autograd_meta"));
auto dtype = (*grad).dtype();
if ((*grad).initialized() &&
(dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::BFLOAT16)) {
if ((*grad).initialized() && ((*grad).dtype() == phi::DataType::FLOAT16 ||
(*grad).dtype() == phi::DataType::BFLOAT16)) {
auto master_grad =
paddle::experimental::cast(*grad, phi::DataType::FLOAT32);
grad->set_impl(master_grad.impl());
}
VLOG(6) << "finish setting master_grad for tensor: " << tensor.name();
}
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
......
......@@ -30,7 +30,7 @@
namespace phi {
template <typename T, typename MT>
template <typename T, typename TG, typename MT>
__global__ void AdamKernelREG(MT beta1,
MT beta2,
MT epsilon,
......@@ -41,7 +41,7 @@ __global__ void AdamKernelREG(MT beta1,
const MT* moment2,
MT* moment2_out,
const MT* lr_,
const T* grad,
const TG* grad,
const T* param,
T* param_out,
const MT* master_param,
......@@ -73,7 +73,7 @@ __global__ void AdamKernelREG(MT beta1,
}
}
template <typename T, typename MT>
template <typename T, typename TG, typename MT>
__global__ void AdamKernelMEM(MT beta1,
MT beta2,
MT epsilon,
......@@ -84,7 +84,7 @@ __global__ void AdamKernelMEM(MT beta1,
const MT* moment2,
MT* moment2_out,
const MT* lr_,
const T* grad,
const TG* grad,
const T* param,
T* param_out,
const MT* master_param,
......@@ -152,6 +152,7 @@ void AdamDenseKernel(const Context& dev_ctx,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs) {
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
const auto grad_type = grad.dtype();
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
......@@ -212,7 +213,27 @@ void AdamDenseKernel(const Context& dev_ctx,
if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
// Compute with betapow in REG
AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
if (grad_type == phi::DataType::FLOAT32) {
AdamKernelREG<T, float, MPDType>
<<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
*beta1_pow.data<MPDType>(),
*beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<float>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
} else {
AdamKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
......@@ -229,6 +250,7 @@ void AdamDenseKernel(const Context& dev_ctx,
master_in_data,
master_out_data,
param.numel());
}
if (!use_global_beta_pow) {
// Cpu update
dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
......@@ -237,7 +259,27 @@ void AdamDenseKernel(const Context& dev_ctx,
beta2_ * beta2_pow.data<MPDType>()[0];
}
} else {
AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
if (grad_type == phi::DataType::FLOAT32) {
AdamKernelMEM<T, float, MPDType>
<<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
beta1_pow.data<MPDType>(),
beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<float>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
} else {
AdamKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
......@@ -254,6 +296,7 @@ void AdamDenseKernel(const Context& dev_ctx,
master_in_data,
master_out_data,
param.numel());
}
if (!use_global_beta_pow) {
// Update with gpu
UpdateBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
......@@ -308,10 +351,31 @@ void MergedAdamKernel(
int threads = 512;
int blocks = (param[idx]->numel() + threads - 1) / threads;
const auto grad_type = grad[idx]->dtype();
if (beta1_pow[idx]->place() == CPUPlace() &&
beta2_pow[idx]->place() == CPUPlace()) {
// Compute with betapow in REG
AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
if (grad_type == phi::DataType::FLOAT32) {
AdamKernelREG<T, float, MPDType>
<<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
*beta1_pow[idx]->data<MPDType>(),
*beta2_pow[idx]->data<MPDType>(),
moment1[idx]->data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
moment2[idx]->data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
learning_rate[idx]->data<MPDType>(),
grad[idx]->data<float>(),
param[idx]->data<T>(),
dev_ctx.template Alloc<T>(param_out[idx]),
master_in_data,
master_out_data,
param[idx]->numel());
} else {
AdamKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
......@@ -328,6 +392,7 @@ void MergedAdamKernel(
master_in_data,
master_out_data,
param[idx]->numel());
}
if (!use_global_beta_pow) {
// Cpu update
dev_ctx.template HostAlloc<MPDType>(beta1_pow_out[idx])[0] =
......@@ -336,7 +401,27 @@ void MergedAdamKernel(
beta2_ * beta2_pow[idx]->data<MPDType>()[0];
}
} else {
AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
if (grad_type == phi::DataType::FLOAT32) {
AdamKernelMEM<T, float, MPDType>
<<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
beta1_pow[idx]->data<MPDType>(),
beta2_pow[idx]->data<MPDType>(),
moment1[idx]->data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
moment2[idx]->data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
learning_rate[idx]->data<MPDType>(),
grad[idx]->data<float>(),
param[idx]->data<T>(),
dev_ctx.template Alloc<T>(param_out[idx]),
master_in_data,
master_out_data,
param[idx]->numel());
} else {
AdamKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
......@@ -353,6 +438,7 @@ void MergedAdamKernel(
master_in_data,
master_out_data,
param[idx]->numel());
}
if (!use_global_beta_pow) {
// Update with gpu
UpdateBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
......
......@@ -300,8 +300,25 @@ void MergedMomentumInnerCompute(
} else if (ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
phi::funcs::ForRange<Context> for_range(
static_cast<const Context &>(ctx), params[idx]->numel());
const auto grad_type = grads[idx]->dtype();
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
phi::DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
if (grad_type == phi::DataType::FLOAT32) { \
DenseMomentumFunctor<T, float, MT, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), \
grads[idx]->data<float>(), \
velocitys[idx]->data<MT>(), \
lr_temp->data<MPType>(), \
master_in_data, \
static_cast<MT>(mu), \
static_cast<MT>(rescale_grad), \
params[idx]->numel(), \
regularization_coeff, \
params_out[idx]->data<T>(), \
velocitys_out[idx]->data<MT>(), \
master_out_data); \
for_range(functor); \
} else { \
DenseMomentumFunctor<T, T, MT, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), \
grads[idx]->data<T>(), \
velocitys[idx]->data<MT>(), \
......@@ -314,7 +331,9 @@ void MergedMomentumInnerCompute(
params_out[idx]->data<T>(), \
velocitys_out[idx]->data<MT>(), \
master_out_data); \
for_range(functor);
for_range(functor); \
}
if (use_nesterov) {
if (regularization_flag == phi::RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
......
......@@ -104,6 +104,7 @@ class CPUDenseMomentumFunctor {
};
template <typename T,
typename TG,
typename MT,
RegularizationType kRegType,
typename UpdateMethod>
......@@ -112,11 +113,11 @@ class DenseMomentumFunctor;
// NOTE(dzh) for performance.
// avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two
// functor.
template <typename T, typename MT, RegularizationType kRegType>
class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
template <typename T, typename TG, typename MT, RegularizationType kRegType>
class DenseMomentumFunctor<T, TG, MT, kRegType, UseNesterov> {
private:
const T* param_;
const T* grad_;
const TG* grad_;
const MT* velocity_;
const MultiPrecisionType<MT>* lr_;
const MT* master_param_;
......@@ -130,7 +131,7 @@ class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
public:
DenseMomentumFunctor(const T* param,
const T* grad,
const TG* grad,
const MT* velocity,
const MultiPrecisionType<MT>* learning_rate,
const MT* master_param,
......@@ -176,11 +177,11 @@ class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
}
};
template <typename T, typename MT, RegularizationType kRegType>
class DenseMomentumFunctor<T, MT, kRegType, NoNesterov> {
template <typename T, typename TG, typename MT, RegularizationType kRegType>
class DenseMomentumFunctor<T, TG, MT, kRegType, NoNesterov> {
private:
const T* param_;
const T* grad_;
const TG* grad_;
const MT* velocity_;
const MultiPrecisionType<MT>* lr_;
const MT* master_param_;
......@@ -194,7 +195,7 @@ class DenseMomentumFunctor<T, MT, kRegType, NoNesterov> {
public:
DenseMomentumFunctor(const T* param,
const T* grad,
const TG* grad,
const MT* velocity,
const MultiPrecisionType<MT>* learning_rate,
const MT* master_param,
......@@ -459,8 +460,25 @@ void MomentumDenseImpl(const Context& ctx,
velocity_out);
} else if (ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
funcs::ForRange<Context> for_range(ctx, param.numel());
const auto grad_type = grad.dtype();
#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
if (grad_type == phi::DataType::FLOAT32) { \
DenseMomentumFunctor<T, float, MT, __reg_type, __nesterov> functor( \
param.data<T>(), \
grad.data<float>(), \
velocity.data<MT>(), \
learning_rate.data<MultiPrecisionType<T>>(), \
master_in_data, \
mu, \
rescale_grad, \
param.numel(), \
regularization_coeff, \
ctx.template Alloc<T>(param_out), \
ctx.template Alloc<MT>(velocity_out), \
master_out_data); \
for_range(functor); \
} else { \
DenseMomentumFunctor<T, T, MT, __reg_type, __nesterov> functor( \
param.data<T>(), \
grad.data<T>(), \
velocity.data<MT>(), \
......@@ -473,7 +491,8 @@ void MomentumDenseImpl(const Context& ctx,
ctx.template Alloc<T>(param_out), \
ctx.template Alloc<MT>(velocity_out), \
master_out_data); \
for_range(functor);
for_range(functor); \
}
if (use_nesterov) {
if (regularization_flag == RegularizationType::kL2DECAY) {
......
......@@ -1265,6 +1265,23 @@ class Optimizer:
):
return grad
regularization_term = None
# With master_grad enabled in AMP training the gradient is fp32 while the
# parameter may still be fp16/bf16. Resolve the parameter to a tensor whose
# dtype matches the gradient (preferring the master weight when one exists)
# so the regularizer does not hit a dtype-mismatch error.
def get_target_param(param, grad):
    if param.dtype == grad.dtype:
        return param
    use_master_weight = self._multi_precision and self._is_dtype_fp16_or_bf16(
        param.dtype
    )
    if use_master_weight and self._master_weights:
        return self._master_weights[param.name]
    return param.astype(grad.dtype)
param = get_target_param(param, grad)
if hasattr(param, 'regularizer') and param.regularizer is not None:
# Add variable for regularization term in grad block
regularization_term = param.regularizer(param, grad, grad.block)
......
......@@ -44,7 +44,7 @@ class TestMasterGrad(unittest.TestCase):
# fp16 calls
self.assertEqual(int(op_list['matmul_v2'].split(',')[0]), total_steps)
self.assertEqual(
int(op_list['adamw_'].split(',')[0]),
int(op_list['adam_'].split(',')[0]),
2 * (total_steps / accumulate_batchs_num),
)
self.assertEqual(
......@@ -52,14 +52,11 @@ class TestMasterGrad(unittest.TestCase):
total_steps + total_steps * 2,
)
def run_dygraph(self, total_steps, accumulate_batchs_num):
model = SimpleNet(2, 4)
opt = paddle.optimizer.AdamW(parameters=model.parameters())
def run_dygraph(self, total_steps, accumulate_batchs_num, model, optimizer):
model, opt = paddle.amp.decorate(
model, optimizers=opt, level='O2', master_grad=True
model, optimizers=optimizer, level='O2', master_grad=True
)
scaler = paddle.amp.GradScaler()
paddle.amp.debugging.enable_operator_stats_collection()
for i in range(total_steps):
x = np.random.random((2, 2)).astype('float32')
......@@ -81,16 +78,32 @@ class TestMasterGrad(unittest.TestCase):
op_list = paddle.fluid.core.get_low_precision_op_list()
return fp32_grads, op_list
def test_master_grad(self):
def test_adam_master_grad(self):
total_steps = 4
accumulate_batchs_num = 2
model = SimpleNet(2, 4)
opt = paddle.optimizer.Adam(parameters=model.parameters())
fp32_grads, op_list = self.run_dygraph(
total_steps, accumulate_batchs_num
total_steps, accumulate_batchs_num, model, opt
)
self.check_results(
fp32_grads, op_list, total_steps, accumulate_batchs_num
)
def test_momentum_master_grad(self):
    # Momentum with L1 weight decay under AMP O2 + master_grad should
    # always yield float32 gradients.
    steps = 4
    accum_num = 1
    net = SimpleNet(2, 4)
    l1_decay = paddle.regularizer.L1Decay(0.0001)
    momentum_opt = paddle.optimizer.Momentum(
        parameters=net.parameters(), weight_decay=l1_decay
    )
    grads, _ = self.run_dygraph(steps, accum_num, net, momentum_opt)
    for g in grads:
        self.assertEqual(g.dtype, paddle.float32)
# Allow running this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册