sampcd_processor.py 29.1 KB
Newer Older
1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
T
tianshuo78520a 已提交
14 15

import os
16
import sys
T
tianshuo78520a 已提交
17
import subprocess
18 19 20
import multiprocessing
import math
import platform
21
import inspect
22 23
#import paddle
#import paddle.fluid
Z
zhangchunle 已提交
24
import json
25 26 27 28
import argparse
import shutil
import re
import logging
29 30
"""
Please make sure to run this script from the tools directory.

usage: python sampcd_processor.py {mode}

mode: selects whether the sample codes run against the gpu or the cpu version

For example, you can run the cpu version like this:

    python sampcd_processor.py cpu

"""
T
tianshuo78520a 已提交
39

40 41 42 43 44 45 46
# Configure the root logger: reuse its first handler when one is already
# installed (we assume that is the one to configure), otherwise attach a new
# stream handler. Either way the handler emits the bare message only.
logger = logging.getLogger()
console = logger.handlers[0] if logger.handlers else None
if console is None:
    console = logging.StreamHandler()
    logger.addHandler(console)
console.setFormatter(logging.Formatter("%(message)s"))
48 49 50 51 52 53 54 55 56 57

# Device the sample codes run on ('cpu' or 'gpu'); overwritten in __main__
# from the positional `mode` argument.
RUN_ON_DEVICE = 'cpu'
# Value written to CUDA_VISIBLE_DEVICES in generated sample files when
# RUN_ON_DEVICE is 'gpu'; overwritten in __main__ from --gpu_id.
GPU_ID = 0
# Passed to srccoms_extract() by test()/run_a_test() as the list of API names
# to consider. NOTE(review): nothing in this file populates it.
methods = []
# API names from the diff spec that could not be resolved by eval();
# filled by get_filenames().
whl_error = []
# Spec files compared by get_incrementapi(); get_api_md5() resolves them
# relative to the parent of the current working directory.
API_DEV_SPEC_FN = 'paddle/fluid/API_DEV.spec'
API_PR_SPEC_FN = 'paddle/fluid/API_PR.spec'
# Output of get_incrementapi() and input of get_filenames(); opened relative
# to the current working directory.
API_DIFF_SPEC_FN = 'dev_pr_diff_api.spec'
# Directory that receives the extracted sample code files.
SAMPLECODE_TEMPDIR = 'samplecode_temp'

T
tianshuo78520a 已提交
58 59

def find_all(srcstr, substr):
    """
    Locate every occurrence of ``substr`` inside ``srcstr``.

    The search resumes one character after each hit, so overlapping
    occurrences are all reported.

    Args:
        srcstr(str): the string that is searched
        substr(str): the string that is looked for

    Returns:
        list: the starting indices of all occurrences, in ascending order
    """
    positions = []
    cursor = 0
    while True:
        hit = srcstr.find(substr, cursor)
        if hit < 0:
            return positions
        positions.append(hit)
        cursor = hit + 1


def check_indent(cdline):
    """
    Measure the indentation of one line of code.

    Only the leading run of blanks and tabs is considered. A blank counts
    as one column and a tab counts as four, i.e. '\\t' == '    '.

    Args:
        cdline(str): a single line of code from the source file

    Returns:
        int: the width of the leading whitespace, in blank-equivalents
    """
    leading = cdline[:len(cdline) - len(cdline.lstrip(' \t'))]
    # every tab is worth 4 columns: 1 is already counted by len(), add 3 more
    return len(leading) + 3 * leading.count('\t')


