# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
16
import sys
T
tianshuo78520a 已提交
17
import subprocess
18 19 20
import multiprocessing
import math
import platform
21 22 23
import inspect
import paddle
import paddle.fluid
Z
zhangchunle 已提交
24
import json
25 26 27 28
import argparse
import shutil
import re
import logging
29 30
"""
please make sure to run in the tools path
31
usage: python sample_test.py {arg1} 
32 33 34
arg1: the first arg defined running in gpu version or cpu version

for example, you can run cpu version python2 testing like this:
35 36 37

    python sampcd_processor.py cpu 

38
"""
T
tianshuo78520a 已提交
39

40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
# Configure the root logger. If a handler is already attached we reuse the
# first one; otherwise a console (stream) handler is created for it.
logger = logging.getLogger()
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())
console = logger.handlers[0]
console.setFormatter(
    logging.Formatter(
        "%(asctime)s - %(funcName)s:%(lineno)d - %(levelname)s - %(message)s"))

# Device the sample codes run on ('cpu' or 'gpu') and the GPU to expose.
RUN_ON_DEVICE = 'cpu'
GPU_ID = 0
# Qualified API names to check; filled by get_filenames().
methods = []
# APIs listed in the spec diff that could not be resolved; filled by
# get_filenames().
whl_error = []
# Spec files (paths relative to the repository root) and the diff output.
API_DEV_SPEC_FN = 'paddle/fluid/API_DEV.spec'
API_PR_SPEC_FN = 'paddle/fluid/API_PR.spec'
API_DIFF_SPEC_FN = 'dev_pr_diff_api.spec'
# Directory the extracted sample codes are written to before execution.
SAMPLECODE_TEMPDIR = 'samplecode_temp'


def find_all(srcstr, substr):
    """
    Find every occurrence of ``substr`` in ``srcstr``.

    Overlapping occurrences are reported too, because each new search
    starts one character after the previous hit.

    Args:
        srcstr(str): the string to search in
        substr(str): the string to look for

    Returns:
        list: the starting indices of all occurrences, in ascending order
    """
    indices = []
    start = 0
    while True:
        hit = srcstr.find(substr, start)
        if hit == -1:
            return indices
        indices.append(hit)
        start = hit + 1


def check_indent(cdline):
    """
    Measure the leading indentation of a single line of code.

    Leading blank characters are counted until the first character that
    is neither a space nor a tab. A space counts as 1 and a tab counts
    as 4 (i.e. '\t' == '    ').

    Args:
        cdline(str): a single line of code from the source file

    Returns:
        int: the indentation width in interpreted blank spaces
    """
    widths = {' ': 1, '\t': 4}
    total = 0
    for ch in cdline:
        step = widths.get(ch)
        if step is None:
            break
        total += step
    return total


