未验证 提交 4a09c1a1 编写于 作者: R Ren Wei (任卫) 提交者: GitHub

run the sample codes added by `add_sample_code` in ops.py (#31863)

* skip paddle.Tensor.<lambda>

* some file may not exists. such as version.py, it's generated by setup.py

* debug mode

* add unittests for sampcd_processor.py

* add test cases for sampcd_processor

* add test cases for sampcd_processor

* add testcases

* add test cases

* add testcases

* add testcases

* refactor, add testcases

* add import

* all files map to pool. dont split manually

* __all__ += another list

* add testcases

* add testcases

* handle个锤子啊

* this line should not removed

https://github.com/wadefelix/Paddle/commit/882e7f7c3be6c2415f58550f82be338b84f0c0ef#diff-cb0679475bf60202fd803ae05b9146989437c3f787d1502616be6c71c69d0fb1

* print -> logger

* regulate the logging infomation

* regulate the logging infomation

* logger to file

* logger

* threads or subprocesses number config

* follow the good code style

don't touch wlist.json

* run test_sampcd_processor.py, it's a unittest for sampcd_processor.py

* update unittest for sampcd_processor.py

test=document_fix
上级 0624ea56
......@@ -53,6 +53,7 @@ API_FILES=("CMakeLists.txt"
"python/paddle/fluid/tests/unittests/white_list/check_op_sequence_batch_1_input_white_list.py"
"python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py"
"tools/wlist.json"
"tools/sampcd_processor.py"
"paddle/scripts/paddle_build.bat"
"tools/windows/run_unittests.sh"
"tools/parallel_UT_rule.py"
......@@ -79,6 +80,12 @@ function add_failed(){
echo_list="${echo_list[@]}$1"
}
function run_test_sampcd_processor() {
CUR_PWD=$(pwd)
cd ${PADDLE_ROOT}/tools
python test_sampcd_processor.py
cd ${CUR_PWD}
}
if [[ $git_files -gt 19 || $git_count -gt 999 ]];then
echo_line="You must have Dianhai approval for change 20+ files or add than 1000+ lines of content.\n"
......@@ -136,6 +143,9 @@ for API_FILE in ${API_FILES[*]}; do
elif [ "${API_FILE}" == "tools/wlist.json" ];then
echo_line="You must have one TPM (jzhang533) approval for the api whitelist for the tools/wlist.json.\n"
check_approval 1 29231
elif [ "${API_FILE}" == "tools/sampcd_processor.py" ];then
echo_line="test_sampcd_processor.py will be executed for changed sampcd_processor.py.\n"
run_test_sampcd_processor
elif [ "${API_FILE}" == "python/paddle/distributed/fleet/__init__.py" ]; then
echo_line="You must have (fuyinno4 (Recommend), raindrops2sea) approval for ${API_FILE} changes"
check_approval 1 35824027 38231817
......
......@@ -22,6 +22,10 @@ import inspect
import paddle
import paddle.fluid
import json
import argparse
import shutil
import re
import logging
"""
please make sure to run in the tools path
usage: python sample_test.py {arg1}
......@@ -33,6 +37,26 @@ for example, you can run cpu version python2 testing like this:
"""
logger = logging.getLogger()
if logger.handlers:
console = logger.handlers[
0] # we assume the first handler is the one we want to configure
else:
console = logging.StreamHandler()
logger.addHandler(console)
console.setFormatter(
logging.Formatter(
"%(asctime)s - %(funcName)s:%(lineno)d - %(levelname)s - %(message)s"))
RUN_ON_DEVICE = 'cpu'
GPU_ID = 0
methods = []
whl_error = []
API_DEV_SPEC_FN = 'paddle/fluid/API_DEV.spec'
API_PR_SPEC_FN = 'paddle/fluid/API_PR.spec'
API_DIFF_SPEC_FN = 'dev_pr_diff_api.spec'
SAMPLECODE_TEMPDIR = 'samplecode_temp'
def find_all(srcstr, substr):
"""
......@@ -98,9 +122,13 @@ def sampcd_extract_and_run(srccom, name, htype="def", hname=""):
Returns:
result: True or False
name(str): the name of the API.
msg(str): messages
"""
global GPU_ID, RUN_ON_DEVICE, SAMPLECODE_TEMPDIR
result = True
msg = None
def sampcd_header_print(name, sampcd, htype, hname):
"""
......@@ -113,7 +141,8 @@ def sampcd_extract_and_run(srccom, name, htype="def", hname=""):
hname(str): the name of the hint banners , e.t. def hname.
flushed.
"""
print_header(htype, hname)
print(htype, " name:", hname)
print("-----------------------")
print("Sample code ", str(y), " extracted for ", name, " :")
print(sampcd)
print("----example code check----\n")
......@@ -122,11 +151,9 @@ def sampcd_extract_and_run(srccom, name, htype="def", hname=""):
sampcd_begins = find_all(srccom, " code-block:: python")
if len(sampcd_begins) == 0:
print_header(htype, hname)
'''
detect sample codes using >>> to format
and consider this situation as wrong
'''
# detect sample codes using >>> to format and consider this situation as wrong
print(htype, " name:", hname)
print("-----------------------")
if srccom.find("Examples:") != -1:
print("----example code check----\n")
if srccom.find(">>>") != -1:
......@@ -164,23 +191,22 @@ def sampcd_extract_and_run(srccom, name, htype="def", hname=""):
sampcd_to_write.append(cdline[min_indent:])
sampcd = '\n'.join(sampcd_to_write)
if sys.argv[1] == "cpu":
sampcd = '\nimport os\n' + 'os.environ["CUDA_VISIBLE_DEVICES"] = ""\n' + sampcd
if sys.argv[1] == "gpu":
sampcd = '\nimport os\n' + 'os.environ["CUDA_VISIBLE_DEVICES"] = "0"\n' + sampcd
if RUN_ON_DEVICE == "cpu":
sampcd = '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = ""\n' + sampcd
if RUN_ON_DEVICE == "gpu":
sampcd = '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = "{}"\n'.format(
GPU_ID) + sampcd
sampcd += '\nprint(' + '\"' + name + ' sample code is executed successfully!\")'
if len(sampcd_begins) > 1:
tfname = name + "_example_" + str(y) + ".py"
else:
tfname = name + "_example" + ".py"
tempf = open("samplecode_temp/" + tfname, 'w')
tempf.write(sampcd)
tempf.close()
tfname = os.path.join(SAMPLECODE_TEMPDIR, '{}_example{}'.format(
name, '.py' if len(sampcd_begins) == 1 else '_{}.py'.format(y)))
logging.info('running %s', tfname)
with open(tfname, 'w') as tempf:
tempf.write(sampcd)
if platform.python_version()[0] == "2":
cmd = ["python", "samplecode_temp/" + tfname]
cmd = ["python", tfname]
elif platform.python_version()[0] == "3":
cmd = ["python3", "samplecode_temp/" + tfname]
cmd = ["python3", tfname]
else:
print("Error: fail to parse python version!")
result = False
......@@ -199,11 +225,12 @@ def sampcd_extract_and_run(srccom, name, htype="def", hname=""):
print("Error Raised from Sample Code ", name, " :\n")
print(err)
print(msg)
logging.warning('%s error: %s', tfname, err)
logging.warning('%s msg: %s', tfname, msg)
result = False
# msg is the returned code execution report
#os.remove("samplecode_temp/" + tfname)
return result
return result, name, msg
def single_defcom_extract(start_from, srcls, is_class_begin=False):
......@@ -264,12 +291,7 @@ def single_defcom_extract(start_from, srcls, is_class_begin=False):
return fcombody
def print_header(htype, name):
print(htype, " name:", name)
print("-----------------------")
def srccoms_extract(srcfile, wlist):
def srccoms_extract(srcfile, wlist, methods):
"""
Given a source file ``srcfile``, this function will
extract its API(doc comments) and run sample codes in the
......@@ -278,12 +300,15 @@ def srccoms_extract(srcfile, wlist):
Args:
srcfile(file): the source file
wlist(list): white list
methods(list): only elements of this list considered.
Returns:
result: True or False
error_methods: the methods that failed.
"""
process_result = True
error_methods = []
srcc = srcfile.read()
# 2. get defs and classes header line number
# set file pointer to its beginning
......@@ -292,8 +317,8 @@ def srccoms_extract(srcfile, wlist):
# 1. fetch__all__ list
allidx = srcc.find("__all__")
srcfile_new = srcfile.name
srcfile_new = srcfile_new.replace('.py', '')
logger.debug('processing %s, methods: %s', srcfile.name, str(methods))
srcfile_new, _ = os.path.splitext(srcfile.name)
srcfile_list = srcfile_new.split('/')
srcfile_str = ''
for i in range(4, len(srcfile_list)):
......@@ -323,15 +348,27 @@ def srccoms_extract(srcfile, wlist):
if '' in alllist:
alllist.remove('')
api_alllist_count = len(alllist)
logger.debug('found %d items: %s', api_alllist_count, str(alllist))
api_count = 0
handled = []
# get src contents in layers/ops.py
if srcfile.name.find("ops.py") != -1:
for i in range(0, len(srcls)):
if srcls[i].find("__doc__") != -1:
opname = srcls[i][:srcls[i].find("__doc__") - 1]
opname = None
opres = re.match(r"^(\w+)\.__doc__", srcls[i])
if opres is not None:
opname = opres.group(1)
else:
opres = re.match(
r"^add_sample_code\(globals\(\)\[\"(\w+)\"\]", srcls[i])
if opres is not None:
opname = opres.group(1)
if opname is not None:
if opname in wlist:
logger.info('%s is in the whitelist, skip it.', opname)
continue
else:
logger.debug('%s\'s docstring found.', opname)
comstart = i
for j in range(i, len(srcls)):
if srcls[j].find("\"\"\"") != -1:
......@@ -341,51 +378,73 @@ def srccoms_extract(srcfile, wlist):
opcom += srcls[j]
if srcls[j].find("\"\"\"") != -1:
break
result, _, _ = sampcd_extract_and_run(opcom, opname, "def",
opname)
if not result:
error_methods.append(opname)
process_result = False
api_count += 1
handled.append(
opname) # ops.py also has normal formatted functions
# use list 'handled' to mark the functions have been handled here
# which will be ignored in the following step
# handled what?
logger.debug('%s already handled.', str(handled))
for i in range(0, len(srcls)):
if srcls[i].startswith(
'def '): # a function header is detected in line i
f_header = srcls[i].replace(" ", '')
fn = f_header[len('def'):f_header.find('(')] # function name
if "%s%s" % (srcfile_str, fn) not in methods:
logger.info(
'[file:%s, function:%s] not in methods list, skip it.',
srcfile_str, fn)
continue
if fn in handled:
continue
if fn in alllist:
api_count += 1
if fn in wlist or fn + "@" + srcfile.name in wlist:
logger.info('[file:%s, function:%s] skip by wlist.',
srcfile_str, fn)
continue
fcombody = single_defcom_extract(i, srcls)
if fcombody == "": # if no comment
print_header("def", fn)
print("def name:", fn)
print("-----------------------")
print("WARNING: no comments in function ", fn,
", but it deserves.")
continue
else:
if not sampcd_extract_and_run(fcombody, fn, "def", fn):
result, _, _ = sampcd_extract_and_run(fcombody, fn,
"def", fn)
if not result:
error_methods.append(fn)
process_result = False
if srcls[i].startswith('class '):
c_header = srcls[i].replace(" ", '')
cn = c_header[len('class'):c_header.find('(')] # class name
if '%s%s' % (srcfile_str, cn) not in methods:
logger.info(
'[file:%s, class:%s] not in methods list, skip it.',
srcfile_str, cn)
continue
if cn in handled:
continue
if cn in alllist:
api_count += 1
if cn in wlist or cn + "@" + srcfile.name in wlist:
logger.info('[file:%s, class:%s] skip by wlist.',
srcfile_str, cn)
continue
# class comment
classcom = single_defcom_extract(i, srcls, True)
if classcom != "":
if not sampcd_extract_and_run(classcom, cn, "class",
cn):
result, _, _ = sampcd_extract_and_run(classcom, cn,
"class", cn)
if not result:
error_methods.append(cn)
process_result = False
else:
print("WARNING: no comments in class itself ", cn,
......@@ -410,10 +469,19 @@ def srccoms_extract(srcfile, wlist):
if '%s%s' % (
srcfile_str, name
) not in methods: # class method not in api.spec
logger.info(
'[file:%s, func:%s] not in methods, skip it.',
srcfile_str, name)
continue
if mn.startswith('_'):
logger.info(
'[file:%s, func:%s] startswith _, it\'s private method, skip it.',
srcfile_str, name)
continue
if name in wlist or name + "@" + srcfile.name in wlist:
logger.info(
'[file:%s, class:%s] skip by wlist.',
srcfile_str, name)
continue
thismethod = [thisl[indent:]
] # method body lines
......@@ -434,22 +502,38 @@ def srccoms_extract(srcfile, wlist):
thismtdcom = single_defcom_extract(0,
thismethod)
if thismtdcom != "":
if not sampcd_extract_and_run(
thismtdcom, name, "method", name):
result, _, _ = sampcd_extract_and_run(
thismtdcom, name, "method", name)
if not result:
error_methods.append(name)
process_result = False
else:
logger.warning('__all__ not found in file:%s', srcfile.name)
return process_result
return process_result, error_methods
def test(file_list):
global methods # readonly
process_result = True
for file in file_list:
with open(file, 'r') as src:
if not srccoms_extract(src, wlist):
if not srccoms_extract(src, wlist, methods):
process_result = False
return process_result
def run_a_test(tc_filename):
"""
execute a sample code-block.
"""
global methods # readonly
process_result = True
with open(tc_filename, 'r') as src:
process_result, error_methods = srccoms_extract(src, wlist, methods)
return process_result, tc_filename, error_methods
def get_filenames():
'''
this function will get the modules that pending for check.
......@@ -460,12 +544,12 @@ def get_filenames():
'''
filenames = []
global methods
global methods # write
global whl_error
methods = []
whl_error = []
get_incrementapi()
API_spec = 'dev_pr_diff_api.spec'
API_spec = API_DIFF_SPEC_FN
with open(API_spec) as f:
for line in f.readlines():
api = line.replace('\n', '')
......@@ -474,17 +558,30 @@ def get_filenames():
except AttributeError:
whl_error.append(api)
continue
except SyntaxError:
logger.warning('line:%s, api:%s', line, api)
# paddle.Tensor.<lambda>
continue
if len(module.split('.')) > 1:
filename = '../python/'
# work for .so?
module_py = '%s.py' % module.split('.')[-1]
for i in range(0, len(module.split('.')) - 1):
filename = filename + '%s/' % module.split('.')[i]
filename = filename + module_py
else:
filename = ''
print("\nWARNING:----Exception in get api filename----\n")
print("\n" + api + ' module is ' + module + "\n")
if filename != '' and filename not in filenames:
logger.warning("WARNING: Exception in getting api:%s module:%s",
api, module)
if filename in filenames:
continue
elif not filename:
logger.warning('filename invalid: %s', line)
continue
elif not os.path.exists(filename):
logger.warning('file not exists: %s', filename)
continue
else:
filenames.append(filename)
# get all methods
method = ''
......@@ -496,9 +593,9 @@ def get_filenames():
name = '%s.%s' % (api.split('.')[-2], api.split('.')[-1])
else:
name = ''
print("\nWARNING:----Exception in get api methods----\n")
print("\n" + line + "\n")
print("\n" + api + ' method is None!!!' + "\n")
logger.warning(
"WARNING: Exception when getting api:%s, line:%s", api,
line)
for j in range(2, len(module.split('.'))):
method = method + '%s.' % module.split('.')[j]
method = method + name
......@@ -508,26 +605,27 @@ def get_filenames():
return filenames
def get_api_md5(path):
api_md5 = {}
API_spec = '%s/%s' % (os.path.abspath(os.path.join(os.getcwd(), "..")),
path)
with open(API_spec) as f:
for line in f.readlines():
api = line.split(' ', 1)[0]
md5 = line.split("'document', ")[1].replace(')', '').replace('\n',
'')
api_md5[api] = md5
return api_md5
def get_incrementapi():
'''
this function will get the apis that difference between API_DEV.spec and API_PR.spec.
'''
def get_api_md5(path):
api_md5 = {}
API_spec = '%s/%s' % (os.path.abspath(os.path.join(os.getcwd(), "..")),
path)
with open(API_spec) as f:
for line in f.readlines():
api = line.split(' ', 1)[0]
md5 = line.split("'document', ")[1].replace(')', '').replace(
'\n', '')
api_md5[api] = md5
return api_md5
dev_api = get_api_md5('paddle/fluid/API_DEV.spec')
pr_api = get_api_md5('paddle/fluid/API_PR.spec')
with open('dev_pr_diff_api.spec', 'w') as f:
global API_DEV_SPEC_FN, API_PR_SPEC_FN, API_DIFF_SPEC_FN ## readonly
dev_api = get_api_md5(API_DEV_SPEC_FN)
pr_api = get_api_md5(API_PR_SPEC_FN)
with open(API_DIFF_SPEC_FN, 'w') as f:
for key in pr_api:
if key in dev_api:
if dev_api[key] != pr_api[key]:
......@@ -538,7 +636,7 @@ def get_incrementapi():
f.write('\n')
def get_wlist():
def get_wlist(fn="wlist.json"):
'''
this function will get the white list of API.
......@@ -551,7 +649,7 @@ def get_wlist():
wlist_file = []
# only white on CPU
gpu_not_white = []
with open("wlist.json", 'r') as load_f:
with open(fn, 'r') as load_f:
load_dict = json.load(load_f)
for key in load_dict:
if key == 'wlist_dir':
......@@ -567,31 +665,77 @@ def get_wlist():
return wlist, wlist_file, gpu_not_white
wlist, wlist_file, gpu_not_white = get_wlist()
arguments = [
# flags, dest, type, default, help
['--gpu_id', 'gpu_id', int, 0, 'GPU device id to use [0]'],
['--logf', 'logf', str, None, 'file for logging'],
['--threads', 'threads', int, 0, 'sub processes number'],
]
if len(sys.argv) < 2:
print("Error: inadequate number of arguments")
print('''If you are going to run it on
"CPU: >>> python sampcd_processor.py cpu
"GPU: >>> python sampcd_processor.py gpu
''')
sys.exit("lack arguments")
else:
if sys.argv[1] == "gpu":
def parse_args():
"""
Parse input arguments
"""
global arguments
parser = argparse.ArgumentParser(description='run Sample Code Test')
# parser.add_argument('--cpu', dest='cpu_mode', action="store_true",
# help='Use CPU mode (overrides --gpu)')
# parser.add_argument('--gpu', dest='gpu_mode', action="store_true")
parser.add_argument('--debug', dest='debug', action="store_true")
parser.add_argument('mode', type=str, help='run on device', default='cpu')
for item in arguments:
parser.add_argument(
item[0], dest=item[1], help=item[4], type=item[2], default=item[3])
if len(sys.argv) == 1:
args = parser.parse_args(['cpu'])
return args
# parser.print_help()
# sys.exit(1)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
if args.debug:
logger.setLevel(logging.DEBUG)
if args.logf:
logfHandler = logging.FileHandler(args.logf)
logfHandler.setFormatter(
logging.Formatter(
"%(asctime)s - %(funcName)s:%(lineno)d - %(levelname)s - %(message)s"
))
logger.addHandler(logfHandler)
wlist, wlist_file, gpu_not_white = get_wlist()
if args.mode == "gpu":
GPU_ID = args.gpu_id
logger.info("using GPU_ID %d", GPU_ID)
for _gnw in gpu_not_white:
wlist.remove(_gnw)
elif sys.argv[1] != "cpu":
print("Unrecognized argument:'", sys.argv[1], "' , 'cpu' or 'gpu' is ",
"desired\n")
elif args.mode != "cpu":
logger.error("Unrecognized argument:%s, 'cpu' or 'gpu' is desired.",
args.mode)
sys.exit("Invalid arguments")
print("API check -- Example Code")
print("sample_test running under python", platform.python_version())
if not os.path.isdir("./samplecode_temp"):
os.mkdir("./samplecode_temp")
cpus = multiprocessing.cpu_count()
RUN_ON_DEVICE = args.mode
logger.info("API check -- Example Code")
logger.info("sample_test running under python %s",
platform.python_version())
if os.path.exists(SAMPLECODE_TEMPDIR):
if not os.path.isdir(SAMPLECODE_TEMPDIR):
os.remove(SAMPLECODE_TEMPDIR)
os.mkdir(SAMPLECODE_TEMPDIR)
else:
os.mkdir(SAMPLECODE_TEMPDIR)
filenames = get_filenames()
if len(filenames) == 0 and len(whl_error) == 0:
print("-----API_PR.spec is the same as API_DEV.spec-----")
logger.info("-----API_PR.spec is the same as API_DEV.spec-----")
exit(0)
rm_file = []
for f in filenames:
......@@ -600,51 +744,52 @@ else:
rm_file.append(f)
filenames.remove(f)
if len(rm_file) != 0:
print("REMOVE white files: %s" % rm_file)
print("API_PR is diff from API_DEV: %s" % filenames)
one_part_filenum = int(math.ceil(len(filenames) / cpus))
if one_part_filenum == 0:
one_part_filenum = 1
divided_file_list = [
filenames[i:i + one_part_filenum]
for i in range(0, len(filenames), one_part_filenum)
]
po = multiprocessing.Pool()
results = po.map_async(test, divided_file_list)
logger.info("REMOVE white files: %s", rm_file)
logger.info("API_PR is diff from API_DEV: %s", filenames)
threads = multiprocessing.cpu_count()
if args.threads:
threads = args.threads
po = multiprocessing.Pool(threads)
# results = po.map_async(test, divided_file_list)
results = po.map_async(run_a_test, filenames)
po.close()
po.join()
result = results.get()
# delete temp files
for root, dirs, files in os.walk("./samplecode_temp"):
for fntemp in files:
os.remove("./samplecode_temp/" + fntemp)
os.rmdir("./samplecode_temp")
if not args.debug:
shutil.rmtree(SAMPLECODE_TEMPDIR)
print("----------------End of the Check--------------------")
logger.info("----------------End of the Check--------------------")
if len(whl_error) != 0:
print("%s is not in whl." % whl_error)
print("")
print("Please check the whl package and API_PR.spec!")
print("You can follow these steps in order to generate API.spec:")
print("1. cd ${paddle_path}, compile paddle;")
print("2. pip install build/python/dist/(build whl package);")
print(
logger.info("%s is not in whl.", whl_error)
logger.info("")
logger.info("Please check the whl package and API_PR.spec!")
logger.info("You can follow these steps in order to generate API.spec:")
logger.info("1. cd ${paddle_path}, compile paddle;")
logger.info("2. pip install build/python/dist/(build whl package);")
logger.info(
"3. run 'python tools/print_signatures.py paddle > paddle/fluid/API.spec'."
)
for temp in result:
if not temp:
print("")
print("In addition, mistakes found in sample codes.")
print("Please check sample codes.")
print("----------------------------------------------------")
if not temp[0]:
logger.info("In addition, mistakes found in sample codes: %s",
temp[1])
logger.info("error_methods: %s", str(temp[2]))
logger.info("----------------------------------------------------")
exit(1)
else:
has_error = False
for temp in result:
if not temp:
print("Mistakes found in sample codes.")
print("Please check sample codes.")
exit(1)
print("Sample code check is successful!")
if not temp[0]:
logger.info("In addition, mistakes found in sample codes: %s",
temp[1])
logger.info("error_methods: %s", str(temp[2]))
has_error = True
if has_error:
logger.info("Mistakes found in sample codes.")
logger.info("Please check sample codes.")
exit(1)
logger.info("Sample code check is successful!")
#! python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import tempfile
import shutil
import sys
import importlib
from sampcd_processor import find_all
from sampcd_processor import check_indent
from sampcd_processor import sampcd_extract_and_run
from sampcd_processor import single_defcom_extract
from sampcd_processor import srccoms_extract
from sampcd_processor import get_api_md5
from sampcd_processor import get_incrementapi
from sampcd_processor import get_wlist
class Test_find_all(unittest.TestCase):
def test_find_none(self):
self.assertEqual(0, len(find_all('hello', 'world')))
def test_find_one(self):
self.assertListEqual([0], find_all('hello', 'hello'))
def test_find_two(self):
self.assertListEqual([1, 15],
find_all(' hello, world; hello paddle!', 'hello'))
class Test_check_indent(unittest.TestCase):
def test_no_indent(self):
self.assertEqual(0, check_indent('hello paddle'))
def test_indent_4_spaces(self):
self.assertEqual(4, check_indent(' hello paddle'))
def test_indent_1_tab(self):
self.assertEqual(4, check_indent("\thello paddle"))
class Test_sampcd_extract_and_run(unittest.TestCase):
def setUp(self):
if not os.path.exists('samplecode_temp/'):
os.mkdir('samplecode_temp/')
def test_run_a_defs_samplecode(self):
comments = """
Examples:
.. code-block:: python
print(1+1)
"""
funcname = 'one_plus_one'
res, name, msg = sampcd_extract_and_run(comments, funcname)
self.assertTrue(res)
self.assertEqual(funcname, name)
def test_run_a_def_no_code(self):
comments = """
placeholder
"""
funcname = 'one_plus_one'
res, name, msg = sampcd_extract_and_run(comments, funcname)
self.assertFalse(res)
self.assertEqual(funcname, name)
def test_run_a_def_raise_expection(self):
comments = """
placeholder
Examples:
.. code-block:: python
print(1/0)
"""
funcname = 'one_plus_one'
res, name, msg = sampcd_extract_and_run(comments, funcname)
self.assertFalse(res)
self.assertEqual(funcname, name)
class Test_single_defcom_extract(unittest.TestCase):
def test_extract_from_func(self):
defstr = '''
import os
def foo():
"""
foo is a function.
"""
pass
def bar():
pass
'''
comm = single_defcom_extract(
2, defstr.splitlines(True), is_class_begin=False)
self.assertEqual(" foo is a function.\n", comm)
pass
def test_extract_from_func_with_no_docstring(self):
defstr = '''
import os
def bar():
pass
'''
comm = single_defcom_extract(
2, defstr.splitlines(True), is_class_begin=False)
self.assertEqual('', comm)
pass
def test_extract_from_class(self):
defstr = r'''
import os
class Foo():
"""
Foo is a class.
second line.
"""
pass
def bar():
pass
def foo():
pass
'''
comm = single_defcom_extract(
2, defstr.splitlines(True), is_class_begin=True)
rcomm = """ Foo is a class.
second line.
"""
self.assertEqual(rcomm, comm)
pass
def test_extract_from_class_with_no_docstring(self):
defstr = '''
import os
class Foo():
pass
def bar():
pass
def foo():
pass
'''
comm = single_defcom_extract(
0, defstr.splitlines(True), is_class_begin=True)
self.assertEqual('', comm)
class Test_get_api_md5(unittest.TestCase):
def setUp(self):
self.api_pr_spec_filename = os.path.abspath(
os.path.join(os.getcwd(), "..", 'paddle/fluid/API_PR.spec'))
with open(self.api_pr_spec_filename, 'w') as f:
f.write("\n".join([
"""one_plus_one (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of one_plus_one'))""",
"""two_plus_two (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of two_plus_two'))""",
"""three_plus_three (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of three_plus_three'))""",
"""four_plus_four (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of four_plus_four'))""",
]))
def tearDown(self):
os.remove(self.api_pr_spec_filename)
pass
def test_get_api_md5(self):
res = get_api_md5('paddle/fluid/API_PR.spec')
self.assertEqual("'md5sum of one_plus_one'", res['one_plus_one'])
self.assertEqual("'md5sum of two_plus_two'", res['two_plus_two'])
self.assertEqual("'md5sum of three_plus_three'",
res['three_plus_three'])
self.assertEqual("'md5sum of four_plus_four'", res['four_plus_four'])
class Test_get_incrementapi(unittest.TestCase):
def setUp(self):
self.api_pr_spec_filename = os.path.abspath(
os.path.join(os.getcwd(), "..", 'paddle/fluid/API_PR.spec'))
with open(self.api_pr_spec_filename, 'w') as f:
f.write("\n".join([
"""one_plus_one (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of one_plus_one'))""",
"""two_plus_two (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of two_plus_two'))""",
"""three_plus_three (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of three_plus_three'))""",
"""four_plus_four (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of four_plus_four'))""",
]))
self.api_dev_spec_filename = os.path.abspath(
os.path.join(os.getcwd(), "..", 'paddle/fluid/API_DEV.spec'))
with open(self.api_dev_spec_filename, 'w') as f:
f.write("\n".join([
"""one_plus_one (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of one_plus_one'))""",
]))
self.api_diff_spec_filename = os.path.abspath(
os.path.join(os.getcwd(), "dev_pr_diff_api.spec"))
def tearDown(self):
os.remove(self.api_pr_spec_filename)
os.remove(self.api_dev_spec_filename)
os.remove(self.api_diff_spec_filename)
def test_it(self):
get_incrementapi()
with open(self.api_diff_spec_filename, 'r') as f:
lines = f.readlines()
self.assertCountEqual(
["two_plus_two\n", "three_plus_three\n", "four_plus_four\n"],
lines)
class Test_get_wlist(unittest.TestCase):
def setUp(self):
self.tmpDir = tempfile.mkdtemp()
self.wlist_filename = os.path.join(self.tmpDir, 'wlist.json')
with open(self.wlist_filename, 'w') as f:
f.write(r'''
{
"wlist_dir":[
{
"name":"../python/paddle/fluid/contrib",
"annotation":""
},
{
"name":"../python/paddle/verison.py",
"annotation":""
}
],
"wlist_api":[
{
"name":"xxxxx",
"annotation":"not a real api, just for example"
}
],
"wlist_temp_api":[
"to_tensor",
"save_persistables@dygraph/checkpoint.py"
],
"gpu_not_white":[
"deformable_conv"
]
}
''')
def tearDown(self):
os.remove(self.wlist_filename)
shutil.rmtree(self.tmpDir)
def test_get_wlist(self):
wlist, wlist_file, gpu_not_white = get_wlist(self.wlist_filename)
self.assertCountEqual(
["xxxxx", "to_tensor",
"save_persistables@dygraph/checkpoint.py"], wlist)
self.assertCountEqual([
"../python/paddle/fluid/contrib",
"../python/paddle/verison.py",
], wlist_file)
self.assertCountEqual(["deformable_conv"], gpu_not_white)
class Test_srccoms_extract(unittest.TestCase):
def setUp(self):
self.tmpDir = tempfile.mkdtemp()
sys.path.append(self.tmpDir)
self.api_pr_spec_filename = os.path.abspath(
os.path.join(os.getcwd(), "..", 'paddle/fluid/API_PR.spec'))
with open(self.api_pr_spec_filename, 'w') as f:
f.write("\n".join([
"""one_plus_one (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', "one_plus_one"))""",
"""two_plus_two (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', "two_plus_two"))""",
"""three_plus_three (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', "three_plus_three"))""",
"""four_plus_four (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', "four_plus_four"))""",
]))
def tearDown(self):
sys.path.remove(self.tmpDir)
shutil.rmtree(self.tmpDir)
os.remove(self.api_pr_spec_filename)
def test_from_ops_py(self):
filecont = '''
def add_sample_code(obj, docstr):
pass
__unary_func__ = [
'exp',
]
__all__ = []
__all__ += __unary_func__
__all__ += ['one_plus_one']
def exp():
pass
add_sample_code(globals()["exp"], r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.exp(x)
print(out)
# [0.67032005 0.81873075 1.10517092 1.34985881]
""")
def one_plus_one():
return 1+1
one_plus_one.__doc__ = """
placeholder
Examples:
.. code-block:: python
print(1+1)
"""
__all__ += ['two_plus_two']
def two_plus_two():
return 2+2
add_sample_code(globals()["two_plus_two"], """
Examples:
.. code-block:: python
print(2+2)
""")
'''
pyfilename = os.path.join(self.tmpDir, 'ops.py')
with open(pyfilename, 'w') as pyfile:
pyfile.write(filecont)
self.assertTrue(os.path.exists(pyfilename))
utsp = importlib.import_module('ops')
print('testing srccoms_extract from ops.py')
methods = ['one_plus_one', 'two_plus_two', 'exp']
# os.remove("samplecode_temp/" "one_plus_one_example.py")
self.assertFalse(
os.path.exists("samplecode_temp/"
"one_plus_one_example.py"))
with open(pyfilename, 'r') as pyfile:
res, error_methods = srccoms_extract(pyfile, [], methods)
self.assertTrue(res)
self.assertTrue(
os.path.exists("samplecode_temp/"
"one_plus_one_example.py"))
os.remove("samplecode_temp/" "one_plus_one_example.py")
self.assertTrue(
os.path.exists("samplecode_temp/"
"two_plus_two_example.py"))
os.remove("samplecode_temp/" "two_plus_two_example.py")
self.assertTrue(os.path.exists("samplecode_temp/" "exp_example.py"))
os.remove("samplecode_temp/" "exp_example.py")
def test_from_not_ops_py(self):
filecont = '''
__all__ = [
'one_plus_one'
]
def one_plus_one():
"""
placeholder
Examples:
.. code-block:: python
print(1+1)
"""
return 1+1
'''
pyfilename = os.path.join(self.tmpDir, 'opo.py')
with open(pyfilename, 'w') as pyfile:
pyfile.write(filecont)
utsp = importlib.import_module('opo')
methods = ['one_plus_one']
with open(pyfilename, 'r') as pyfile:
res, error_methods = srccoms_extract(pyfile, [], methods)
self.assertTrue(res)
self.assertTrue(
os.path.exists("samplecode_temp/"
"one_plus_one_example.py"))
os.remove("samplecode_temp/" "one_plus_one_example.py")
def test_with_empty_wlist(self):
"""
see test_from_ops_py
"""
pass
def test_with_wlist(self):
filecont = '''
__all__ = [
'four_plus_four',
'three_plus_three'
]
def four_plus_four():
"""
placeholder
Examples:
.. code-block:: python
print(4+4)
"""
return 4+4
def three_plus_three():
"""
placeholder
Examples:
.. code-block:: python
print(3+3)
"""
return 3+3
'''
pyfilename = os.path.join(self.tmpDir, 'three_and_four.py')
with open(pyfilename, 'w') as pyfile:
pyfile.write(filecont)
utsp = importlib.import_module('three_and_four')
methods = ['four_plus_four', 'three_plus_three']
with open(pyfilename, 'r') as pyfile:
res, error_methods = srccoms_extract(pyfile, ['three_plus_three'],
methods)
self.assertTrue(res)
self.assertTrue(
os.path.exists("samplecode_temp/four_plus_four_example.py"))
os.remove("samplecode_temp/" "four_plus_four_example.py")
self.assertFalse(
os.path.exists("samplecode_temp/three_plus_three_example.py"))
# https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/layers/ops.py
# why? unabled to use the ast module. emmmmm
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册