未验证 提交 ea4182d7 编写于 作者: M megemini 提交者: GitHub

[Change] 利用multiprocessing对xdoctester进行环境隔离 (#56400)

* [Change] make xdoctester multiprocessing

* [Add] add timeout directive

* [Fix] fix code-block

* [Fix] add ubelt requirements

* [Fix] patch xdoctest in __init__

* [Fix] codestyle

* [Change] restore metric.py
上级 7e7e2b7b
...@@ -18,3 +18,4 @@ parameterized ...@@ -18,3 +18,4 @@ parameterized
wandb>=0.13 wandb>=0.13
xlsxwriter==3.0.9 xlsxwriter==3.0.9
xdoctest xdoctest
ubelt # just for xdoctest
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import collections import dataclasses
import inspect import inspect
import logging import logging
import os import os
...@@ -42,20 +42,17 @@ API_DIFF_SPEC_FN = 'dev_pr_diff_api.spec' ...@@ -42,20 +42,17 @@ API_DIFF_SPEC_FN = 'dev_pr_diff_api.spec'
TEST_TIMEOUT = 10 TEST_TIMEOUT = 10
TestResult = collections.namedtuple( @dataclasses.dataclass
"TestResult", class TestResult:
( name: str
"name", nocode: bool = False
"nocode", passed: bool = False
"passed", skipped: bool = False
"skipped", failed: bool = False
"failed", timeout: bool = False
"time", time: float = float('inf')
"test_msg", test_msg: str = ""
"extra_info", extra_info: str = ""
),
defaults=(None, False, False, False, False, -1, "", None),
)
class DocTester: class DocTester:
...@@ -76,18 +73,20 @@ class DocTester: ...@@ -76,18 +73,20 @@ class DocTester:
If the `style` is set to `google` and `target` is set to `codeblock`, we should implement/overwrite `ensemble_docstring` method, If the `style` is set to `google` and `target` is set to `codeblock`, we should implement/overwrite `ensemble_docstring` method,
where ensemble the codeblock into a docstring with a `Examples:` and some indents as least. where ensemble the codeblock into a docstring with a `Examples:` and some indents as least.
directives(list[str]): `DocTester` hold the default directives, we can/should replace them with method `convert_directive`. directives(list[str]): `DocTester` hold the default directives, we can/should replace them with method `convert_directive`.
For example:
``` text
# doctest: +SKIP
# doctest: +REQUIRES(env:CPU)
# doctest: +REQUIRES(env:GPU)
# doctest: +REQUIRES(env:XPU)
# doctest: +REQUIRES(env:DISTRIBUTED)
# doctest: +REQUIRES(env:GPU, env:XPU)
```
""" """
style = 'google' style = 'google'
target = 'docstring' target = 'docstring'
directives = [ directives = None
"# doctest: +SKIP",
"# doctest: +REQUIRES(env:CPU)",
"# doctest: +REQUIRES(env:GPU)",
"# doctest: +REQUIRES(env:XPU)",
"# doctest: +REQUIRES(env:DISTRIBUTED)",
"# doctest: +REQUIRES(env:GPU, env:XPU)",
]
def ensemble_docstring(self, codeblock: str) -> str: def ensemble_docstring(self, codeblock: str) -> str:
"""Ensemble a cleaned codeblock into a docstring. """Ensemble a cleaned codeblock into a docstring.
......
...@@ -24,10 +24,13 @@ for example, you can run cpu version testing like this: ...@@ -24,10 +24,13 @@ for example, you can run cpu version testing like this:
import functools import functools
import logging import logging
import multiprocessing
import os import os
import platform import platform
import queue
import re import re
import sys import sys
import threading
import time import time
import typing import typing
...@@ -168,9 +171,62 @@ def _patch_float_precision(digits): ...@@ -168,9 +171,62 @@ def _patch_float_precision(digits):
checker.check_output = check_output checker.check_output = check_output
class Directive:
"""Base class of global direvtives just for `xdoctest`."""
pattern: typing.Pattern
def parse_directive(self, docstring: str) -> typing.Tuple[str, typing.Any]:
pass
class TimeoutDirective(Directive):
pattern = re.compile(
r"""
(?:
(?:
\s*\>{3}\s*\#\s*x?doctest\:\s*
)
(?P<op>[\+\-])
(?:
TIMEOUT
)
\(
(?P<time>\d+)
\)
(?:
\s*?
)
)
""",
re.X | re.S,
)
def __init__(self, timeout):
self._timeout = timeout
def parse_directive(self, docstring):
match_obj = self.pattern.search(docstring)
if match_obj is not None:
op_time = match_obj.group('time')
match_start = match_obj.start()
match_end = match_obj.end()
return (
(docstring[:match_start] + '\n' + docstring[match_end:]),
float(op_time),
)
return docstring, float(self._timeout)
class Xdoctester(DocTester): class Xdoctester(DocTester):
"""A Xdoctest doctester.""" """A Xdoctest doctester."""
directives: typing.Dict[str, typing.Tuple[typing.Type[Directive], ...]] = {
'timeout': (TimeoutDirective, TEST_TIMEOUT)
}
def __init__( def __init__(
self, self,
debug=False, debug=False,
...@@ -180,8 +236,8 @@ class Xdoctester(DocTester): ...@@ -180,8 +236,8 @@ class Xdoctester(DocTester):
verbose=2, verbose=2,
patch_global_state=True, patch_global_state=True,
patch_tensor_place=True, patch_tensor_place=True,
patch_float_precision=True, patch_float_precision=5,
patch_float_digits=5, use_multiprocessing=True,
**config, **config,
): ):
self.debug = debug self.debug = debug
...@@ -192,14 +248,13 @@ class Xdoctester(DocTester): ...@@ -192,14 +248,13 @@ class Xdoctester(DocTester):
self.verbose = verbose self.verbose = verbose
self.config = {**XDOCTEST_CONFIG, **(config or {})} self.config = {**XDOCTEST_CONFIG, **(config or {})}
if patch_global_state: self._patch_global_state = patch_global_state
_patch_global_state(self.debug, self.verbose) self._patch_tensor_place = patch_tensor_place
self._patch_float_precision = patch_float_precision
self._use_multiprocessing = use_multiprocessing
if patch_tensor_place: # patch xdoctest before `xdoctest.core.parse_docstr_examples`
_patch_tensor_place() self._patch_xdoctest()
if patch_float_precision:
_patch_float_precision(patch_float_digits)
self.docstring_parser = functools.partial( self.docstring_parser = functools.partial(
xdoctest.core.parse_docstr_examples, style=self.style xdoctest.core.parse_docstr_examples, style=self.style
...@@ -216,6 +271,28 @@ class Xdoctester(DocTester): ...@@ -216,6 +271,28 @@ class Xdoctester(DocTester):
self.directive_prefix = 'xdoctest' self.directive_prefix = 'xdoctest'
def _patch_xdoctest(self):
if self._patch_global_state:
_patch_global_state(self.debug, self.verbose)
if self._patch_tensor_place:
_patch_tensor_place()
if self._patch_float_precision is not None:
_patch_float_precision(self._patch_float_precision)
def _parse_directive(
self, docstring: str
) -> typing.Tuple[str, typing.Dict[str, Directive]]:
directives = {}
for name, directive_cls in self.directives.items():
docstring, direct = directive_cls[0](
*directive_cls[1:]
).parse_directive(docstring)
directives[name] = direct
return docstring, directives
def convert_directive(self, docstring: str) -> str: def convert_directive(self, docstring: str) -> str:
"""Replace directive prefix with xdoctest""" """Replace directive prefix with xdoctest"""
return self.directive_pattern.sub(self.directive_prefix, docstring) return self.directive_pattern.sub(self.directive_prefix, docstring)
...@@ -253,12 +330,31 @@ class Xdoctester(DocTester): ...@@ -253,12 +330,31 @@ class Xdoctester(DocTester):
def run(self, api_name: str, docstring: str) -> typing.List[TestResult]: def run(self, api_name: str, docstring: str) -> typing.List[TestResult]:
"""Run the xdoctest with a docstring.""" """Run the xdoctest with a docstring."""
# parse global directive
docstring, directives = self._parse_directive(docstring)
# extract xdoctest examples
examples_to_test, examples_nocode = self._extract_examples( examples_to_test, examples_nocode = self._extract_examples(
api_name, docstring api_name, docstring, **directives
) )
return self._execute_xdoctest(examples_to_test, examples_nocode)
def _extract_examples(self, api_name, docstring): # run xdoctest
try:
result = self._execute_xdoctest(
examples_to_test, examples_nocode, **directives
)
except queue.Empty:
result = [
TestResult(
name=api_name,
timeout=True,
time=directives.get('timeout', TEST_TIMEOUT),
)
]
return result
def _extract_examples(self, api_name, docstring, **directives):
"""Extract code block examples from docstring.""" """Extract code block examples from docstring."""
examples_to_test = {} examples_to_test = {}
examples_nocode = {} examples_nocode = {}
...@@ -281,8 +377,40 @@ class Xdoctester(DocTester): ...@@ -281,8 +377,40 @@ class Xdoctester(DocTester):
return examples_to_test, examples_nocode return examples_to_test, examples_nocode
def _execute_xdoctest(self, examples_to_test, examples_nocode): def _execute_xdoctest(
self, examples_to_test, examples_nocode, **directives
):
if self._use_multiprocessing:
_ctx = multiprocessing.get_context('spawn')
result_queue = _ctx.Queue()
exec_processer = functools.partial(_ctx.Process, daemon=True)
else:
result_queue = queue.Queue()
exec_processer = functools.partial(threading.Thread, daemon=True)
processer = exec_processer(
target=self._execute_with_queue,
args=(
result_queue,
examples_to_test,
examples_nocode,
),
)
processer.start()
result = result_queue.get(
timeout=directives.get('timeout', TEST_TIMEOUT)
)
processer.join()
return result
def _execute(self, examples_to_test, examples_nocode):
"""Run xdoctest for each example""" """Run xdoctest for each example"""
# patch xdoctest first in each process
self._patch_xdoctest()
# run the xdoctest
test_results = [] test_results = []
for _, example in examples_to_test.items(): for _, example in examples_to_test.items():
start_time = time.time() start_time = time.time()
...@@ -295,7 +423,7 @@ class Xdoctester(DocTester): ...@@ -295,7 +423,7 @@ class Xdoctester(DocTester):
passed=result['passed'], passed=result['passed'],
skipped=result['skipped'], skipped=result['skipped'],
failed=result['failed'], failed=result['failed'],
test_msg=result['exc_info'], test_msg=str(result['exc_info']),
time=end_time - start_time, time=end_time - start_time,
) )
) )
...@@ -305,10 +433,14 @@ class Xdoctester(DocTester): ...@@ -305,10 +433,14 @@ class Xdoctester(DocTester):
return test_results return test_results
def _execute_with_queue(self, queue, examples_to_test, examples_nocode):
queue.put(self._execute(examples_to_test, examples_nocode))
def print_summary(self, test_results, whl_error=None): def print_summary(self, test_results, whl_error=None):
summary_success = [] summary_success = []
summary_failed = [] summary_failed = []
summary_skiptest = [] summary_skiptest = []
summary_timeout = []
summary_nocodes = [] summary_nocodes = []
stdout_handler = logging.StreamHandler(stream=sys.stdout) stdout_handler = logging.StreamHandler(stream=sys.stdout)
...@@ -335,7 +467,6 @@ class Xdoctester(DocTester): ...@@ -335,7 +467,6 @@ class Xdoctester(DocTester):
logger.info("----------------------------------------------------") logger.info("----------------------------------------------------")
sys.exit(1) sys.exit(1)
else: else:
timeovered_test = {}
for test_result in test_results: for test_result in test_results:
if not test_result.nocode: if not test_result.nocode:
if test_result.passed: if test_result.passed:
...@@ -347,18 +478,16 @@ class Xdoctester(DocTester): ...@@ -347,18 +478,16 @@ class Xdoctester(DocTester):
if test_result.failed: if test_result.failed:
summary_failed.append(test_result.name) summary_failed.append(test_result.name)
if test_result.time > TEST_TIMEOUT: if test_result.timeout:
timeovered_test[test_result.name] = test_result.time summary_timeout.append(
{
'api_name': test_result.name,
'run_time': test_result.time,
}
)
else: else:
summary_nocodes.append(test_result.name) summary_nocodes.append(test_result.name)
if len(timeovered_test):
logger.info(
"%d sample codes ran time over 10s", len(timeovered_test)
)
if self.debug:
for k, v in timeovered_test.items():
logger.info(f'{k} - {v}s')
if len(summary_success): if len(summary_success):
logger.info("%d sample codes ran success", len(summary_success)) logger.info("%d sample codes ran success", len(summary_success))
logger.info('\n'.join(summary_success)) logger.info('\n'.join(summary_success))
...@@ -374,6 +503,13 @@ class Xdoctester(DocTester): ...@@ -374,6 +503,13 @@ class Xdoctester(DocTester):
) )
logger.info('\n'.join(summary_nocodes)) logger.info('\n'.join(summary_nocodes))
if len(summary_timeout):
logger.info("%d sample codes ran timeout", len(summary_timeout))
for _result in summary_timeout:
logger.info(
f"{_result['api_name']} - more than {_result['run_time']}s"
)
if len(summary_failed): if len(summary_failed):
logger.info("%d sample codes ran failed", len(summary_failed)) logger.info("%d sample codes ran failed", len(summary_failed))
logger.info('\n'.join(summary_failed)) logger.info('\n'.join(summary_failed))
......
...@@ -225,8 +225,9 @@ class TestGetTestResults(unittest.TestCase): ...@@ -225,8 +225,9 @@ class TestGetTestResults(unittest.TestCase):
self.assertIn('set_default', tr_1.name) self.assertIn('set_default', tr_1.name)
self.assertTrue(tr_1.passed) self.assertTrue(tr_1.passed)
# tr_2 is passed, because of multiprocessing
self.assertIn('after_set_default', tr_2.name) self.assertIn('after_set_default', tr_2.name)
self.assertFalse(tr_2.passed) self.assertTrue(tr_2.passed)
# test new default global_exec # test new default global_exec
doctester = Xdoctester( doctester = Xdoctester(
...@@ -321,8 +322,9 @@ class TestGetTestResults(unittest.TestCase): ...@@ -321,8 +322,9 @@ class TestGetTestResults(unittest.TestCase):
self.assertIn('enable_static', tr_1.name) self.assertIn('enable_static', tr_1.name)
self.assertTrue(tr_1.passed) self.assertTrue(tr_1.passed)
# tr_2 is passed, because of multiprocessing
self.assertIn('after_enable_static', tr_2.name) self.assertIn('after_enable_static', tr_2.name)
self.assertFalse(tr_2.passed) self.assertTrue(tr_2.passed)
# test new default global_exec # test new default global_exec
doctester = Xdoctester( doctester = Xdoctester(
...@@ -780,7 +782,7 @@ class TestGetTestResults(unittest.TestCase): ...@@ -780,7 +782,7 @@ class TestGetTestResults(unittest.TestCase):
test_capacity = {'cpu'} test_capacity = {'cpu'}
doctester = Xdoctester( doctester = Xdoctester(
style='freeform', target='codeblock', patch_float_precision=False style='freeform', target='codeblock', patch_float_precision=None
) )
doctester.prepare(test_capacity) doctester.prepare(test_capacity)
...@@ -1816,6 +1818,256 @@ class TestGetTestResults(unittest.TestCase): ...@@ -1816,6 +1818,256 @@ class TestGetTestResults(unittest.TestCase):
self.assertTrue(tr_0.skipped) self.assertTrue(tr_0.skipped)
self.assertFalse(tr_0.failed) self.assertFalse(tr_0.failed)
def test_multiprocessing_xdoctester(self):
docstrings_to_test = {
'static_0': """
this is docstring...
Examples:
.. code-block:: python
>>> import numpy as np
>>> import paddle
>>> paddle.enable_static()
>>> data = paddle.static.data(name='X', shape=[None, 2, 28, 28], dtype='float32')
""",
'static_1': """
this is docstring...
Examples:
.. code-block:: python
>>> import numpy as np
>>> import paddle
>>> paddle.enable_static()
>>> data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
""",
}
_clear_environ()
test_capacity = {'cpu'}
doctester = Xdoctester()
doctester.prepare(test_capacity)
test_results = get_test_results(doctester, docstrings_to_test)
self.assertEqual(len(test_results), 2)
tr_0, tr_1 = test_results
self.assertIn('static_0', tr_0.name)
self.assertTrue(tr_0.passed)
self.assertIn('static_1', tr_1.name)
self.assertTrue(tr_1.passed)
_clear_environ()
test_capacity = {'cpu'}
doctester = Xdoctester(use_multiprocessing=False)
doctester.prepare(test_capacity)
test_results = get_test_results(doctester, docstrings_to_test)
self.assertEqual(len(test_results), 2)
tr_0, tr_1 = test_results
self.assertIn('static_0', tr_0.name)
self.assertTrue(tr_0.passed)
self.assertIn('static_1', tr_1.name)
self.assertFalse(tr_1.passed)
self.assertTrue(tr_1.failed)
def test_timeout(self):
docstrings_to_test = {
'timeout_false': """
this is docstring...
Examples:
.. code-block:: python
>>> # doctest: +TIMEOUT(2)
>>> import time
>>> time.sleep(0.1)
""",
'timeout_true': """
this is docstring...
Examples:
.. code-block:: python
>>> # doctest: +TIMEOUT(2)
>>> import time
>>> time.sleep(3)
""",
'timeout_false_with_skip_0': """
this is docstring...
Examples:
.. code-block:: python
>>> # doctest: +TIMEOUT(2)
>>> # doctest: +SKIP
>>> import time
>>> time.sleep(0.1)
""",
'timeout_false_with_skip_1': """
this is docstring...
Examples:
.. code-block:: python
>>> # doctest: +SKIP
>>> # doctest: +TIMEOUT(2)
>>> import time
>>> time.sleep(0.1)
""",
'timeout_true_with_skip_0': """
this is docstring...
Examples:
.. code-block:: python
>>> # doctest: +TIMEOUT(2)
>>> # doctest: +SKIP
>>> import time
>>> time.sleep(3)
""",
'timeout_true_with_skip_1': """
this is docstring...
Examples:
.. code-block:: python
>>> # doctest: +SKIP
>>> # doctest: +TIMEOUT(2)
>>> import time
>>> time.sleep(3)
""",
'timeout_more_codes': """
this is docstring...
Examples:
.. code-block:: python
>>> # doctest: +TIMEOUT(2)
>>> import time
>>> time.sleep(0.1)
.. code-block:: python
>>> # doctest: +TIMEOUT(2)
>>> import time
>>> time.sleep(3)
""",
}
_clear_environ()
test_capacity = {'cpu'}
doctester = Xdoctester()
doctester.prepare(test_capacity)
test_results = get_test_results(doctester, docstrings_to_test)
self.assertEqual(len(test_results), 8)
tr_0, tr_1, tr_2, tr_3, tr_4, tr_5, tr_6, tr_7 = test_results
self.assertIn('timeout_false', tr_0.name)
self.assertTrue(tr_0.passed)
self.assertFalse(tr_0.timeout)
self.assertIn('timeout_true', tr_1.name)
self.assertFalse(tr_1.passed)
self.assertTrue(tr_1.timeout)
self.assertIn('timeout_false_with_skip_0', tr_2.name)
self.assertFalse(tr_2.passed)
self.assertFalse(tr_2.timeout)
self.assertTrue(tr_2.skipped)
self.assertIn('timeout_false_with_skip_1', tr_3.name)
self.assertFalse(tr_3.passed)
self.assertFalse(tr_3.timeout)
self.assertTrue(tr_3.skipped)
self.assertIn('timeout_true_with_skip_0', tr_4.name)
self.assertFalse(tr_4.passed)
self.assertFalse(tr_4.timeout)
self.assertTrue(tr_4.skipped)
self.assertIn('timeout_true_with_skip_1', tr_5.name)
self.assertFalse(tr_5.passed)
self.assertFalse(tr_5.timeout)
self.assertTrue(tr_5.skipped)
self.assertIn('timeout_more_codes', tr_6.name)
self.assertTrue(tr_6.passed)
self.assertFalse(tr_6.timeout)
self.assertIn('timeout_more_codes', tr_7.name)
self.assertFalse(tr_7.passed)
self.assertTrue(tr_7.timeout)
_clear_environ()
test_capacity = {'cpu'}
doctester = Xdoctester(use_multiprocessing=False)
doctester.prepare(test_capacity)
test_results = get_test_results(doctester, docstrings_to_test)
self.assertEqual(len(test_results), 8)
tr_0, tr_1, tr_2, tr_3, tr_4, tr_5, tr_6, tr_7 = test_results
self.assertIn('timeout_false', tr_0.name)
self.assertTrue(tr_0.passed)
self.assertFalse(tr_0.timeout)
self.assertIn('timeout_true', tr_1.name)
self.assertFalse(tr_1.passed)
self.assertTrue(tr_1.timeout)
self.assertIn('timeout_false_with_skip_0', tr_2.name)
self.assertFalse(tr_2.passed)
self.assertFalse(tr_2.timeout)
self.assertTrue(tr_2.skipped)
self.assertIn('timeout_false_with_skip_1', tr_3.name)
self.assertFalse(tr_3.passed)
self.assertFalse(tr_3.timeout)
self.assertTrue(tr_3.skipped)
self.assertIn('timeout_true_with_skip_0', tr_4.name)
self.assertFalse(tr_4.passed)
self.assertFalse(tr_4.timeout)
self.assertTrue(tr_4.skipped)
self.assertIn('timeout_true_with_skip_1', tr_5.name)
self.assertFalse(tr_5.passed)
self.assertFalse(tr_5.timeout)
self.assertTrue(tr_5.skipped)
self.assertIn('timeout_more_codes', tr_6.name)
self.assertTrue(tr_6.passed)
self.assertFalse(tr_6.timeout)
self.assertIn('timeout_more_codes', tr_7.name)
self.assertFalse(tr_7.passed)
self.assertTrue(tr_7.timeout)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册