# srccom: raw comments in the source,including ''' and original indent
def sampcd_extract_and_run(srccom, name, htype="def", hname=""):
    """
    Extract and run sample codes from source comment and
    the result will be returned.

    Each ``.. code-block:: python`` section of ``srccom`` is processed as
    follows:

    * It is de-indented.
    * A preamble is prepended that sets ``CUDA_VISIBLE_DEVICES`` according
      to ``RUN_ON_DEVICE``/``GPU_ID``.
    * A "sample code is executed successfully" print is appended.
    * The result is written to a file in ``SAMPLECODE_TEMPDIR`` and
      executed in a python subprocess.

    Args:
        srccom(str): the source comment of some API whose
                     example codes will be extracted and run.
        name(str): the name of the API.
        htype(str): the type of hint banners, def/class/method.
        hname(str): the name of the hint  banners , e.t. def hname.

    Returns:
        result: False in any of these cases, otherwise True:
                * a sample code exits with a non-zero code;
                * a code-block is empty;
                * no code-block is found and the comment either has no
                  ``Examples:`` section or uses the deprecated ``>>>``
                  style.
        name(str): the name of the API.
        msg(str): stdout of the last executed sample code (None if no
                  sample code was executed).
    """
    global GPU_ID, RUN_ON_DEVICE, SAMPLECODE_TEMPDIR

    result = True
    msg = None
    code_block_mark = " code-block:: python"

    def sampcd_header_print(name, sampcd, htype, hname, index):
        """
        print hint banner headers.

        Args:
            name(str): the name of the API.
            sampcd(str): sample code string
            htype(str): the type of hint banners, def/class/method.
            hname(str): the name of the hint  banners , e.t. def hname.
            index(int): 1-based index of the sample code within srccom.
        """
        print(htype, " name:", hname)
        print("-----------------------")
        print("Sample code ", str(index), " extracted for ", name, "   :")
        print(sampcd)
        print("----example code check----\n")
        print("executing sample code .....")
        print("execution result:")

    sampcd_begins = find_all(srccom, code_block_mark)
    if len(sampcd_begins) == 0:
        # detect sample codes using >>> to format and consider this situation as wrong
        print(htype, " name:", hname)
        print("-----------------------")
        if srccom.find("Examples:") != -1:
            print("----example code check----\n")
            if srccom.find(">>>") != -1:
                print(
                    "Deprecated sample code style:\n\n    Examples:\n\n        >>>codeline\n        >>>codeline\n\n\n ",
                    "Please use '.. code-block:: python' to ",
                    "format sample code.\n")
                result = False
        else:
            print("Error: No sample code!\n")
            result = False

    for y in range(1, len(sampcd_begins) + 1):
        sampcd_begin = sampcd_begins[y - 1]
        # skip the directive itself and the character right after it
        sampcd = srccom[sampcd_begin + len(code_block_mark) + 1:]
        sampcd = sampcd.split("\n")
        # remove starting empty lines; stop if nothing is left so that an
        # empty code-block is reported instead of raising IndexError
        while sampcd and sampcd[0].replace(' ', '').replace('\t', '') == '':
            sampcd.pop(0)
        if not sampcd:
            print(htype, " name:", hname)
            print("-----------------------")
            print("Error: empty sample code block found in ", name, "\n")
            result = False
            continue

        # the minimum indent, which is the indent of the first
        # non-empty line; the sample ends at the first non-empty line
        # indented less than that
        min_indent = check_indent(sampcd[0])
        sampcd_to_write = []
        for cdline in sampcd:
            # handle empty lines or those only with spaces/tabs
            if cdline.strip() == '':
                continue
            if check_indent(cdline) < min_indent:
                break
            cdline = cdline.replace('\t', '    ')
            sampcd_to_write.append(cdline[min_indent:])

        sampcd = '\n'.join(sampcd_to_write)
        if RUN_ON_DEVICE == "cpu":
            sampcd = '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = ""\n' + sampcd
        if RUN_ON_DEVICE == "gpu":
            sampcd = '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = "{}"\n'.format(
                GPU_ID) + sampcd
        sampcd += '\nprint(' + '\"' + name + ' sample code is executed successfully!\")'

        tfname = os.path.join(SAMPLECODE_TEMPDIR, '{}_example{}'.format(
            name, '.py' if len(sampcd_begins) == 1 else '_{}.py'.format(y)))
        logger.info('running %s', tfname)
        with open(tfname, 'w') as tempf:
            tempf.write(sampcd)
        if platform.python_version()[0] == "2":
            cmd = ["python", tfname]
        elif platform.python_version()[0] == "3":
            cmd = ["python3", tfname]
        else:
            print("Error: fail to parse python version!")
            result = False
            sys.exit(1)

        subprc = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = subprc.communicate()
        # msg is the returned code execution report; undecodable bytes are
        # replaced so that odd output cannot crash the checker itself
        msg = output.decode(encoding='utf-8', errors='replace')
        err = error.decode(encoding='utf-8', errors='replace')

        if subprc.returncode != 0:
            print("\nSample code error found in ", name, ":\n")
            sampcd_header_print(name, sampcd, htype, hname, y)
            print("subprocess return code: ", str(subprc.returncode))
            print("Error Raised from Sample Code ", name, " :\n")
            print(err)
            print(msg)
            logger.warning('%s error: %s', tfname, err)
            logger.warning('%s msg: %s', tfname, msg)
            result = False

    return result, name, msg


