diff --git a/tools/check_file_diff_approvals.sh b/tools/check_file_diff_approvals.sh index f3bf3ea508ba71a5ea8d504ee6be5525700e767d..05466883e58d28ab16ab4ca7e561672eb392185f 100644 --- a/tools/check_file_diff_approvals.sh +++ b/tools/check_file_diff_approvals.sh @@ -53,6 +53,7 @@ API_FILES=("CMakeLists.txt" "python/paddle/fluid/tests/unittests/white_list/check_op_sequence_batch_1_input_white_list.py" "python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py" "tools/wlist.json" + "tools/sampcd_processor.py" "paddle/scripts/paddle_build.bat" "tools/windows/run_unittests.sh" "tools/parallel_UT_rule.py" @@ -79,6 +80,12 @@ function add_failed(){ echo_list="${echo_list[@]}$1" } +function run_test_sampcd_processor() { + CUR_PWD=$(pwd) + cd ${PADDLE_ROOT}/tools + python test_sampcd_processor.py + cd ${CUR_PWD} +} if [[ $git_files -gt 19 || $git_count -gt 999 ]];then echo_line="You must have Dianhai approval for change 20+ files or add than 1000+ lines of content.\n" @@ -136,6 +143,9 @@ for API_FILE in ${API_FILES[*]}; do elif [ "${API_FILE}" == "tools/wlist.json" ];then echo_line="You must have one TPM (jzhang533) approval for the api whitelist for the tools/wlist.json.\n" check_approval 1 29231 + elif [ "${API_FILE}" == "tools/sampcd_processor.py" ];then + echo_line="test_sampcd_processor.py will be executed for changed sampcd_processor.py.\n" + run_test_sampcd_processor elif [ "${API_FILE}" == "python/paddle/distributed/fleet/__init__.py" ]; then echo_line="You must have (fuyinno4 (Recommend), raindrops2sea) approval for ${API_FILE} changes" check_approval 1 35824027 38231817 diff --git a/tools/sampcd_processor.py b/tools/sampcd_processor.py index ce0490d487fbe7798cba06e7ff0c11b457a18979..fde01329340b2ba5c8043c8b18c2558539667fc2 100644 --- a/tools/sampcd_processor.py +++ b/tools/sampcd_processor.py @@ -22,6 +22,10 @@ import inspect import paddle import paddle.fluid import json +import argparse +import shutil +import re +import logging """ please make sure to run in the tools path usage: python sample_test.py {arg1} @@ -33,6 +37,26 @@ for example, you can run cpu version python2 testing like this: """ +logger = logging.getLogger() +if logger.handlers: + console = logger.handlers[ + 0] # we assume the first handler is the one we want to configure +else: + console = logging.StreamHandler() + logger.addHandler(console) +console.setFormatter( + logging.Formatter( + "%(asctime)s - %(funcName)s:%(lineno)d - %(levelname)s - %(message)s")) + +RUN_ON_DEVICE = 'cpu' +GPU_ID = 0 +methods = [] +whl_error = [] +API_DEV_SPEC_FN = 'paddle/fluid/API_DEV.spec' +API_PR_SPEC_FN = 'paddle/fluid/API_PR.spec' +API_DIFF_SPEC_FN = 'dev_pr_diff_api.spec' +SAMPLECODE_TEMPDIR = 'samplecode_temp' + def find_all(srcstr, substr): """ @@ -98,9 +122,13 @@ def sampcd_extract_and_run(srccom, name, htype="def", hname=""): Returns: result: True or False + name(str): the name of the API. + msg(str): messages """ + global GPU_ID, RUN_ON_DEVICE, SAMPLECODE_TEMPDIR result = True + msg = None def sampcd_header_print(name, sampcd, htype, hname): """ @@ -113,7 +141,8 @@ def sampcd_extract_and_run(srccom, name, htype="def", hname=""): hname(str): the name of the hint banners , e.t. def hname. flushed. """ - print_header(htype, hname) + print(htype, " name:", hname) + print("-----------------------") print("Sample code ", str(y), " extracted for ", name, " :") print(sampcd) print("----example code check----\n") @@ -122,11 +151,9 @@ def sampcd_extract_and_run(srccom, name, htype="def", hname=""): sampcd_begins = find_all(srccom, " code-block:: python") if len(sampcd_begins) == 0: - print_header(htype, hname) - ''' - detect sample codes using >>> to format - and consider this situation as wrong - ''' + # detect sample codes using >>> to format and consider this situation as wrong + print(htype, " name:", hname) + print("-----------------------") if srccom.find("Examples:") != -1: print("----example code check----\n") if srccom.find(">>>") != -1: @@ -164,23 +191,22 @@ def sampcd_extract_and_run(srccom, name, htype="def", hname=""): sampcd_to_write.append(cdline[min_indent:]) sampcd = '\n'.join(sampcd_to_write) - if sys.argv[1] == "cpu": - sampcd = '\nimport os\n' + 'os.environ["CUDA_VISIBLE_DEVICES"] = ""\n' + sampcd - if sys.argv[1] == "gpu": - sampcd = '\nimport os\n' + 'os.environ["CUDA_VISIBLE_DEVICES"] = "0"\n' + sampcd + if RUN_ON_DEVICE == "cpu": + sampcd = '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = ""\n' + sampcd + if RUN_ON_DEVICE == "gpu": + sampcd = '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = "{}"\n'.format( + GPU_ID) + sampcd sampcd += '\nprint(' + '\"' + name + ' sample code is executed successfully!\")' - if len(sampcd_begins) > 1: - tfname = name + "_example_" + str(y) + ".py" - else: - tfname = name + "_example" + ".py" - tempf = open("samplecode_temp/" + tfname, 'w') - tempf.write(sampcd) - tempf.close() + tfname = os.path.join(SAMPLECODE_TEMPDIR, '{}_example{}'.format( + name, '.py' if len(sampcd_begins) == 1 else '_{}.py'.format(y))) + logging.info('running %s', tfname) + with open(tfname, 'w') as tempf: + tempf.write(sampcd) if platform.python_version()[0] == "2": - cmd = ["python", "samplecode_temp/" + tfname] + cmd = ["python", tfname] elif platform.python_version()[0] == "3": - cmd = ["python3", "samplecode_temp/" + tfname] + cmd = ["python3", tfname] else: print("Error: fail to parse python version!") result = False @@ -199,11 +225,12 @@ def sampcd_extract_and_run(srccom, name, htype="def", hname=""): print("Error Raised from Sample Code ", name, " :\n") print(err) print(msg) + logging.warning('%s error: %s', tfname, err) + logging.warning('%s msg: %s', tfname, msg) result = False # msg is the returned code execution report - #os.remove("samplecode_temp/" + tfname) - return result + return result, name, msg def single_defcom_extract(start_from, srcls, is_class_begin=False): @@ -264,12 +291,7 @@ def single_defcom_extract(start_from, srcls, is_class_begin=False): return fcombody -def print_header(htype, name): - print(htype, " name:", name) - print("-----------------------") - - -def srccoms_extract(srcfile, wlist): +def srccoms_extract(srcfile, wlist, methods): """ Given a source file ``srcfile``, this function will extract its API(doc comments) and run sample codes in the @@ -278,12 +300,15 @@ def srccoms_extract(srcfile, wlist): Args: srcfile(file): the source file wlist(list): white list + methods(list): only elements of this list considered. Returns: result: True or False + error_methods: the methods that failed. """ process_result = True + error_methods = [] srcc = srcfile.read() # 2. get defs and classes header line number # set file pointer to its beginning @@ -292,8 +317,8 @@ def srccoms_extract(srcfile, wlist): # 1. fetch__all__ list allidx = srcc.find("__all__") - srcfile_new = srcfile.name - srcfile_new = srcfile_new.replace('.py', '') + logger.debug('processing %s, methods: %s', srcfile.name, str(methods)) + srcfile_new, _ = os.path.splitext(srcfile.name) srcfile_list = srcfile_new.split('/') srcfile_str = '' for i in range(4, len(srcfile_list)): @@ -323,15 +348,27 @@ def srccoms_extract(srcfile, wlist): if '' in alllist: alllist.remove('') api_alllist_count = len(alllist) + logger.debug('found %d items: %s', api_alllist_count, str(alllist)) api_count = 0 handled = [] # get src contents in layers/ops.py if srcfile.name.find("ops.py") != -1: for i in range(0, len(srcls)): - if srcls[i].find("__doc__") != -1: - opname = srcls[i][:srcls[i].find("__doc__") - 1] + opname = None + opres = re.match(r"^(\w+)\.__doc__", srcls[i]) + if opres is not None: + opname = opres.group(1) + else: + opres = re.match( + r"^add_sample_code\(globals\(\)\[\"(\w+)\"\]", srcls[i]) + if opres is not None: + opname = opres.group(1) + if opname is not None: if opname in wlist: + logger.info('%s is in the whitelist, skip it.', opname) continue + else: + logger.debug('%s\'s docstring found.', opname) comstart = i for j in range(i, len(srcls)): if srcls[j].find("\"\"\"") != -1: @@ -341,51 +378,73 @@ def srccoms_extract(srcfile, wlist): opcom += srcls[j] if srcls[j].find("\"\"\"") != -1: break + result, _, _ = sampcd_extract_and_run(opcom, opname, "def", + opname) + if not result: + error_methods.append(opname) + process_result = False api_count += 1 handled.append( opname) # ops.py also has normal formatted functions # use list 'handled' to mark the functions have been handled here # which will be ignored in the following step + # handled what? + logger.debug('%s already handled.', str(handled)) for i in range(0, len(srcls)): if srcls[i].startswith( 'def '): # a function header is detected in line i f_header = srcls[i].replace(" ", '') fn = f_header[len('def'):f_header.find('(')] # function name if "%s%s" % (srcfile_str, fn) not in methods: + logger.info( + '[file:%s, function:%s] not in methods list, skip it.', + srcfile_str, fn) continue if fn in handled: continue if fn in alllist: api_count += 1 if fn in wlist or fn + "@" + srcfile.name in wlist: + logger.info('[file:%s, function:%s] skip by wlist.', + srcfile_str, fn) continue fcombody = single_defcom_extract(i, srcls) if fcombody == "": # if no comment - print_header("def", fn) + print("def name:", fn) + print("-----------------------") print("WARNING: no comments in function ", fn, ", but it deserves.") continue else: - if not sampcd_extract_and_run(fcombody, fn, "def", fn): + result, _, _ = sampcd_extract_and_run(fcombody, fn, + "def", fn) + if not result: + error_methods.append(fn) process_result = False if srcls[i].startswith('class '): c_header = srcls[i].replace(" ", '') cn = c_header[len('class'):c_header.find('(')] # class name if '%s%s' % (srcfile_str, cn) not in methods: + logger.info( + '[file:%s, class:%s] not in methods list, skip it.', + srcfile_str, cn) continue if cn in handled: continue if cn in alllist: api_count += 1 if cn in wlist or cn + "@" + srcfile.name in wlist: + logger.info('[file:%s, class:%s] skip by wlist.', + srcfile_str, cn) continue # class comment classcom = single_defcom_extract(i, srcls, True) if classcom != "": - if not sampcd_extract_and_run(classcom, cn, "class", - cn): - + result, _, _ = sampcd_extract_and_run(classcom, cn, + "class", cn) + if not result: + error_methods.append(cn) process_result = False else: print("WARNING: no comments in class itself ", cn, @@ -410,10 +469,19 @@ def srccoms_extract(srcfile, wlist): if '%s%s' % ( srcfile_str, name ) not in methods: # class method not in api.spec + logger.info( + '[file:%s, func:%s] not in methods, skip it.', + srcfile_str, name) continue if mn.startswith('_'): + logger.info( + '[file:%s, func:%s] startswith _, it\'s private method, skip it.', + srcfile_str, name) continue if name in wlist or name + "@" + srcfile.name in wlist: + logger.info( + '[file:%s, class:%s] skip by wlist.', + srcfile_str, name) continue thismethod = [thisl[indent:] ] # method body lines @@ -434,22 +502,38 @@ def srccoms_extract(srcfile, wlist): thismtdcom = single_defcom_extract(0, thismethod) if thismtdcom != "": - if not sampcd_extract_and_run( - thismtdcom, name, "method", name): + result, _, _ = sampcd_extract_and_run( + thismtdcom, name, "method", name) + if not result: + error_methods.append(name) process_result = False + else: + logger.warning('__all__ not found in file:%s', srcfile.name) - return process_result + return process_result, error_methods def test(file_list): + global methods # readonly process_result = True for file in file_list: with open(file, 'r') as src: - if not srccoms_extract(src, wlist): + if not srccoms_extract(src, wlist, methods): process_result = False return process_result +def run_a_test(tc_filename): + """ + execute a sample code-block. + """ + global methods # readonly + process_result = True + with open(tc_filename, 'r') as src: + process_result, error_methods = srccoms_extract(src, wlist, methods) + return process_result, tc_filename, error_methods + + def get_filenames(): ''' this function will get the modules that pending for check. @@ -460,12 +544,12 @@ def get_filenames(): ''' filenames = [] - global methods + global methods # write global whl_error methods = [] whl_error = [] get_incrementapi() - API_spec = 'dev_pr_diff_api.spec' + API_spec = API_DIFF_SPEC_FN with open(API_spec) as f: for line in f.readlines(): api = line.replace('\n', '') @@ -474,17 +558,30 @@ def get_filenames(): except AttributeError: whl_error.append(api) continue + except SyntaxError: + logger.warning('line:%s, api:%s', line, api) + # paddle.Tensor. + continue if len(module.split('.')) > 1: filename = '../python/' + # work for .so? module_py = '%s.py' % module.split('.')[-1] for i in range(0, len(module.split('.')) - 1): filename = filename + '%s/' % module.split('.')[i] filename = filename + module_py else: filename = '' - print("\nWARNING:----Exception in get api filename----\n") - print("\n" + api + ' module is ' + module + "\n") - if filename != '' and filename not in filenames: + logger.warning("WARNING: Exception in getting api:%s module:%s", + api, module) + if filename in filenames: + continue + elif not filename: + logger.warning('filename invalid: %s', line) + continue + elif not os.path.exists(filename): + logger.warning('file not exists: %s', filename) + continue + else: filenames.append(filename) # get all methods method = '' @@ -496,9 +593,9 @@ def get_filenames(): name = '%s.%s' % (api.split('.')[-2], api.split('.')[-1]) else: name = '' - print("\nWARNING:----Exception in get api methods----\n") - print("\n" + line + "\n") - print("\n" + api + ' method is None!!!' + "\n") + logger.warning( + "WARNING: Exception when getting api:%s, line:%s", api, + line) for j in range(2, len(module.split('.'))): method = method + '%s.' % module.split('.')[j] method = method + name @@ -508,26 +605,27 @@ def get_filenames(): return filenames +def get_api_md5(path): + api_md5 = {} + API_spec = '%s/%s' % (os.path.abspath(os.path.join(os.getcwd(), "..")), + path) + with open(API_spec) as f: + for line in f.readlines(): + api = line.split(' ', 1)[0] + md5 = line.split("'document', ")[1].replace(')', '').replace('\n', + '') + api_md5[api] = md5 + return api_md5 + + def get_incrementapi(): ''' this function will get the apis that difference between API_DEV.spec and API_PR.spec. ''' - - def get_api_md5(path): - api_md5 = {} - API_spec = '%s/%s' % (os.path.abspath(os.path.join(os.getcwd(), "..")), - path) - with open(API_spec) as f: - for line in f.readlines(): - api = line.split(' ', 1)[0] - md5 = line.split("'document', ")[1].replace(')', '').replace( - '\n', '') - api_md5[api] = md5 - return api_md5 - - dev_api = get_api_md5('paddle/fluid/API_DEV.spec') - pr_api = get_api_md5('paddle/fluid/API_PR.spec') - with open('dev_pr_diff_api.spec', 'w') as f: + global API_DEV_SPEC_FN, API_PR_SPEC_FN, API_DIFF_SPEC_FN ## readonly + dev_api = get_api_md5(API_DEV_SPEC_FN) + pr_api = get_api_md5(API_PR_SPEC_FN) + with open(API_DIFF_SPEC_FN, 'w') as f: for key in pr_api: if key in dev_api: if dev_api[key] != pr_api[key]: @@ -538,7 +636,7 @@ def get_incrementapi(): f.write('\n') -def get_wlist(): +def get_wlist(fn="wlist.json"): ''' this function will get the white list of API. @@ -551,7 +649,7 @@ def get_wlist(): wlist_file = [] # only white on CPU gpu_not_white = [] - with open("wlist.json", 'r') as load_f: + with open(fn, 'r') as load_f: load_dict = json.load(load_f) for key in load_dict: if key == 'wlist_dir': @@ -567,31 +665,77 @@ def get_wlist(): return wlist, wlist_file, gpu_not_white -wlist, wlist_file, gpu_not_white = get_wlist() +arguments = [ + # flags, dest, type, default, help + ['--gpu_id', 'gpu_id', int, 0, 'GPU device id to use [0]'], + ['--logf', 'logf', str, None, 'file for logging'], + ['--threads', 'threads', int, 0, 'sub processes number'], +] -if len(sys.argv) < 2: - print("Error: inadequate number of arguments") - print('''If you are going to run it on - "CPU: >>> python sampcd_processor.py cpu - "GPU: >>> python sampcd_processor.py gpu - ''') - sys.exit("lack arguments") -else: - if sys.argv[1] == "gpu": + +def parse_args(): + """ + Parse input arguments + """ + global arguments + parser = argparse.ArgumentParser(description='run Sample Code Test') + # parser.add_argument('--cpu', dest='cpu_mode', action="store_true", + # help='Use CPU mode (overrides --gpu)') + # parser.add_argument('--gpu', dest='gpu_mode', action="store_true") + parser.add_argument('--debug', dest='debug', action="store_true") + parser.add_argument('mode', type=str, help='run on device', default='cpu') + for item in arguments: + parser.add_argument( + item[0], dest=item[1], help=item[4], type=item[2], default=item[3]) + + if len(sys.argv) == 1: + args = parser.parse_args(['cpu']) + return args + # parser.print_help() + # sys.exit(1) + + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + if args.debug: + logger.setLevel(logging.DEBUG) + if args.logf: + logfHandler = logging.FileHandler(args.logf) + logfHandler.setFormatter( + logging.Formatter( + "%(asctime)s - %(funcName)s:%(lineno)d - %(levelname)s - %(message)s" + )) + logger.addHandler(logfHandler) + + wlist, wlist_file, gpu_not_white = get_wlist() + + if args.mode == "gpu": + GPU_ID = args.gpu_id + logger.info("using GPU_ID %d", GPU_ID) for _gnw in gpu_not_white: wlist.remove(_gnw) - elif sys.argv[1] != "cpu": - print("Unrecognized argument:'", sys.argv[1], "' , 'cpu' or 'gpu' is ", - "desired\n") + elif args.mode != "cpu": + logger.error("Unrecognized argument:%s, 'cpu' or 'gpu' is desired.", + args.mode) sys.exit("Invalid arguments") - print("API check -- Example Code") - print("sample_test running under python", platform.python_version()) - if not os.path.isdir("./samplecode_temp"): - os.mkdir("./samplecode_temp") - cpus = multiprocessing.cpu_count() + RUN_ON_DEVICE = args.mode + logger.info("API check -- Example Code") + logger.info("sample_test running under python %s", + platform.python_version()) + + if os.path.exists(SAMPLECODE_TEMPDIR): + if not os.path.isdir(SAMPLECODE_TEMPDIR): + os.remove(SAMPLECODE_TEMPDIR) + os.mkdir(SAMPLECODE_TEMPDIR) + else: + os.mkdir(SAMPLECODE_TEMPDIR) + filenames = get_filenames() if len(filenames) == 0 and len(whl_error) == 0: - print("-----API_PR.spec is the same as API_DEV.spec-----") + logger.info("-----API_PR.spec is the same as API_DEV.spec-----") exit(0) rm_file = [] for f in filenames: @@ -600,51 +744,52 @@ else: rm_file.append(f) filenames.remove(f) if len(rm_file) != 0: - print("REMOVE white files: %s" % rm_file) - print("API_PR is diff from API_DEV: %s" % filenames) - one_part_filenum = int(math.ceil(len(filenames) / cpus)) - if one_part_filenum == 0: - one_part_filenum = 1 - divided_file_list = [ - filenames[i:i + one_part_filenum] - for i in range(0, len(filenames), one_part_filenum) - ] - - po = multiprocessing.Pool() - results = po.map_async(test, divided_file_list) + logger.info("REMOVE white files: %s", rm_file) + logger.info("API_PR is diff from API_DEV: %s", filenames) + + threads = multiprocessing.cpu_count() + if args.threads: + threads = args.threads + po = multiprocessing.Pool(threads) + # results = po.map_async(test, divided_file_list) + results = po.map_async(run_a_test, filenames) po.close() po.join() result = results.get() # delete temp files - for root, dirs, files in os.walk("./samplecode_temp"): - for fntemp in files: - os.remove("./samplecode_temp/" + fntemp) - os.rmdir("./samplecode_temp") + if not args.debug: + shutil.rmtree(SAMPLECODE_TEMPDIR) - print("----------------End of the Check--------------------") + logger.info("----------------End of the Check--------------------") if len(whl_error) != 0: - print("%s is not in whl." % whl_error) - print("") - print("Please check the whl package and API_PR.spec!") - print("You can follow these steps in order to generate API.spec:") - print("1. cd ${paddle_path}, compile paddle;") - print("2. pip install build/python/dist/(build whl package);") - print( + logger.info("%s is not in whl.", whl_error) + logger.info("") + logger.info("Please check the whl package and API_PR.spec!") + logger.info("You can follow these steps in order to generate API.spec:") + logger.info("1. cd ${paddle_path}, compile paddle;") + logger.info("2. pip install build/python/dist/(build whl package);") + logger.info( "3. run 'python tools/print_signatures.py paddle > paddle/fluid/API.spec'." ) for temp in result: - if not temp: - print("") - print("In addition, mistakes found in sample codes.") - print("Please check sample codes.") - print("----------------------------------------------------") + if not temp[0]: + logger.info("In addition, mistakes found in sample codes: %s", + temp[1]) + logger.info("error_methods: %s", str(temp[2])) + logger.info("----------------------------------------------------") exit(1) else: + has_error = False for temp in result: - if not temp: - print("Mistakes found in sample codes.") - print("Please check sample codes.") - exit(1) - print("Sample code check is successful!") + if not temp[0]: + logger.info("In addition, mistakes found in sample codes: %s", + temp[1]) + logger.info("error_methods: %s", str(temp[2])) + has_error = True + if has_error: + logger.info("Mistakes found in sample codes.") + logger.info("Please check sample codes.") + exit(1) + logger.info("Sample code check is successful!") diff --git a/tools/test_sampcd_processor.py b/tools/test_sampcd_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f47d1af5815d1ad4d730e418d80d1c1e134c4d --- /dev/null +++ b/tools/test_sampcd_processor.py @@ -0,0 +1,439 @@ +#! python + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import tempfile +import shutil +import sys +import importlib +from sampcd_processor import find_all +from sampcd_processor import check_indent +from sampcd_processor import sampcd_extract_and_run +from sampcd_processor import single_defcom_extract +from sampcd_processor import srccoms_extract +from sampcd_processor import get_api_md5 +from sampcd_processor import get_incrementapi +from sampcd_processor import get_wlist + + +class Test_find_all(unittest.TestCase): + def test_find_none(self): + self.assertEqual(0, len(find_all('hello', 'world'))) + + def test_find_one(self): + self.assertListEqual([0], find_all('hello', 'hello')) + + def test_find_two(self): + self.assertListEqual([1, 15], + find_all(' hello, world; hello paddle!', 'hello')) + + +class Test_check_indent(unittest.TestCase): + def test_no_indent(self): + self.assertEqual(0, check_indent('hello paddle')) + + def test_indent_4_spaces(self): + self.assertEqual(4, check_indent(' hello paddle')) + + def test_indent_1_tab(self): + self.assertEqual(4, check_indent("\thello paddle")) + + +class Test_sampcd_extract_and_run(unittest.TestCase): + def setUp(self): + if not os.path.exists('samplecode_temp/'): + os.mkdir('samplecode_temp/') + + def test_run_a_defs_samplecode(self): + comments = """ + Examples: + .. code-block:: python + print(1+1) + """ + funcname = 'one_plus_one' + res, name, msg = sampcd_extract_and_run(comments, funcname) + self.assertTrue(res) + self.assertEqual(funcname, name) + + def test_run_a_def_no_code(self): + comments = """ + placeholder + """ + funcname = 'one_plus_one' + res, name, msg = sampcd_extract_and_run(comments, funcname) + self.assertFalse(res) + self.assertEqual(funcname, name) + + def test_run_a_def_raise_expection(self): + comments = """ + placeholder + Examples: + .. code-block:: python + print(1/0) + """ + funcname = 'one_plus_one' + res, name, msg = sampcd_extract_and_run(comments, funcname) + self.assertFalse(res) + self.assertEqual(funcname, name) + + +class Test_single_defcom_extract(unittest.TestCase): + def test_extract_from_func(self): + defstr = ''' +import os +def foo(): + """ + foo is a function. + """ + pass +def bar(): + pass +''' + comm = single_defcom_extract( + 2, defstr.splitlines(True), is_class_begin=False) + self.assertEqual(" foo is a function.\n", comm) + pass + + def test_extract_from_func_with_no_docstring(self): + defstr = ''' +import os +def bar(): + pass +''' + comm = single_defcom_extract( + 2, defstr.splitlines(True), is_class_begin=False) + self.assertEqual('', comm) + pass + + def test_extract_from_class(self): + defstr = r''' +import os +class Foo(): + """ + Foo is a class. + second line. + """ + pass + def bar(): + pass +def foo(): + pass +''' + comm = single_defcom_extract( + 2, defstr.splitlines(True), is_class_begin=True) + rcomm = """ Foo is a class. + second line. +""" + self.assertEqual(rcomm, comm) + pass + + def test_extract_from_class_with_no_docstring(self): + defstr = ''' +import os +class Foo(): + pass + def bar(): + pass +def foo(): + pass +''' + comm = single_defcom_extract( + 0, defstr.splitlines(True), is_class_begin=True) + self.assertEqual('', comm) + + +class Test_get_api_md5(unittest.TestCase): + def setUp(self): + self.api_pr_spec_filename = os.path.abspath( + os.path.join(os.getcwd(), "..", 'paddle/fluid/API_PR.spec')) + with open(self.api_pr_spec_filename, 'w') as f: + f.write("\n".join([ + """one_plus_one (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of one_plus_one'))""", + """two_plus_two (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of two_plus_two'))""", + """three_plus_three (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of three_plus_three'))""", + """four_plus_four (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of four_plus_four'))""", + ])) + + def tearDown(self): + os.remove(self.api_pr_spec_filename) + pass + + def test_get_api_md5(self): + res = get_api_md5('paddle/fluid/API_PR.spec') + self.assertEqual("'md5sum of one_plus_one'", res['one_plus_one']) + self.assertEqual("'md5sum of two_plus_two'", res['two_plus_two']) + self.assertEqual("'md5sum of three_plus_three'", + res['three_plus_three']) + self.assertEqual("'md5sum of four_plus_four'", res['four_plus_four']) + + +class Test_get_incrementapi(unittest.TestCase): + def setUp(self): + self.api_pr_spec_filename = os.path.abspath( + os.path.join(os.getcwd(), "..", 'paddle/fluid/API_PR.spec')) + with open(self.api_pr_spec_filename, 'w') as f: + f.write("\n".join([ + """one_plus_one (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of one_plus_one'))""", + """two_plus_two (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of two_plus_two'))""", + """three_plus_three (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of three_plus_three'))""", + """four_plus_four (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of four_plus_four'))""", + ])) + self.api_dev_spec_filename = os.path.abspath( + os.path.join(os.getcwd(), "..", 'paddle/fluid/API_DEV.spec')) + with open(self.api_dev_spec_filename, 'w') as f: + f.write("\n".join([ + """one_plus_one (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', 'md5sum of one_plus_one'))""", + ])) + self.api_diff_spec_filename = os.path.abspath( + os.path.join(os.getcwd(), "dev_pr_diff_api.spec")) + + def tearDown(self): + os.remove(self.api_pr_spec_filename) + os.remove(self.api_dev_spec_filename) + os.remove(self.api_diff_spec_filename) + + def test_it(self): + get_incrementapi() + with open(self.api_diff_spec_filename, 'r') as f: + lines = f.readlines() + self.assertCountEqual( + ["two_plus_two\n", "three_plus_three\n", "four_plus_four\n"], + lines) + + +class Test_get_wlist(unittest.TestCase): + def setUp(self): + self.tmpDir = tempfile.mkdtemp() + self.wlist_filename = os.path.join(self.tmpDir, 'wlist.json') + with open(self.wlist_filename, 'w') as f: + f.write(r''' +{ + "wlist_dir":[ + { + "name":"../python/paddle/fluid/contrib", + "annotation":"" + }, + { + "name":"../python/paddle/verison.py", + "annotation":"" + } + ], + "wlist_api":[ + { + "name":"xxxxx", + "annotation":"not a real api, just for example" + } + ], + "wlist_temp_api":[ + "to_tensor", + "save_persistables@dygraph/checkpoint.py" + ], + "gpu_not_white":[ + "deformable_conv" + ] +} +''') + + def tearDown(self): + os.remove(self.wlist_filename) + shutil.rmtree(self.tmpDir) + + def test_get_wlist(self): + wlist, wlist_file, gpu_not_white = get_wlist(self.wlist_filename) + self.assertCountEqual( + ["xxxxx", "to_tensor", + "save_persistables@dygraph/checkpoint.py"], wlist) + self.assertCountEqual([ + "../python/paddle/fluid/contrib", + "../python/paddle/verison.py", + ], wlist_file) + self.assertCountEqual(["deformable_conv"], gpu_not_white) + + +class Test_srccoms_extract(unittest.TestCase): + def setUp(self): + self.tmpDir = tempfile.mkdtemp() + sys.path.append(self.tmpDir) + self.api_pr_spec_filename = os.path.abspath( + os.path.join(os.getcwd(), "..", 'paddle/fluid/API_PR.spec')) + with open(self.api_pr_spec_filename, 'w') as f: + f.write("\n".join([ + """one_plus_one (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', "one_plus_one"))""", + """two_plus_two (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', "two_plus_two"))""", + """three_plus_three (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', "three_plus_three"))""", + """four_plus_four (ArgSpec(args=[], varargs=None, keywords=None, defaults=(,)), ('document', "four_plus_four"))""", + ])) + + def tearDown(self): + sys.path.remove(self.tmpDir) + shutil.rmtree(self.tmpDir) + os.remove(self.api_pr_spec_filename) + + def test_from_ops_py(self): + filecont = ''' +def add_sample_code(obj, docstr): + pass + +__unary_func__ = [ + 'exp', +] + +__all__ = [] +__all__ += __unary_func__ +__all__ += ['one_plus_one'] + +def exp(): + pass +add_sample_code(globals()["exp"], r""" +Examples: + .. code-block:: python + import paddle + x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3]) + out = paddle.exp(x) + print(out) + # [0.67032005 0.81873075 1.10517092 1.34985881] +""") + +def one_plus_one(): + return 1+1 + +one_plus_one.__doc__ = """ + placeholder + + Examples: + .. code-block:: python + print(1+1) +""" + +__all__ += ['two_plus_two'] +def two_plus_two(): + return 2+2 +add_sample_code(globals()["two_plus_two"], """ + Examples: + .. code-block:: python + print(2+2) +""") +''' + pyfilename = os.path.join(self.tmpDir, 'ops.py') + with open(pyfilename, 'w') as pyfile: + pyfile.write(filecont) + self.assertTrue(os.path.exists(pyfilename)) + utsp = importlib.import_module('ops') + print('testing srccoms_extract from ops.py') + methods = ['one_plus_one', 'two_plus_two', 'exp'] + # os.remove("samplecode_temp/" "one_plus_one_example.py") + self.assertFalse( + os.path.exists("samplecode_temp/" + "one_plus_one_example.py")) + with open(pyfilename, 'r') as pyfile: + res, error_methods = srccoms_extract(pyfile, [], methods) + self.assertTrue(res) + self.assertTrue( + os.path.exists("samplecode_temp/" + "one_plus_one_example.py")) + os.remove("samplecode_temp/" "one_plus_one_example.py") + self.assertTrue( + os.path.exists("samplecode_temp/" + "two_plus_two_example.py")) + os.remove("samplecode_temp/" "two_plus_two_example.py") + self.assertTrue(os.path.exists("samplecode_temp/" "exp_example.py")) + os.remove("samplecode_temp/" "exp_example.py") + + def test_from_not_ops_py(self): + filecont = ''' +__all__ = [ + 'one_plus_one' +] + +def one_plus_one(): + """ + placeholder + + Examples: + .. code-block:: python + print(1+1) + """ + return 1+1 + +''' + pyfilename = os.path.join(self.tmpDir, 'opo.py') + with open(pyfilename, 'w') as pyfile: + pyfile.write(filecont) + utsp = importlib.import_module('opo') + methods = ['one_plus_one'] + with open(pyfilename, 'r') as pyfile: + res, error_methods = srccoms_extract(pyfile, [], methods) + self.assertTrue(res) + self.assertTrue( + os.path.exists("samplecode_temp/" + "one_plus_one_example.py")) + os.remove("samplecode_temp/" "one_plus_one_example.py") + + def test_with_empty_wlist(self): + """ + see test_from_ops_py + """ + pass + + def test_with_wlist(self): + filecont = ''' +__all__ = [ + 'four_plus_four', + 'three_plus_three' + ] + +def four_plus_four(): + """ + placeholder + + Examples: + .. code-block:: python + print(4+4) + """ + return 4+4 +def three_plus_three(): + """ + placeholder + + Examples: + .. code-block:: python + print(3+3) + """ + return 3+3 + +''' + pyfilename = os.path.join(self.tmpDir, 'three_and_four.py') + with open(pyfilename, 'w') as pyfile: + pyfile.write(filecont) + utsp = importlib.import_module('three_and_four') + methods = ['four_plus_four', 'three_plus_three'] + with open(pyfilename, 'r') as pyfile: + res, error_methods = srccoms_extract(pyfile, ['three_plus_three'], + methods) + self.assertTrue(res) + self.assertTrue( + os.path.exists("samplecode_temp/four_plus_four_example.py")) + os.remove("samplecode_temp/" "four_plus_four_example.py") + self.assertFalse( + os.path.exists("samplecode_temp/three_plus_three_example.py")) + + +# https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/layers/ops.py +# why? unabled to use the ast module. emmmmm + +if __name__ == '__main__': + unittest.main()