108 109 110
# srccom: raw comments in the source, including ''' and original indent
def sampcd_extract_and_run(srccom, name, htype="def", hname=""):
    """
    Extract the sample codes from a source comment, run each of them and
    report the overall outcome.

    Args:
        srccom(str): the source comment of some API whose
                     example codes will be extracted and run.
        name(str): the name of the API.
        htype(str): the type of hint banners, def/class/method.
        hname(str): the name of the hint banners, e.g. def hname.

    Returns:
        result(bool): True only if at least one sample code was found and
                      all of them ran successfully.
        name(str): the name of the API.
        msg(str): 'success!', 'No sample code!' or the comma separated list
                  of the sample code files that failed.
    """
    sample_code_filenames = sampcd_extract_to_file(srccom, name, htype, hname)
    if not sample_code_filenames:
        return False, name, 'No sample code!'

    failed_fn = []
    for tfname in sample_code_filenames:
        # execute_samplecode_test returns (result, tfname, msg); only the
        # status is needed here. Indexing instead of unpacking keeps this
        # call working regardless of how many fields follow the status.
        outcome = execute_samplecode_test(tfname)
        if not outcome[0]:
            failed_fn.append(tfname)

    if failed_fn:
        return False, name, 'failed sample codes: ' + ','.join(failed_fn)
    return True, name, 'success!'


def sampcd_extract_to_file(srccom, name, htype="def", hname=""):
    """
    Extract sample codes from a docstring and write each one to its own file.

    Every ``code-block:: python`` section found in ``srccom`` is dedented,
    prefixed with a ``CUDA_VISIBLE_DEVICES`` setting chosen from
    ``RUN_ON_DEVICE``/``GPU_ID``, suffixed with a ``print`` announcing
    success, and written into ``SAMPLECODE_TEMPDIR``. Blank lines inside a
    sample are dropped. A section ends at the first non-blank line that is
    indented less than the section's first non-blank line.

    Args:
        srccom(str): the source comment of some API whose
                     example codes will be extracted.
        name(str): the name of the API; used in the file names.
        htype(str): the type of hint banners, def/class/method.
        hname(str): the name of the hint banners, e.g. def hname.

    Returns:
        sample_code_filenames(list of str): paths of the files written,
        possibly empty.
    """
    directive = " code-block:: python"
    sampcd_begins = find_all(srccom, directive)
    if not sampcd_begins:
        # detect sample codes using >>> to format and consider this situation as wrong
        print(htype, " name:", hname)
        print("-----------------------")
        if srccom.find("Examples:") != -1:
            print("----example code check----\n")
            if srccom.find(">>>") != -1:
                print(
                    "Deprecated sample code style:\n\n    Examples:\n\n        >>>codeline\n        >>>codeline\n\n\n ",
                    "Please use '.. code-block:: python' to ",
                    "format sample code.\n")
                return []
        else:
            print("Error: No sample code!\n")
            return []

    sample_code_filenames = []
    for y, sampcd_begin in enumerate(sampcd_begins, 1):
        # everything after the directive line, one entry per source line
        sampcd = srccom[sampcd_begin + len(directive) + 1:].split("\n")
        # remove starting empty lines
        while sampcd and sampcd[0].replace(' ', '').replace('\t', '') == '':
            sampcd.pop(0)
        if not sampcd:
            # The directive was followed by blank lines only. There is no
            # code to write, and indexing sampcd[0] would raise IndexError.
            print("Error: empty code-block in", name, "\n")
            continue

        # the minimum indent, which is the indent of the first
        # non-empty line
        min_indent = check_indent(sampcd[0])
        sampcd_to_write = []
        for cdline in sampcd:
            # skip empty lines or those only with spaces/tabs
            if cdline.strip() == '':
                continue
            if check_indent(cdline) < min_indent:
                # dedented text: the code block is over
                break
            # check_indent counts a tab as 4 columns, so expand tabs before
            # cutting min_indent characters off the front
            sampcd_to_write.append(cdline.replace('\t', '    ')[min_indent:])

        sampcd = '\n'.join(sampcd_to_write)
        if RUN_ON_DEVICE == "cpu":
            sampcd = '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = ""\n' + sampcd
        if RUN_ON_DEVICE == "gpu":
            sampcd = '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = "{}"\n'.format(
                GPU_ID) + sampcd
        sampcd += '\nprint(' + '\"' + name + ' sample code is executed successfully!\")'

        # a single sample gets '<name>_example.py', several get a numeric suffix
        tfname = os.path.join(SAMPLECODE_TEMPDIR, '{}_example{}'.format(
            name, '.py' if len(sampcd_begins) == 1 else '_{}.py'.format(y)))
        with open(tfname, 'w') as tempf:
            tempf.write(sampcd)
        sample_code_filenames.append(tfname)
    return sample_code_filenames