def single_defcom_extract(start_from, srcls, is_class_begin=False):
    """
    to extract a def function/class/method comments body

    Lines after ``start_from`` are scanned until a top-level ``def ``/
    ``class `` line is reached. If ``is_class_begin`` is true, the scan
    also stops at an (expanded) ``    def `` line, i.e. the first method.
    The first line whose stripped text begins with a triple quote (plain
    or ``r``-prefixed) opens the comment. Following lines are collected
    until a line whose stripped text begins with the same triple quote.

    Args:
        start_from(int): the line num of "def" header
        srcls(list): the source file in lines
        is_class_begin(bool): whether start_from is the beginning of a
            class. A class body without a docstring would otherwise run
            into its first method's docstring.

    Returns:
        string: the extracted comment body. The lines holding the opening
        and closing quote marks are NOT included. Returns "" when no
        comment is found.
    """
    closing_mark = None  # triple quote that will close the comment
    collected = []
    for idx in range(start_from + 1, len(srcls)):
        line = srcls[idx]
        if is_class_begin and line.replace('\t', '    ').startswith('    def '):
            break
        if line.startswith('def ') or line.startswith('class '):
            break
        compact = line.replace(" ", '').replace("\t", '').replace("\n", '')
        if closing_mark is None:
            # still looking for the opening quote line; double quotes are
            # checked before single quotes
            for mark in ("\"\"\"", "\'\'\'"):
                if compact.startswith(mark) or compact.startswith("r" + mark):
                    closing_mark = mark
                    break
            continue
        if compact.startswith(closing_mark):
            break
        collected.append(line)
    return ''.join(collected)


