未验证 提交 9dbfadab 编写于 作者: N Nyakku Shigure 提交者: GitHub

[CodeStyle][B017] catch more specific exceptions in unittests (#52553)

上级 160dfd01
......@@ -41,10 +41,6 @@ select = [
"B002",
"B003",
"B004",
# "B005",
# "B006",
# "B007",
# "B008",
"B009",
"B010",
"B011",
......@@ -53,65 +49,30 @@ select = [
"B014",
"B015",
"B016",
# "B017",
"B017",
"B018",
"B019",
"B020",
"B021",
"B022",
# "B023",
# "B024",
"B025",
# "B026",
# "B027",
# "B028",
"B029",
# "B030",
"B032",
# "B904",
# Pylint
"PLC0414",
# "PLC1901",
"PLC3002",
"PLE0100",
"PLE0101",
# "PLE0116",
# "PLE0117",
# "PLE0118",
"PLE0604",
"PLE0605",
"PLE1142",
"PLE1205",
"PLE1206",
"PLE1307",
# "PLE1310",
# "PLE1507",
"PLE2502",
# "PLE2510",
# "PLE2512",
# "PLE2513",
# "PLE2514",
# "PLE2515",
# "PLR0133",
"PLR0206",
"PLR0402",
# "PLR0911",
# "PLR0912",
# "PLR0913",
# "PLR0915",
# "PLR1701",
# "PLR1711",
# "PLR1722",
# "PLR2004",
# "PLR5501",
# "PLW0120",
# "PLW0129",
# "PLW0602",
# "PLW0603",
# "PLW0711",
# "PLW1508",
# "PLW2901",
]
unfixable = [
"NPY001"
......
......@@ -43,7 +43,7 @@ class Rank:
return (
'{'
+ 'rank={};'.format(self.kind)
+ f'rank={self.kind};'
+ ','.join([node.name for node in self.nodes])
+ '}'
)
......@@ -97,7 +97,7 @@ class Graph:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
logging.warning("write block debug graph to {}".format(image_path))
logging.warning(f"write block debug graph to {image_path}")
return image_path
def show(self, dot_path):
......@@ -125,13 +125,11 @@ class Graph:
def __str__(self):
reprs = [
'digraph G {',
'title = {}'.format(crepr(self.title)),
f'title = {crepr(self.title)}',
]
for attr in self.attrs:
reprs.append(
"{key}={value};".format(key=attr, value=crepr(self.attrs[attr]))
)
reprs.append(f"{attr}={crepr(self.attrs[attr])};")
reprs.append(self._rank_repr())
......@@ -161,8 +159,7 @@ class Node:
label=self.label,
extra=','
+ ','.join(
"%s=%s" % (key, crepr(value))
for key, value in self.attrs.items()
f"{key}={crepr(value)}" for key, value in self.attrs.items()
)
if self.attrs
else "",
......@@ -191,8 +188,7 @@ class Edge:
if not self.attrs
else "["
+ ','.join(
"{}={}".format(attr[0], crepr(attr[1]))
for attr in self.attrs.items()
f"{attr[0]}={crepr(attr[1])}" for attr in self.attrs.items()
)
+ "]",
)
......@@ -292,5 +288,5 @@ class GraphPreviewGenerator:
source,
target,
color="#00000" if not highlight else "orange",
**kwargs
**kwargs,
)
......@@ -39,7 +39,7 @@ class CheckPassConflictTest2(PassConflictChecker):
]
def test_resnet(self):
with self.assertRaises(Exception):
with self.assertRaises(Exception): # noqa: B017
self.check_main(resnet_model, batch_size=32)
......
......@@ -72,7 +72,9 @@ class TRTDynamicShapeOutOfBound1Test(TRTDynamicShapeTest):
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
with self.assertRaises(Exception):
with self.assertRaisesRegex(
ValueError, "The fed Variable 'data' should have dimensions"
):
self.check_output_with_option(use_gpu)
......@@ -99,7 +101,9 @@ class TRTDynamicShapeOutOfBound3Test(TRTDynamicShapeTest):
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
with self.assertRaises(Exception):
with self.assertRaisesRegex(
ValueError, "The fed Variable 'data' should have dimensions"
):
self.check_output_with_option(use_gpu)
......
......@@ -16,11 +16,12 @@ import unittest
import paddle
from paddle.jit import to_static
from paddle.jit.dy2static.convert_call_func import translator_logger
def dyfunc_generator():
for i in range(100):
yield paddle.fluid.dygraph.to_variable([i] * 10)
yield paddle.to_tensor([i] * 10)
def main_func():
......@@ -31,8 +32,17 @@ def main_func():
class TestConvertGenerator(unittest.TestCase):
def test_raise_error(self):
with self.assertRaises(Exception):
translator_logger.verbosity_level = 1
with self.assertLogs(
translator_logger.logger_name, level='WARNING'
) as cm:
to_static(main_func)()
self.assertRegex(
cm.output[0],
"Your function:`dyfunc_generator` doesn't support "
"to transform to static function because it is a "
"generator function",
)
if __name__ == '__main__':
......
......@@ -299,10 +299,10 @@ class TestTransformsCV2(unittest.TestCase):
trans_batch = transforms.Compose([transforms.Resize(-1)])
with self.assertRaises(Exception):
with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans)
with self.assertRaises(Exception):
with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans_batch)
with self.assertRaises(ValueError):
......@@ -411,22 +411,35 @@ class TestTransformsCV2(unittest.TestCase):
with self.assertRaises(NotImplementedError):
transform = transforms.BrightnessTransform('0.1', keys='a')
with self.assertRaises(Exception):
with self.assertRaisesRegex(
AssertionError, "scale should be a tuple or list"
):
transform = transforms.RandomErasing(scale=0.5)
with self.assertRaises(Exception):
with self.assertRaisesRegex(
AssertionError, "ratio should be a tuple or list"
):
transform = transforms.RandomErasing(ratio=0.8)
with self.assertRaises(Exception):
with self.assertRaisesRegex(
AssertionError,
r"scale should be of kind \(min, max\) and in range \[0, 1\]",
):
transform = transforms.RandomErasing(scale=(10, 0.4))
with self.assertRaises(Exception):
with self.assertRaisesRegex(
AssertionError, r"ratio should be of kind \(min, max\)"
):
transform = transforms.RandomErasing(ratio=(3.3, 0.3))
with self.assertRaises(Exception):
with self.assertRaisesRegex(
AssertionError, r"The probability should be in range \[0, 1\]"
):
transform = transforms.RandomErasing(prob=1.5)
with self.assertRaises(Exception):
with self.assertRaisesRegex(
ValueError, r"value must be 'random' when type is str"
):
transform = transforms.RandomErasing(value="0")
def test_info(self):
......@@ -571,10 +584,10 @@ class TestTransformsTensor(TestTransformsCV2):
trans_batch = transforms.Compose([transforms.Resize(-1)])
with self.assertRaises(Exception):
with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans)
with self.assertRaises(Exception):
with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans_batch)
with self.assertRaises(ValueError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册