未验证 提交 c5e1011b 编写于 作者: N Nyakku Shigure 提交者: GitHub

[CodeStyle][F821] refactor unittests utility function `parameterized` related code (#47869)

上级 18549417
......@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import functools
import inspect
import re
import sys
from unittest import SkipTest
import numpy as np
import config
......@@ -72,7 +73,6 @@ def parameterize_cls(fields, values=None):
def parameterize_func(
input, name_func=None, doc_func=None, skip_on_empty=False
):
doc_func = doc_func or default_doc_func
name_func = name_func or default_name_func
def wrapper(f, instance=None):
......@@ -87,7 +87,7 @@ def parameterize_func(
"`parameterized.expand([], skip_on_empty=True)` to skip "
"this test when the input is empty)"
)
return wraps(f)(skip_on_empty_helper)
return functools.wraps(f)(skip_on_empty_helper)
digits = len(str(len(parameters) - 1))
for num, p in enumerate(parameters):
......@@ -100,7 +100,7 @@ def parameterize_func(
# patch objects between new functions
nf = reapply_patches_if_need(f)
frame_locals[name] = param_as_standalone_func(p, nf, name)
frame_locals[name].__doc__ = doc_func(f, num, p)
frame_locals[name].__doc__ = f.__doc__
# Delete original patches to prevent new function from evaluating
# original patching object as well as re-constrfucted patches.
......@@ -113,7 +113,7 @@ def parameterize_func(
def reapply_patches_if_need(func):
def dummy_wrapper(orgfunc):
@wraps(orgfunc)
@functools.wraps(orgfunc)
def dummy_func(*args, **kwargs):
return orgfunc(*args, **kwargs)
......@@ -142,27 +142,6 @@ def default_name_func(func, num, p):
return base_name + name_suffix
def default_doc_func(func, num, p):
if func.__doc__ is None:
return None
all_args_with_values = parameterized_argument_value_pairs(func, p)
# Assumes that the function passed is a bound method.
descs = ["%s=%s" % (n, short_repr(v)) for n, v in all_args_with_values]
# The documentation might be a multiline string, so split it
# and just work with the first string, ignoring the period
# at the end if there is one.
first, nl, rest = func.__doc__.lstrip().partition("\n")
suffix = ""
if first.endswith("."):
suffix = "."
first = first[:-1]
args = "%s[with %s]" % (len(first) and " " or "", ", ".join(descs))
return "".join(to_text(x) for x in [first.rstrip(), args, suffix, nl, rest])
def param_as_standalone_func(p, func, name):
@functools.wraps(func)
def standalone_func(*a):
......@@ -252,22 +231,6 @@ def to_safe_name(s):
return str(re.sub("[^a-zA-Z0-9_]+", "_", s))
@contextlib.contextmanager
def stgraph(func, *args):
"""static graph exec context"""
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype)
output = func(input, n, axes, norm)
exe = paddle.static.Executor(place)
exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
yield output
paddle.disable_static()
# alias
parameterize = parameterize_func
param_cls = parameterize_cls
......
......@@ -1817,13 +1817,14 @@ class TestFftShift(unittest.TestCase):
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype)
output = paddle.fft.fftshift(input, axes)
input = paddle.static.data(
'input', self.x.shape, dtype=self.x.dtype
)
output = paddle.fft.fftshift(input, self.axes)
exe = paddle.static.Executor(place)
exe = paddle.static.Executor(self.place)
exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
yield output
[output] = exe.run(mp, feed={'input': self.x}, fetch_list=[output])
paddle.disable_static()
......@@ -1848,13 +1849,14 @@ class TestIfftShift(unittest.TestCase):
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype)
output = paddle.fft.ifftshift(input, axes)
input = paddle.static.data(
'input', self.x.shape, dtype=self.x.dtype
)
output = paddle.fft.ifftshift(input, self.axes)
exe = paddle.static.Executor(place)
exe = paddle.static.Executor(self.place)
exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
yield output
[output] = exe.run(mp, feed={'input': self.x}, fetch_list=[output])
paddle.disable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册