def execute_samplecode_test(tfname):
    """
    Run one sample code file in a subprocess and report whether it succeeded.

    The file is executed with the interpreter that runs this script. On a
    non-zero return code the file content, the return code, stderr and
    stdout are printed and logged as warnings.

    Args:
        tfname(str): path of the sample code file to execute.

    Returns:
        result(bool): True if the subprocess exited with return code 0.
        tfname(str): the path that was executed.
        msg(str): the decoded standard output of the subprocess.
    """
    if platform.python_version()[0] not in ["2", "3"]:
        print("Error: fail to parse python version!")
        sys.exit(1)
    cmd = [sys.executable, tfname]

    logging.info('running %s', tfname)
    print("\n----example code check----")
    print("executing sample code .....", tfname)
    subprc = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = subprc.communicate()
    msg = output.decode(encoding='utf-8')
    err = error.decode(encoding='utf-8')

    result = subprc.returncode == 0
    if not result:
        print("Sample code error found in ", tfname, ":")
        print("-----------------------")
        # read through a context manager so the handle is closed promptly
        with open(tfname) as samplecode:
            print(samplecode.read())
        print("-----------------------")
        print("subprocess return code: ", str(subprc.returncode))
        print("Error Raised from Sample Code ", tfname, " :")
        print(err)
        print(msg)
        print("----example code check failed----\n")
        logging.warning('%s error: %s', tfname, err)
        logging.warning('%s msg: %s', tfname, msg)
    else:
        print("----example code check success----\n")

    # msg is the returned code execution report
    return result, tfname, msg
T
tianshuo78520a 已提交
257 258 259


def single_defcom_extract(start_from, srcls, is_class_begin=False):
    """
    Pull the docstring body that follows a def/class header.

    Scanning starts on the line after ``start_from``. The docstring opens
    on the first line that, ignoring blanks and tabs, starts with a triple
    quote (optionally prefixed by ``r``) and closes on the next line that
    starts with the same kind of triple quote. Scanning also stops at a
    non-indented ``def``/``class`` line, and additionally, when
    ``is_class_begin`` is set, at the first method header indented by four
    columns (a class without its own docstring must not pick up the
    docstring of its first method).

    Args:
        start_from(int): the line index of the "def"/"class" header
        srcls(list): the source file in lines
        is_class_begin(bool): whether start_from is the header of a class

    Returns:
        string: the lines strictly between the opening and the closing
        quote lines, concatenated; "" if no docstring was found.
    """
    body_lines = []
    quote = None  # the triple quote that opened the docstring, once found
    for idx in range(start_from + 1, len(srcls)):
        line = srcls[idx]
        if is_class_begin and line.replace('\t', '    ').startswith('    def '):
            break
        if line.startswith(('def ', 'class ')):
            break

        squeezed = line.replace(" ", '').replace("\t", '').replace("\n", '')
        if quote is None:
            # still looking for the opening quote line
            if squeezed.startswith(("\"\"\"", "r\"\"\"")):
                quote = "\"\"\""
            elif squeezed.startswith(("\'\'\'", "r\'\'\'")):
                quote = "\'\'\'"
            continue
        if squeezed.startswith(quote):
            # closing quote line reached
            break
        body_lines.append(line)
    return "".join(body_lines)


