# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Please make sure to run this script from the tools path.

usage: python sampcd_processor.py {cpu or gpu}
    {cpu or gpu}: running in cpu version or gpu version

For example, you can run the cpu version of the test like this:

    python sampcd_processor.py cpu

"""
import argparse
import inspect
import logging
import multiprocessing
import os
import platform
import re
import shutil
import subprocess
import sys
import time

# Configure the root logger: reuse its first handler when one is already
# installed, otherwise attach a stderr handler. Either way the handler is
# switched to a bare "%(message)s" format.
logger = logging.getLogger()
if logger.handlers:
    # we assume the first handler is the one we want to configure
    console = logger.handlers[0]
else:
    console = logging.StreamHandler(stream=sys.stderr)
    logger.addHandler(console)
console.setFormatter(logging.Formatter("%(message)s"))

# Device the sample codes run on; overwritten from the command line.
RUN_ON_DEVICE = 'cpu'
# Capabilities of the current environment, filled by get_test_capacity().
SAMPLE_CODE_TEST_CAPACITY = set()
GPU_ID = 0
# API names that could not be resolved while collecting sample codes.
whl_error = []
API_DEV_SPEC_FN = 'paddle/fluid/API_DEV.spec'
API_PR_SPEC_FN = 'paddle/fluid/API_PR.spec'
API_DIFF_SPEC_FN = 'dev_pr_diff_api.spec'
SAMPLECODE_TEMPDIR = 'samplecode_temp'
ENV_KEY_CODES_FRONTEND = 'CODES_INSERTED_INTO_FRONTEND'
ENV_KEY_TEST_CAPACITY = 'SAMPLE_CODE_TEST_CAPACITY'
SUMMARY_INFO = {
    'success': [],
    'failed': [],
    'skiptest': [],
    'nocodes': [],
    # further keys are added at runtime, one per unmatched "required" value
}


def find_all(srcstr, substr):
    """
    Find every occurrence of a substring and return the starting indices.

    The search resumes one character after each hit, so overlapping
    occurrences are all reported.

    Args:
        srcstr(str): the parent string
        substr(str): the substring to look for

    Returns:
        list: the starting indices of the occurrences, in ascending order;
              empty when ``substr`` does not occur.
    """
    indices = []
    gotone = srcstr.find(substr)
    while gotone != -1:
        indices.append(gotone)
        # advance by one (not len(substr)) so overlapping matches are kept
        gotone = srcstr.find(substr, gotone + 1)
    return indices


def find_last_future_line_end(cbstr):
    """
    Find where the last line mentioning `__future__` ends.

    Only lines terminated by a newline are considered, because the pattern
    requires the trailing newline character.

    Args:
        cbstr(str): the source text of a code-block
    Return:
        index just past the newline of the last `__future__` line,
        or None when there is no such line.
    """
    pat = re.compile('__future__.*\n')
    lastmo = None
    # exhaust the iterator; the loop variable is left on the final match
    for lastmo in pat.finditer(cbstr):
        pass
    if lastmo is None:
        return None
    return lastmo.end()


def extract_code_blocks_from_docstr(docstr):
    """
    Extract the python code-blocks from the given docstring.

    Only the text from the first ``Examples:`` onwards is scanned, so the
    *Examples* section must be the last one. DON'T include a
    multiline-string definition in code-blocks. Blank lines inside a
    code-block are dropped.

    Args:
        docstr(str): docstring
    Return:
        code_blocks: A list of code-blocks, indent removed.
                     element {'name': the code-block's name, 'id': sequence id.
                              'codes': codes, 'required': 'gpu'}
                     'name' and 'required' are None when not given.
    """
    code_blocks = []

    mo = re.search(r"Examples:", docstr)
    if mo is None:
        return code_blocks
    ds_list = docstr[mo.start() :].replace("\t", '    ').split("\n")
    lastlineindex = len(ds_list) - 1

    cb_start_pat = re.compile(r"code-block::\s*python")
    cb_param_pat = re.compile(r"^\s*:(\w+):\s*(\S*)\s*$")
    # "# required: xxx" or "# requires: xxx"; the class is [sd], not
    # [s|d], so a stray '|' is no longer accepted as part of the keyword.
    cb_required_pat = re.compile(r"^\s*#\s*require[sd]\s*:\s*(\S+)\s*$")

    # parser state, kept in a dict so the nested helpers can mutate it
    cb_info = {
        'cb_started': False,
        'cb_cur': [],
        'cb_cur_indent': -1,
        'cb_cur_name': None,
        'cb_cur_seq_id': 0,
        'cb_required': None,
    }

    def _cb_started():
        """Mark a new code-block as open and reset its per-block options."""
        cb_info['cb_started'] = True
        cb_info['cb_cur_seq_id'] += 1
        cb_info['cb_cur_name'] = None
        cb_info['cb_required'] = None

    def _append_code_block():
        """Store the lines collected so far as one finished code-block."""
        code_blocks.append(
            {
                'codes': inspect.cleandoc("\n" + "\n".join(cb_info['cb_cur'])),
                'name': cb_info['cb_cur_name'],
                'id': cb_info['cb_cur_seq_id'],
                'required': cb_info['cb_required'],
            }
        )

    for lineno, linecont in enumerate(ds_list):
        if cb_start_pat.search(linecont):
            if cb_info['cb_started']:
                # a new directive ends the block in progress; flush it
                # before _cb_started() resets its name/required/id
                if cb_info['cb_cur']:
                    _append_code_block()
                cb_info['cb_cur_indent'] = -1
                cb_info['cb_cur'] = []
            _cb_started()
            continue

        if not cb_info['cb_started']:
            continue

        # handle the code-block directive's options
        mo_p = cb_param_pat.match(linecont)
        if mo_p:
            if mo_p.group(1) == 'name':
                cb_info['cb_cur_name'] = mo_p.group(2)
            continue
        # read the required directive
        mo_r = cb_required_pat.match(linecont)
        if mo_r:
            cb_info['cb_required'] = mo_r.group(1)
        # docstring end
        if lineno == lastlineindex:
            mo = re.search(r"\S", linecont)
            if mo is not None and cb_info['cb_cur_indent'] <= mo.start():
                cb_info['cb_cur'].append(linecont)
            if cb_info['cb_cur']:
                _append_code_block()
            break
        # check indent for cur block start and end.
        mo = re.search(r"\S", linecont)
        if mo is None:
            continue
        if cb_info['cb_cur_indent'] < 0:
            # the first non empty line fixes the block's indent
            cb_info['cb_cur_indent'] = mo.start()
            cb_info['cb_cur'].append(linecont)
        elif cb_info['cb_cur_indent'] <= mo.start():
            cb_info['cb_cur'].append(linecont)
        elif linecont[mo.start()] == '#':
            # a dedented comment neither belongs to nor ends the block
            continue
        else:
            # a dedented non-comment line ends the block
            if cb_info['cb_cur']:
                _append_code_block()
            cb_info['cb_started'] = False
            cb_info['cb_cur_indent'] = -1
            cb_info['cb_cur'] = []
    return code_blocks


def get_test_capacity():
    """
    Collect capacities and add them to SAMPLE_CODE_TEST_CAPACITY.

    The capacities come from the comma-separated environment variable
    named by ENV_KEY_TEST_CAPACITY (entries are stripped and lower-cased,
    empty entries are ignored), plus 'cpu', plus RUN_ON_DEVICE when set.
    """
    global SAMPLE_CODE_TEST_CAPACITY  # write
    global ENV_KEY_TEST_CAPACITY, RUN_ON_DEVICE  # readonly
    for r in os.environ.get(ENV_KEY_TEST_CAPACITY, '').split(','):
        rr = r.strip().lower()
        # test the normalised value: a whitespace-only entry must not
        # register an empty-string capacity
        if rr:
            SAMPLE_CODE_TEST_CAPACITY.add(rr)
    SAMPLE_CODE_TEST_CAPACITY.add('cpu')

    if RUN_ON_DEVICE:
        SAMPLE_CODE_TEST_CAPACITY.add(RUN_ON_DEVICE)


def is_required_match(requirestr, cbtitle='not-specified'):
    """
    Check whether a code-block's requirements match the running environment.

    The requirement set always contains 'cpu'. It is extended with the
    comma-separated entries of ``requirestr`` (stripped, lower-cased), or
    with RUN_ON_DEVICE when ``requirestr`` is empty.

    environment values of equipped: cpu, gpu, xpu, distributed, skip
    'skip' / 'skiptest' are special flags to skip the test; for them this
    function returns None.

    Args:
        requirestr(str): the required string.
        cbtitle(str): the title of the code-block, used in log messages.
    returns:
        True - yes, matched
        False - not match
        None - skipped  # trick
    """
    global SAMPLE_CODE_TEST_CAPACITY, RUN_ON_DEVICE  # readonly
    requires = {'cpu'}
    if requirestr:
        for r in requirestr.split(','):
            rr = r.strip().lower()
            if rr:
                requires.add(rr)
    else:
        requires.add(RUN_ON_DEVICE)
    if 'skip' in requires or 'skiptest' in requires:
        logger.info('%s: skipped', cbtitle)
        return None

    # the skip flags were handled above, so every remaining entry must be
    # one of the equipped capacities
    if all(k in SAMPLE_CODE_TEST_CAPACITY for k in requires):
        return True

    # sorted so the message does not depend on set iteration order
    logger.info(
        '%s: the equipments [%s] not match the required [%s].',
        cbtitle,
        ','.join(sorted(SAMPLE_CODE_TEST_CAPACITY)),
        ','.join(sorted(requires)),
    )
    return False


def insert_codes_into_codeblock(codeblock, apiname='not-specified'):
    """
    Insert some codes in the frontend and backend into the code-block.

    The frontend code is the value of the environment variable named by
    ENV_KEY_CODES_FRONTEND when it is set and non-empty. Otherwise it is a
    snippet that sets CUDA_VISIBLE_DEVICES (empty for 'cpu', GPU_ID for
    'gpu'), chosen by the code-block's 'required' value or, when that is
    missing or empty, by RUN_ON_DEVICE. Any other device inserts nothing.
    The backend code prints a success message.

    Args:
        codeblock(dict): a code-block with 'codes', 'name' and 'id' keys
                         and an optional 'required' key.
        apiname(str): the API name used in the success message.

    Returns:
        str: the sample code with the inserted codes. The frontend code is
             placed after the last `__future__` line when there is one,
             otherwise at the very beginning.
    """
    global ENV_KEY_CODES_FRONTEND, GPU_ID, RUN_ON_DEVICE  # readonly
    inserted_codes_f = os.environ.get(ENV_KEY_CODES_FRONTEND, '')
    if not inserted_codes_f:
        device_codes = {
            'cpu': '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = ""\n',
            'gpu': '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = "{}"\n'.format(
                GPU_ID
            ),
        }
        # the block's own requirement wins over the global device
        device = codeblock.get('required') or RUN_ON_DEVICE
        inserted_codes_f = device_codes.get(device, '')
    inserted_codes_b = '\nprint("{}\'s sample code (name:{}, id:{}) is executed successfully!")'.format(
        apiname, codeblock['name'], codeblock['id']
    )

    cb = codeblock['codes']
    last_future_line_end = find_last_future_line_end(cb)
    if last_future_line_end:
        # nothing but __future__ imports may precede the inserted codes
        return (
            cb[:last_future_line_end]
            + inserted_codes_f
            + cb[last_future_line_end:]
            + inserted_codes_b
        )
    return inserted_codes_f + cb + inserted_codes_b


def sampcd_extract_to_file(srccom, name, htype="def", hname=""):
    """
    Extract sample codes from __doc__, and write them to files.

    Each code-block whose requirement matches the environment is written
    to SAMPLECODE_TEMPDIR as `<name>_example.py`, or `<name>_example_<n>.py`
    when the docstring has several code-blocks. SUMMARY_INFO is updated
    for APIs without codes, skipped blocks and blocks whose requirement
    does not match.

    Args:
        srccom(str): the source comment of some API whose
                     example codes will be extracted and run.
        name(str): the name of the API.
        htype(str): the type of hint banners, def/class/method.
        hname(str): the name of the hint banners, e.g. def hname.
                    Currently unused.

    Returns:
        sample_code_filenames(list of str)
    """
    global GPU_ID, RUN_ON_DEVICE, SAMPLECODE_TEMPDIR  # readonly
    global SUMMARY_INFO  # update

    codeblocks = extract_code_blocks_from_docstr(srccom)
    if not codeblocks:
        SUMMARY_INFO['nocodes'].append(name)
        logger.info(htype + " name:" + name)
        logger.info("-----------------------")
        if "Examples:" in srccom:
            logger.info("----example code check----")
            # sample codes formatted with >>> are considered wrong
            if ">>>" in srccom:
                logger.warning(
                    r"""Deprecated sample code style:
    Examples:
        >>>codeline
        >>>codeline

