未验证 提交 9dbfadab 编写于 作者: N Nyakku Shigure 提交者: GitHub

[CodeStyle][B017] catch more specific exceptions in unittests (#52553)

上级 160dfd01
...@@ -41,10 +41,6 @@ select = [ ...@@ -41,10 +41,6 @@ select = [
"B002", "B002",
"B003", "B003",
"B004", "B004",
# "B005",
# "B006",
# "B007",
# "B008",
"B009", "B009",
"B010", "B010",
"B011", "B011",
...@@ -53,65 +49,30 @@ select = [ ...@@ -53,65 +49,30 @@ select = [
"B014", "B014",
"B015", "B015",
"B016", "B016",
# "B017", "B017",
"B018", "B018",
"B019", "B019",
"B020", "B020",
"B021", "B021",
"B022", "B022",
# "B023",
# "B024",
"B025", "B025",
# "B026",
# "B027",
# "B028",
"B029", "B029",
# "B030",
"B032", "B032",
# "B904",
# Pylint # Pylint
"PLC0414", "PLC0414",
# "PLC1901",
"PLC3002", "PLC3002",
"PLE0100", "PLE0100",
"PLE0101", "PLE0101",
# "PLE0116",
# "PLE0117",
# "PLE0118",
"PLE0604", "PLE0604",
"PLE0605", "PLE0605",
"PLE1142", "PLE1142",
"PLE1205", "PLE1205",
"PLE1206", "PLE1206",
"PLE1307", "PLE1307",
# "PLE1310",
# "PLE1507",
"PLE2502", "PLE2502",
# "PLE2510",
# "PLE2512",
# "PLE2513",
# "PLE2514",
# "PLE2515",
# "PLR0133",
"PLR0206", "PLR0206",
"PLR0402", "PLR0402",
# "PLR0911",
# "PLR0912",
# "PLR0913",
# "PLR0915",
# "PLR1701",
# "PLR1711",
# "PLR1722",
# "PLR2004",
# "PLR5501",
# "PLW0120",
# "PLW0129",
# "PLW0602",
# "PLW0603",
# "PLW0711",
# "PLW1508",
# "PLW2901",
] ]
unfixable = [ unfixable = [
"NPY001" "NPY001"
......
...@@ -43,7 +43,7 @@ class Rank: ...@@ -43,7 +43,7 @@ class Rank:
return ( return (
'{' '{'
+ 'rank={};'.format(self.kind) + f'rank={self.kind};'
+ ','.join([node.name for node in self.nodes]) + ','.join([node.name for node in self.nodes])
+ '}' + '}'
) )
...@@ -97,7 +97,7 @@ class Graph: ...@@ -97,7 +97,7 @@ class Graph:
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
) )
logging.warning("write block debug graph to {}".format(image_path)) logging.warning(f"write block debug graph to {image_path}")
return image_path return image_path
def show(self, dot_path): def show(self, dot_path):
...@@ -125,13 +125,11 @@ class Graph: ...@@ -125,13 +125,11 @@ class Graph:
def __str__(self): def __str__(self):
reprs = [ reprs = [
'digraph G {', 'digraph G {',
'title = {}'.format(crepr(self.title)), f'title = {crepr(self.title)}',
] ]
for attr in self.attrs: for attr in self.attrs:
reprs.append( reprs.append(f"{attr}={crepr(self.attrs[attr])};")
"{key}={value};".format(key=attr, value=crepr(self.attrs[attr]))
)
reprs.append(self._rank_repr()) reprs.append(self._rank_repr())
...@@ -161,8 +159,7 @@ class Node: ...@@ -161,8 +159,7 @@ class Node:
label=self.label, label=self.label,
extra=',' extra=','
+ ','.join( + ','.join(
"%s=%s" % (key, crepr(value)) f"{key}={crepr(value)}" for key, value in self.attrs.items()
for key, value in self.attrs.items()
) )
if self.attrs if self.attrs
else "", else "",
...@@ -191,8 +188,7 @@ class Edge: ...@@ -191,8 +188,7 @@ class Edge:
if not self.attrs if not self.attrs
else "[" else "["
+ ','.join( + ','.join(
"{}={}".format(attr[0], crepr(attr[1])) f"{attr[0]}={crepr(attr[1])}" for attr in self.attrs.items()
for attr in self.attrs.items()
) )
+ "]", + "]",
) )
...@@ -292,5 +288,5 @@ class GraphPreviewGenerator: ...@@ -292,5 +288,5 @@ class GraphPreviewGenerator:
source, source,
target, target,
color="#00000" if not highlight else "orange", color="#00000" if not highlight else "orange",
**kwargs **kwargs,
) )
...@@ -39,7 +39,7 @@ class CheckPassConflictTest2(PassConflictChecker): ...@@ -39,7 +39,7 @@ class CheckPassConflictTest2(PassConflictChecker):
] ]
def test_resnet(self): def test_resnet(self):
with self.assertRaises(Exception): with self.assertRaises(Exception): # noqa: B017
self.check_main(resnet_model, batch_size=32) self.check_main(resnet_model, batch_size=32)
......
...@@ -72,7 +72,9 @@ class TRTDynamicShapeOutOfBound1Test(TRTDynamicShapeTest): ...@@ -72,7 +72,9 @@ class TRTDynamicShapeOutOfBound1Test(TRTDynamicShapeTest):
def test_check_output(self): def test_check_output(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
use_gpu = True use_gpu = True
with self.assertRaises(Exception): with self.assertRaisesRegex(
ValueError, "The fed Variable 'data' should have dimensions"
):
self.check_output_with_option(use_gpu) self.check_output_with_option(use_gpu)
...@@ -99,7 +101,9 @@ class TRTDynamicShapeOutOfBound3Test(TRTDynamicShapeTest): ...@@ -99,7 +101,9 @@ class TRTDynamicShapeOutOfBound3Test(TRTDynamicShapeTest):
def test_check_output(self): def test_check_output(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
use_gpu = True use_gpu = True
with self.assertRaises(Exception): with self.assertRaisesRegex(
ValueError, "The fed Variable 'data' should have dimensions"
):
self.check_output_with_option(use_gpu) self.check_output_with_option(use_gpu)
......
...@@ -16,11 +16,12 @@ import unittest ...@@ -16,11 +16,12 @@ import unittest
import paddle import paddle
from paddle.jit import to_static from paddle.jit import to_static
from paddle.jit.dy2static.convert_call_func import translator_logger
def dyfunc_generator(): def dyfunc_generator():
for i in range(100): for i in range(100):
yield paddle.fluid.dygraph.to_variable([i] * 10) yield paddle.to_tensor([i] * 10)
def main_func(): def main_func():
...@@ -31,8 +32,17 @@ def main_func(): ...@@ -31,8 +32,17 @@ def main_func():
class TestConvertGenerator(unittest.TestCase): class TestConvertGenerator(unittest.TestCase):
def test_raise_error(self): def test_raise_error(self):
with self.assertRaises(Exception): translator_logger.verbosity_level = 1
with self.assertLogs(
translator_logger.logger_name, level='WARNING'
) as cm:
to_static(main_func)() to_static(main_func)()
self.assertRegex(
cm.output[0],
"Your function:`dyfunc_generator` doesn't support "
"to transform to static function because it is a "
"generator function",
)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -299,10 +299,10 @@ class TestTransformsCV2(unittest.TestCase): ...@@ -299,10 +299,10 @@ class TestTransformsCV2(unittest.TestCase):
trans_batch = transforms.Compose([transforms.Resize(-1)]) trans_batch = transforms.Compose([transforms.Resize(-1)])
with self.assertRaises(Exception): with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans) self.do_transform(trans)
with self.assertRaises(Exception): with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans_batch) self.do_transform(trans_batch)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -411,22 +411,35 @@ class TestTransformsCV2(unittest.TestCase): ...@@ -411,22 +411,35 @@ class TestTransformsCV2(unittest.TestCase):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
transform = transforms.BrightnessTransform('0.1', keys='a') transform = transforms.BrightnessTransform('0.1', keys='a')
with self.assertRaises(Exception): with self.assertRaisesRegex(
AssertionError, "scale should be a tuple or list"
):
transform = transforms.RandomErasing(scale=0.5) transform = transforms.RandomErasing(scale=0.5)
with self.assertRaises(Exception): with self.assertRaisesRegex(
AssertionError, "ratio should be a tuple or list"
):
transform = transforms.RandomErasing(ratio=0.8) transform = transforms.RandomErasing(ratio=0.8)
with self.assertRaises(Exception): with self.assertRaisesRegex(
AssertionError,
r"scale should be of kind \(min, max\) and in range \[0, 1\]",
):
transform = transforms.RandomErasing(scale=(10, 0.4)) transform = transforms.RandomErasing(scale=(10, 0.4))
with self.assertRaises(Exception): with self.assertRaisesRegex(
AssertionError, r"ratio should be of kind \(min, max\)"
):
transform = transforms.RandomErasing(ratio=(3.3, 0.3)) transform = transforms.RandomErasing(ratio=(3.3, 0.3))
with self.assertRaises(Exception): with self.assertRaisesRegex(
AssertionError, r"The probability should be in range \[0, 1\]"
):
transform = transforms.RandomErasing(prob=1.5) transform = transforms.RandomErasing(prob=1.5)
with self.assertRaises(Exception): with self.assertRaisesRegex(
ValueError, r"value must be 'random' when type is str"
):
transform = transforms.RandomErasing(value="0") transform = transforms.RandomErasing(value="0")
def test_info(self): def test_info(self):
...@@ -571,10 +584,10 @@ class TestTransformsTensor(TestTransformsCV2): ...@@ -571,10 +584,10 @@ class TestTransformsTensor(TestTransformsCV2):
trans_batch = transforms.Compose([transforms.Resize(-1)]) trans_batch = transforms.Compose([transforms.Resize(-1)])
with self.assertRaises(Exception): with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans) self.do_transform(trans)
with self.assertRaises(Exception): with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans_batch) self.do_transform(trans_batch)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册