317
def srccoms_extract(srcfile, wlist, methods):
    """
    Given a source file ``srcfile``, this function will
    extract its API(doc comments) and run sample codes in the
    API.

    Only files that contain ``__all__`` are processed. Top-level functions
    and classes are checked when they are listed in ``__all__``, their
    dotted name (module prefix + name) is in ``methods`` and they are not
    white-listed; methods of such classes are checked too unless they start
    with ``_``. Files whose name contains ``fluid/layers/ops.py`` get an
    extra pass for docstrings attached via ``X.__doc__`` or
    ``add_sample_code(globals()["X"]``.

    Args:
        srcfile(file): the opened source file; ``read``, ``seek``,
                       ``readlines`` and ``name`` are used.
        wlist(list): white list; entries are either ``name`` or
                     ``name@<srcfile.name>``.
        methods(list): only elements of this list considered.

    Returns:
        result: True or False
        error_methods: the methods that failed.
    """

    process_result = True
    error_methods = []
    srcc = srcfile.read()
    # 2. get defs and classes header line number
    # set file pointer to its beginning
    srcfile.seek(0, 0)
    srcls = srcfile.readlines()  # source lines

    # 1. fetch__all__ list
    allidx = srcc.find("__all__")
    logger.debug('processing %s, methods: %s', srcfile.name, str(methods))
    srcfile_new, _ = os.path.splitext(srcfile.name)
    # Build the dotted prefix from the path components starting at index 4,
    # each followed by '.'.
    # NOTE(review): index 4 presumably matches a fixed directory depth of
    # the caller's paths -- confirm.
    srcfile_list = srcfile_new.split('/')
    srcfile_str = ''
    for i in range(4, len(srcfile_list)):
        srcfile_str = srcfile_str + srcfile_list[i] + '.'
    if allidx != -1:
        alllist = []
        # get all list for layers/ops.py
        if srcfile.name.find("fluid/layers/ops.py") != -1:
            for ai in range(0, len(srcls)):
                if srcls[ai].startswith("__all__"):
                    lb = srcls[ai].find('[')
                    rb = srcls[ai].find(']')
                    if lb == -1:
                        continue
                    # NOTE(review): the whole bracket content is appended as
                    # ONE element (no split on ','), and rb may be -1.
                    allele = srcls[ai][lb + 1:rb].replace("'", '').replace(
                        " ", '').replace("\"", '')
                    alllist.append(allele)
            if '' in alllist:
                alllist.remove('')
        else:
            # take the text between the first '[' and the first ']' that
            # follow the first occurrence of "__all__"
            alllist_b = allidx + len("__all__")
            allstr = srcc[alllist_b + srcc[alllist_b:].find("[") + 1:alllist_b +
                          srcc[alllist_b:].find("]")]
            allstr = allstr.replace("\n", '').replace(" ", '').replace(
                "'", '').replace("\"", '')
            alllist = allstr.split(',')
            # list.remove drops only the first empty entry
            if '' in alllist:
                alllist.remove('')
        api_alllist_count = len(alllist)
        logger.debug('found %d items: %s', api_alllist_count, str(alllist))
        # NOTE(review): api_count is incremented below but never read or
        # returned in this function.
        api_count = 0
        handled = []
        # get src contents in layers/ops.py
        if srcfile.name.find("fluid/layers/ops.py") != -1:
            for i in range(0, len(srcls)):
                opname = None
                opres = re.match(r"^(\w+)\.__doc__", srcls[i])
                if opres is not None:
                    opname = opres.group(1)
                else:
                    opres = re.match(
                        r"^add_sample_code\(globals\(\)\[\"(\w+)\"\]", srcls[i])
                    if opres is not None:
                        opname = opres.group(1)
                if opname is not None:
                    if opname in wlist:
                        logger.info('%s is in the whitelist, skip it.', opname)
                        continue
                    else:
                        logger.debug('%s\'s docstring found.', opname)
                    comstart = i
                    # NOTE(review): this loop assigns i (not j), so it never
                    # changes comstart -- possibly `comstart = j; break` was
                    # intended. Confirm before changing.
                    for j in range(i, len(srcls)):
                        if srcls[j].find("\"\"\"") != -1:
                            comstart = i
                    # collect the lines after comstart up to and including
                    # the first one that contains a triple double quote
                    opcom = ""
                    for j in range(comstart + 1, len(srcls)):
                        opcom += srcls[j]
                        if srcls[j].find("\"\"\"") != -1:
                            break
                    result, _, _ = sampcd_extract_and_run(opcom, opname, "def",
                                                          opname)
                    if not result:
                        error_methods.append(opname)
                        process_result = False
                    api_count += 1
                    handled.append(
                        opname)  # ops.py also has normal formatted functions
                    # use list 'handled' to mark the functions that have been
                    # handled here, which will be ignored in the following step
        logger.debug('%s already handled.', str(handled))
        for i in range(0, len(srcls)):
            if srcls[i].startswith(
                    'def '):  # a function header is detected in line i
                # all blanks are removed, so the name sits between 'def'
                # and the first '('
                f_header = srcls[i].replace(" ", '')
                fn = f_header[len('def'):f_header.find('(')]  # function name
                if "%s%s" % (srcfile_str, fn) not in methods:
                    logger.info(
                        '[file:%s, function:%s] not in methods list, skip it.',
                        srcfile_str, fn)
                    continue
                if fn in handled:
                    continue
                if fn in alllist:
                    api_count += 1
                    if fn in wlist or fn + "@" + srcfile.name in wlist:
                        logger.info('[file:%s, function:%s] skip by wlist.',
                                    srcfile_str, fn)
                        continue
                    fcombody = single_defcom_extract(i, srcls)
                    if fcombody == "":  # if no comment
                        print("def name:", fn)
                        print("-----------------------")
                        print("WARNING: no comments in function ", fn,
                              ", but it deserves.")
                        continue
                    else:
                        result, _, _ = sampcd_extract_and_run(fcombody, fn,
                                                              "def", fn)
                        if not result:
                            error_methods.append(fn)
                            process_result = False

            if srcls[i].startswith('class '):
                c_header = srcls[i].replace(" ", '')
                # NOTE(review): if the header has no '(', find() returns -1
                # and the slice drops the last character instead.
                cn = c_header[len('class'):c_header.find('(')]  # class name
                if '%s%s' % (srcfile_str, cn) not in methods:
                    logger.info(
                        '[file:%s, class:%s] not in methods list, skip it.',
                        srcfile_str, cn)
                    continue
                if cn in handled:
                    continue
                if cn in alllist:
                    api_count += 1
                    if cn in wlist or cn + "@" + srcfile.name in wlist:
                        logger.info('[file:%s, class:%s] skip by wlist.',
                                    srcfile_str, cn)
                        continue
                    # class comment
                    classcom = single_defcom_extract(i, srcls, True)
                    if classcom != "":
                        result, _, _ = sampcd_extract_and_run(classcom, cn,
                                                              "class", cn)
                        if not result:
                            error_methods.append(cn)
                            process_result = False
                    else:
                        print("WARNING: no comments in class itself ", cn,
                              ", but it deserves.\n")
                    # handling methods in class bodies
                    for x in range(
                            i + 1,
                            len(srcls)):  # from the next line of class header
                        if (srcls[x].startswith('def ') or
                                srcls[x].startswith('class ')):
                            break
                        else:
                            # member method def header
                            # (tabs are expanded in place in srcls)
                            srcls[x] = srcls[x].replace('\t', '    ')
                            if (srcls[x].startswith(
                                    '    def ')):  # detect a method header..
                                thisl = srcls[x]
                                indent = len(thisl) - len(thisl.lstrip())
                                mn = thisl[indent + len('def '):thisl.find(
                                    '(')]  # method name
                                name = cn + "." + mn  # full name
                                if '%s%s' % (
                                        srcfile_str, name
                                ) not in methods:  # class method not in api.spec
                                    logger.info(
                                        '[file:%s, func:%s] not in methods, skip it.',
                                        srcfile_str, name)
                                    continue
                                if mn.startswith('_'):
                                    logger.info(
                                        '[file:%s, func:%s] startswith _, it\'s private method, skip it.',
                                        srcfile_str, name)
                                    continue
                                if name in wlist or name + "@" + srcfile.name in wlist:
                                    logger.info(
                                        '[file:%s, class:%s] skip by wlist.',
                                        srcfile_str, name)
                                    continue
                                thismethod = [thisl[indent:]
                                              ]  # method body lines
                                # get all the lines of a single method body
                                # into thismethod(list), dedented by the
                                # method's indent, and send it to
                                # single_defcom_extract
                                for y in range(x + 1, len(srcls)):
                                    srcls[y] = srcls[y].replace('\t', '    ')
                                    if (srcls[y].startswith('def ') or
                                            srcls[y].startswith('class ')):
                                        # end of method
                                        break
                                    elif srcls[y].startswith('    def '):
                                        # end of method
                                        break
                                    else:
                                        thismethod.append(srcls[y][indent:])
                                thismtdcom = single_defcom_extract(0,
                                                                   thismethod)
                                if thismtdcom != "":
                                    result, _, _ = sampcd_extract_and_run(
                                        thismtdcom, name, "method", name)
                                    if not result:
                                        error_methods.append(name)
                                        process_result = False
    else:
        logger.warning('__all__ not found in file:%s', srcfile.name)

    return process_result, error_methods
