未验证 提交 512cb296 编写于 作者: N Nyakku Shigure 提交者: GitHub

[CodeStyle][black] format dy2static unittests (#47268)

* [CodeStyle][black] format dy2static unittests

* format some missing files

* update lineno in test_origin_info

* update lineno in test_error

* update lineno
上级 cc753aa4
...@@ -33,12 +33,6 @@ repos: ...@@ -33,12 +33,6 @@ repos:
hooks: hooks:
- id: black - id: black
files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
# Temporary exclude, will be formatted in a separate PR
exclude: |
(?x)^(
python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py|
python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py
)$
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 4.0.1 rev: 4.0.1
hooks: hooks:
......
...@@ -22,17 +22,18 @@ import os ...@@ -22,17 +22,18 @@ import os
def parse_args(): def parse_args():
parser = argparse.ArgumentParser("conda build for paddlepaddle version") parser = argparse.ArgumentParser("conda build for paddlepaddle version")
parser.add_argument("--paddle_version", parser.add_argument(
"--paddle_version",
type=str, type=str,
required=True, required=True,
help="paddle version for conda build.") help="paddle version for conda build.",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
class ConstantVar: class ConstantVar:
def __init__(self): def __init__(self):
self.build = r""" self.build = r"""
build: build:
...@@ -96,7 +97,10 @@ about: ...@@ -96,7 +97,10 @@ about:
self.python39 = r" - python>=3.9, <3.10" self.python39 = r" - python>=3.9, <3.10"
self.python_version = [ self.python_version = [
self.python36, self.python37, self.python38, self.python39 self.python36,
self.python37,
self.python38,
self.python39,
] ]
self.cuda101 = r""" self.cuda101 = r"""
...@@ -112,9 +116,11 @@ about: ...@@ -112,9 +116,11 @@ about:
- cudnn>=8.1, <8.2 - cudnn>=8.1, <8.2
""" """
self.cuda_info = [(self.cuda101, "cuda10.1", ".post101"), self.cuda_info = [
(self.cuda101, "cuda10.1", ".post101"),
(self.cuda102, "cuda10.2", ""), (self.cuda102, "cuda10.2", ""),
(self.cuda112, "cuda11.2", ".post112")] (self.cuda112, "cuda11.2", ".post112"),
]
self.py_str = ["py36", "py37", "py38", "py39"] self.py_str = ["py36", "py37", "py38", "py39"]
self.pip_end = ".whl --no-deps" self.pip_end = ".whl --no-deps"
self.pip_prefix_linux = "pip install /package/paddlepaddle" self.pip_prefix_linux = "pip install /package/paddlepaddle"
...@@ -122,25 +128,36 @@ about: ...@@ -122,25 +128,36 @@ about:
self.pip_gpu = "_gpu-" self.pip_gpu = "_gpu-"
self.pip_cpu = "-" self.pip_cpu = "-"
self.mac_pip = [ self.mac_pip = [
"-cp36-cp36m-macosx_10_6_intel", "-cp37-cp37m-macosx_10_6_intel", "-cp36-cp36m-macosx_10_6_intel",
"-cp38-cp38-macosx_10_14_x86_64", "-cp39-cp39-macosx_10_14_x86_64" "-cp37-cp37m-macosx_10_6_intel",
"-cp38-cp38-macosx_10_14_x86_64",
"-cp39-cp39-macosx_10_14_x86_64",
] ]
self.linux_pip = [ self.linux_pip = [
"-cp36-cp36m-linux_x86_64", "-cp37-cp37m-linux_x86_64", "-cp36-cp36m-linux_x86_64",
"-cp38-cp38-linux_x86_64", "-cp39-cp39-linux_x86_64" "-cp37-cp37m-linux_x86_64",
"-cp38-cp38-linux_x86_64",
"-cp39-cp39-linux_x86_64",
] ]
self.windows_pip = [ self.windows_pip = [
"-cp36-cp36m-win_amd64", "-cp37-cp37m-win_amd64", "-cp36-cp36m-win_amd64",
"-cp38-cp38-win_amd64", "-cp39-cp39-win_amd64" "-cp37-cp37m-win_amd64",
"-cp38-cp38-win_amd64",
"-cp39-cp39-win_amd64",
] ]
def meta_build_mac(var, python_str, paddle_version, build_var, build_name_str): def meta_build_mac(var, python_str, paddle_version, build_var, build_name_str):
package_str = """ package_str = (
"""
package: package:
name: paddlepaddle name: paddlepaddle
version: """ + paddle_version version: """
requirement = var.requirement_build + python_str + var.requirement_run + python_str + paddle_version
)
requirement = (
var.requirement_build + python_str + var.requirement_run + python_str
)
meta_build = var.build + build_name_str meta_build = var.build + build_name_str
meta_str = package_str + meta_build + requirement + var.test + var.about meta_str = package_str + meta_build + requirement + var.test + var.about
build_str = var.build_const + build_var build_str = var.build_const + build_var
...@@ -153,23 +170,28 @@ package: ...@@ -153,23 +170,28 @@ package:
f.write(build_str) f.write(build_str)
def meta_build_linux(var, def meta_build_linux(
python_str, var, python_str, paddle_version, build_var, build_name_str, cuda_str=None
paddle_version, ):
build_var,
build_name_str,
cuda_str=None):
if cuda_str == None: if cuda_str == None:
package_str = """ package_str = (
"""
package: package:
name: paddlepaddle name: paddlepaddle
version: """ + paddle_version version: """
+ paddle_version
)
else: else:
package_str = """ package_str = (
"""
package: package:
name: paddlepaddle-gpu name: paddlepaddle-gpu
version: """ + paddle_version version: """
requirement = var.requirement_build + python_str + var.requirement_run + python_str + paddle_version
)
requirement = (
var.requirement_build + python_str + var.requirement_run + python_str
)
meta_build = var.build + build_name_str meta_build = var.build + build_name_str
meta_str = package_str + meta_build + requirement meta_str = package_str + meta_build + requirement
if not (cuda_str == None): if not (cuda_str == None):
...@@ -186,24 +208,32 @@ package: ...@@ -186,24 +208,32 @@ package:
f.write(build_str) f.write(build_str)
def meta_build_windows(var, def meta_build_windows(
python_str, var, python_str, paddle_version, blt_var, build_name_str, cuda_str=None
paddle_version, ):
blt_var,
build_name_str,
cuda_str=None):
if cuda_str == None: if cuda_str == None:
package_str = """ package_str = (
"""
package: package:
name: paddlepaddle name: paddlepaddle
version: """ + paddle_version version: """
+ paddle_version
)
else: else:
package_str = """ package_str = (
"""
package: package:
name: paddlepaddle-gpu name: paddlepaddle-gpu
version: """ + paddle_version version: """
+ paddle_version
requirement = var.requirement_build + python_str + var.requirement_run_windows + python_str )
requirement = (
var.requirement_build
+ python_str
+ var.requirement_run_windows
+ python_str
)
meta_build = var.build + build_name_str meta_build = var.build + build_name_str
meta_str = package_str + meta_build + requirement meta_str = package_str + meta_build + requirement
...@@ -223,12 +253,17 @@ package: ...@@ -223,12 +253,17 @@ package:
def conda_build(paddle_version, var): def conda_build(paddle_version, var):
sysstr = platform.system() sysstr = platform.system()
if (sysstr == "Windows"): if sysstr == "Windows":
os.system("mkdir paddle") os.system("mkdir paddle")
os.chdir(r"./paddle") os.chdir(r"./paddle")
for i in range(len(var.python_version)): for i in range(len(var.python_version)):
blt_var = var.pip_prefix_windows + var.pip_cpu + paddle_version + var.windows_pip[ blt_var = (
i] + var.pip_end var.pip_prefix_windows
+ var.pip_cpu
+ paddle_version
+ var.windows_pip[i]
+ var.pip_end
)
name = var.py_str[i] + "_cpu_windows" name = var.py_str[i] + "_cpu_windows"
python_str = var.python_version[i] python_str = var.python_version[i]
meta_build_windows(var, python_str, paddle_version, blt_var, name) meta_build_windows(var, python_str, paddle_version, blt_var, name)
...@@ -237,21 +272,38 @@ def conda_build(paddle_version, var): ...@@ -237,21 +272,38 @@ def conda_build(paddle_version, var):
for i in range(len(var.python_version)): for i in range(len(var.python_version)):
for cuda_str in var.cuda_info: for cuda_str in var.cuda_info:
post = cuda_str[2] post = cuda_str[2]
blt_var = var.pip_prefix_windows + var.pip_gpu + paddle_version + post + var.windows_pip[ blt_var = (
i] + var.pip_end var.pip_prefix_windows
+ var.pip_gpu
+ paddle_version
+ post
+ var.windows_pip[i]
+ var.pip_end
)
name = var.py_str[i] + "_gpu_" + cuda_str[1] + "_windows" name = var.py_str[i] + "_gpu_" + cuda_str[1] + "_windows"
cuda_cudnn_str = cuda_str[0] cuda_cudnn_str = cuda_str[0]
python_str = var.python_version[i] python_str = var.python_version[i]
meta_build_windows(var, python_str, paddle_version, blt_var, meta_build_windows(
name, cuda_cudnn_str) var,
python_str,
paddle_version,
blt_var,
name,
cuda_cudnn_str,
)
os.system("conda build .") os.system("conda build .")
elif (sysstr == "Linux"): elif sysstr == "Linux":
os.system("mkdir paddle") os.system("mkdir paddle")
os.chdir(r"./paddle") os.chdir(r"./paddle")
for i in range(len(var.python_version)): for i in range(len(var.python_version)):
build_var = var.pip_prefix_linux + var.pip_cpu + paddle_version + var.linux_pip[ build_var = (
i] + var.pip_end var.pip_prefix_linux
+ var.pip_cpu
+ paddle_version
+ var.linux_pip[i]
+ var.pip_end
)
name = var.py_str[i] + "_cpu_many_linux" name = var.py_str[i] + "_cpu_many_linux"
python_str = var.python_version[i] python_str = var.python_version[i]
meta_build_linux(var, python_str, paddle_version, build_var, name) meta_build_linux(var, python_str, paddle_version, build_var, name)
...@@ -260,23 +312,40 @@ def conda_build(paddle_version, var): ...@@ -260,23 +312,40 @@ def conda_build(paddle_version, var):
for i in range(len(var.python_version)): for i in range(len(var.python_version)):
for cuda_str in var.cuda_info: for cuda_str in var.cuda_info:
post = cuda_str[2] post = cuda_str[2]
build_var = var.pip_prefix_linux + var.pip_gpu + paddle_version + post + var.linux_pip[ build_var = (
i] + var.pip_end var.pip_prefix_linux
+ var.pip_gpu
+ paddle_version
+ post
+ var.linux_pip[i]
+ var.pip_end
)
name = var.py_str[i] + "_gpu_" + cuda_str[1] + "_many_linux" name = var.py_str[i] + "_gpu_" + cuda_str[1] + "_many_linux"
cuda_cudnn_str = cuda_str[0] cuda_cudnn_str = cuda_str[0]
python_str = var.python_version[i] python_str = var.python_version[i]
meta_build_linux(var, python_str, paddle_version, build_var, meta_build_linux(
name, cuda_cudnn_str) var,
python_str,
paddle_version,
build_var,
name,
cuda_cudnn_str,
)
os.system("conda build .") os.system("conda build .")
os.system("cd ..") os.system("cd ..")
elif (sysstr == "Darwin"): elif sysstr == "Darwin":
os.system("mkdir paddle") os.system("mkdir paddle")
os.chdir(r"./paddle") os.chdir(r"./paddle")
for i in range(len(var.python_version)): for i in range(len(var.python_version)):
build_var = var.pip_prefix_linux + var.pip_cpu + paddle_version + var.mac_pip[ build_var = (
i] + var.pip_end var.pip_prefix_linux
+ var.pip_cpu
+ paddle_version
+ var.mac_pip[i]
+ var.pip_end
)
name = var.py_str[i] + "_mac" name = var.py_str[i] + "_mac"
python_str = var.python_version[i] python_str = var.python_version[i]
meta_build_mac(var, python_str, paddle_version, build_var, name) meta_build_mac(var, python_str, paddle_version, build_var, name)
......
...@@ -20,16 +20,17 @@ from model_zoo import resnet_model ...@@ -20,16 +20,17 @@ from model_zoo import resnet_model
class TestBuildCINNPass(DistPassTestBase): class TestBuildCINNPass(DistPassTestBase):
def init(self): def init(self):
self.atol = 0.5 self.atol = 0.5
self.rtol = 0.0 self.rtol = 0.0
def apply_passes(self, main_prog, startup_prog): def apply_passes(self, main_prog, startup_prog):
pass_manager = PassManager([ pass_manager = PassManager(
[
new_pass("build_cinn"), new_pass("build_cinn"),
new_pass("fuse_elewise_add_act"), new_pass("fuse_elewise_add_act"),
]) ]
)
pass_manager.apply([main_prog], [startup_prog]) pass_manager.apply([main_prog], [startup_prog])
print(pass_manager.names) print(pass_manager.names)
......
...@@ -20,16 +20,17 @@ from model_zoo import simple_net ...@@ -20,16 +20,17 @@ from model_zoo import simple_net
class TestBuildCINNPass(DistPassTestBase): class TestBuildCINNPass(DistPassTestBase):
def init(self): def init(self):
self.atol = 0.0 self.atol = 0.0
self.rtol = 0.0 self.rtol = 0.0
def apply_passes(self, main_prog, startup_prog): def apply_passes(self, main_prog, startup_prog):
pass_manager = PassManager([ pass_manager = PassManager(
[
new_pass("build_cinn"), new_pass("build_cinn"),
new_pass("fuse_elewise_add_act"), new_pass("fuse_elewise_add_act"),
]) ]
)
pass_manager.apply([main_prog], [startup_prog]) pass_manager.apply([main_prog], [startup_prog])
op_types = [op.type for op in main_prog.global_block().ops] op_types = [op.type for op in main_prog.global_block().ops]
self.assertTrue('cinn_launch' in op_types) self.assertTrue('cinn_launch' in op_types)
......
...@@ -23,7 +23,6 @@ program_translator = ProgramTranslator() ...@@ -23,7 +23,6 @@ program_translator = ProgramTranslator()
class TestResnetWithPass(unittest.TestCase): class TestResnetWithPass(unittest.TestCase):
def setUp(self): def setUp(self):
self.build_strategy = paddle.static.BuildStrategy() self.build_strategy = paddle.static.BuildStrategy()
self.build_strategy.fuse_elewise_add_act_ops = True self.build_strategy.fuse_elewise_add_act_ops = True
...@@ -48,19 +47,24 @@ class TestResnetWithPass(unittest.TestCase): ...@@ -48,19 +47,24 @@ class TestResnetWithPass(unittest.TestCase):
dy_pre, dy_pre,
st_pre, st_pre,
rtol=1e-05, rtol=1e-05,
err_msg='dy_pre:\n {}\n, st_pre: \n{}.'.format(dy_pre, st_pre)) err_msg='dy_pre:\n {}\n, st_pre: \n{}.'.format(dy_pre, st_pre),
)
np.testing.assert_allclose( np.testing.assert_allclose(
dy_jit_pre, dy_jit_pre,
st_pre, st_pre,
rtol=1e-05, rtol=1e-05,
err_msg='dy_jit_pre:\n {}\n, st_pre: \n{}.'.format( err_msg='dy_jit_pre:\n {}\n, st_pre: \n{}.'.format(
dy_jit_pre, st_pre)) dy_jit_pre, st_pre
),
)
np.testing.assert_allclose( np.testing.assert_allclose(
predictor_pre, predictor_pre,
st_pre, st_pre,
rtol=1e-05, rtol=1e-05,
err_msg='predictor_pre:\n {}\n, st_pre: \n{}.'.format( err_msg='predictor_pre:\n {}\n, st_pre: \n{}.'.format(
predictor_pre, st_pre)) predictor_pre, st_pre
),
)
def test_resnet(self): def test_resnet(self):
static_loss = self.train(to_static=True) static_loss = self.train(to_static=True)
...@@ -70,7 +74,9 @@ class TestResnetWithPass(unittest.TestCase): ...@@ -70,7 +74,9 @@ class TestResnetWithPass(unittest.TestCase):
dygraph_loss, dygraph_loss,
rtol=1e-05, rtol=1e-05,
err_msg='static_loss: {} \n dygraph_loss: {}'.format( err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss)) static_loss, dygraph_loss
),
)
self.verify_predict() self.verify_predict()
def test_in_static_mode_mkldnn(self): def test_in_static_mode_mkldnn(self):
...@@ -83,9 +89,7 @@ class TestResnetWithPass(unittest.TestCase): ...@@ -83,9 +89,7 @@ class TestResnetWithPass(unittest.TestCase):
class TestError(unittest.TestCase): class TestError(unittest.TestCase):
def test_type_error(self): def test_type_error(self):
def foo(x): def foo(x):
out = x + 1 out = x + 1
return out return out
......
...@@ -66,13 +66,13 @@ def func_decorated_by_other_2(): ...@@ -66,13 +66,13 @@ def func_decorated_by_other_2():
class LayerErrorInCompiletime(fluid.dygraph.Layer): class LayerErrorInCompiletime(fluid.dygraph.Layer):
def __init__(self, fc_size=20): def __init__(self, fc_size=20):
super(LayerErrorInCompiletime, self).__init__() super(LayerErrorInCompiletime, self).__init__()
self._linear = fluid.dygraph.Linear(fc_size, fc_size) self._linear = fluid.dygraph.Linear(fc_size, fc_size)
@paddle.jit.to_static( @paddle.jit.to_static(
input_spec=[paddle.static.InputSpec(shape=[20, 20], dtype='float32')]) input_spec=[paddle.static.InputSpec(shape=[20, 20], dtype='float32')]
)
def forward(self, x): def forward(self, x):
y = self._linear(x) y = self._linear(x)
z = fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int") z = fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")
...@@ -81,7 +81,6 @@ class LayerErrorInCompiletime(fluid.dygraph.Layer): ...@@ -81,7 +81,6 @@ class LayerErrorInCompiletime(fluid.dygraph.Layer):
class LayerErrorInCompiletime2(fluid.dygraph.Layer): class LayerErrorInCompiletime2(fluid.dygraph.Layer):
def __init__(self): def __init__(self):
super(LayerErrorInCompiletime2, self).__init__() super(LayerErrorInCompiletime2, self).__init__()
...@@ -93,7 +92,7 @@ class LayerErrorInCompiletime2(fluid.dygraph.Layer): ...@@ -93,7 +92,7 @@ class LayerErrorInCompiletime2(fluid.dygraph.Layer):
""" """
NOTE: The next line has a tab. And this test to check the IndentationError when spaces and tabs are mixed. NOTE: The next line has a tab. And this test to check the IndentationError when spaces and tabs are mixed.
A tab here. A tab here.
""" """ # fmt: skip
return return
...@@ -108,7 +107,6 @@ def func_error_in_runtime_with_empty_line(x): ...@@ -108,7 +107,6 @@ def func_error_in_runtime_with_empty_line(x):
class SuggestionErrorTestNet(paddle.nn.Layer): class SuggestionErrorTestNet(paddle.nn.Layer):
def __init__(self): def __init__(self):
super(SuggestionErrorTestNet, self).__init__() super(SuggestionErrorTestNet, self).__init__()
self.inner_net = SuggestionErrorTestNet2() self.inner_net = SuggestionErrorTestNet2()
...@@ -118,11 +116,10 @@ class SuggestionErrorTestNet(paddle.nn.Layer): ...@@ -118,11 +116,10 @@ class SuggestionErrorTestNet(paddle.nn.Layer):
return self.inner_net.forward(x) return self.inner_net.forward(x)
class SuggestionErrorTestNet2(): class SuggestionErrorTestNet2:
def __init__(self): def __init__(self):
super(SuggestionErrorTestNet2, self).__init__() super(SuggestionErrorTestNet2, self).__init__()
self.w = paddle.to_tensor([2.]) self.w = paddle.to_tensor([2.0])
def forward(self, x): def forward(self, x):
out = paddle.matmul(self.w, x) out = paddle.matmul(self.w, x)
...@@ -135,7 +132,6 @@ def func_suggestion_error_in_runtime(x): ...@@ -135,7 +132,6 @@ def func_suggestion_error_in_runtime(x):
class TestFlags(unittest.TestCase): class TestFlags(unittest.TestCase):
def setUp(self): def setUp(self):
self.reset_flags_to_default() self.reset_flags_to_default()
...@@ -144,13 +140,15 @@ class TestFlags(unittest.TestCase): ...@@ -144,13 +140,15 @@ class TestFlags(unittest.TestCase):
# 1. A flag to set whether to open the dygraph2static error reporting module # 1. A flag to set whether to open the dygraph2static error reporting module
os.environ[error.DISABLE_ERROR_ENV_NAME] = str( os.environ[error.DISABLE_ERROR_ENV_NAME] = str(
error.DEFAULT_DISABLE_NEW_ERROR) error.DEFAULT_DISABLE_NEW_ERROR
)
disable_error = int(os.getenv(error.DISABLE_ERROR_ENV_NAME, 999)) disable_error = int(os.getenv(error.DISABLE_ERROR_ENV_NAME, 999))
self.assertEqual(disable_error, 0) self.assertEqual(disable_error, 0)
# 2. A flag to set whether to display the simplified error stack # 2. A flag to set whether to display the simplified error stack
os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str( os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str(
error.DEFAULT_SIMPLIFY_NEW_ERROR) error.DEFAULT_SIMPLIFY_NEW_ERROR
)
simplify_error = int(os.getenv(error.SIMPLIFY_ERROR_ENV_NAME, 999)) simplify_error = int(os.getenv(error.SIMPLIFY_ERROR_ENV_NAME, 999))
self.assertEqual(simplify_error, 1) self.assertEqual(simplify_error, 1)
...@@ -167,7 +165,6 @@ class TestFlags(unittest.TestCase): ...@@ -167,7 +165,6 @@ class TestFlags(unittest.TestCase):
class TestErrorBase(unittest.TestCase): class TestErrorBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.set_input() self.set_input()
self.set_func() self.set_func()
...@@ -188,20 +185,24 @@ class TestErrorBase(unittest.TestCase): ...@@ -188,20 +185,24 @@ class TestErrorBase(unittest.TestCase):
def set_exception_type(self): def set_exception_type(self):
raise NotImplementedError( raise NotImplementedError(
"Error test should implement set_exception_type") "Error test should implement set_exception_type"
)
def set_message(self): def set_message(self):
raise NotImplementedError("Error test should implement set_message") raise NotImplementedError("Error test should implement set_message")
def reset_flags_to_default(self): def reset_flags_to_default(self):
os.environ[error.DISABLE_ERROR_ENV_NAME] = str( os.environ[error.DISABLE_ERROR_ENV_NAME] = str(
error.DEFAULT_DISABLE_NEW_ERROR) error.DEFAULT_DISABLE_NEW_ERROR
)
os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str( os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str(
error.DEFAULT_SIMPLIFY_NEW_ERROR) error.DEFAULT_SIMPLIFY_NEW_ERROR
)
def disable_new_error(self): def disable_new_error(self):
os.environ[error.DISABLE_ERROR_ENV_NAME] = str( os.environ[error.DISABLE_ERROR_ENV_NAME] = str(
1 - error.DEFAULT_DISABLE_NEW_ERROR) 1 - error.DEFAULT_DISABLE_NEW_ERROR
)
def _test_new_error_message(self, new_exception, disable_new_error=0): def _test_new_error_message(self, new_exception, disable_new_error=0):
error_message = str(new_exception) error_message = str(new_exception)
...@@ -242,7 +243,6 @@ class TestErrorBase(unittest.TestCase): ...@@ -242,7 +243,6 @@ class TestErrorBase(unittest.TestCase):
# Situation 1: Call StaticLayer.__call__ to use Dynamic-to-Static # Situation 1: Call StaticLayer.__call__ to use Dynamic-to-Static
class TestErrorStaticLayerCallInCompiletime(TestErrorBase): class TestErrorStaticLayerCallInCompiletime(TestErrorBase):
def set_func(self): def set_func(self):
self.func = func_error_in_compile_time self.func = func_error_in_compile_time
...@@ -255,7 +255,8 @@ class TestErrorStaticLayerCallInCompiletime(TestErrorBase): ...@@ -255,7 +255,8 @@ class TestErrorStaticLayerCallInCompiletime(TestErrorBase):
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
'File "{}", line 33, in func_error_in_compile_time'.format( 'File "{}", line 33, in func_error_in_compile_time'.format(
self.filepath), self.filepath
),
'inner_func()', 'inner_func()',
'File "{}", line 26, in inner_func'.format(self.filepath), 'File "{}", line 26, in inner_func'.format(self.filepath),
'def inner_func():', 'def inner_func():',
...@@ -274,8 +275,8 @@ class TestErrorStaticLayerCallInCompiletime(TestErrorBase): ...@@ -274,8 +275,8 @@ class TestErrorStaticLayerCallInCompiletime(TestErrorBase):
class TestErrorStaticLayerCallInCompiletime_2( class TestErrorStaticLayerCallInCompiletime_2(
TestErrorStaticLayerCallInCompiletime): TestErrorStaticLayerCallInCompiletime
):
def set_func(self): def set_func(self):
self.func = func_error_in_compile_time_2 self.func = func_error_in_compile_time_2
...@@ -285,7 +286,8 @@ class TestErrorStaticLayerCallInCompiletime_2( ...@@ -285,7 +286,8 @@ class TestErrorStaticLayerCallInCompiletime_2(
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
'File "{}", line 44, in func_error_in_compile_time_2'.format( 'File "{}", line 44, in func_error_in_compile_time_2'.format(
self.filepath), self.filepath
),
'def func_error_in_compile_time_2(x):', 'def func_error_in_compile_time_2(x):',
'x = fluid.dygraph.to_variable(x)', 'x = fluid.dygraph.to_variable(x)',
'x = fluid.layers.reshape(x, shape=[1, 2])', 'x = fluid.layers.reshape(x, shape=[1, 2])',
...@@ -295,8 +297,8 @@ class TestErrorStaticLayerCallInCompiletime_2( ...@@ -295,8 +297,8 @@ class TestErrorStaticLayerCallInCompiletime_2(
class TestErrorStaticLayerCallInCompiletime_3( class TestErrorStaticLayerCallInCompiletime_3(
TestErrorStaticLayerCallInCompiletime): TestErrorStaticLayerCallInCompiletime
):
def setUp(self): def setUp(self):
self.reset_flags_to_default() self.reset_flags_to_default()
self.set_func_call() self.set_func_call()
...@@ -309,7 +311,7 @@ class TestErrorStaticLayerCallInCompiletime_3( ...@@ -309,7 +311,7 @@ class TestErrorStaticLayerCallInCompiletime_3(
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
'File "{}", line 90, in forward'.format(self.filepath), 'File "{}", line 89, in forward'.format(self.filepath),
'@paddle.jit.to_static', '@paddle.jit.to_static',
'def forward(self):', 'def forward(self):',
'self.test_func()', 'self.test_func()',
...@@ -325,7 +327,6 @@ class TestErrorStaticLayerCallInCompiletime_3( ...@@ -325,7 +327,6 @@ class TestErrorStaticLayerCallInCompiletime_3(
class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime): class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime):
def set_func(self): def set_func(self):
self.func = func_error_in_runtime self.func = func_error_in_runtime
...@@ -335,7 +336,8 @@ class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime): ...@@ -335,7 +336,8 @@ class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime):
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
'File "{}", line 52, in func_error_in_runtime'.format( 'File "{}", line 52, in func_error_in_runtime'.format(
self.filepath), self.filepath
),
'x = fluid.dygraph.to_variable(x)', 'x = fluid.dygraph.to_variable(x)',
'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")', 'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")',
'x = fluid.layers.reshape(x, shape=[1, two])', 'x = fluid.layers.reshape(x, shape=[1, two])',
...@@ -345,14 +347,14 @@ class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime): ...@@ -345,14 +347,14 @@ class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime):
class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime): class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime):
def set_func(self): def set_func(self):
self.func = func_error_in_runtime_with_empty_line self.func = func_error_in_runtime_with_empty_line
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
'File "{}", line 105, in func_error_in_runtime_with_empty_line'. 'File "{}", line 104, in func_error_in_runtime_with_empty_line'.format(
format(self.filepath), self.filepath
),
'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")', 'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")',
'x = fluid.layers.reshape(x, shape=[1, two])', 'x = fluid.layers.reshape(x, shape=[1, two])',
'<--- HERE', '<--- HERE',
...@@ -362,29 +364,29 @@ class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime): ...@@ -362,29 +364,29 @@ class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime):
# Situation 2: Call ProgramTranslator().get_output(...) to use Dynamic-to-Static # Situation 2: Call ProgramTranslator().get_output(...) to use Dynamic-to-Static
class TestErrorGetOutputInCompiletime(TestErrorStaticLayerCallInCompiletime): class TestErrorGetOutputInCompiletime(TestErrorStaticLayerCallInCompiletime):
def set_func_call(self): def set_func_call(self):
self.func_call = lambda: self.prog_trans.get_output( self.func_call = lambda: self.prog_trans.get_output(
unwrap(self.func), self.input) unwrap(self.func), self.input
)
class TestErrorGetOutputInCompiletime_2(TestErrorStaticLayerCallInCompiletime_2
):
class TestErrorGetOutputInCompiletime_2(
TestErrorStaticLayerCallInCompiletime_2
):
def set_func_call(self): def set_func_call(self):
self.func_call = lambda: self.prog_trans.get_output( self.func_call = lambda: self.prog_trans.get_output(
unwrap(self.func), self.input) unwrap(self.func), self.input
)
class TestErrorGetOutputInRuntime(TestErrorStaticLayerCallInRuntime): class TestErrorGetOutputInRuntime(TestErrorStaticLayerCallInRuntime):
def set_func_call(self): def set_func_call(self):
self.func_call = lambda: self.prog_trans.get_output( self.func_call = lambda: self.prog_trans.get_output(
unwrap(self.func), self.input) unwrap(self.func), self.input
)
class TestJitSaveInCompiletime(TestErrorBase): class TestJitSaveInCompiletime(TestErrorBase):
def setUp(self): def setUp(self):
self.reset_flags_to_default() self.reset_flags_to_default()
self.set_func_call() self.set_func_call()
...@@ -409,7 +411,8 @@ class TestJitSaveInCompiletime(TestErrorBase): ...@@ -409,7 +411,8 @@ class TestJitSaveInCompiletime(TestErrorBase):
def set_func_call(self): def set_func_call(self):
layer = LayerErrorInCompiletime() layer = LayerErrorInCompiletime()
self.func_call = lambda: paddle.jit.save( self.func_call = lambda: paddle.jit.save(
layer, path="./test_dy2stat_error/model") layer, path="./test_dy2stat_error/model"
)
def test_error(self): def test_error(self):
self._test_raise_new_exception() self._test_raise_new_exception()
...@@ -419,21 +422,20 @@ class TestJitSaveInCompiletime(TestErrorBase): ...@@ -419,21 +422,20 @@ class TestJitSaveInCompiletime(TestErrorBase):
class TestSuggestionErrorInRuntime(TestErrorBase): class TestSuggestionErrorInRuntime(TestErrorBase):
def set_func(self): def set_func(self):
self.func = func_suggestion_error_in_runtime self.func = func_suggestion_error_in_runtime
def set_input(self): def set_input(self):
self.input = paddle.to_tensor([2.]) self.input = paddle.to_tensor([2.0])
def set_exception_type(self): def set_exception_type(self):
self.exception_type = ValueError self.exception_type = ValueError
def set_message(self): def set_message(self):
self.expected_message = [ self.expected_message = [
'File "{}", line 118, in forward'.format(self.filepath), 'File "{}", line 116, in forward'.format(self.filepath),
'return self.inner_net.forward(x)', 'return self.inner_net.forward(x)',
'File "{}", line 128, in forward'.format(self.filepath), 'File "{}", line 125, in forward'.format(self.filepath),
'def forward(self, x):', 'def forward(self, x):',
'out = paddle.matmul(self.w, x)', 'out = paddle.matmul(self.w, x)',
'<--- HERE', '<--- HERE',
...@@ -460,7 +462,6 @@ def func_ker_error(x): ...@@ -460,7 +462,6 @@ def func_ker_error(x):
class TestKeyError(unittest.TestCase): class TestKeyError(unittest.TestCase):
def test_key_error(self): def test_key_error(self):
paddle.disable_static() paddle.disable_static()
with self.assertRaises(error.Dy2StKeyError): with self.assertRaises(error.Dy2StKeyError):
...@@ -476,7 +477,6 @@ def NpApiErr(): ...@@ -476,7 +477,6 @@ def NpApiErr():
class TestNumpyApiErr(unittest.TestCase): class TestNumpyApiErr(unittest.TestCase):
def test_numpy_api_err(self): def test_numpy_api_err(self):
with self.assertRaises(TypeError) as e: with self.assertRaises(TypeError) as e:
NpApiErr() NpApiErr()
...@@ -490,11 +490,11 @@ class TestNumpyApiErr(unittest.TestCase): ...@@ -490,11 +490,11 @@ class TestNumpyApiErr(unittest.TestCase):
self.assertIn( self.assertIn(
"values will be changed to variables by dy2static, numpy api can not handle variables", "values will be changed to variables by dy2static, numpy api can not handle variables",
error_message) error_message,
)
class test_set_state_dict_err_layer(paddle.nn.Layer): class test_set_state_dict_err_layer(paddle.nn.Layer):
def __init__(self): def __init__(self):
super(test_set_state_dict_err_layer, self).__init__() super(test_set_state_dict_err_layer, self).__init__()
self.linear = paddle.nn.Linear(5, 2) self.linear = paddle.nn.Linear(5, 2)
...@@ -514,11 +514,10 @@ class test_set_state_dict_err_layer(paddle.nn.Layer): ...@@ -514,11 +514,10 @@ class test_set_state_dict_err_layer(paddle.nn.Layer):
class TestSetStateDictErr(unittest.TestCase): class TestSetStateDictErr(unittest.TestCase):
def test_set_state_dict_err(self): def test_set_state_dict_err(self):
with self.assertRaises(ValueError) as e: with self.assertRaises(ValueError) as e:
layer = test_set_state_dict_err_layer() layer = test_set_state_dict_err_layer()
x = paddle.to_tensor([1., 2., 3., 4., 5.]) x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0, 5.0])
y = layer(x) y = layer(x)
new_exception = e.exception new_exception = e.exception
...@@ -530,7 +529,8 @@ class TestSetStateDictErr(unittest.TestCase): ...@@ -530,7 +529,8 @@ class TestSetStateDictErr(unittest.TestCase):
self.assertIn( self.assertIn(
"This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'.", "This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'.",
error_message) error_message,
)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -15,8 +15,19 @@ ...@@ -15,8 +15,19 @@
import sys import sys
import unittest import unittest
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import (
from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, ORIGI_INFO, OriginInfo, attach_origin_info, create_and_update_origin_info_map, gast, inspect, unwrap DygraphToStaticAst,
)
from paddle.fluid.dygraph.dygraph_to_static.origin_info import (
Location,
ORIGI_INFO,
OriginInfo,
attach_origin_info,
create_and_update_origin_info_map,
gast,
inspect,
unwrap,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
...@@ -27,7 +38,6 @@ def simple_func(x): ...@@ -27,7 +38,6 @@ def simple_func(x):
def nested_func(x): def nested_func(x):
def f1(a): def f1(a):
return a return a
...@@ -47,7 +57,6 @@ def decorated_func2(x): ...@@ -47,7 +57,6 @@ def decorated_func2(x):
class TestOriginInfo(unittest.TestCase): class TestOriginInfo(unittest.TestCase):
def setUp(self): def setUp(self):
self.set_test_func() self.set_test_func()
self.dygraph_func = unwrap(self.func) self.dygraph_func = unwrap(self.func)
...@@ -55,8 +64,9 @@ class TestOriginInfo(unittest.TestCase): ...@@ -55,8 +64,9 @@ class TestOriginInfo(unittest.TestCase):
self.source_code = inspect.getsource(self.dygraph_func) self.source_code = inspect.getsource(self.dygraph_func)
lines, self.start_lineno = inspect.getsourcelines(self.dygraph_func) lines, self.start_lineno = inspect.getsourcelines(self.dygraph_func)
lines = [line.strip("\n") for line in lines] lines = [line.strip("\n") for line in lines]
self.lines = [line for line in lines self.lines = [
if line != ""] # Delete empty lines line for line in lines if line != ""
] # Delete empty lines
self.set_static_lineno() self.set_static_lineno()
self.set_dygraph_info() self.set_dygraph_info()
...@@ -77,8 +87,9 @@ class TestOriginInfo(unittest.TestCase): ...@@ -77,8 +87,9 @@ class TestOriginInfo(unittest.TestCase):
def set_origin_info_list(self, dygraph_ast): def set_origin_info_list(self, dygraph_ast):
assert isinstance(dygraph_ast, gast.Module) assert isinstance(dygraph_ast, gast.Module)
self.transformed_node_list = [ self.transformed_node_list = [
dygraph_ast.body[0], dygraph_ast.body[0].body[0], dygraph_ast.body[0],
dygraph_ast.body[0].body[1] dygraph_ast.body[0].body[0],
dygraph_ast.body[0].body[1],
] ]
def _get_OriginInfo_map(self): def _get_OriginInfo_map(self):
...@@ -91,8 +102,9 @@ class TestOriginInfo(unittest.TestCase): ...@@ -91,8 +102,9 @@ class TestOriginInfo(unittest.TestCase):
# step3 # step3
self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func) self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func)
info_map = create_and_update_origin_info_map(dygraph_ast, info_map = create_and_update_origin_info_map(
self.static_func) dygraph_ast, self.static_func
)
return info_map return info_map
def test_origin_info_map(self): def test_origin_info_map(self):
...@@ -115,9 +127,12 @@ class TestOriginInfo(unittest.TestCase): ...@@ -115,9 +127,12 @@ class TestOriginInfo(unittest.TestCase):
code = self.lines[line_idx] code = self.lines[line_idx]
origin_info = OriginInfo( origin_info = OriginInfo(
Location(self.dygraph_filepath, dy_lineno, dy_col_offset), Location(self.dygraph_filepath, dy_lineno, dy_col_offset),
self.dy_func_name[i], code) self.dy_func_name[i],
self.assertEqual(str(origin_info_map[staic_loc.line_location]), code,
str(origin_info)) )
self.assertEqual(
str(origin_info_map[staic_loc.line_location]), str(origin_info)
)
def test_attach_origin_info(self): def test_attach_origin_info(self):
dygraph_ast = gast.parse(self.source_code) dygraph_ast = gast.parse(self.source_code)
...@@ -144,7 +159,6 @@ class TestOriginInfo(unittest.TestCase): ...@@ -144,7 +159,6 @@ class TestOriginInfo(unittest.TestCase):
class TestOriginInfoWithNestedFunc(TestOriginInfo): class TestOriginInfoWithNestedFunc(TestOriginInfo):
def set_test_func(self): def set_test_func(self):
self.func = nested_func self.func = nested_func
...@@ -154,23 +168,26 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo): ...@@ -154,23 +168,26 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo):
def set_dygraph_info(self): def set_dygraph_info(self):
self.line_num = 5 self.line_num = 5
self.line_index_list = [0, 1, 2, 3, 4] self.line_index_list = [0, 1, 2, 3, 4]
self.dy_rel_lineno_list = [0, 2, 3, 5, 6] self.dy_rel_lineno_list = [0, 1, 2, 4, 5]
self.dy_abs_col_offset = [0, 4, 8, 4, 4] self.dy_abs_col_offset = [0, 4, 8, 4, 4]
self.dy_func_name = [self.dygraph_func.__name__] + \ self.dy_func_name = (
["f1"] * 2 + \ [self.dygraph_func.__name__]
[self.dygraph_func.__name__] * 2 + ["f1"] * 2
+ [self.dygraph_func.__name__] * 2
)
def set_origin_info_list(self, dygraph_ast): def set_origin_info_list(self, dygraph_ast):
assert isinstance(dygraph_ast, gast.Module) assert isinstance(dygraph_ast, gast.Module)
self.transformed_node_list = [ self.transformed_node_list = [
dygraph_ast.body[0], dygraph_ast.body[0].body[0], dygraph_ast.body[0],
dygraph_ast.body[0].body[0].body[0], dygraph_ast.body[0].body[1], dygraph_ast.body[0].body[0],
dygraph_ast.body[0].body[2] dygraph_ast.body[0].body[0].body[0],
dygraph_ast.body[0].body[1],
dygraph_ast.body[0].body[2],
] ]
class TestOriginInfoWithDecoratedFunc(TestOriginInfo): class TestOriginInfoWithDecoratedFunc(TestOriginInfo):
def set_test_func(self): def set_test_func(self):
self.func = decorated_func self.func = decorated_func
...@@ -205,7 +222,6 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo): ...@@ -205,7 +222,6 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo):
class TestOriginInfoWithDecoratedFunc2(TestOriginInfo): class TestOriginInfoWithDecoratedFunc2(TestOriginInfo):
def set_test_func(self): def set_test_func(self):
self.func = decorated_func2 self.func = decorated_func2
......
...@@ -19,7 +19,6 @@ from test_eager_deletion_padding_rnn import RNNConfig, PaddingRNNTestBase ...@@ -19,7 +19,6 @@ from test_eager_deletion_padding_rnn import RNNConfig, PaddingRNNTestBase
class FusionGroupPaddingRNNTest(PaddingRNNTestBase): class FusionGroupPaddingRNNTest(PaddingRNNTestBase):
def set_customed_config(self): def set_customed_config(self):
self.build_strategy.enable_auto_fusion = True self.build_strategy.enable_auto_fusion = True
......
...@@ -25,7 +25,6 @@ import tempfile ...@@ -25,7 +25,6 @@ import tempfile
class TestPassBuilder(unittest.TestCase): class TestPassBuilder(unittest.TestCase):
def check_network_convergence(self, use_cuda, build_strategy=None): def check_network_convergence(self, use_cuda, build_strategy=None):
os.environ['CPU_NUM'] = str(4) os.environ['CPU_NUM'] = str(4)
main = fluid.Program() main = fluid.Program()
...@@ -47,20 +46,22 @@ class TestPassBuilder(unittest.TestCase): ...@@ -47,20 +46,22 @@ class TestPassBuilder(unittest.TestCase):
feed_dict = {'image': image, 'label': label} feed_dict = {'image': image, 'label': label}
train_cp = compiler.CompiledProgram(main).with_data_parallel( train_cp = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy) loss_name=loss.name, build_strategy=build_strategy
)
test_cp = compiler.CompiledProgram(test_program).with_data_parallel( test_cp = compiler.CompiledProgram(test_program).with_data_parallel(
loss_name=loss.name, loss_name=loss.name,
build_strategy=build_strategy, build_strategy=build_strategy,
share_vars_from=train_cp) share_vars_from=train_cp,
)
for i in range(5): for i in range(5):
_ = exe.run(train_cp, fetch_list=[loss.name], feed=feed_dict) _ = exe.run(train_cp, fetch_list=[loss.name], feed=feed_dict)
test_loss, = exe.run(test_cp, (test_loss,) = exe.run(
fetch_list=[loss.name], test_cp, fetch_list=[loss.name], feed=feed_dict
feed=feed_dict) )
train_loss, = exe.run(train_cp, (train_loss,) = exe.run(
fetch_list=[loss.name], train_cp, fetch_list=[loss.name], feed=feed_dict
feed=feed_dict) )
avg_test_loss_val = np.array(test_loss).mean() avg_test_loss_val = np.array(test_loss).mean()
if math.isnan(float(avg_test_loss_val)): if math.isnan(float(avg_test_loss_val)):
...@@ -70,32 +71,38 @@ class TestPassBuilder(unittest.TestCase): ...@@ -70,32 +71,38 @@ class TestPassBuilder(unittest.TestCase):
if math.isnan(float(avg_train_loss_val)): if math.isnan(float(avg_train_loss_val)):
sys.exit("got NaN loss, training failed.") sys.exit("got NaN loss, training failed.")
np.testing.assert_allclose(train_loss, np.testing.assert_allclose(
train_loss,
test_loss, test_loss,
rtol=1e-05, rtol=1e-05,
atol=1e-08, atol=1e-08,
err_msg='Train loss: ' + err_msg='Train loss: '
str(train_loss) + '\n Test loss:' + + str(train_loss)
str(test_loss)) + '\n Test loss:'
+ str(test_loss),
)
def test_parallel_testing_with_new_strategy(self): def test_parallel_testing_with_new_strategy(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
self.assertFalse(build_strategy.fuse_elewise_add_act_ops) self.assertFalse(build_strategy.fuse_elewise_add_act_ops)
build_strategy.fuse_elewise_add_act_ops = True build_strategy.fuse_elewise_add_act_ops = True
#FIXME: currently fuse_elewise_add_act_ops not compatible with below options # FIXME: currently fuse_elewise_add_act_ops not compatible with below options
build_strategy.enable_inplace = False build_strategy.enable_inplace = False
build_strategy.memory_optimize = False build_strategy.memory_optimize = False
pass_builder = build_strategy._finalize_strategy_and_create_passes() pass_builder = build_strategy._finalize_strategy_and_create_passes()
self.assertTrue("fuse_elewise_add_act_pass" in self.assertTrue(
[p.type() for p in pass_builder.all_passes()]) "fuse_elewise_add_act_pass"
in [p.type() for p in pass_builder.all_passes()]
)
origin_len = len(pass_builder.all_passes()) origin_len = len(pass_builder.all_passes())
viz_pass = pass_builder.append_pass("graph_viz_pass") viz_pass = pass_builder.append_pass("graph_viz_pass")
self.assertEqual(origin_len + 1, len(pass_builder.all_passes())) self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
pass_builder.insert_pass(len(pass_builder.all_passes()), pass_builder.insert_pass(
"graph_viz_pass") len(pass_builder.all_passes()), "graph_viz_pass"
)
self.assertEqual(origin_len + 2, len(pass_builder.all_passes())) self.assertEqual(origin_len + 2, len(pass_builder.all_passes()))
pass_builder.remove_pass(len(pass_builder.all_passes()) - 1) pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
...@@ -106,7 +113,8 @@ class TestPassBuilder(unittest.TestCase): ...@@ -106,7 +113,8 @@ class TestPassBuilder(unittest.TestCase):
self.check_network_convergence( self.check_network_convergence(
use_cuda=core.is_compiled_with_cuda(), use_cuda=core.is_compiled_with_cuda(),
build_strategy=build_strategy) build_strategy=build_strategy,
)
try: try:
os.stat(graph_viz_path) os.stat(graph_viz_path)
except os.error: except os.error:
......
...@@ -865,18 +865,18 @@ class TestOverlapAddException(unittest.TestCase): ...@@ -865,18 +865,18 @@ class TestOverlapAddException(unittest.TestCase):
@parameterize( @parameterize(
(TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'pad_mode', 'normalized', 'onesided'), (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'pad_mode', 'normalized', 'onesided'),
[ [
('test_1d_input', rand_x(1, np.float64, shape=[160000]), ('test_1d_input', rand_x(1, np.float64, shape=[160000]), 512,
512, None, None, get_window('hann', 512), True, 'reflect', False, True), None, None, get_window('hann', 512), True, 'reflect', False, True),
('test_2d_input', rand_x(2, np.float64, shape=[1, 160000]), ('test_2d_input', rand_x(2, np.float64, shape=[1, 160000]), 512,
512, None, None, get_window('hann', 512), True, 'reflect', False, True), None, None, get_window('hann', 512), True, 'reflect', False, True),
('test_hop_length', rand_x(2, np.float64, shape=[1, 160000]), ('test_hop_length', rand_x(2, np.float64, shape=[1, 160000]), 512,
512, 255, None, get_window('hann', 512), True, 'reflect', False, True), 255, None, get_window('hann', 512), True, 'reflect', False, True),
('test_win_length', rand_x(2, np.float64, shape=[1, 160000]), ('test_win_length', rand_x(2, np.float64, shape=[1, 160000]), 512,
512, 255, 499, get_window('hann', 499), True, 'reflect', False, True), 255, 499, get_window('hann', 499), True, 'reflect', False, True),
('test_window', rand_x(2, np.float64, shape=[1, 160000]), ('test_window', rand_x(2, np.float64, shape=[1, 160000]), 512,
512, None, None, None, True, 'reflect', False, True), None, None, None, True, 'reflect', False, True),
('test_center', rand_x(2, np.float64, shape=[1, 160000]), ('test_center', rand_x(2, np.float64, shape=[1, 160000]), 512,
512, None, None, None, False, 'reflect', False, True), None, None, None, False, 'reflect', False, True),
]) # fmt: skip ]) # fmt: skip
class TestStft(unittest.TestCase): class TestStft(unittest.TestCase):
def test_stft(self): def test_stft(self):
...@@ -917,22 +917,22 @@ class TestStft(unittest.TestCase): ...@@ -917,22 +917,22 @@ class TestStft(unittest.TestCase):
@parameterize( @parameterize(
(TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'pad_mode', 'normalized', 'onesided', 'expect_exception'), (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'pad_mode', 'normalized', 'onesided', 'expect_exception'),
[ [
('test_dims', rand_x(1, np.float64, shape=[1, 2, 3]), ('test_dims', rand_x(1, np.float64, shape=[1, 2, 3]), 512,
512, None, None, None, True, 'reflect', False, True, AssertionError), None, None, None, True, 'reflect', False, True, AssertionError),
('test_hop_length', rand_x(1, np.float64, shape=[16000]), ('test_hop_length', rand_x(1, np.float64, shape=[16000]), 512,
512, 0, None, None, True, 'reflect', False, True, AssertionError), 0, None, None, True, 'reflect', False, True, AssertionError),
('test_nfft1', rand_x(1, np.float64, shape=[16000]), ('test_nfft1', rand_x(1, np.float64, shape=[16000]), 0,
0, None, None, None, True, 'reflect', False, True, AssertionError), None, None, None, True, 'reflect', False, True, AssertionError),
('test_nfft2', rand_x(1, np.float64, shape=[16000]), ('test_nfft2', rand_x(1, np.float64, shape=[16000]), 16001,
16001, None, None, None, True, 'reflect', False, True, AssertionError), None, None, None, True, 'reflect', False, True, AssertionError),
('test_win_length', rand_x(1, np.float64, shape=[16000]), ('test_win_length', rand_x(1, np.float64, shape=[16000]), 512,
512, None, 0, None, True, 'reflect', False, True, AssertionError), None, 0, None, True, 'reflect', False, True, AssertionError),
('test_win_length', rand_x(1, np.float64, shape=[16000]), ('test_win_length', rand_x(1, np.float64, shape=[16000]), 512,
512, None, 513, None, True, 'reflect', False, True, AssertionError), None, 513, None, True, 'reflect', False, True, AssertionError),
('test_pad_mode', rand_x(1, np.float64, shape=[16000]), ('test_pad_mode', rand_x(1, np.float64, shape=[16000]), 512,
512, None, None, None, True, 'nonsense', False, True, AssertionError), None, None, None, True, 'nonsense', False, True, AssertionError),
('test_complex_onesided', rand_x(1, np.float64, shape=[16000], complex=True), ('test_complex_onesided', rand_x(1, np.float64, shape=[16000], complex=True), 512,
512, None, None, None, False, 'reflect', False, True, AssertionError), None, None, None, False, 'reflect', False, True, AssertionError),
]) # fmt: skip ]) # fmt: skip
class TestStftException(unittest.TestCase): class TestStftException(unittest.TestCase):
def test_stft(self): def test_stft(self):
...@@ -959,20 +959,20 @@ class TestStftException(unittest.TestCase): ...@@ -959,20 +959,20 @@ class TestStftException(unittest.TestCase):
@parameterize( @parameterize(
(TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'normalized', 'onesided', 'length', 'return_complex'), (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'normalized', 'onesided', 'length', 'return_complex'),
[ [
('test_2d_input', rand_x(2, np.float64, shape=[257, 471], complex=True), ('test_2d_input', rand_x(2, np.float64, shape=[257, 471], complex=True), 512,
512, None, None, get_window('hann', 512), True, False, True, None, False), None, None, get_window('hann', 512), True, False, True, None, False),
('test_3d_input', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_3d_input', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, None, None, get_window('hann', 512), True, False, True, None, False), None, None, get_window('hann', 512), True, False, True, None, False),
('test_hop_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_hop_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, 99, None, get_window('hann', 512), True, False, True, None, False), 99, None, get_window('hann', 512), True, False, True, None, False),
('test_win_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_win_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, 99, 299, get_window('hann', 299), True, False, True, None, False), 99, 299, get_window('hann', 299), True, False, True, None, False),
('test_window', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_window', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, None, None, None, True, False, True, None, False), None, None, None, True, False, True, None, False),
('test_center', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_center', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, None, None, None, False, False, True, None, False), None, None, None, False, False, True, None, False),
('test_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_length', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, None, None, None, False, False, True, 1888, False), None, None, None, False, False, True, 1888, False),
]) # fmt: skip ]) # fmt: skip
class TestIstft(unittest.TestCase): class TestIstft(unittest.TestCase):
def test_istft(self): def test_istft(self):
...@@ -1013,30 +1013,30 @@ class TestIstft(unittest.TestCase): ...@@ -1013,30 +1013,30 @@ class TestIstft(unittest.TestCase):
@parameterize( @parameterize(
(TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'normalized', 'onesided', 'length', 'return_complex', 'expect_exception'), (TEST_CASE_NAME, 'x', 'n_fft', 'hop_length', 'win_length', 'window', 'center', 'normalized', 'onesided', 'length', 'return_complex', 'expect_exception'),
[ [
('test_dims', rand_x(4, np.float64, shape=[1, 2, 3, 4], complex=True), ('test_dims', rand_x(4, np.float64, shape=[1, 2, 3, 4], complex=True), 512,
512, None, None, get_window('hann', 512), True, False, True, None, False, AssertionError), None, None, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_n_fft', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_n_fft', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 257,
257, None, None, get_window('hann', 512), True, False, True, None, False, AssertionError), None, None, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_hop_length1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_hop_length1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, 0, None, get_window('hann', 512), True, False, True, None, False, AssertionError), 0, None, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_hop_length2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_hop_length2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, 513, None, get_window('hann', 512), True, False, True, None, False, AssertionError), 513, None, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_win_length1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_win_length1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, None, 0, get_window('hann', 512), True, False, True, None, False, AssertionError), None, 0, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_win_length2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_win_length2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, None, 513, get_window('hann', 512), True, False, True, None, False, AssertionError), None, 513, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_onesided1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_onesided1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 20,
20, None, None, get_window('hann', 512), True, False, True, None, False, AssertionError), None, None, get_window('hann', 512), True, False, True, None, False, AssertionError),
('test_onesided2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_onesided2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 256,
256, None, None, None, True, False, False, None, False, AssertionError), None, None, None, True, False, False, None, False, AssertionError),
('test_window', rand_x(3, np.float64, shape=[1, 512, 471], complex=True), ('test_window', rand_x(3, np.float64, shape=[1, 512, 471], complex=True), 512,
512, None, 511, get_window('hann', 512), True, False, False, None, False, AssertionError), None, 511, get_window('hann', 512), True, False, False, None, False, AssertionError),
('test_return_complex1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_return_complex1', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, None, None, get_window('hann', 512), True, False, True, None, True, AssertionError), None, None, get_window('hann', 512), True, False, True, None, True, AssertionError),
('test_return_complex2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_return_complex2', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, None, None, rand_x(1, np.float64, shape=[512], complex=True), True, False, True, None, False, AssertionError), None, None, rand_x(1, np.float64, shape=[512], complex=True), True, False, True, None, False, AssertionError),
('test_NOLA', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), ('test_NOLA', rand_x(3, np.float64, shape=[1, 257, 471], complex=True), 512,
512, 512, None, get_window('hann', 512), True, False, True, None, False, ValueError), 512, None, get_window('hann', 512), True, False, True, None, False, ValueError),
]) # fmt: skip ]) # fmt: skip
class TestIstftException(unittest.TestCase): class TestIstftException(unittest.TestCase):
def test_istft(self): def test_istft(self):
......
...@@ -26,10 +26,12 @@ def strToSecond(strTime): ...@@ -26,10 +26,12 @@ def strToSecond(strTime):
def getUsefulBuildTimeFile(filename): def getUsefulBuildTimeFile(filename):
os.system( os.system(
"grep -Po -- '-o .*' %s | grep ' elapsed' | grep -P -v '0:00.* elapse' > %s/tools/analysis_build_time" "grep -Po -- '-o .*' %s | grep ' elapsed' | grep -P -v '0:00.* elapse' > %s/tools/analysis_build_time"
% (filename, root_path)) % (filename, root_path)
)
os.system( os.system(
"grep -v -- '-o .*' %s |grep ' elapse' | grep -P -v '0:00.* elapse' >> %s/tools/analysis_build_time" "grep -v -- '-o .*' %s |grep ' elapse' | grep -P -v '0:00.* elapse' >> %s/tools/analysis_build_time"
% (filename, root_path)) % (filename, root_path)
)
def analysisBuildTime(): def analysisBuildTime():
...@@ -45,19 +47,24 @@ def analysisBuildTime(): ...@@ -45,19 +47,24 @@ def analysisBuildTime():
buildFile = line.split(', ')[0].split(' ')[1] buildFile = line.split(', ')[0].split(' ')[1]
buildTime = line.split(', ')[1].split('elapsed')[0].strip() buildTime = line.split(', ')[1].split('elapsed')[0].strip()
secondTime = strToSecond(buildTime) secondTime = strToSecond(buildTime)
os.system("echo %s, %s >> %s/tools/tempbuildTime.txt" % os.system(
(buildFile, secondTime, root_path)) "echo %s, %s >> %s/tools/tempbuildTime.txt"
% (buildFile, secondTime, root_path)
)
else: else:
buildTime = line.split(', ')[1].split('elapsed')[0].strip() buildTime = line.split(', ')[1].split('elapsed')[0].strip()
secondTime = strToSecond(buildTime) secondTime = strToSecond(buildTime)
if secondTime > 30: if secondTime > 30:
os.system("echo %s, %s >> %s/tools/tempbuildTime.txt" % os.system(
(line, secondTime, root_path)) "echo %s, %s >> %s/tools/tempbuildTime.txt"
% (line, secondTime, root_path)
)
except ValueError: except ValueError:
print(line) print(line)
os.system( os.system(
'sort -n -k 2 -r %s/tools/tempbuildTime.txt > %s/tools/buildTime.txt' % 'sort -n -k 2 -r %s/tools/tempbuildTime.txt > %s/tools/buildTime.txt'
(root_path, root_path)) % (root_path, root_path)
)
analysisBuildTime() analysisBuildTime()
...@@ -18,12 +18,14 @@ ...@@ -18,12 +18,14 @@
def is_manylinux1_compatible(): def is_manylinux1_compatible():
# Only Linux, and only x86-64 / i686 # Only Linux, and only x86-64 / i686
from distutils.util import get_platform from distutils.util import get_platform
if get_platform() not in ["linux-x86_64", "linux-i686"]: if get_platform() not in ["linux-x86_64", "linux-i686"]:
return False return False
# Check for presence of _manylinux module # Check for presence of _manylinux module
try: try:
import _manylinux import _manylinux
return bool(_manylinux.manylinux1_compatible) return bool(_manylinux.manylinux1_compatible)
except (ImportError, AttributeError): except (ImportError, AttributeError):
# Fall through to heuristic check below # Fall through to heuristic check below
...@@ -62,9 +64,10 @@ def have_compatible_glibc(major, minimum_minor): ...@@ -62,9 +64,10 @@ def have_compatible_glibc(major, minimum_minor):
import sys import sys
if is_manylinux1_compatible(): if is_manylinux1_compatible():
print("%s is manylinux1 compatible" % (sys.executable, )) print("%s is manylinux1 compatible" % (sys.executable,))
sys.exit(0) sys.exit(0)
else: else:
print("%s is NOT manylinux1 compatible" % (sys.executable, )) print("%s is NOT manylinux1 compatible" % (sys.executable,))
sys.exit(1) sys.exit(1)
...@@ -21,22 +21,24 @@ import sys ...@@ -21,22 +21,24 @@ import sys
print("Testing SSL certificate checking for Python:", sys.version) print("Testing SSL certificate checking for Python:", sys.version)
if (sys.version_info[:2] < (2, 7) or sys.version_info[:2] < (3, 4)): if sys.version_info[:2] < (2, 7) or sys.version_info[:2] < (3, 4):
print("This version never checks SSL certs; skipping tests") print("This version never checks SSL certs; skipping tests")
sys.exit(0) sys.exit(0)
if sys.version_info[0] >= 3: if sys.version_info[0] >= 3:
from urllib.request import urlopen from urllib.request import urlopen
EXC = OSError EXC = OSError
else: else:
from urllib import urlopen from urllib import urlopen
EXC = IOError EXC = IOError
print("Connecting to %s should work" % (GOOD_SSL, )) print("Connecting to %s should work" % (GOOD_SSL,))
urlopen(GOOD_SSL) urlopen(GOOD_SSL)
print("...it did, yay.") print("...it did, yay.")
print("Connecting to %s should fail" % (BAD_SSL, )) print("Connecting to %s should fail" % (BAD_SSL,))
try: try:
urlopen(BAD_SSL) urlopen(BAD_SSL)
# If we get here then we failed: # If we get here then we failed:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册