def srccoms_extract(srcfile, wlist, methods):
    """
    Given a source file ``srcfile``, this function will
    extract its API(doc comments) and run sample codes in the
    API.

    An API is checked only when all of these hold:

    * it is listed in the file's ``__all__``;
    * its qualified name (a module prefix derived from the file path,
      plus the API name) is in ``methods``;
    * it is not white-listed in ``wlist``, either as ``name`` or as
      ``name@<file path>``.

    Methods whose name starts with ``_`` are skipped.

    Args:
        srcfile(file): the source file, opened for reading
        wlist(list): white list
        methods(list): only elements of this list considered.

    Returns:
        result: True or False
        error_methods: the methods that failed.
    """
    process_result = True
    error_methods = []
    srcc = srcfile.read()
    # rewind the file so it can be read again, this time as lines
    srcfile.seek(0, 0)
    srcls = srcfile.readlines()  # source lines

    # fetch the __all__ list
    allidx = srcc.find("__all__")
    logger.debug('processing %s, methods: %s', srcfile.name, str(methods))
    # Dotted module prefix used to look names up in ``methods``: the path
    # without extension, minus its first four '/'-separated components
    # (e.g. '../python/paddle/fluid/layers/nn.py' -> 'layers.nn.').
    srcfile_new, _ = os.path.splitext(srcfile.name)
    srcfile_list = srcfile_new.split('/')
    srcfile_str = ''.join(part + '.' for part in srcfile_list[4:])

    if allidx == -1:
        logger.warning('__all__ not found in file:%s', srcfile.name)
        return process_result, error_methods

    # NOTE(review): this matches any file name containing "ops.py", not
    # only layers/ops.py.
    is_ops_file = srcfile.name.find("ops.py") != -1

    alllist = []
    if is_ops_file:
        # layers/ops.py builds __all__ incrementally; collect the names
        # from every line that starts with "__all__" and contains a '['
        for line in srcls:
            if not line.startswith("__all__"):
                continue
            lb = line.find('[')
            rb = line.find(']')
            if lb == -1:
                continue
            allele = line[lb + 1:rb].replace("'", '').replace(
                " ", '').replace("\"", '')
            # a single line may list several names
            alllist.extend(allele.split(','))
    else:
        alllist_b = allidx + len("__all__")
        allstr = srcc[alllist_b + srcc[alllist_b:].find("[") + 1:alllist_b +
                      srcc[alllist_b:].find("]")]
        allstr = allstr.replace("\n", '').replace(" ", '').replace(
            "'", '').replace("\"", '')
        alllist = allstr.split(',')
    # drop empty entries (e.g. from "__all__ = []" or a trailing comma)
    alllist = [item for item in alllist if item != '']
    api_alllist_count = len(alllist)
    logger.debug('found %d items: %s', api_alllist_count, str(alllist))

    # ops.py also has normal formatted functions; the ops handled in the
    # block below are remembered in 'handled' so that the generic pass
    # afterwards ignores them
    handled = []
    if is_ops_file:
        for i in range(0, len(srcls)):
            opname = None
            opres = re.match(r"^(\w+)\.__doc__", srcls[i])
            if opres is not None:
                opname = opres.group(1)
            else:
                opres = re.match(
                    r"^add_sample_code\(globals\(\)\[\"(\w+)\"\]", srcls[i])
                if opres is not None:
                    opname = opres.group(1)
            if opname is None:
                continue
            if opname in wlist:
                logger.info('%s is in the whitelist, skip it.', opname)
                continue
            logger.debug('%s\'s docstring found.', opname)
            # locate the line opening the docstring: the first line at or
            # after line i that contains triple double quotes
            comstart = i
            for j in range(i, len(srcls)):
                if "\"\"\"" in srcls[j]:
                    comstart = j
                    break
            # collect the docstring up to (and including) its closing line
            opcom = ""
            for j in range(comstart + 1, len(srcls)):
                opcom += srcls[j]
                if "\"\"\"" in srcls[j]:
                    break
            result, _, _ = sampcd_extract_and_run(opcom, opname, "def",
                                                  opname)
            if not result:
                error_methods.append(opname)
                process_result = False
            handled.append(opname)
    logger.debug('%s already handled.', str(handled))

    for i in range(0, len(srcls)):
        if srcls[i].startswith('def '):  # a function header is detected in line i
            f_header = srcls[i].replace(" ", '')
            fn = f_header[len('def'):f_header.find('(')]  # function name
            if "%s%s" % (srcfile_str, fn) not in methods:
                logger.info(
                    '[file:%s, function:%s] not in methods list, skip it.',
                    srcfile_str, fn)
                continue
            if fn in handled:
                continue
            if fn in alllist:
                if fn in wlist or fn + "@" + srcfile.name in wlist:
                    logger.info('[file:%s, function:%s] skip by wlist.',
                                srcfile_str, fn)
                    continue
                fcombody = single_defcom_extract(i, srcls)
                if fcombody == "":  # if no comment
                    print("def name:", fn)
                    print("-----------------------")
                    print("WARNING: no comments in function ", fn,
                          ", but it deserves.")
                    continue
                result, _, _ = sampcd_extract_and_run(fcombody, fn, "def",
                                                      fn)
                if not result:
                    error_methods.append(fn)
                    process_result = False

        if srcls[i].startswith('class '):
            c_header = srcls[i].replace(" ", '')
            cn = c_header[len('class'):c_header.find('(')]  # class name
            if '%s%s' % (srcfile_str, cn) not in methods:
                logger.info(
                    '[file:%s, class:%s] not in methods list, skip it.',
                    srcfile_str, cn)
                continue
            if cn in handled:
                continue
            if cn in alllist:
                if cn in wlist or cn + "@" + srcfile.name in wlist:
                    logger.info('[file:%s, class:%s] skip by wlist.',
                                srcfile_str, cn)
                    continue
                # class comment
                classcom = single_defcom_extract(i, srcls, True)
                if classcom != "":
                    result, _, _ = sampcd_extract_and_run(classcom, cn,
                                                          "class", cn)
                    if not result:
                        error_methods.append(cn)
                        process_result = False
                else:
                    print("WARNING: no comments in class itself ", cn,
                          ", but it deserves.\n")
                # handling methods in class bodies, from the next line of
                # the class header up to the next top-level def/class
                for x in range(i + 1, len(srcls)):
                    if (srcls[x].startswith('def ') or
                            srcls[x].startswith('class ')):
                        break
                    # tabs are expanded in place so indentation checks work
                    srcls[x] = srcls[x].replace('\t', '    ')
                    if not srcls[x].startswith('    def '):
                        continue
                    # member method def header
                    thisl = srcls[x]
                    indent = len(thisl) - len(thisl.lstrip())
                    mn = thisl[indent + len('def '):thisl.find(
                        '(')]  # method name
                    name = cn + "." + mn  # full name
                    if '%s%s' % (
                            srcfile_str, name
                    ) not in methods:  # class method not in api.spec
                        logger.info(
                            '[file:%s, func:%s] not in methods, skip it.',
                            srcfile_str, name)
                        continue
                    if mn.startswith('_'):
                        logger.info(
                            '[file:%s, func:%s] startswith _, it\'s private method, skip it.',
                            srcfile_str, name)
                        continue
                    if name in wlist or name + "@" + srcfile.name in wlist:
                        logger.info('[file:%s, class:%s] skip by wlist.',
                                    srcfile_str, name)
                        continue
                    # gather the de-indented lines of this single method
                    # (up to the next method or top-level def/class) and
                    # send them to single_defcom_extract
                    thismethod = [thisl[indent:]]  # method body lines
                    for y in range(x + 1, len(srcls)):
                        srcls[y] = srcls[y].replace('\t', '    ')
                        if (srcls[y].startswith('def ') or
                                srcls[y].startswith('class ')):
                            # end of method
                            break
                        elif srcls[y].startswith('    def '):
                            # end of method
                            break
                        else:
                            thismethod.append(srcls[y][indent:])
                    thismtdcom = single_defcom_extract(0, thismethod)
                    if thismtdcom != "":
                        result, _, _ = sampcd_extract_and_run(
                            thismtdcom, name, "method", name)
                        if not result:
                            error_methods.append(name)
                            process_result = False

    return process_result, error_methods