Please use '.. code-block:: python' to format the sample code."""
                )
        else:
            logger.error(
                "Error: No sample code found! Please check if the API comment contais string 'Examples:' correctly"
            )
        return []

    sample_code_filenames = []
    for y, cb in enumerate(codeblocks):
        matched = is_required_match(cb['required'], name)
        # matched has three states:
        # True - please execute it;
        # None - the code-block asked to be skipped;
        # False - it need other special equipment or environment.
        # so, the following conditional statements are intentionally arranged.
        if matched:
            suffix = '.py' if len(codeblocks) == 1 else f'_{y + 1}.py'
            tfname = os.path.join(
                SAMPLECODE_TEMPDIR, '{}_example{}'.format(name, suffix)
            )
            # Python reads source files as UTF-8 by default, so write the
            # sample code as UTF-8 whatever the locale encoding is.
            with open(tfname, 'w', encoding='utf-8') as tempf:
                tempf.write(insert_codes_into_codeblock(cb, name))
            sample_code_filenames.append(tfname)
        elif matched is None:
            logger.info(
                '{}\' code block (name:{}, id:{}) is skipped.'.format(
                    name, cb['name'], cb['id']
                )
            )
            SUMMARY_INFO['skiptest'].append("{}-{}".format(name, cb['id']))
        else:
            logger.info(
                '{}\' code block (name:{}, id:{}) required({}) not match capacity({}).'.format(
                    name,
                    cb['name'],
                    cb['id'],
                    cb['required'],
                    SAMPLE_CODE_TEST_CAPACITY,
                )
            )
            # group the unmatched blocks by their required string
            SUMMARY_INFO.setdefault(cb['required'], []).append(
                "{}-{}".format(name, cb['id'])
            )

    return sample_code_filenames


def execute_samplecode(tfname):
    """
    Execute a sample-code test

    The file is run with the current interpreter (``sys.executable``) in a
    subprocess. On a non-zero return code the file content, the return
    code, stderr and stdout are logged as a warning.

    Args:
        tfname: the filename of the sample code

    Returns:
        result: success or not
        tfname: same as the input argument
        msg: the stdout output of the sample code executing
        time: time consumed by sample code
    """
    if platform.python_version()[0] != "3":
        logger.error("Error: fail to parse python version!")
        sys.exit(1)
    cmd = [sys.executable, tfname]

    logger.info("----example code check----")
    logger.info("executing sample code: %s", tfname)
    start_time = time.time()
    completed = subprocess.run(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    # errors='replace': undecodable output must not crash the checker
    msg = completed.stdout.decode(encoding='utf-8', errors='replace')
    err = completed.stderr.decode(encoding='utf-8', errors='replace')
    end_time = time.time()

    result = completed.returncode == 0
    if result:
        logger.info("----example code check success----")
    else:
        with open(tfname, 'r', encoding='utf-8', errors='replace') as f:
            logger.warning(
                """Sample code error found in %s:
