未验证 提交 d7e0f875 编写于 作者: N niuliling123 提交者: GitHub

Change the print in debugging to RuntimeError (#56622)

上级 c5fc413a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
...@@ -115,6 +115,8 @@ def check_layer_numerics(func): ...@@ -115,6 +115,8 @@ def check_layer_numerics(func):
if args: if args:
# Set temp data and temp.gradient = False # Set temp data and temp.gradient = False
start_data = args[0] start_data = args[0]
if not isinstance(start_data, paddle.Tensor):
raise RuntimeError("First input of this layer must be tensor.")
start_data.stop_gradient = False start_data.stop_gradient = False
modified_args = list(args) # Convert args to a mutable list modified_args = list(args) # Convert args to a mutable list
# Set FLAGS_check_nan_inf = 1 # Set FLAGS_check_nan_inf = 1
...@@ -125,7 +127,7 @@ def check_layer_numerics(func): ...@@ -125,7 +127,7 @@ def check_layer_numerics(func):
out = _C_ops.disable_check_model_nan_inf(out_data, 0) out = _C_ops.disable_check_model_nan_inf(out_data, 0)
return out return out
else: else:
print("No elements found in args") raise RuntimeError("No elements found in args.")
out = func(self, *args, **kwargs) out = func(self, *args, **kwargs)
return out return out
......
...@@ -138,6 +138,37 @@ class TestCheckLayerNumerics(unittest.TestCase): ...@@ -138,6 +138,37 @@ class TestCheckLayerNumerics(unittest.TestCase):
loss.backward() loss.backward()
adam.step() adam.step()
def test_error_no_element(self):
class MyLayer(paddle.nn.Layer):
def __init__(self, dtype):
super().__init__()
self._w = self.create_parameter([2, 3], dtype=dtype)
@paddle.amp.debugging.check_layer_numerics
def forward(self):
return self._w
with self.assertRaises(RuntimeError):
dtype = 'float32'
model = MyLayer(dtype)
data = model()
def test_error_type_error(self):
class MyLayer(paddle.nn.Layer):
def __init__(self, dtype):
super().__init__()
self._w = self.create_parameter([2, 3], dtype=dtype)
@paddle.amp.debugging.check_layer_numerics
def forward(self, x):
return self._w * x
x = 1
with self.assertRaises(RuntimeError):
dtype = 'float32'
model = MyLayer(dtype)
data = model(x)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册