diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index c2ccd48d4ca1e51161a1e15b7f9c978a0da51bff..14b9bc5aae0bce7d73f79c09858054d655c8353b 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -187,18 +187,14 @@ void TensorAdd(const VarType& src, VarType* dst) {
   auto data_type = framework::TransToProtoVarType(src_tensor.dtype());
   auto place = src_tensor.place();
 
-  PADDLE_ENFORCE_EQ(framework::TransToProtoVarType(dst_tensor->dtype()),
-                    data_type,
-                    platform::errors::PreconditionNotMet(
-                        "The data type of source tensor and destination tensor "
-                        "should be equal, Otherwise, the calculation results "
-                        "will be incorrect."));
-
   // if src and dst are in different place, copy dst to src's place
   if (dst_tensor->place() != place) {
     paddle::framework::TensorCopySync(*dst_tensor, place, dst_tensor);
   }
 
+  // AddKernel already supports inputs of different dtypes. For AMP master_grad,
+  // the dtypes of the source tensor and the destination tensor can differ, so
+  // the check requiring identical input dtypes has been removed.
 #define PADDLE_TENSOR_ADD(T, CONTEXT)                                  \
   if (data_type == framework::DataTypeTrait<T>::DataType()) {          \
     auto cpu_ctx = static_cast<CONTEXT*>(                              \
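For context on why the dtype check can be dropped: once master_grad is enabled, the destination gradient may already have been cast to FP32 while freshly computed gradients are still FP16, for example when gradients are accumulated over several backward passes without calling clear_grad(). Below is a minimal sketch of that scenario, assuming a CUDA device with float16 support and using the master_grad API added later in this diff.

```python
# Sketch only: an FP16 gradient is accumulated into an existing FP32 master grad.
# Assumes a CUDA device with float16 support; names follow the API added in this PR.
import paddle

model = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.SGD(parameters=model.parameters())
model, opt = paddle.amp.decorate(model, optimizers=opt, level='O2', master_grad=True)

for _ in range(2):  # no clear_grad() between iterations, so gradients accumulate
    with paddle.amp.auto_cast(level='O2'):
        loss = model(paddle.rand([4, 4])).mean()
    # After the first backward, the final hook has cast weight.grad to FP32; the
    # second backward then accumulates an FP16 gradient into it, which is exactly
    # the mixed-dtype case the removed check used to reject.
    loss.backward()

print(model.weight.grad.dtype)  # paddle.float32
```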
diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc
index 8df301520ec50cc378be340013f804bf4d2dcb73..848fa1fe742e8b5f1d3ae399278c4cb3313f5639 100644
--- a/paddle/fluid/pybind/eager_functions.cc
+++ b/paddle/fluid/pybind/eager_functions.cc
@@ -47,6 +47,7 @@ typedef SSIZE_T ssize_t;
 #include "paddle/fluid/pybind/op_function_common.h"
 #include "paddle/fluid/pybind/tensor_py.h"
 #include "paddle/phi/api/ext/op_meta_info.h"
+#include "paddle/phi/api/include/api.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/compat/convert_utils.h"
@@ -1246,6 +1247,37 @@ static PyObject* eager_api__add_backward_final_hook(PyObject* self,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
 
+static PyObject* eager_api_set_master_grads(PyObject* self,
+                                            PyObject* args,
+                                            PyObject* kwargs) {
+  EAGER_TRY
+  // tensor_list is a list of model parameters.
+  auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0);
+  for (auto& tensor : tensor_list) {
+    VLOG(6) << "set master_grad for tensor: " << tensor.name();
+    PADDLE_ENFORCE_EQ(
+        egr::egr_utils_api::IsLeafTensor(tensor),
+        true,
+        paddle::platform::errors::Fatal("Only leaf Tensor's grad can be set."));
+    paddle::Tensor* grad = egr::EagerUtils::mutable_grad(tensor);
+    PADDLE_ENFORCE_NE(grad,
+                      nullptr,
+                      paddle::platform::errors::Fatal(
+                          "Detected NULL grad. "
+                          "Please check if you have manually cleared "
+                          "the grad inside autograd_meta."));
+    auto dtype = (*grad).dtype();
+    if ((*grad).initialized() &&
+        (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::BFLOAT16)) {
+      auto master_grad =
+          paddle::experimental::cast(*grad, phi::DataType::FLOAT32);
+      grad->set_impl(master_grad.impl());
+    }
+  }
+  RETURN_PY_NONE
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 PyMethodDef variable_functions[] = {
     // TODO(jiabin): Remove scale when we have final state tests
     {"scale",
@@ -1314,6 +1346,11 @@ PyMethodDef variable_functions[] = {
      (PyCFunction)(void (*)(void))eager_api_reset_saved_tensors_hooks,
      METH_VARARGS | METH_KEYWORDS,
      NULL},
+    /**amp functions**/
+    {"set_master_grads",
+     (PyCFunction)(void (*)(void))eager_api_set_master_grads,
+     METH_VARARGS | METH_KEYWORDS,
+     NULL},
     /**sparse functions**/
 #if defined(PADDLE_WITH_CUDA)
     {"async_read",
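The new binding is exposed to Python as core.eager.set_master_grads(params). As a rough Python-level paraphrase of its casting rule: initialized FP16/BF16 gradients of leaf parameters are cast to FP32, everything else is left alone. The collect_master_grads helper below is a hypothetical illustration, not part of this PR; the real binding swaps the grad's underlying impl in place on the C++ side.

```python
# Illustration of the rule applied by eager_api_set_master_grads; hypothetical
# helper, not part of the PR (the binding itself mutates param.grad in place).
import paddle


def collect_master_grads(params):
    master = {}
    for p in params:
        g = p.grad
        if g is not None and g.dtype in (paddle.float16, paddle.bfloat16):
            # low-precision grads are promoted to FP32
            master[p.name] = paddle.cast(g, 'float32')
        else:
            # FP32 (or not yet initialized) grads are left untouched
            master[p.name] = g
    return master
```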
diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index 7e20980be9567698a9e2fd7929fdc4cbdbdb8c83..33c7855d89724350e99a388efe534570f4961376 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -101,6 +101,23 @@ def amp_state():
     return _g_amp_state_
 
 
+class AMPGlobalState:
+    def __init__(self):
+        self.model_parameters = []
+        self.use_master_grad = False
+        self.already_register_final_backward_hook = False
+
+    def __setattr__(self, name, val):
+        self.__dict__[name] = val
+
+
+_amp_global_state = AMPGlobalState()
+
+
+def amp_global_state():
+    return _amp_global_state
+
+
 # NOTE(zhiqiu): similar as paddle.static.amp.fp16_lists.AutoMixedPrecisionLists._update_list
 # The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode.
 def _update_list(
@@ -418,6 +435,21 @@ def amp_guard(
         amp_level = AMP_LEVEL.O0
         amp_dtype = "float32"
 
+    # master_grad_hook will run at the end of backward.
+    # Since backward_final_hooks are cleared once they have run,
+    # the hook has to be registered again every step.
+    if (
+        amp_global_state().use_master_grad
+        and not amp_global_state().already_register_final_backward_hook
+    ):
+
+        def master_grad_hook():
+            core.eager.set_master_grads(amp_global_state().model_parameters)
+            amp_global_state().already_register_final_backward_hook = False
+
+        core.eager._add_backward_final_hook(master_grad_hook)
+        amp_global_state().already_register_final_backward_hook = True
+
     if tracer:
         # enable auto_cast
         original_amp_level = tracer._amp_level
@@ -486,6 +518,7 @@ def amp_decorate(
     dtype='float16',
     master_weight=None,
     save_dtype=None,
+    master_grad=False,
 ):
     """
     Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
@@ -599,6 +632,14 @@ def amp_decorate(
     for opt in optimizers:
         _set_multi_precision(opt, use_multi_precision)
 
+    # support master_grad
+    if master_grad:
+        amp_global_state().use_master_grad = True
+        for idx in range(len(models)):
+            amp_global_state().model_parameters.extend(
+                models[idx].parameters()
+            )
+
     if save_dtype is not None:
         if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']):
             raise ValueError(
@@ -696,6 +737,7 @@ def decorate(
     dtype='float16',
     master_weight=None,
     save_dtype=None,
+    master_grad=False,
 ):
     """
     Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
@@ -712,6 +754,8 @@ def decorate(
         master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
         save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None. The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
+        master_grad(bool, optional): For level='O2', whether to use FP32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If it is enabled, the weight
+            gradients will be of FP32 dtype after backpropagation. Default is False.
 
     Examples:
 
@@ -761,5 +805,5 @@ def decorate(
         print(output.dtype) # FP16
     """
     return amp_decorate(
-        models, optimizers, level, dtype, master_weight, save_dtype
+        models, optimizers, level, dtype, master_weight, save_dtype, master_grad
     )
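Taken together, the Python-side changes give the following user-facing flow. This is a minimal usage sketch assuming a CUDA device with float16 support; it mirrors the unit test added below.

```python
# Minimal sketch of the new master_grad=True option (GPU with FP16 support assumed).
import paddle

model = paddle.nn.Linear(2, 4)
opt = paddle.optimizer.AdamW(parameters=model.parameters())
model, opt = paddle.amp.decorate(model, optimizers=opt, level='O2', master_grad=True)
scaler = paddle.amp.GradScaler()

with paddle.amp.auto_cast(level='O2'):
    loss = model(paddle.rand([2, 2])).mean()
scaler.scale(loss).backward()

# The backward_final_hook registered in amp_guard has already cast the FP16
# gradients to FP32, so gradient clipping, weight decay and the update all see
# FP32 gradients.
assert model.weight.grad.dtype == paddle.float32

scaler.step(opt)
scaler.update()
opt.clear_grad()
```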
diff --git a/test/amp/test_amp_master_grad.py b/test/amp/test_amp_master_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b5aebf35771e637108898e55d2d93ed2896caa3
--- /dev/null
+++ b/test/amp/test_amp_master_grad.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.fluid import core
+
+
+class SimpleNet(paddle.nn.Layer):
+    def __init__(self, input_size, output_size):
+        super().__init__()
+        self.linear = paddle.nn.Linear(input_size, output_size)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_float16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support float16",
+)
+class TestMasterGrad(unittest.TestCase):
+    def check_results(
+        self, fp32_grads, op_list, total_steps, accumulate_batchs_num
+    ):
+        for grad in fp32_grads:
+            self.assertEqual(grad.dtype, paddle.float32)
+        # fp16 calls
+        self.assertEqual(int(op_list['matmul_v2'].split(',')[0]), total_steps)
+        self.assertEqual(
+            int(op_list['adamw_'].split(',')[0]),
+            2 * (total_steps / accumulate_batchs_num),
+        )
+        self.assertEqual(
+            int(op_list['transfer_dtype'].split(',')[0]),
+            total_steps + total_steps * 2,
+        )
+
+    def run_dygraph(self, total_steps, accumulate_batchs_num):
+        model = SimpleNet(2, 4)
+        opt = paddle.optimizer.AdamW(parameters=model.parameters())
+        model, opt = paddle.amp.decorate(
+            model, optimizers=opt, level='O2', master_grad=True
+        )
+        scaler = paddle.amp.GradScaler()
+
+        paddle.amp.debugging.enable_operator_stats_collection()
+        for i in range(total_steps):
+            x = np.random.random((2, 2)).astype('float32')
+            label = np.random.random((2, 4)).astype('float32')
+
+            with paddle.amp.auto_cast(level='O2'):
+                out = model(paddle.to_tensor(x))
+                loss = paddle.nn.functional.l1_loss(
+                    out, paddle.to_tensor(label)
+                )
+            scaled = scaler.scale(loss)
+            scaled.backward()
+            fp32_grads = [model.linear.weight.grad, model.linear.bias.grad]
+            if (i + 1) % accumulate_batchs_num == 0:
+                scaler.step(opt)
+                scaler.update()
+                opt.clear_grad()
+        paddle.amp.debugging.disable_operator_stats_collection()
+        op_list = paddle.fluid.core.get_low_precision_op_list()
+        return fp32_grads, op_list
+
+    def test_master_grad(self):
+        total_steps = 4
+        accumulate_batchs_num = 2
+        fp32_grads, op_list = self.run_dygraph(
+            total_steps, accumulate_batchs_num
+        )
+        self.check_results(
+            fp32_grads, op_list, total_steps, accumulate_batchs_num
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()
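One note on the numbers asserted in check_results, as an editorial reading rather than anything stated in the PR itself:

```python
# With the values used in test_master_grad (total_steps=4, accumulate_batchs_num=2):
#   matmul_v2      = total_steps = 4
#       one FP16 matmul per training step for the single Linear layer
#   adamw_         = 2 * (total_steps / accumulate_batchs_num) = 4
#       two parameters (weight, bias), each updated once per accumulated batch
#   transfer_dtype = total_steps + total_steps * 2 = 12
#       the "total_steps * 2" term plausibly corresponds to the two per-step
#       master-grad casts (weight.grad and bias.grad, FP16 -> FP32)
```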