diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index f1756bc02055c20d43ea053cdef3169f35f789c0..f162d622f5fedc8bb86aadfa07b4942c031159dd 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -167,11 +167,12 @@ class FusedCommBuffer: self._acc_steps = acc_steps self._comm_group = comm_group - use_main_grad = hasattr(self._params[0], "main_grad") + self.use_main_grad = hasattr(self._params[0], "main_grad") self._task = None self._params_step_dict = {} self._params_checked_in = 0 + self._params_to_addr = {} self._act = act if self._act == HOOK_ACTION.ALL_REDUCE: @@ -186,7 +187,20 @@ class FusedCommBuffer: self._init_step_dict() - self.grad_storage = flatten_dense_tensors(self._params, use_main_grad) + self.grad_storage = flatten_dense_tensors( + self._params, self.use_main_grad + ) + + self._record_addr() + + def _record_addr(self): + for param in self._params: + addr = ( + param.main_grad.data_ptr() + if self.use_main_grad + else param.grad.data_ptr() + ) + self._params_to_addr[param.name] = addr def _init_step_dict(self): for p in self._params: @@ -206,6 +220,18 @@ class FusedCommBuffer: def add_grad(self, param): assert param.name in self._params_step_dict + current_ptr = ( + param.main_grad.data_ptr() + if self.use_main_grad + else param.grad.data_ptr() + ) + if self._params_to_addr[param.name] != current_ptr: + raise ValueError( + "The address of the grad/main_grad of the param has been changed during training, " + "which is not allowed for dp/sharding overlap with pp. " + "This may be caused by some non-inplace operations on the grad/main_grad. " + "Please use the inplace version of the operations or disable the overlapping." 
+ ) self._params_step_dict[param.name] += 1 diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index e0d89932a2921336cb994bdcfc5f0920b4dc5be5..84786e2ace3ed90a65c174659f9d26ef350e9c40 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -75,6 +75,7 @@ if(NOT WITH_GPU) list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api) list(REMOVE_ITEM TEST_OPS test_fused_attention_pass) list(REMOVE_ITEM TEST_OPS test_fused_feedforward_pass) + list(REMOVE_ITEM TEST_OPS test_fused_comm_buffer) endif() list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op) diff --git a/python/paddle/fluid/tests/unittests/test_fused_comm_buffer.py b/python/paddle/fluid/tests/unittests/test_fused_comm_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..ad771b6dfe5a2bbe220cb424a16b9ee77c18ce7b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_comm_buffer.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+import paddle
+from paddle.distributed.fleet.meta_parallel.pp_utils.utils import (
+    HOOK_ACTION,
+    FusedCommBuffer,
+)
+
+
+class TestFusedCommBufferGradChecker(unittest.TestCase):
+    """add_grad must raise ValueError once a param's main_grad is rebound."""
+
+    def test_fused_comm_buffer_grad_checker(self):
+        linear = paddle.nn.Linear(10, 10)
+        w = linear.weight
+        b = linear.bias
+        # FusedCommBuffer turns on use_main_grad via a hasattr check on its
+        # first param, so the attribute has to exist before construction.
+        w.main_grad = None
+        b.main_grad = None
+        buffer = FusedCommBuffer(
+            id=0,
+            params=[w, b],
+            comm_group=None,
+            acc_steps=10,
+            act=HOOK_ACTION.ALL_REDUCE,
+        )
+        self.assertTrue(buffer.use_main_grad)
+        buffer.add_grad(w)
+        buffer.add_grad(b)
+        # Swap main_grad for a different tensor; add_grad is expected to reject it.
+        w.main_grad = paddle.to_tensor([1], stop_gradient=True, dtype="float32")
+        with self.assertRaises(ValueError):
+            buffer.add_grad(w)
+
+
+if __name__ == "__main__":
+    unittest.main()