未验证 提交 d7e0f875 编写于 作者: N niuliling123 提交者: GitHub

Change the print in debugging to RuntimeError (#56622)

上级 c5fc413a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......
......@@ -115,6 +115,8 @@ def check_layer_numerics(func):
if args:
# Set temp data and temp.gradient = False
start_data = args[0]
if not isinstance(start_data, paddle.Tensor):
raise RuntimeError("First input of this layer must be tensor.")
start_data.stop_gradient = False
modified_args = list(args) # Convert args to a mutable list
# Set FLAGS_check_nan_inf = 1
......@@ -125,7 +127,7 @@ def check_layer_numerics(func):
out = _C_ops.disable_check_model_nan_inf(out_data, 0)
return out
else:
print("No elements found in args")
raise RuntimeError("No elements found in args.")
out = func(self, *args, **kwargs)
return out
......
......@@ -138,6 +138,37 @@ class TestCheckLayerNumerics(unittest.TestCase):
loss.backward()
adam.step()
def test_error_no_element(self):
class MyLayer(paddle.nn.Layer):
def __init__(self, dtype):
super().__init__()
self._w = self.create_parameter([2, 3], dtype=dtype)
@paddle.amp.debugging.check_layer_numerics
def forward(self):
return self._w
with self.assertRaises(RuntimeError):
dtype = 'float32'
model = MyLayer(dtype)
data = model()
def test_error_type_error(self):
class MyLayer(paddle.nn.Layer):
def __init__(self, dtype):
super().__init__()
self._w = self.create_parameter([2, 3], dtype=dtype)
@paddle.amp.debugging.check_layer_numerics
def forward(self, x):
return self._w * x
x = 1
with self.assertRaises(RuntimeError):
dtype = 'float32'
model = MyLayer(dtype)
data = model(x)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册