diff --git a/python/paddle/fluid/tests/unittests/distribution/parameterize.py b/python/paddle/fluid/tests/unittests/distribution/parameterize.py index f962efc7139997679329b2759f06ddc1ecd0735c..b2e5ae1fe33563cd5f0ea009745cd47ba3decf92 100644 --- a/python/paddle/fluid/tests/unittests/distribution/parameterize.py +++ b/python/paddle/fluid/tests/unittests/distribution/parameterize.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections -import contextlib import functools import inspect import re import sys +from unittest import SkipTest + import numpy as np import config @@ -72,7 +73,6 @@ def parameterize_cls(fields, values=None): def parameterize_func( input, name_func=None, doc_func=None, skip_on_empty=False ): - doc_func = doc_func or default_doc_func name_func = name_func or default_name_func def wrapper(f, instance=None): @@ -87,7 +87,7 @@ def parameterize_func( "`parameterized.expand([], skip_on_empty=True)` to skip " "this test when the input is empty)" ) - return wraps(f)(skip_on_empty_helper) + return functools.wraps(f)(skip_on_empty_helper) digits = len(str(len(parameters) - 1)) for num, p in enumerate(parameters): @@ -100,7 +100,7 @@ def parameterize_func( # patch objects between new functions nf = reapply_patches_if_need(f) frame_locals[name] = param_as_standalone_func(p, nf, name) - frame_locals[name].__doc__ = doc_func(f, num, p) + frame_locals[name].__doc__ = f.__doc__ # Delete original patches to prevent new function from evaluating # original patching object as well as re-constrfucted patches. @@ -113,7 +113,7 @@ def parameterize_func( def reapply_patches_if_need(func): def dummy_wrapper(orgfunc): - @wraps(orgfunc) + @functools.wraps(orgfunc) def dummy_func(*args, **kwargs): return orgfunc(*args, **kwargs) @@ -142,27 +142,6 @@ def default_name_func(func, num, p): return base_name + name_suffix -def default_doc_func(func, num, p): - if func.__doc__ is None: - return None - - all_args_with_values = parameterized_argument_value_pairs(func, p) - - # Assumes that the function passed is a bound method. - descs = ["%s=%s" % (n, short_repr(v)) for n, v in all_args_with_values] - - # The documentation might be a multiline string, so split it - # and just work with the first string, ignoring the period - # at the end if there is one. - first, nl, rest = func.__doc__.lstrip().partition("\n") - suffix = "" - if first.endswith("."): - suffix = "." - first = first[:-1] - args = "%s[with %s]" % (len(first) and " " or "", ", ".join(descs)) - return "".join(to_text(x) for x in [first.rstrip(), args, suffix, nl, rest]) - - def param_as_standalone_func(p, func, name): @functools.wraps(func) def standalone_func(*a): @@ -252,22 +231,6 @@ def to_safe_name(s): return str(re.sub("[^a-zA-Z0-9_]+", "_", s)) -@contextlib.contextmanager -def stgraph(func, *args): - """static graph exec context""" - paddle.enable_static() - mp, sp = paddle.static.Program(), paddle.static.Program() - with paddle.static.program_guard(mp, sp): - input = paddle.static.data('input', x.shape, dtype=x.dtype) - output = func(input, n, axes, norm) - - exe = paddle.static.Executor(place) - exe.run(sp) - [output] = exe.run(mp, feed={'input': x}, fetch_list=[output]) - yield output - paddle.disable_static() - - # alias parameterize = parameterize_func param_cls = parameterize_cls diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py index 35828ed160664c627ec072a6f5c6720b9eac07ea..5386ac2dc2b7364676de1cfb809d704bcd1687af 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py @@ -1817,13 +1817,14 @@ class TestFftShift(unittest.TestCase): paddle.enable_static() mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): - input = paddle.static.data('input', x.shape, dtype=x.dtype) - output = paddle.fft.fftshift(input, axes) + input = paddle.static.data( + 'input', self.x.shape, dtype=self.x.dtype + ) + output = paddle.fft.fftshift(input, self.axes) - exe = paddle.static.Executor(place) + exe = paddle.static.Executor(self.place) exe.run(sp) - [output] = exe.run(mp, feed={'input': x}, fetch_list=[output]) - yield output + [output] = exe.run(mp, feed={'input': self.x}, fetch_list=[output]) paddle.disable_static() @@ -1848,13 +1849,14 @@ class TestIfftShift(unittest.TestCase): paddle.enable_static() mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): - input = paddle.static.data('input', x.shape, dtype=x.dtype) - output = paddle.fft.ifftshift(input, axes) + input = paddle.static.data( + 'input', self.x.shape, dtype=self.x.dtype + ) + output = paddle.fft.ifftshift(input, self.axes) - exe = paddle.static.Executor(place) + exe = paddle.static.Executor(self.place) exe.run(sp) - [output] = exe.run(mp, feed={'input': x}, fetch_list=[output]) - yield output + [output] = exe.run(mp, feed={'input': self.x}, fetch_list=[output]) paddle.disable_static()