-----------------------
%s
-----------------------
subprocess return code: %d
Error Raised from Sample Code:
stderr: %s
stdout: %s
""",
                tfname,
                f.read(),
                completed.returncode,
                err,
                msg,
            )
        logger.info("----example code check failed----")

    # msg is the returned code execution report
    return result, tfname, msg, end_time - start_time


def get_filenames(full_test=False):
    '''
    Collect the sample code files that are pending for check.

    The API names are first written to API_DIFF_SPEC_FN (all APIs of the
    PR spec for a full test, otherwise only the ones that differ from the
    dev spec). Each name is then resolved to an object and the sample
    codes of its docstring are extracted to files. Names that raise
    AttributeError while being resolved are collected in the global
    ``whl_error``.

    Args:
        full_test: the full apis or the increment

    Returns:

        dict: the sample code files pending for check; key is the file
              name, value is the api name.

    '''
    global whl_error
    # imported so that the `paddle.xxx` names below can be resolved
    import paddle  # noqa: F401
    import paddle.static.quantization  # noqa: F401

    whl_error = []
    if full_test:
        get_full_api_from_pr_spec()
    else:
        get_incrementapi()
    all_sample_code_filenames = {}
    with open(API_DIFF_SPEC_FN) as f:
        for line in f:
            api = line.replace('\n', '')
            try:
                # NOTE(review): eval() executes whatever the spec file
                # contains; the spec files must come from a trusted source.
                api_obj = eval(api)
            except AttributeError:
                whl_error.append(api)
                continue
            except SyntaxError:
                logger.warning('line:%s, api:%s', line, api)
                # paddle.Tensor.<lambda>
                continue
            doc = getattr(api_obj, '__doc__', None)
            if doc:
                for tfname in sampcd_extract_to_file(doc, api):
                    all_sample_code_filenames[tfname] = api
    return all_sample_code_filenames


def get_api_md5(path):
    """
    Read the api spec file, and extract the md5sum value of every api's docstring.

    Args:
        path: the api spec file. ATTENTION: a relative path is resolved
              against the parent of the current working directory.

    Returns:
        api_md5(dict): key is the api's real fullname, value is the md5sum.
                       Empty when the spec file does not exist.
    """
    api_md5 = {}
    API_spec = os.path.abspath(os.path.join(os.getcwd(), "..", path))
    if not os.path.isfile(API_spec):
        return api_md5
    pat = re.compile(r'\((paddle[^,]+)\W*document\W*([0-9a-z]{32})')
    patArgSpec = re.compile(
        r'^(paddle[^,]+)\s+\(ArgSpec.*document\W*([0-9a-z]{32})'
    )
    with open(API_spec) as f:
        for line in f:
            # try the tuple-style line first, then the ArgSpec-style line
            mo = pat.search(line) or patArgSpec.search(line)
            if mo:
                api_md5[mo.group(1)] = mo.group(2)
    return api_md5


def get_full_api():
    """
    Get all the apis and write their names to API_DIFF_SPEC_FN.

    The names are the keys of the mapping returned by
    print_signatures.get_all_api_from_modulelist, one per line.
    """
    global API_DIFF_SPEC_FN  # readonly
    from print_signatures import get_all_api_from_modulelist

    member_dict = get_all_api_from_modulelist()
    with open(API_DIFF_SPEC_FN, 'w') as f:
        f.write("\n".join(member_dict))


def get_full_api_by_walk():
    """
    Get all the apis and write their names to API_DIFF_SPEC_FN.

    The names are the first item of every entry returned by
    print_signatures.get_all_api, one per line.
    """
    global API_DIFF_SPEC_FN  # readonly
    from print_signatures import get_all_api

    apilist = get_all_api()
    with open(API_DIFF_SPEC_FN, 'w') as f:
        f.write("\n".join(ai[0] for ai in apilist))


def get_full_api_from_pr_spec():
    """
    Get all the apis listed in the PR spec and write them to API_DIFF_SPEC_FN.

    Falls back to get_full_api_by_walk() when no api could be read from
    the PR spec file.
    """
    global API_PR_SPEC_FN, API_DIFF_SPEC_FN  # readonly
    pr_api = get_api_md5(API_PR_SPEC_FN)
    if not pr_api:
        get_full_api_by_walk()
        return
    with open(API_DIFF_SPEC_FN, 'w') as f:
        f.write("\n".join(pr_api))


def get_incrementapi():
    '''
    Write the apis that differ between API_DEV.spec and API_PR.spec.

    An api of the PR spec is written to API_DIFF_SPEC_FN (one name per
    line) when it is missing from the dev spec or when its md5sum differs.
    '''
    global API_DEV_SPEC_FN, API_PR_SPEC_FN, API_DIFF_SPEC_FN  # readonly
    dev_api = get_api_md5(API_DEV_SPEC_FN)
    pr_api = get_api_md5(API_PR_SPEC_FN)
    with open(API_DIFF_SPEC_FN, 'w') as f:
        for key, pr_md5 in pr_api.items():
            if key not in dev_api:
                logger.debug("%s is not in dev", key)
            elif dev_api[key] != pr_md5:
                logger.debug(
                    "%s in dev is %s, different from pr's %s",
                    key,
                    dev_api[key],
                    pr_md5,
                )
            else:
                # unchanged api, nothing to check
                continue
            f.write(key)
            f.write('\n')


def exec_gen_doc():
    """
    Run `bash document_preview.sh` and report whether it succeeded.

    The script's stdout is always logged; its stderr is logged only on a
    non-zero return code. Afterwards the existence of the two index pages
    is logged; a missing page does not change the returned result.

    Returns:
        result: True when the script returned 0
        msg: the stdout output of the script
        time: time consumed by the script
    """
    cmd = ["bash", "document_preview.sh"]
    logger.info("----exec gen_doc----")
    start_time = time.time()
    completed = subprocess.run(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    # errors='replace': undecodable output must not abort the report
    msg = completed.stdout.decode(encoding='utf-8', errors='replace')
    err = completed.stderr.decode(encoding='utf-8', errors='replace')
    end_time = time.time()

    result = completed.returncode == 0
    logger.info("----gen_doc msg----")
    logger.info(msg)
    if result:
        logger.info("----exec gen_doc success----")
    else:
        logger.error("----gen_doc error msg----")
        logger.error(err)
        logger.error("----exec gen_doc failed----")

    for fn in [
        '/docs/en/develop/index_en.html',
        '/docs/zh/develop/index_cn.html',
    ]:
        if os.path.exists(fn):
            logger.info('%s exists.', fn)
        else:
            logger.error('%s not exists.', fn)

    # msg is the returned code execution report
    return result, msg, end_time - start_time


# Extra command-line options registered by parse_args(). Each row is
# consumed positionally as [flag, dest, type, default, help].
arguments = [
    # flags, dest, type, default, help
    ['--gpu_id', 'gpu_id', int, 0, 'GPU device id to use [0]'],
    ['--logf', 'logf', str, None, 'file for logging'],
    ['--threads', 'threads', int, 0, 'sub processes number'],
]


def parse_args():
    """
    Parse input arguments

    Besides the fixed flags (--debug, --full-test, --build-doc) every row
    of the module-level ``arguments`` table is registered as an option.
    The positional ``mode`` is optional and defaults to 'cpu'.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    global arguments
    parser = argparse.ArgumentParser(description='run Sample Code Test')
    parser.add_argument('--debug', dest='debug', action="store_true")
    parser.add_argument('--full-test', dest='full_test', action="store_true")
    # nargs='?' is what makes the default effective: without it argparse
    # treats the positional as required and ignores `default`.
    parser.add_argument(
        'mode', type=str, nargs='?', help='run on device', default='cpu'
    )
    parser.add_argument(
        '--build-doc',
        dest='build_doc',
        action='store_true',
        help='build doc if need.',
    )
    for item in arguments:
        parser.add_argument(
            item[0], dest=item[1], help=item[4], type=item[2], default=item[3]
        )

    return parser.parse_args()


