Unverified commit 161dad50, authored by Yuang Liu, committed by GitHub

[hybrid performance] Add addr check for overlap grad fusion. (#54543) (#54552)

Parent 37bda166
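The check introduced below records the device address of each parameter's grad/main_grad when the fused communication buffer is built and verifies it again on every add_grad call. The following is a minimal, hypothetical sketch (not part of this commit; tensor names and values are illustrative) of why comparing data_ptr() catches non-inplace gradient updates:

import paddle

grad = paddle.zeros([4], dtype="float32")
recorded_ptr = grad.data_ptr()  # address recorded when the fused buffer is built

# An in-place update keeps the original storage, so the recorded address stays valid.
grad.add_(paddle.ones([4], dtype="float32"))
assert grad.data_ptr() == recorded_ptr

# A non-inplace update allocates a new tensor; the fused buffer no longer
# aliases this gradient, which is exactly what the new check detects.
grad = grad * 2.0
assert grad.data_ptr() != recorded_ptr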
@@ -167,11 +167,12 @@ class FusedCommBuffer:
         self._acc_steps = acc_steps
         self._comm_group = comm_group
-        use_main_grad = hasattr(self._params[0], "main_grad")
+        self.use_main_grad = hasattr(self._params[0], "main_grad")
         self._task = None
         self._params_step_dict = {}
         self._params_checked_in = 0
+        self._params_to_addr = {}
         self._act = act
         if self._act == HOOK_ACTION.ALL_REDUCE:
@@ -186,7 +187,20 @@ class FusedCommBuffer:
         self._init_step_dict()
-        self.grad_storage = flatten_dense_tensors(self._params, use_main_grad)
+        self.grad_storage = flatten_dense_tensors(
+            self._params, self.use_main_grad
+        )
+        self._record_addr()
+
+    def _record_addr(self):
+        for param in self._params:
+            addr = (
+                param.main_grad.data_ptr()
+                if self.use_main_grad
+                else param.grad.data_ptr()
+            )
+            self._params_to_addr[param.name] = addr

     def _init_step_dict(self):
         for p in self._params:
@@ -206,6 +220,18 @@
     def add_grad(self, param):
         assert param.name in self._params_step_dict
+        current_ptr = (
+            param.main_grad.data_ptr()
+            if self.use_main_grad
+            else param.grad.data_ptr()
+        )
+        if self._params_to_addr[param.name] != current_ptr:
+            raise ValueError(
+                "The address of the grad/main_grad of the param has been changed during training, "
+                "which is not allowed for dp/sharding overlap with pp. "
+                "This may be caused by some non-inplace operations on the grad/main_grad. "
+                "Please use the inplace version of the operations or disable the overlapping."
+            )
         self._params_step_dict[param.name] += 1
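The error message above recommends in-place updates. Below is a hedged sketch of the difference (assuming a parameter with a populated main_grad; the loss_scale factor is illustrative and not from this commit):

import paddle

param = paddle.nn.Linear(4, 4).weight
param.main_grad = paddle.zeros_like(param)  # stand-in for the view into the fused buffer
recorded = param.main_grad.data_ptr()

loss_scale = 0.5  # illustrative factor

# Scaling in place preserves the storage recorded by _record_addr().
param.main_grad.scale_(loss_scale)
assert param.main_grad.data_ptr() == recorded

# Rebinding would allocate new storage and make add_grad() raise ValueError:
# param.main_grad = param.main_grad * loss_scale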
@@ -75,6 +75,7 @@ if(NOT WITH_GPU)
   list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
   list(REMOVE_ITEM TEST_OPS test_fused_attention_pass)
   list(REMOVE_ITEM TEST_OPS test_fused_feedforward_pass)
+  list(REMOVE_ITEM TEST_OPS test_fused_comm_buffer)
 endif()

 list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import paddle
from paddle.distributed.fleet.meta_parallel.pp_utils.utils import (
    HOOK_ACTION,
    FusedCommBuffer,
)


class TestFusedCommBufferGradChecker(unittest.TestCase):
    def test_fused_comm_buffer_grad_checker(self):
        linear = paddle.nn.Linear(10, 10)
        w = linear.weight
        b = linear.bias
        w.main_grad = None
        b.main_grad = None
        buffer = FusedCommBuffer(
            id=0,
            params=[w, b],
            comm_group=None,
            acc_steps=10,
            act=HOOK_ACTION.ALL_REDUCE,
        )
        assert buffer.use_main_grad
        buffer.add_grad(w)
        buffer.add_grad(b)
        w.main_grad = paddle.to_tensor([1], stop_gradient=True, dtype="float32")
        try:
            buffer.add_grad(w)
            raise AssertionError(
                "Above add_grad should raise value error, this assertion should be unreachable."
            )
        except ValueError:
            pass


if __name__ == "__main__":
    unittest.main()