def test(file_list):
    """
    Check the sample codes of every file in ``file_list`` sequentially.

    Reads the module-level ``wlist`` (assigned in the ``__main__``
    section) and ``methods`` (filled by get_filenames()).

    Args:
        file_list(list): paths of the source files to check.

    Returns:
        bool: True if the sample codes of all files passed, else False.
    """
    global methods  # readonly
    process_result = True
    for filename in file_list:
        with open(filename, 'r') as src:
            # srccoms_extract returns (result, error_methods); the tuple
            # itself is always truthy, so it must be unpacked first
            file_result, _ = srccoms_extract(src, wlist, methods)
        if not file_result:
            process_result = False
    return process_result


def run_a_test(tc_filename):
    """
    execute a sample code-block.

    Check every sample code of the source file ``tc_filename``, using the
    module-level ``wlist`` and ``methods``.

    Returns:
        tuple: (result, tc_filename, error_methods) where ``result`` is
        True when all checked sample codes passed.
    """
    global methods  # readonly
    with open(tc_filename, 'r') as source_file:
        outcome = srccoms_extract(source_file, wlist, methods)
    passed, failed_methods = outcome
    return passed, tc_filename, failed_methods


def get_filenames():
    '''
    this function will get the modules that pending for check.

    The function does the following:

    * Regenerates API_DIFF_SPEC_FN via get_incrementapi().
    * Resolves each API listed there with eval().
    * Maps the API's module to a source file under '../python/'.
    * Records the API's qualified name in the global ``methods`` list.

    APIs that raise AttributeError on lookup are collected in the global
    ``whl_error`` list. The diff spec file is deleted afterwards.

    Returns:

        list: the modules pending for check .

    '''
    filenames = []
    global methods  # write
    global whl_error
    methods = []
    whl_error = []
    get_incrementapi()
    API_spec = API_DIFF_SPEC_FN
    with open(API_spec) as f:
        for line in f.readlines():
            api = line.replace('\n', '')
            try:
                # NOTE(review): eval() executes text read from the spec
                # file; acceptable only because that file is produced by
                # the repository's own tooling.
                api_obj = eval(api)
                module = api_obj.__module__
            except AttributeError:
                whl_error.append(api)
                continue
            except SyntaxError:
                logger.warning('line:%s, api:%s', line, api)
                # paddle.Tensor.<lambda>
                continue
            module_parts = module.split('.')
            if len(module_parts) > 1:
                # e.g. paddle.fluid.layers.nn -> ../python/paddle/fluid/layers/nn.py
                # work for .so?
                filename = '../python/' + '/'.join(module_parts) + '.py'
            else:
                filename = ''
                logger.warning("WARNING: Exception in getting api:%s module:%s",
                               api, module)
            if not filename:
                logger.warning('filename invalid: %s', line)
                continue
            if filename not in filenames:
                if not os.path.exists(filename):
                    logger.warning('file not exists: %s', filename)
                    continue
                filenames.append(filename)
            # Record the method even when its file was already listed;
            # otherwise only the first changed API of each file is checked.
            if inspect.isclass(api_obj) or inspect.isfunction(api_obj):
                name = api.split('.')[-1]
            elif inspect.ismethod(api_obj):
                name = '%s.%s' % (api.split('.')[-2], api.split('.')[-1])
            else:
                name = ''
                logger.warning(
                    "WARNING: Exception when getting api:%s, line:%s", api,
                    line)
            # the first two module components are dropped, matching the
            # prefix srccoms_extract() derives from the file path
            method = ''.join('%s.' % part for part in module_parts[2:]) + name
            if method not in methods:
                methods.append(method)
    os.remove(API_spec)
    return filenames