T
tianshuo78520a 已提交
537 538


539
def test(file_list):
    """
    Check the sample codes of every source file in ``file_list``.

    Relies on the module-level ``wlist`` and ``methods``.

    Args:
        file_list(list of str): paths of the source files to process.

    Returns:
        bool: True if srccoms_extract reported success for every file.
    """
    global methods  # readonly
    process_result = True
    for filename in file_list:
        with open(filename, 'r') as src:
            # srccoms_extract returns (result, error_methods). The tuple
            # itself is always truthy, so the status must be unpacked.
            result, _ = srccoms_extract(src, wlist, methods)
            if not result:
                process_result = False
    return process_result
547 548


549 550 551 552 553 554 555 556 557 558 559
def run_a_test(tc_filename):
    """
    Check the sample codes of a single source file.

    Relies on the module-level ``wlist`` and ``methods``.

    Args:
        tc_filename(str): path of the source file to process.

    Returns:
        tuple: (result, tc_filename, error_methods), where result and
        error_methods are what srccoms_extract returned for the file.
    """
    with open(tc_filename, 'r') as source_file:
        outcome, failed_methods = srccoms_extract(source_file, wlist, methods)
    return outcome, tc_filename, failed_methods


560
def get_filenames():
    '''
    Collect the sample code files that are pending for check.

    The diff spec is regenerated first. Every API name in it is then
    resolved with eval, and the sample codes found in its docstring are
    written to files. Names that raise AttributeError are recorded in
    ``whl_error``; names that are not valid expressions are logged and
    skipped.

    Returns:

        dict: maps each sample code file name to the API it was taken from.

    '''
    global methods, whl_error
    # looks unused, but eval() below needs `paddle` in the local scope
    import paddle
    whl_error = []
    get_incrementapi()
    pending = {}
    with open(API_DIFF_SPEC_FN) as diff_file:
        for line in diff_file:
            api = line.replace('\n', '')
            # SECURITY: eval runs whatever the spec file contains; the file
            # must come from a trusted source.
            try:
                api_obj = eval(api)
            except AttributeError:
                whl_error.append(api)
                continue
            except SyntaxError:
                logger.warning('line:%s, api:%s', line, api)
                # paddle.Tensor.<lambda>
                continue
            docstring = getattr(api_obj, '__doc__', None)
            if not docstring:
                continue
            for tfname in sampcd_extract_to_file(docstring, api):
                pending[tfname] = api
    return pending
