Unverified commit 4970dd65, authored by Zhang Ting, committed by GitHub

[AMP] support master_grad for amp training (#52235)

* support set master_grad

* move register_hook to auto_cast

* update unittest

* fix fp16 test

* update for review comments
Parent 6934ac79
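For orientation, here is a minimal usage sketch of the feature this commit adds, mirroring the new unit test at the bottom of the diff; the model, shapes, and data are illustrative, and it assumes a CUDA device with float16 support:

```python
import paddle

model = paddle.nn.Linear(2, 4)
opt = paddle.optimizer.AdamW(parameters=model.parameters())
# master_grad=True: keep float32 ("master") gradients after backward in O2.
model, opt = paddle.amp.decorate(
    model, optimizers=opt, level='O2', master_grad=True
)
scaler = paddle.amp.GradScaler()

with paddle.amp.auto_cast(level='O2'):
    out = model(paddle.rand([2, 2]))
    loss = paddle.nn.functional.l1_loss(out, paddle.rand([2, 4]))
scaler.scale(loss).backward()

# The backward final hook has cast the float16 grads to float32.
print(model.weight.grad.dtype)  # paddle.float32
print(model.bias.grad.dtype)    # paddle.float32

scaler.step(opt)
scaler.update()
opt.clear_grad()
```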
@@ -187,18 +187,14 @@ void TensorAdd(const VarType& src, VarType* dst) {
auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
auto place = src_tensor.place();
PADDLE_ENFORCE_EQ(framework::TransToProtoVarType(dst_tensor->dtype()),
data_type,
platform::errors::PreconditionNotMet(
"The data type of source tensor and destination tensor "
"should be equal, Otherwise, the calculation results "
"will be incorrect."));
// if src and dst are in different place, copy dst to src's place
if (dst_tensor->place() != place) {
paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
}
// AddKernel already supports inputs of different dtype. For AMP master_grad,
// the dtype of the source tensor and destination tensor will be different, so
// the check requiring input dtypes to be the same has been removed.
#define PADDLE_TENSOR_ADD(T, CONTEXT) \
if (data_type == framework::DataTypeTrait<T>::DataType()) { \
auto cpu_ctx = static_cast<CONTEXT*>( \
...
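The removed dtype check matters once master_grad is combined with gradient accumulation: after the backward final hook has cast a parameter's .grad to float32, the next backward produces float16 gradients that must be added into that float32 buffer, so TensorAdd now legitimately receives mixed dtypes. A hedged sketch of that situation, under the same assumptions as the sketch above:

```python
import paddle

model = paddle.nn.Linear(2, 4)
opt = paddle.optimizer.AdamW(parameters=model.parameters())
model, opt = paddle.amp.decorate(
    model, optimizers=opt, level='O2', master_grad=True
)
scaler = paddle.amp.GradScaler()

for step in range(2):  # two micro-batches, no clear_grad() in between
    with paddle.amp.auto_cast(level='O2'):
        out = model(paddle.rand([2, 2]))
        loss = paddle.nn.functional.l1_loss(out, paddle.rand([2, 4]))
    scaler.scale(loss).backward()
    # step 0: grads are cast to float32 by the master_grad hook;
    # step 1: the fresh float16 grads are accumulated into the existing
    #         float32 grad, i.e. the mixed-dtype path this hunk enables.
    print(model.weight.grad.dtype)  # paddle.float32 on both steps

scaler.step(opt)
scaler.update()
opt.clear_grad()
```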
@@ -47,6 +47,7 @@ typedef SSIZE_T ssize_t;
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
...@@ -1246,6 +1247,37 @@ static PyObject* eager_api__add_backward_final_hook(PyObject* self, ...@@ -1246,6 +1247,37 @@ static PyObject* eager_api__add_backward_final_hook(PyObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
static PyObject* eager_api_set_master_grads(PyObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
// tensor_list is a list of model parameters.
auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0);
for (auto& tensor : tensor_list) {
VLOG(6) << "set master_grad for tensor: " << tensor.name();
PADDLE_ENFORCE_EQ(
egr::egr_utils_api::IsLeafTensor(tensor),
true,
paddle::platform::errors::Fatal("Only leaf Tensor can be set grad."));
paddle::Tensor* grad = egr::EagerUtils::mutable_grad(tensor);
PADDLE_ENFORCE_NE(grad,
nullptr,
paddle::platform::errors::Fatal(
"Detected NULL grad"
"Please check if you have manually cleared"
"the grad inside autograd_meta"));
auto dtype = (*grad).dtype();
if ((*grad).initialized() &&
(dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::BFLOAT16)) {
auto master_grad =
paddle::experimental::cast(*grad, phi::DataType::FLOAT32);
grad->set_impl(master_grad.impl());
}
}
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyMethodDef variable_functions[] = {
// TODO(jiabin): Remove scale when we have final state tests
{"scale",
@@ -1314,6 +1346,11 @@ PyMethodDef variable_functions[] = {
(PyCFunction)(void (*)(void))eager_api_reset_saved_tensors_hooks,
METH_VARARGS | METH_KEYWORDS,
NULL},
/**amp functions**/
{"set_master_grads",
(PyCFunction)(void (*)(void))eager_api_set_master_grads,
METH_VARARGS | METH_KEYWORDS,
NULL},
/**sparse functions**/
#if defined(PADDLE_WITH_CUDA)
{"async_read",
...
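The new binding is exposed as core.eager.set_master_grads and is normally invoked for you by the backward final hook that amp_guard registers (see the auto_cast.py hunk below). A hedged sketch of calling it directly, assuming a CUDA device with float16 support; the model and shapes are illustrative:

```python
import paddle
from paddle.fluid import core

model = paddle.nn.Linear(2, 4)
opt = paddle.optimizer.AdamW(parameters=model.parameters())
# master_grad is NOT enabled here, so gradients stay float16 after backward.
model, opt = paddle.amp.decorate(model, optimizers=opt, level='O2')

with paddle.amp.auto_cast(level='O2'):
    out = model(paddle.rand([2, 2]))
    loss = paddle.nn.functional.l1_loss(out, paddle.rand([2, 4]))
loss.backward()
print(model.weight.grad.dtype)  # paddle.float16

# Cast the float16 grads of these leaf parameters to float32 in place.
core.eager.set_master_grads(model.parameters())
print(model.weight.grad.dtype)  # paddle.float32
```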
@@ -101,6 +101,23 @@ def amp_state():
return _g_amp_state_
class AMPGlobalState:
def __init__(self):
self.model_parameters = []
self.use_master_grad = False
self.already_register_final_backward_hook = False
def __setattr__(self, name, val):
self.__dict__[name] = val
_amp_global_state = AMPGlobalState()
def amp_global_state():
return _amp_global_state
# NOTE(zhiqiu): similar as paddle.static.amp.fp16_lists.AutoMixedPrecisionLists._update_list
# The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode.
def _update_list(
@@ -418,6 +435,21 @@ def amp_guard(
amp_level = AMP_LEVEL.O0
amp_dtype = "float32"
# master_grad_hook will run at the end of backward.
# Since backward final hooks are cleared once they have run, we should
# register the hook every step.
if (
amp_global_state().use_master_grad
and not amp_global_state().already_register_final_backward_hook
):
def master_grad_hook():
core.eager.set_master_grads(amp_global_state().model_parameters)
amp_global_state().already_register_final_backward_hook = False
core.eager._add_backward_final_hook(master_grad_hook)
amp_global_state().already_register_final_backward_hook = True
if tracer:
# enable auto_cast
original_amp_level = tracer._amp_level
@@ -486,6 +518,7 @@ def amp_decorate(
dtype='float16',
master_weight=None,
save_dtype=None,
master_grad=False,
):
"""
Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
@@ -599,6 +632,14 @@ def amp_decorate(
for opt in optimizers:
_set_multi_precision(opt, use_multi_precision)
# support master_grad
if master_grad:
amp_global_state().use_master_grad = True
for idx in range(len(models)):
amp_global_state().model_parameters.extend(
models[idx].parameters()
)
if save_dtype is not None:
if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']):
raise ValueError(
@@ -696,6 +737,7 @@ def decorate(
dtype='float16',
master_weight=None,
save_dtype=None,
master_grad=False,
):
"""
Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
@@ -712,6 +754,8 @@ def decorate(
master_weight(bool, optional): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
save_dtype(float, optional): The dtype used when saving model parameters with `paddle.save` or `paddle.jit.save`; it should be float16, bfloat16, float32, float64 or None.
The save_dtype does not change the dtype of the model parameters, only the dtype of the state_dict. When save_dtype is None, the save dtype is the same as the model dtype. Default is None.
master_grad(bool, optional): For level='O2', whether to use FP32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If enabled, the weight
gradients will be FP32 dtype after backpropagation. Default is False.
Examples:
@@ -761,5 +805,5 @@ def decorate(
print(output.dtype) # FP16
"""
return amp_decorate(
models, optimizers, level, dtype, master_weight, save_dtype
models, optimizers, level, dtype, master_weight, save_dtype, master_grad
)
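The hook registration added to amp_guard relies on backward final hooks being one-shot: they run at the end of the next backward and are then cleared, which is why the code above re-registers the hook (and resets the flag inside master_grad_hook) every step. A small illustrative sketch of that lifecycle, unrelated to AMP itself:

```python
import paddle
from paddle.fluid import core

def announce():
    print("backward finished")

x = paddle.rand([2, 2])
x.stop_gradient = False

core.eager._add_backward_final_hook(announce)
x.sum().backward()  # prints "backward finished"; the hook is then cleared
x.sum().backward()  # prints nothing: the hook must be registered again

core.eager._add_backward_final_hook(announce)
x.sum().backward()  # prints "backward finished" again
```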
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
class SimpleNet(paddle.nn.Layer):
def __init__(self, input_size, output_size):
super().__init__()
self.linear = paddle.nn.Linear(input_size, output_size)
def forward(self, x):
x = self.linear(x)
return x
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_float16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the float16",
)
class TestMasterGrad(unittest.TestCase):
def check_results(
self, fp32_grads, op_list, total_steps, accumulate_batchs_num
):
for grad in fp32_grads:
self.assertEqual(grad.dtype, paddle.float32)
# fp16 calls
self.assertEqual(int(op_list['matmul_v2'].split(',')[0]), total_steps)
self.assertEqual(
int(op_list['adamw_'].split(',')[0]),
2 * (total_steps / accumulate_batchs_num),
)
self.assertEqual(
int(op_list['transfer_dtype'].split(',')[0]),
total_steps + total_steps * 2,
)
def run_dygraph(self, total_steps, accumulate_batchs_num):
model = SimpleNet(2, 4)
opt = paddle.optimizer.AdamW(parameters=model.parameters())
model, opt = paddle.amp.decorate(
model, optimizers=opt, level='O2', master_grad=True
)
scaler = paddle.amp.GradScaler()
paddle.amp.debugging.enable_operator_stats_collection()
for i in range(total_steps):
x = np.random.random((2, 2)).astype('float32')
label = np.random.random((2, 4)).astype('float32')
with paddle.amp.auto_cast(level='O2'):
out = model(paddle.to_tensor(x))
loss = paddle.nn.functional.l1_loss(
out, paddle.to_tensor(label)
)
scaled = scaler.scale(loss)
scaled.backward()
fp32_grads = [model.linear.weight.grad, model.linear.bias.grad]
if (i + 1) % accumulate_batchs_num == 0:
scaler.step(opt)
scaler.update()
opt.clear_grad()
paddle.amp.debugging.disable_operator_stats_collection()
op_list = paddle.fluid.core.get_low_precision_op_list()
return fp32_grads, op_list
def test_master_grad(self):
total_steps = 4
accumulate_batchs_num = 2
fp32_grads, op_list = self.run_dygraph(
total_steps, accumulate_batchs_num
)
self.check_results(
fp32_grads, op_list, total_steps, accumulate_batchs_num
)
if __name__ == '__main__':
unittest.main()