# Script entry: collect the sample codes of the changed (or all) APIs, run
# them in a process pool, print a summary and exit with 1 on any failure.
if __name__ == '__main__':
    args = parse_args()
    logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
    if args.logf:
        logfHandler = logging.FileHandler(args.logf)
        logfHandler.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(funcName)s:%(lineno)d - %(levelname)s - %(message)s"
            )
        )
        logger.addHandler(logfHandler)

    if args.mode == "gpu":
        GPU_ID = args.gpu_id
        logger.info("using GPU_ID %d", GPU_ID)
    elif args.mode != "cpu":
        logger.error(
            "Unrecognized argument:%s, 'cpu' or 'gpu' is desired.", args.mode
        )
        sys.exit("Invalid arguments")
    RUN_ON_DEVICE = args.mode
    get_test_capacity()
    logger.info("API check -- Example Code")
    logger.info(
        "sample_test running under python %s", platform.python_version()
    )

    # make sure the temp path is a directory; an existing directory is reused
    if os.path.exists(SAMPLECODE_TEMPDIR):
        if not os.path.isdir(SAMPLECODE_TEMPDIR):
            os.remove(SAMPLECODE_TEMPDIR)
            os.mkdir(SAMPLECODE_TEMPDIR)
    else:
        os.mkdir(SAMPLECODE_TEMPDIR)

    filenames = get_filenames(args.full_test)
    if len(filenames) == 0 and len(whl_error) == 0:
        logger.info("-----API_PR.spec is the same as API_DEV.spec-----")
        sys.exit(0)
    logger.info("API_PR is diff from API_DEV: %s", filenames)

    # --threads overrides the default of one worker per cpu
    threads = args.threads if args.threads else multiprocessing.cpu_count()
    # map() blocks until every sample code has run; the context manager
    # releases the workers even when an exception is raised
    with multiprocessing.Pool(threads) as po:
        result = po.map(execute_samplecode, list(filenames))

    # delete temp files
    if not args.debug:
        shutil.rmtree(SAMPLECODE_TEMPDIR)

    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logger.addHandler(stdout_handler)
    logger.info("----------------End of the Check--------------------")
    if len(whl_error) != 0:
        logger.info("%s is not in whl.", whl_error)
        logger.info("")
        logger.info("Please check the whl package and API_PR.spec!")
        logger.info("You can follow these steps in order to generate API.spec:")
        logger.info("1. cd ${paddle_path}, compile paddle;")
        logger.info("2. pip install build/python/dist/(build whl package);")
        logger.info(
            "3. run 'python tools/print_signatures.py paddle > paddle/fluid/API.spec'."
        )
        for succeeded, tfname, _msg, _elapsed in result:
            if not succeeded:
                logger.info(
                    "In addition, mistakes found in sample codes: %s", tfname
                )
        logger.info("----------------------------------------------------")
        sys.exit(1)
    else:
        timeovered_test = {}
        for succeeded, tfname, _msg, elapsed in result:
            if not succeeded:
                logger.info(
                    "In addition, mistakes found in sample codes: %s", tfname
                )
                SUMMARY_INFO['failed'].append(tfname)
            else:
                SUMMARY_INFO['success'].append(tfname)
            if elapsed > 10:
                timeovered_test[tfname] = elapsed

        if len(timeovered_test):
            logger.info(
                "%d sample codes ran time over 10s", len(timeovered_test)
            )
            if args.debug:
                for k, v in timeovered_test.items():
                    logger.info(f'{k} - {v}s')
        if len(SUMMARY_INFO['success']):
            logger.info(
                "%d sample codes ran success", len(SUMMARY_INFO['success'])
            )
        # every key other than the four fixed ones is a "required" string
        # that did not match the capacity
        for k, v in SUMMARY_INFO.items():
            if k not in ['success', 'failed', 'skiptest', 'nocodes']:
                logger.info(
                    "%d sample codes required not match for %s", len(v), k
                )
        if len(SUMMARY_INFO['skiptest']):
            logger.info(
                "%d sample codes skipped", len(SUMMARY_INFO['skiptest'])
            )
            if args.debug:
                logger.info('\n'.join(SUMMARY_INFO['skiptest']))
        if len(SUMMARY_INFO['nocodes']):
            logger.info(
                "%d apis don't have sample codes", len(SUMMARY_INFO['nocodes'])
            )
            if args.debug:
                logger.info('\n'.join(SUMMARY_INFO['nocodes']))
        if len(SUMMARY_INFO['failed']):
            logger.info(
                "%d sample codes ran failed", len(SUMMARY_INFO['failed'])
            )
            logger.info('\n'.join(SUMMARY_INFO['failed']))
            logger.info(
                "Mistakes found in sample codes. Please recheck the sample codes."
            )
            sys.exit(1)

    logger.info("Sample code check is successful!")

    if args.mode == "cpu":
        # As cpu mode is also run with the GPU whl, so skip it in gpu mode.
        exec_gen_doc()