593 594


595 596 597 598
def get_api_md5(path):
    """
    Parse an API spec file into a mapping of API name to document md5.

    Args:
        path(str): location of the spec file, relative to the parent of
                   the current working directory.

    Returns:
        dict: API name -> 32 character md5 string, for every line that
        matches one of the two known spec formats.
    """
    parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
    spec_path = '%s/%s' % (parent_dir, path)
    # tried in order; the first pattern that matches a line wins
    patterns = (
        re.compile(r'\((paddle[^,]+)\W*document\W*([0-9a-z]{32})'),
        re.compile(r'^(paddle[^,]+)\s+\(ArgSpec.*document\W*([0-9a-z]{32})'),
    )
    api_md5 = {}
    with open(spec_path) as spec_file:
        for line in spec_file:
            for pattern in patterns:
                found = pattern.search(line)
                if found:
                    api_md5[found.group(1)] = found.group(2)
                    break
    return api_md5


612 613 614 615
def get_incrementapi():
    '''
    Write the APIs that differ between API_DEV.spec and API_PR.spec.

    An API of the PR spec is written to API_DIFF_SPEC_FN, one name per
    line, when it is missing from the dev spec or when its md5 differs.
    '''
    dev_api = get_api_md5(API_DEV_SPEC_FN)
    pr_api = get_api_md5(API_PR_SPEC_FN)
    with open(API_DIFF_SPEC_FN, 'w') as diff_file:
        for key, pr_md5 in pr_api.items():
            if key not in dev_api:
                logger.debug("%s is not in dev", key)
            elif dev_api[key] != pr_md5:
                logger.debug("%s in dev is %s, different from pr's %s", key,
                             dev_api[key], pr_md5)
            else:
                # unchanged API, nothing to record
                continue
            diff_file.write(key + '\n')