def get_api_md5(path):
    """
    Read an API spec file and map every API name to its document md5.

    The file is looked up as ``<parent of the current dir>/<path>``.
    Each line starts with the API name followed by a space. The md5 is the
    text after the marker ``'document', `` with every ')' and newline
    removed.

    Args:
        path(str): spec file path, relative to the parent directory of the
            current working directory.

    Returns:
        dict: API name -> md5 string. Blank lines are ignored. Lines
        without the ``'document', `` marker are skipped with a warning
        instead of raising IndexError.
    """
    api_md5 = {}
    marker = "'document', "
    API_spec = '%s/%s' % (os.path.abspath(os.path.join(os.getcwd(), "..")),
                          path)
    with open(API_spec) as f:
        for line in f:
            if not line.strip():
                continue
            if marker not in line:
                logging.warning('skip malformed line in %s: %s', API_spec,
                                line.rstrip('\n'))
                continue
            api = line.split(' ', 1)[0]
            md5 = line.split(marker)[1].replace(')', '').replace('\n', '')
            api_md5[api] = md5
    return api_md5


def get_incrementapi():
    '''
    this function will get the apis that difference between API_DEV.spec and API_PR.spec.

    Every API of the PR spec that is new or whose document md5 differs from
    the DEV spec is written, one per line, to API_DIFF_SPEC_FN.
    '''
    global API_DEV_SPEC_FN, API_PR_SPEC_FN, API_DIFF_SPEC_FN  ## readonly
    dev_api = get_api_md5(API_DEV_SPEC_FN)
    pr_api = get_api_md5(API_PR_SPEC_FN)
    with open(API_DIFF_SPEC_FN, 'w') as f:
        for key in pr_api:
            unchanged = key in dev_api and dev_api[key] == pr_api[key]
            if unchanged:
                continue
            f.write(key)
            f.write('\n')


def get_wlist(fn="wlist.json"):
Z
zhangchunle 已提交
640 641 642 643 644 645 646 647 648
    '''
    this function will get the white list of API.

    Returns:

        wlist: a list of API that should not trigger the example check .

    '''
    wlist = []
Z
zhangchunle 已提交
649
    wlist_file = []
650 651
    # only white on CPU
    gpu_not_white = []
652
    with open(fn, 'r') as load_f:
Z
zhangchunle 已提交
653 654
        load_dict = json.load(load_f)
        for key in load_dict:
655 656 657 658 659 660 661 662
            if key == 'wlist_dir':
                for item in load_dict[key]:
                    wlist_file.append(item["name"])
            elif key == "gpu_not_white":
                gpu_not_white = load_dict[key]
            elif key == "wlist_api":
                for item in load_dict[key]:
                    wlist.append(item["name"])
Z
zhangchunle 已提交
663 664
            else:
                wlist = wlist + load_dict[key]
665
    return wlist, wlist_file, gpu_not_white
Z
zhangchunle 已提交
666 667


668 669 670 671 672 673
arguments = [
    # flags, dest, type, default, help
    ['--gpu_id', 'gpu_id', int, 0, 'GPU device id to use [0]'],
    ['--logf', 'logf', str, None, 'file for logging'],
    ['--threads', 'threads', int, 0, 'sub processes number'],
]


