未验证 提交 c5e1011b 编写于 作者: N Nyakku Shigure 提交者: GitHub

[CodeStyle][F821] refactor unittests utility function `parameterized` related code (#47869)

上级 18549417
...@@ -12,12 +12,13 @@ ...@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections import collections
import contextlib
import functools import functools
import inspect import inspect
import re import re
import sys import sys
from unittest import SkipTest
import numpy as np import numpy as np
import config import config
...@@ -72,7 +73,6 @@ def parameterize_cls(fields, values=None): ...@@ -72,7 +73,6 @@ def parameterize_cls(fields, values=None):
def parameterize_func( def parameterize_func(
input, name_func=None, doc_func=None, skip_on_empty=False input, name_func=None, doc_func=None, skip_on_empty=False
): ):
doc_func = doc_func or default_doc_func
name_func = name_func or default_name_func name_func = name_func or default_name_func
def wrapper(f, instance=None): def wrapper(f, instance=None):
...@@ -87,7 +87,7 @@ def parameterize_func( ...@@ -87,7 +87,7 @@ def parameterize_func(
"`parameterized.expand([], skip_on_empty=True)` to skip " "`parameterized.expand([], skip_on_empty=True)` to skip "
"this test when the input is empty)" "this test when the input is empty)"
) )
return wraps(f)(skip_on_empty_helper) return functools.wraps(f)(skip_on_empty_helper)
digits = len(str(len(parameters) - 1)) digits = len(str(len(parameters) - 1))
for num, p in enumerate(parameters): for num, p in enumerate(parameters):
...@@ -100,7 +100,7 @@ def parameterize_func( ...@@ -100,7 +100,7 @@ def parameterize_func(
# patch objects between new functions # patch objects between new functions
nf = reapply_patches_if_need(f) nf = reapply_patches_if_need(f)
frame_locals[name] = param_as_standalone_func(p, nf, name) frame_locals[name] = param_as_standalone_func(p, nf, name)
frame_locals[name].__doc__ = doc_func(f, num, p) frame_locals[name].__doc__ = f.__doc__
# Delete original patches to prevent new function from evaluating # Delete original patches to prevent new function from evaluating
# original patching object as well as re-constrfucted patches. # original patching object as well as re-constrfucted patches.
...@@ -113,7 +113,7 @@ def parameterize_func( ...@@ -113,7 +113,7 @@ def parameterize_func(
def reapply_patches_if_need(func): def reapply_patches_if_need(func):
def dummy_wrapper(orgfunc): def dummy_wrapper(orgfunc):
@wraps(orgfunc) @functools.wraps(orgfunc)
def dummy_func(*args, **kwargs): def dummy_func(*args, **kwargs):
return orgfunc(*args, **kwargs) return orgfunc(*args, **kwargs)
...@@ -142,27 +142,6 @@ def default_name_func(func, num, p): ...@@ -142,27 +142,6 @@ def default_name_func(func, num, p):
return base_name + name_suffix return base_name + name_suffix
def default_doc_func(func, num, p):
if func.__doc__ is None:
return None
all_args_with_values = parameterized_argument_value_pairs(func, p)
# Assumes that the function passed is a bound method.
descs = ["%s=%s" % (n, short_repr(v)) for n, v in all_args_with_values]
# The documentation might be a multiline string, so split it
# and just work with the first string, ignoring the period
# at the end if there is one.
first, nl, rest = func.__doc__.lstrip().partition("\n")
suffix = ""
if first.endswith("."):
suffix = "."
first = first[:-1]
args = "%s[with %s]" % (len(first) and " " or "", ", ".join(descs))
return "".join(to_text(x) for x in [first.rstrip(), args, suffix, nl, rest])
def param_as_standalone_func(p, func, name): def param_as_standalone_func(p, func, name):
@functools.wraps(func) @functools.wraps(func)
def standalone_func(*a): def standalone_func(*a):
...@@ -252,22 +231,6 @@ def to_safe_name(s): ...@@ -252,22 +231,6 @@ def to_safe_name(s):
return str(re.sub("[^a-zA-Z0-9_]+", "_", s)) return str(re.sub("[^a-zA-Z0-9_]+", "_", s))
@contextlib.contextmanager
def stgraph(func, *args):
"""static graph exec context"""
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype)
output = func(input, n, axes, norm)
exe = paddle.static.Executor(place)
exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
yield output
paddle.disable_static()
# alias # alias
parameterize = parameterize_func parameterize = parameterize_func
param_cls = parameterize_cls param_cls = parameterize_cls
......
...@@ -1817,13 +1817,14 @@ class TestFftShift(unittest.TestCase): ...@@ -1817,13 +1817,14 @@ class TestFftShift(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program() mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp): with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype) input = paddle.static.data(
output = paddle.fft.fftshift(input, axes) 'input', self.x.shape, dtype=self.x.dtype
)
output = paddle.fft.fftshift(input, self.axes)
exe = paddle.static.Executor(place) exe = paddle.static.Executor(self.place)
exe.run(sp) exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output]) [output] = exe.run(mp, feed={'input': self.x}, fetch_list=[output])
yield output
paddle.disable_static() paddle.disable_static()
...@@ -1848,13 +1849,14 @@ class TestIfftShift(unittest.TestCase): ...@@ -1848,13 +1849,14 @@ class TestIfftShift(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program() mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp): with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype) input = paddle.static.data(
output = paddle.fft.ifftshift(input, axes) 'input', self.x.shape, dtype=self.x.dtype
)
output = paddle.fft.ifftshift(input, self.axes)
exe = paddle.static.Executor(place) exe = paddle.static.Executor(self.place)
exe.run(sp) exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output]) [output] = exe.run(mp, feed={'input': self.x}, fetch_list=[output])
yield output
paddle.disable_static() paddle.disable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册