633
def get_wlist(fn="wlist.json"):
    '''
    Load the white lists from a json file.

    Args:

        fn(str): path of the json file.

    Returns:

        wlist: a list of API that should not trigger the example check;
               built from the "name" of every "wlist_api" entry plus the
               content of every key not mentioned below.
        wlist_file: the "name" of every "wlist_dir" entry.
        gpu_not_white: the value of the "gpu_not_white" key, i.e. the
                       entries that are only white on CPU.

    '''
    with open(fn, 'r') as load_f:
        load_dict = json.load(load_f)

    wlist = []
    wlist_file = []
    gpu_not_white = []
    for key, value in load_dict.items():
        if key == 'wlist_dir':
            for entry in value:
                wlist_file.append(entry["name"])
        elif key == "gpu_not_white":
            gpu_not_white = value
        elif key == "wlist_api":
            for entry in value:
                wlist.append(entry["name"])
        else:
            wlist = wlist + value
    return wlist, wlist_file, gpu_not_white
Z
zhangchunle 已提交
660 661


662 663 664 665 666 667
# Optional command line arguments, registered one by one in parse_args().
# Each row is positional: item[0] is the flag, item[1] the dest, item[2] the
# type, item[3] the default and item[4] the help text.
arguments = [
    # flags, dest, type, default, help
    ['--gpu_id', 'gpu_id', int, 0, 'GPU device id to use [0]'],
    ['--logf', 'logf', str, None, 'file for logging'],
    ['--threads', 'threads', int, 0, 'sub processes number'],
]
668

669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711

def parse_args():
    """
    Parse the command line arguments.

    The positional ``mode`` is optional and defaults to 'cpu'. ``--debug``
    and every option listed in the module-level ``arguments`` table are
    accepted as well.

    Returns:
        argparse.Namespace: with the attributes ``debug``, ``mode`` and
        one attribute per ``dest`` of the ``arguments`` table.
    """
    parser = argparse.ArgumentParser(description='run Sample Code Test')
    parser.add_argument('--debug', dest='debug', action="store_true")
    # nargs='?' is what makes the default effective for a positional.
    # Without it 'mode' is required, and e.g. `--debug` alone is rejected.
    parser.add_argument(
        'mode', type=str, nargs='?', help='run on device', default='cpu')
    for flag, dest, arg_type, default, help_text in arguments:
        parser.add_argument(
            flag, dest=dest, help=help_text, type=arg_type, default=default)

    return parser.parse_args()


