From 1980e33a901efa5128e7799a83bcd35ee8ada199 Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Wed, 2 Mar 2022 18:54:54 +0800
Subject: [PATCH] add check for backward hook (#40041)

* add check for backward hook

* refine ut
---
 paddle/fluid/imperative/basic_engine.cc       |  1 +
 .../fluid/imperative/gradient_accumulator.cc  |  1 +
 .../fluid/imperative/gradient_accumulator.h   | 24 ++++++++++++
 .../test_imperative_auto_mixed_precision.py   | 38 ++++++++++++++++++-
 4 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc
index 8373c7fe50d..7416d206fc4 100644
--- a/paddle/fluid/imperative/basic_engine.cc
+++ b/paddle/fluid/imperative/basic_engine.cc
@@ -317,6 +317,7 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
         auto tmp_var = var;
         for (const auto& hook_pair : var->GetVariableWrapperHooks()) {
           tmp_var = (*hook_pair.second)(tmp_var);
+          CheckVar(var, tmp_var);
         }
         (*tmp_ins_ptr)[pair.first][i] = tmp_var;
       }
diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index 0abc5ad90e2..12aa13bbacc 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -732,6 +732,7 @@ void GradientAccumulator::CallGradientHooks() {
             << var_->GetVariableWrapperHooks().size();
     for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
       tmp_var = (*hook_pair.second)(tmp_var);
+      CheckVar(inner_var_, tmp_var);
     }
     inner_var_ = tmp_var;
   }
diff --git a/paddle/fluid/imperative/gradient_accumulator.h b/paddle/fluid/imperative/gradient_accumulator.h
index e74711c2a79..03f6775defc 100644
--- a/paddle/fluid/imperative/gradient_accumulator.h
+++ b/paddle/fluid/imperative/gradient_accumulator.h
@@ -179,5 +179,29 @@ void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
 template <typename VarType>
 void TensorAdd(const VarType& src, VarType* dst);
 
+inline void CheckVar(const std::shared_ptr<VariableWrapper>& pre,
+                     const std::shared_ptr<VariableWrapper>& post) {
+  if (pre->IsEmpty() && !post->IsEmpty()) {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "The tensor(%s) in before and after hook are not consistent",
+        pre->Name()));
+  }
+  if (!pre->IsEmpty() && !post->IsEmpty()) {
+    VLOG(4) << pre->DataType() << " " << post->DataType();
+    PADDLE_ENFORCE_EQ(
+        pre->DataType(), post->DataType(),
+        platform::errors::PermissionDenied(
+            "The dtype of tensor(%s) before(%s) and after(%s) hook are not "
+            "consistent",
+            pre->Name(), framework::DataTypeToString(pre->DataType()),
+            framework::DataTypeToString(post->DataType())));
+    PADDLE_ENFORCE_EQ(pre->Place(), post->Place(),
+                      platform::errors::PermissionDenied(
+                          "The place of tensor(%s) before(%s) and after(%s) "
+                          "hook are not consistent",
+                          pre->Name(), pre->Place(), post->Place()));
+  }
+}
+
 }  // namespace imperative
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
index 5cb72512f99..2011a35db68 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
@@ -1156,7 +1156,7 @@ class TestBf16(unittest.TestCase):
                 out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
 
 
-class TestPyLayerWithAmp(unittest.TestCase):
+class TestAmpWithPyLyer(unittest.TestCase):
     def test_pylayer(self):
         class MyMM(PyLayer):
             @staticmethod
@@ -1168,7 +1168,7 @@ class TestPyLayerWithAmp(unittest.TestCase):
             def backward(ctx, grad):
                 a, b = ctx.saved_tensor()
                 # NOTE(zhiqiu): a and b is float32 now, while grad is fp16 when forward runs with auto_cast()
-                # thus, the mm operation raise errors because of the dtype of inputs are inconsistent
+                # thus, the mm operation raise errors because of the dtype of inputs are inconsistent before.
                 return grad.mm(b.t()), a.t().mm(grad)
 
         x = paddle.rand([10, 10])
@@ -1182,5 +1182,39 @@ class TestPyLayerWithAmp(unittest.TestCase):
         loss.backward()
 
 
+class TestAmpWithHook(unittest.TestCase):
+    def test_hook_change_dtype(self):
+        with paddle.fluid.dygraph.guard():
+            v = paddle.rand([3, 3])
+            v.stop_gradient = False
+
+            def foo(grad):
+                print('grad', grad, grad.dtype)  # grad's dtype is float32
+                res = paddle.mm(grad, grad)  # mm runs in fp16
+                print('res', res, res.dtype)  # res's dtype is float16
+                return res
+
+            v.register_hook(foo)
+            with paddle.amp.auto_cast():
+                a = paddle.mm(v, v)
+                loss = a.sum()
+            self.assertRaises(RuntimeError, loss.backward)
+
+    def test_hook_change_place(self):
+        with paddle.fluid.dygraph.guard():
+            v = paddle.rand([3, 3])
+            v.stop_gradient = False
+
+            def foo(grad):
+                res = grad.cpu()  # change place
+                return res
+
+            v.register_hook(foo)
+            with paddle.amp.auto_cast():
+                a = paddle.mm(v, v)
+                loss = a.sum()
+            self.assertRaises(RuntimeError, loss.backward)
+
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab
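
For context, a minimal sketch of a gradient hook that passes the new CheckVar consistency check; this snippet is not part of the patch, and the hook name scale_grad is illustrative. Because the hook returns a tensor with the same dtype and place as the incoming gradient, backward() completes instead of raising the RuntimeError exercised by TestAmpWithHook:

    import paddle

    v = paddle.rand([3, 3])
    v.stop_gradient = False

    def scale_grad(grad):
        # Returns a tensor with the same dtype and place as the incoming
        # gradient, so the before/after-hook check added by this patch passes.
        return grad * 2

    v.register_hook(scale_grad)
    with paddle.amp.auto_cast():
        a = paddle.mm(v, v)
    loss = a.sum()
    loss.backward()  # succeeds; only dtype- or place-changing hooks are rejected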