def parse_args():
    """
    Parse input arguments

    Returns:
        argparse.Namespace: with ``debug`` (bool), ``mode`` ('cpu' when
        omitted), ``gpu_id``, ``logf`` and ``threads``.
    """
    global arguments
    parser = argparse.ArgumentParser(description='run Sample Code Test')
    parser.add_argument('--debug', dest='debug', action="store_true")
    # nargs='?' makes the positional optional so that its default is used;
    # argparse ignores ``default`` on a required positional argument.
    parser.add_argument(
        'mode', type=str, nargs='?', help='run on device', default='cpu')
    for item in arguments:
        parser.add_argument(
            item[0], dest=item[1], help=item[4], type=item[2], default=item[3])
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    if args.debug:
        logger.setLevel(logging.DEBUG)
    if args.logf:
        logfHandler = logging.FileHandler(args.logf)
        logfHandler.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(funcName)s:%(lineno)d - %(levelname)s - %(message)s"
            ))
        logger.addHandler(logfHandler)

    # NOTE(review): wlist, GPU_ID and RUN_ON_DEVICE are module globals read
    # by the pool workers below. This presumably relies on the pool forking
    # this process (the Linux default); with the "spawn" start method the
    # workers would not see these assignments. Confirm before running
    # on other platforms.
    wlist, wlist_file, gpu_not_white = get_wlist()

    if args.mode == "gpu":
        GPU_ID = args.gpu_id
        logger.info("using GPU_ID %d", GPU_ID)
        # APIs in gpu_not_white are white-listed on CPU only. Filtering
        # (instead of list.remove) tolerates entries absent from wlist.
        wlist = [api for api in wlist if api not in gpu_not_white]
    elif args.mode != "cpu":
        logger.error("Unrecognized argument:%s, 'cpu' or 'gpu' is desired.",
                     args.mode)
        sys.exit("Invalid arguments")
    RUN_ON_DEVICE = args.mode
    logger.info("API check -- Example Code")
    logger.info("sample_test running under python %s",
                platform.python_version())

    # make sure SAMPLECODE_TEMPDIR exists and is a directory
    if os.path.exists(SAMPLECODE_TEMPDIR):
        if not os.path.isdir(SAMPLECODE_TEMPDIR):
            os.remove(SAMPLECODE_TEMPDIR)
            os.mkdir(SAMPLECODE_TEMPDIR)
    else:
        os.mkdir(SAMPLECODE_TEMPDIR)

    filenames = get_filenames()
    if len(filenames) == 0 and len(whl_error) == 0:
        logger.info("-----API_PR.spec is the same as API_DEV.spec-----")
        sys.exit(0)
    # Split off the files under white-listed directories. New lists are
    # built instead of removing from ``filenames`` while iterating it, which
    # skipped entries and raised ValueError when a file matched more than
    # one white-listed directory.
    rm_file = [
        f for f in filenames
        if any(f.startswith(w_file) for w_file in wlist_file)
    ]
    filenames = [f for f in filenames if f not in rm_file]
    if len(rm_file) != 0:
        logger.info("REMOVE white files: %s", rm_file)
    logger.info("API_PR is diff from API_DEV: %s", filenames)

    threads = multiprocessing.cpu_count()
    if args.threads:
        threads = args.threads
    po = multiprocessing.Pool(threads)
    results = po.map_async(run_a_test, filenames)
    po.close()
    po.join()

    # one (result, filename, error_methods) tuple per file, from run_a_test
    result = results.get()

    # delete temp files
    if not args.debug:
        shutil.rmtree(SAMPLECODE_TEMPDIR)

    logger.info("----------------End of the Check--------------------")
    if len(whl_error) != 0:
        logger.info("%s is not in whl.", whl_error)
        logger.info("")
        logger.info("Please check the whl package and API_PR.spec!")
        logger.info("You can follow these steps in order to generate API.spec:")
        logger.info("1. cd ${paddle_path}, compile paddle;")
        logger.info("2. pip install build/python/dist/(build whl package);")
        logger.info(
            "3. run 'python tools/print_signatures.py paddle > paddle/fluid/API.spec'."
        )
        for temp in result:
            if not temp[0]:
                logger.info("In addition, mistakes found in sample codes: %s",
                            temp[1])
                logger.info("error_methods: %s", str(temp[2]))
        logger.info("----------------------------------------------------")
        sys.exit(1)
    else:
        has_error = False
        for temp in result:
            if not temp[0]:
                logger.info("In addition, mistakes found in sample codes: %s",
                            temp[1])
                logger.info("error_methods: %s", str(temp[2]))
                has_error = True
        if has_error:
            logger.info("Mistakes found in sample codes.")
            logger.info("Please check sample codes.")
            sys.exit(1)
    logger.info("Sample code check is successful!")