# Script entry point: extract the sample codes of every API that differs
# between API_DEV.spec and API_PR.spec, run them in a process pool and exit
# with status 1 if an API cannot be resolved or a sample code fails.
if __name__ == '__main__':
    args = parse_args()
    if args.debug:
        logger.setLevel(logging.DEBUG)
    if args.logf:
        logfHandler = logging.FileHandler(args.logf)
        logfHandler.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(funcName)s:%(lineno)d - %(levelname)s - %(message)s"
            ))
        logger.addHandler(logfHandler)

    wlist, wlist_file, gpu_not_white = get_wlist()

    if args.mode == "gpu":
        GPU_ID = args.gpu_id
        logger.info("using GPU_ID %d", GPU_ID)
        # entries that are only white on CPU must be checked on GPU
        for _gnw in gpu_not_white:
            # list.remove raises ValueError for a missing entry
            if _gnw in wlist:
                wlist.remove(_gnw)
    elif args.mode != "cpu":
        logger.error("Unrecognized argument:%s, 'cpu' or 'gpu' is desired.",
                     args.mode)
        sys.exit("Invalid arguments")
    RUN_ON_DEVICE = args.mode
    logger.info("API check -- Example Code")
    logger.info("sample_test running under python %s",
                platform.python_version())

    # make sure SAMPLECODE_TEMPDIR exists and is a directory
    if os.path.exists(SAMPLECODE_TEMPDIR):
        if not os.path.isdir(SAMPLECODE_TEMPDIR):
            os.remove(SAMPLECODE_TEMPDIR)
            os.mkdir(SAMPLECODE_TEMPDIR)
    else:
        os.mkdir(SAMPLECODE_TEMPDIR)

    filenames = get_filenames()
    if len(filenames) == 0 and len(whl_error) == 0:
        logger.info("-----API_PR.spec is the same as API_DEV.spec-----")
        # sys.exit rather than the interactive-only exit() helper
        sys.exit(0)
    logger.info("API_PR is diff from API_DEV: %s", filenames)

    threads = multiprocessing.cpu_count()
    if args.threads:
        threads = args.threads
    po = multiprocessing.Pool(threads)
    results = po.map_async(execute_samplecode_test, list(filenames))
    po.close()
    po.join()

    # each entry is the (result, tfname, msg) triple returned by
    # execute_samplecode_test
    result = results.get()
    failed = [temp for temp in result if not temp[0]]

    # delete temp files
    if not args.debug:
        shutil.rmtree(SAMPLECODE_TEMPDIR)

    logger.info("----------------End of the Check--------------------")
    if len(whl_error) != 0:
        logger.info("%s is not in whl.", whl_error)
        logger.info("")
        logger.info("Please check the whl package and API_PR.spec!")
        logger.info("You can follow these steps in order to generate API.spec:")
        logger.info("1. cd ${paddle_path}, compile paddle;")
        logger.info("2. pip install build/python/dist/(build whl package);")
        logger.info(
            "3. run 'python tools/print_signatures.py paddle > paddle/fluid/API.spec'."
        )
        for temp in failed:
            logger.info("In addition, mistakes found in sample codes: %s",
                        temp[1])
            # NOTE(review): temp[2] is the captured stdout of the sample
            # code, not a list of methods as the label suggests.
            logger.info("error_methods: %s", str(temp[2]))
        logger.info("----------------------------------------------------")
        sys.exit(1)
    else:
        for temp in failed:
            logger.info("In addition, mistakes found in sample codes: %s",
                        temp[1])
            logger.info("error_methods: %s", str(temp[2]))
        if failed:
            logger.info("Mistakes found in sample codes.")
            logger.info("Please check sample codes.")
            sys.exit(1)
    logger.info("Sample code check is successful!")