# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
please make sure to run in the tools path
usage: python sampcd_processor.py {cpu or gpu}
    {cpu or gpu}: running in cpu version or gpu version

for example, you can run cpu version testing like this:

    python sampcd_processor.py cpu

"""
import argparse
import inspect
import logging
import multiprocessing
import os
import platform
import re
import shutil
import subprocess
import sys
import time

# Configure the root logger: reuse its first handler when one is already
# installed, otherwise attach a new stderr handler. Either way the handler
# is switched to a bare "%(message)s" format.
logger = logging.getLogger()
if logger.handlers:
    # we assume the first handler is the one we want to configure
    console = logger.handlers[0]
else:
    console = logging.StreamHandler(stream=sys.stderr)
    logger.addHandler(console)
console.setFormatter(logging.Formatter("%(message)s"))

# Device the sample codes run on; overwritten from the `mode` command-line
# argument by the script entry point.
RUN_ON_DEVICE = 'cpu'
# Capabilities of the current environment; filled by get_test_capacity().
SAMPLE_CODE_TEST_CAPACITY = set()
# Value written into CUDA_VISIBLE_DEVICES for gpu sample codes.
GPU_ID = 0
# APIs listed in the spec file that could not be resolved by get_filenames().
whl_error = []
# Spec files, resolved by get_api_md5() relative to the parent of the cwd.
API_DEV_SPEC_FN = 'paddle/fluid/API_DEV.spec'
API_PR_SPEC_FN = 'paddle/fluid/API_PR.spec'
# File (in the cwd) listing the APIs whose sample codes must be checked.
API_DIFF_SPEC_FN = 'dev_pr_diff_api.spec'
# Directory the extracted sample-code files are written to.
SAMPLECODE_TEMPDIR = 'samplecode_temp'
# Environment variable holding code to insert in front of each sample code.
ENV_KEY_CODES_FRONTEND = 'CODES_INSERTED_INTO_FRONTEND'
# Environment variable holding a comma-separated list of extra capacities.
ENV_KEY_TEST_CAPACITY = 'SAMPLE_CODE_TEST_CAPACITY'
# Per-category result lists. sampcd_extract_to_file() adds further keys on
# the fly, one per unmatched `required` value.
SUMMARY_INFO = {
    'success': [],
    'failed': [],
    'skiptest': [],
    'nocodes': [],
    # ... required not-match
}
63

T
tianshuo78520a 已提交
64 65

def find_all(srcstr, substr):
    """
    Find every occurrence of a substring and return the starting indices.

    The search resumes one character after each hit, so overlapping
    occurrences are all reported.

    Args:
        srcstr(str): the parent string
        substr(str): the substring to look for

    Returns:
        list: the start indices of the occurrences found, in ascending order
    """
    indices = []
    gotone = srcstr.find(substr)
    while gotone != -1:
        indices.append(gotone)
        # Advance by one character only, so overlapping hits are kept.
        gotone = srcstr.find(substr, gotone + 1)
    return indices


86 87 88 89 90 91 92 93
def find_last_future_line_end(cbstr):
    """
    Find the end of the last line that mentions `__future__`.

    Args:
        cbstr(str): the source text of a code-block
    Return:
        The index just past the last `__future__` line (including its
        newline when it has one), or None when no such line exists.
    """
    # `(?:\n|$)` also accepts a `__future__` line that ends the string
    # without a trailing newline, e.g. after inspect.cleandoc().
    pat = re.compile(r'__future__.*(?:\n|$)')
    lastmo = None
    for lastmo in pat.finditer(cbstr):
        pass  # keep only the final match
    if lastmo is None:
        return None
    return lastmo.end()
107 108


109 110 111 112 113 114
def extract_code_blocks_from_docstr(docstr):
    """
    Extract the python code-blocks from the given docstring.

    DON'T include the multiline-string definition in code-blocks.
    The *Examples* section must be the last.

    Only the text from the first "Examples:" onwards is scanned. Blank lines
    inside a code-block are dropped, and a block ends at the next
    `code-block:: python` directive, at the first non-comment line indented
    less than the block's first line, or at the end of the docstring.

    Args:
        docstr(str): docstring
    Return:
        code_blocks: A list of code-blocks, indent removed.
                     element {'name': the code-block's name, 'id': sequence id.
                              'codes': codes, 'required': 'gpu'}
    """
    code_blocks = []

    mo = re.search(r"Examples:", docstr)
    if mo is None:
        return code_blocks
    ds_list = docstr[mo.start() :].replace("\t", '    ').split("\n")
    lastlineindex = len(ds_list) - 1

    cb_start_pat = re.compile(r"code-block::\s*python")
    cb_param_pat = re.compile(r"^\s*:(\w+):\s*(\S*)\s*$")
    # `# required: xxx` or `# requires: xxx`
    cb_required_pat = re.compile(r"^\s*#\s*require[sd]\s*:\s*(\S+)\s*$")

    # Parser state, kept in a dict so the nested helpers can update it.
    cb_info = {
        'cb_started': False,  # currently inside a code-block
        'cb_cur': [],  # lines collected for the current block
        'cb_cur_indent': -1,  # indent of the block's first line, -1 = unset
        'cb_cur_name': None,  # value of the `:name:` option
        'cb_cur_seq_id': 0,  # 1-based sequence id of the current block
        'cb_required': None,  # value of the `# required:` comment
    }

    def _cb_started():
        # Enter a new code-block and forget the previous block's metadata.
        cb_info['cb_started'] = True
        cb_info['cb_cur_seq_id'] += 1
        cb_info['cb_cur_name'] = None
        cb_info['cb_required'] = None

    def _append_code_block():
        # Store the collected lines, dedented, with the block's metadata.
        code_blocks.append(
            {
                'codes': inspect.cleandoc("\n".join(cb_info['cb_cur'])),
                'name': cb_info['cb_cur_name'],
                'id': cb_info['cb_cur_seq_id'],
                'required': cb_info['cb_required'],
            }
        )

    def _reset_cur():
        # Drop the collected lines and the remembered indent.
        cb_info['cb_cur_indent'] = -1
        cb_info['cb_cur'] = []

    for lineno, linecont in enumerate(ds_list):
        if cb_start_pat.search(linecont):
            if cb_info['cb_started']:
                # a new directive terminates the block in progress
                if cb_info['cb_cur']:
                    _append_code_block()
                _reset_cur()
            _cb_started()
            continue

        if not cb_info['cb_started']:
            continue

        # handle the code-block directive's options, e.g. `:name: xxx`
        mo_p = cb_param_pat.match(linecont)
        if mo_p:
            if mo_p.group(1) == 'name':
                cb_info['cb_cur_name'] = mo_p.group(2)
            continue

        # read the required directive; the line itself stays in the codes
        mo_r = cb_required_pat.match(linecont)
        if mo_r:
            cb_info['cb_required'] = mo_r.group(1)

        mo = re.search(r"\S", linecont)

        # docstring end: flush whatever has been collected
        if lineno == lastlineindex:
            if mo is not None and cb_info['cb_cur_indent'] <= mo.start():
                cb_info['cb_cur'].append(linecont)
            if cb_info['cb_cur']:
                _append_code_block()
            break

        if mo is None:
            # blank line
            continue

        if cb_info['cb_cur_indent'] < 0:
            # the first non empty line fixes the block's indent
            cb_info['cb_cur_indent'] = mo.start()
            cb_info['cb_cur'].append(linecont)
        elif cb_info['cb_cur_indent'] <= mo.start():
            cb_info['cb_cur'].append(linecont)
        elif linecont[mo.start()] == '#':
            # a less-indented comment neither belongs to nor ends the block
            continue
        else:
            # a less-indented line ends the block
            if cb_info['cb_cur']:
                _append_code_block()
            cb_info['cb_started'] = False
            _reset_cur()
    return code_blocks


def get_test_capacity():
    """
    collect capacities and set to SAMPLE_CODE_TEST_CAPACITY

    The capacities come from the comma-separated environment variable named
    by ENV_KEY_TEST_CAPACITY (lower-cased, empty entries ignored), plus
    'cpu', plus RUN_ON_DEVICE when it is set.
    """
    global SAMPLE_CODE_TEST_CAPACITY  # write
    global ENV_KEY_TEST_CAPACITY, RUN_ON_DEVICE  # readonly
    for r in os.environ.get(ENV_KEY_TEST_CAPACITY, '').split(','):
        rr = r.strip().lower()
        # Test the stripped token: a whitespace-only entry must not add ''.
        if rr:
            SAMPLE_CODE_TEST_CAPACITY.add(rr)
    SAMPLE_CODE_TEST_CAPACITY.add('cpu')

    if RUN_ON_DEVICE:
        SAMPLE_CODE_TEST_CAPACITY.add(RUN_ON_DEVICE)


def is_required_match(requirestr, cbtitle='not-specified'):
    """
    search the required instruction in the code-block, and check it match the current running environment.

    environment values of equipped: cpu, gpu, xpu, distributed, skip
    'skip' (or 'skiptest') is the special flag to skip the test, so
    is_required_match returns None directly when it is present.

    'cpu' is always required. When requirestr is empty, RUN_ON_DEVICE is
    required as well.

    Args:
        requirestr(str): the required string, comma separated.
        cbtitle(str): the title of the code-block.
    returns:
        True - yes, matched
        False - not match
        None - skipped  # trick
    """
    global SAMPLE_CODE_TEST_CAPACITY, RUN_ON_DEVICE  # readonly
    requires = {'cpu'}
    if requirestr:
        for r in requirestr.split(','):
            rr = r.strip().lower()
            if rr:
                requires.add(rr)
    else:
        requires.add(RUN_ON_DEVICE)

    if 'skip' in requires or 'skiptest' in requires:
        logger.info('%s: skipped', cbtitle)
        return None

    # Every requirement must be offered by the current environment.
    if requires.issubset(SAMPLE_CODE_TEST_CAPACITY):
        return True

    # Sorted so that the log line is deterministic.
    logger.info(
        '%s: the equipments [%s] not match the required [%s].',
        cbtitle,
        ','.join(sorted(SAMPLE_CODE_TEST_CAPACITY)),
        ','.join(sorted(requires)),
    )
    return False


def insert_codes_into_codeblock(codeblock, apiname='not-specified'):
    """
    insert some codes in the frontend and backend into the code-block.

    The frontend code is the value of the environment variable named by
    ENV_KEY_CODES_FRONTEND when that is set and non-empty. Otherwise it is a
    snippet setting CUDA_VISIBLE_DEVICES for 'cpu' or 'gpu', chosen by the
    block's 'required' value or, when that is empty, by RUN_ON_DEVICE. Any
    other device gets no frontend code. The frontend code is placed after
    the last `__future__` line when there is one, else at the very top.

    The backend code is a print statement reporting successful execution.

    Args:
        codeblock(dict): a code-block as returned by
                         extract_code_blocks_from_docstr.
        apiname(str): the API name used in the success message.
    Returns:
        str: the sample code with the inserted codes.
    """
    global ENV_KEY_CODES_FRONTEND, GPU_ID, RUN_ON_DEVICE  # readonly
    inserted_codes_f = ''
    if os.environ.get(ENV_KEY_CODES_FRONTEND):
        inserted_codes_f = os.environ[ENV_KEY_CODES_FRONTEND]
    else:
        cpu_str = '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = ""\n'
        gpu_str = (
            '\nimport os\nos.environ["CUDA_VISIBLE_DEVICES"] = "{}"\n'.format(
                GPU_ID
            )
        )
        # The block's own requirement wins over the global device.
        device = codeblock.get('required') or RUN_ON_DEVICE
        if device == 'cpu':
            inserted_codes_f = cpu_str
        elif device == 'gpu':
            inserted_codes_f = gpu_str
    inserted_codes_b = '\nprint("{}\'s sample code (name:{}, id:{}) is executed successfully!")'.format(
        apiname, codeblock['name'], codeblock['id']
    )

    cb = codeblock['codes']
    last_future_line_end = find_last_future_line_end(cb)
    if last_future_line_end:
        # `from __future__` imports must stay at the top of the file.
        return (
            cb[:last_future_line_end]
            + inserted_codes_f
            + cb[last_future_line_end:]
            + inserted_codes_b
        )
    return inserted_codes_f + cb + inserted_codes_b
T
tianshuo78520a 已提交
328 329


330 331 332
def sampcd_extract_to_file(srccom, name, htype="def", hname=""):
    """
    Extract sample codes from __doc__, and write them to files.

    Each code-block whose requirement matches the current environment is
    written to SAMPLECODE_TEMPDIR as `<name>_example.py`, or
    `<name>_example_<k>.py` when the docstring has several blocks. Skipped
    and unmatched blocks, and APIs without any block, are recorded in
    SUMMARY_INFO.

    Args:
        srccom(str): the source comment of some API whose
                     example codes will be extracted and run.
        name(str): the name of the API.
        htype(str): the type of hint banners, def/class/method.
        hname(str): the name of the hint  banners , e.t. def hname.

    Returns:
        sample_code_filenames(list of str)
    """
    global GPU_ID, RUN_ON_DEVICE, SAMPLECODE_TEMPDIR  # readonly
    global SUMMARY_INFO  # update

    codeblocks = extract_code_blocks_from_docstr(srccom)
    if len(codeblocks) == 0:
        SUMMARY_INFO['nocodes'].append(name)
        # detect sample codes using >>> to format and consider this situation as wrong
        logger.info(htype + " name:" + name)
        logger.info("-----------------------")
        if "Examples:" in srccom:
            logger.info("----example code check----")
            if ">>>" in srccom:
                logger.warning(
                    r"""Deprecated sample code style:
    Examples:
        >>>codeline
        >>>codeline

Please use '.. code-block:: python' to format the sample code."""
                )
                return []
        else:
            logger.error(
                "Error: No sample code found! Please check if the API comment contais string 'Examples:' correctly"
            )
            return []

    sample_code_filenames = []
    for y, cb in enumerate(codeblocks):
        matched = is_required_match(cb['required'], name)
        # matched has three states:
        # True - please execute it;
        # None - the code block asks to be skipped;
        # False - it need other special equipment or environment.
        # so, the following conditional statements are intentionally arranged.
        if matched:
            tfname = os.path.join(
                SAMPLECODE_TEMPDIR,
                '{}_example{}'.format(
                    name,
                    '.py' if len(codeblocks) == 1 else f'_{y + 1}.py',
                ),
            )
            # Python reads source files as UTF-8 by default, so write the
            # file that way whatever the locale is.
            with open(tfname, 'w', encoding='utf-8') as tempf:
                sampcd = insert_codes_into_codeblock(cb, name)
                tempf.write(sampcd)
            sample_code_filenames.append(tfname)
        elif matched is None:
            logger.info(
                '{}\' code block (name:{}, id:{}) is skipped.'.format(
                    name, cb['name'], cb['id']
                )
            )
            SUMMARY_INFO['skiptest'].append("{}-{}".format(name, cb['id']))
        else:
            logger.info(
                '{}\' code block (name:{}, id:{}) required({}) not match capacity({}).'.format(
                    name,
                    cb['name'],
                    cb['id'],
                    cb['required'],
                    SAMPLE_CODE_TEST_CAPACITY,
                )
            )
            # One summary bucket per unmatched `required` value.
            SUMMARY_INFO.setdefault(cb['required'], []).append(
                "{}-{}".format(name, cb['id'])
            )

    return sample_code_filenames


415 416
def execute_samplecode(tfname):
    """
    Execute a sample-code test

    The file is run with the current interpreter (sys.executable) in a
    subprocess; a non-zero return code counts as failure and the source and
    both output streams are logged.

    Args:
        tfname: the filename of the sample code

    Returns:
        result: success or not
        tfname: same as the input argument
        msg: the stdout output of the sample code executing
        time: time consumed by sample code
    """
    result = True
    msg = None
    if platform.python_version()[0] in ["3"]:
        cmd = [sys.executable, tfname]
    else:
        logger.error("Error: fail to parse python version!")
        result = False
        sys.exit(1)

    logger.info("----example code check----")
    logger.info("executing sample code: %s", tfname)
    start_time = time.time()
    subprc = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    output, error = subprc.communicate()
    # errors='replace': undecodable bytes in the child's output must not
    # crash the checker itself.
    msg = output.decode(encoding='utf-8', errors='replace')
    err = error.decode(encoding='utf-8', errors='replace')
    end_time = time.time()

    if subprc.returncode != 0:
        with open(tfname, 'r', encoding='utf-8', errors='replace') as f:
            logger.warning(
                """Sample code error found in %s:
-----------------------
%s
-----------------------
subprocess return code: %d
Error Raised from Sample Code:
stderr: %s
stdout: %s
""",
                tfname,
                f.read(),
                subprc.returncode,
                err,
                msg,
            )
        logger.info("----example code check failed----")
        result = False
    else:
        logger.info("----example code check success----")

    # msg is the returned code execution report
    return result, tfname, msg, end_time - start_time
T
tianshuo78520a 已提交
473 474


475
def get_filenames(full_test=False):
    '''
    this function will get the sample code files that pending for check.

    The list of APIs is first written to API_DIFF_SPEC_FN (all APIs of the
    PR spec for a full test, otherwise only those differing from the dev
    spec). Each API is then resolved and the sample codes of its docstring
    are extracted to files. APIs that cannot be resolved (AttributeError)
    are collected in the global whl_error.

    Args:
        full_test: the full apis or the increment

    Returns:

        dict: the sample code files pending for check, mapping each file
              name to the name of the API it was extracted from.

    '''
    global whl_error
    # Imported so that eval() below can resolve the dotted `paddle...` names.
    import paddle  # noqa: F401
    import paddle.static.quantization  # noqa: F401

    whl_error = []
    if full_test:
        get_full_api_from_pr_spec()
    else:
        get_incrementapi()
    all_sample_code_filenames = {}
    with open(API_DIFF_SPEC_FN) as f:
        for line in f:
            api = line.replace('\n', '')
            try:
                # NOTE(review): eval() executes whatever the spec file
                # contains; the spec files must come from a trusted source.
                api_obj = eval(api)
            except AttributeError:
                whl_error.append(api)
                continue
            except SyntaxError:
                logger.warning('line:%s, api:%s', line, api)
                # paddle.Tensor.<lambda>
                continue
            if hasattr(api_obj, '__doc__') and api_obj.__doc__:
                sample_code_filenames = sampcd_extract_to_file(
                    api_obj.__doc__, api
                )
                for tfname in sample_code_filenames:
                    all_sample_code_filenames[tfname] = api
    return all_sample_code_filenames
516 517


518
def get_api_md5(path):
    """
    read the api spec file, and scratch the md5sum value of every api's docstring.

    Args:
        path: the api spec file. ATTENTION the path is resolved relative to
              the parent of the current working directory.

    Returns:
        api_md5(dict): key is the api's real fullname, value is the md5sum.
                       Empty when the spec file does not exist.
    """
    api_md5 = {}
    API_spec = os.path.abspath(os.path.join(os.getcwd(), "..", path))
    if not os.path.isfile(API_spec):
        return api_md5
    # Two line layouts are recognised: `(paddle.xxx, ... document ... <md5>`
    # and `paddle.xxx (ArgSpec... document ... <md5>`.
    pat = re.compile(r'\((paddle[^,]+)\W*document\W*([0-9a-z]{32})')
    patArgSpec = re.compile(
        r'^(paddle[^,]+)\s+\(ArgSpec.*document\W*([0-9a-z]{32})'
    )
    with open(API_spec) as f:
        for line in f:
            mo = pat.search(line) or patArgSpec.search(line)
            if mo:
                api_md5[mo.group(1)] = mo.group(2)
    return api_md5


546 547 548 549
def get_full_api():
    """
    get all the apis

    Writes the keys returned by print_signatures.get_all_api_from_modulelist
    to API_DIFF_SPEC_FN, one per line.
    """
    global API_DIFF_SPEC_FN  # readonly
    from print_signatures import get_all_api_from_modulelist

    member_dict = get_all_api_from_modulelist()
    with open(API_DIFF_SPEC_FN, 'w') as f:
        f.write("\n".join(member_dict.keys()))


def get_full_api_by_walk():
    """
    get all the apis

    Writes the first element of every item returned by
    print_signatures.get_all_api to API_DIFF_SPEC_FN, one per line.
    """
    global API_DIFF_SPEC_FN  # readonly
    from print_signatures import get_all_api

    apilist = get_all_api()
    with open(API_DIFF_SPEC_FN, 'w') as f:
        f.write("\n".join(ai[0] for ai in apilist))


def get_full_api_from_pr_spec():
    """
    get all the apis

    Writes every API name found in API_PR_SPEC_FN to API_DIFF_SPEC_FN, one
    per line. Falls back to get_full_api_by_walk() when the PR spec yields
    no API.
    """
    global API_PR_SPEC_FN, API_DIFF_SPEC_FN  # readonly
    pr_api = get_api_md5(API_PR_SPEC_FN)
    if not pr_api:
        get_full_api_by_walk()
        return
    with open(API_DIFF_SPEC_FN, 'w') as f:
        f.write("\n".join(pr_api.keys()))
581 582


583 584 585 586
def get_incrementapi():
    '''
    this function will get the apis that difference between API_DEV.spec and API_PR.spec.

    An API of the PR spec is written to API_DIFF_SPEC_FN (one per line) when
    it is absent from the dev spec or when its docstring md5 differs.
    '''
    global API_DEV_SPEC_FN, API_PR_SPEC_FN, API_DIFF_SPEC_FN  # readonly
    dev_api = get_api_md5(API_DEV_SPEC_FN)
    pr_api = get_api_md5(API_PR_SPEC_FN)
    with open(API_DIFF_SPEC_FN, 'w') as f:
        for key, pr_md5 in pr_api.items():
            if key not in dev_api:
                logger.debug("%s is not in dev", key)
            elif dev_api[key] != pr_md5:
                logger.debug(
                    "%s in dev is %s, different from pr's %s",
                    key,
                    dev_api[key],
                    pr_md5,
                )
            else:
                # unchanged API, nothing to check
                continue
            f.write(key)
            f.write('\n')


608 609 610 611 612
def exec_gen_doc():
    """
    Run `bash document_preview.sh` and report the outcome.

    The script's stdout and stderr are logged, and the presence of the two
    generated index pages is logged as well; a missing page is only
    reported and does not change the result.

    Returns:
        result: True when the script exited with return code 0
        msg: the stdout output of the script
        time: time consumed by the script, in seconds
    """
    result = True
    cmd = ["bash", "document_preview.sh"]
    logger.info("----exec gen_doc----")
    start_time = time.time()
    subprc = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    output, error = subprc.communicate()
    msg = output.decode(encoding='utf-8')
    err = error.decode(encoding='utf-8')
    end_time = time.time()

    logger.info("----gen_doc msg----")
    logger.info(msg)
    if subprc.returncode != 0:
        logger.error("----gen_doc error msg----")
        logger.error(err)
        logger.error("----exec gen_doc failed----")
        result = False
    else:
        logger.info("----exec gen_doc success----")

    for fn in [
        '/docs/en/develop/index_en.html',
        '/docs/zh/develop/index_cn.html',
    ]:
        if os.path.exists(fn):
            logger.info('%s exists.', fn)
        else:
            logger.error('%s not exists.', fn)

    # msg is the returned code execution report
    return result, msg, end_time - start_time


646 647 648 649 650 651
# Extra command-line options registered by parse_args().
# Each row is: [flag, dest, type, default, help].
arguments = [
    [
        '--gpu_id',
        'gpu_id',
        int,
        0,
        'GPU device id to use [0]',
    ],
    [
        '--logf',
        'logf',
        str,
        None,
        'file for logging',
    ],
    [
        '--threads',
        'threads',
        int,
        0,
        'sub processes number',
    ],
]
652

653 654 655 656 657 658 659 660 661 662 663

def parse_args():
    """
    Parse input arguments

    The positional `mode` (the device to run on) is optional and defaults to
    'cpu'. The options listed in the global `arguments` table are registered
    in addition to --debug, --full-test and --build-doc.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    global arguments
    parser = argparse.ArgumentParser(description='run Sample Code Test')
    parser.add_argument('--debug', dest='debug', action="store_true")
    parser.add_argument('--full-test', dest='full_test', action="store_true")
    # nargs='?' makes the positional optional; without it argparse ignores
    # `default` and a call such as `--debug` alone is rejected.
    parser.add_argument(
        'mode', type=str, nargs='?', default='cpu', help='run on device'
    )
    parser.add_argument(
        '--build-doc',
        dest='build_doc',
        action='store_true',
        help='build doc if need.',
    )
    for flag, dest, arg_type, default, help_text in arguments:
        parser.add_argument(
            flag, dest=dest, help=help_text, type=arg_type, default=default
        )

    return parser.parse_args()


if __name__ == '__main__':
    # Script entry point: extract the sample codes of the selected APIs,
    # run each of them in a subprocess and print a summary. The exit status
    # is 1 when an API is missing from the installed package or when a
    # sample code fails.
    args = parse_args()
    logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
    if args.logf:
        logfHandler = logging.FileHandler(args.logf)
        logfHandler.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(funcName)s:%(lineno)d - %(levelname)s - %(message)s"
            )
        )
        logger.addHandler(logfHandler)

    if args.mode == "gpu":
        GPU_ID = args.gpu_id
        logger.info("using GPU_ID %d", GPU_ID)
    elif args.mode != "cpu":
        logger.error(
            "Unrecognized argument:%s, 'cpu' or 'gpu' is desired.", args.mode
        )
        sys.exit("Invalid arguments")
    RUN_ON_DEVICE = args.mode
    get_test_capacity()
    logger.info("API check -- Example Code")
    logger.info(
        "sample_test running under python %s", platform.python_version()
    )

    # Make sure the temp directory exists; a plain file of the same name is
    # replaced by a directory.
    if os.path.exists(SAMPLECODE_TEMPDIR) and not os.path.isdir(
        SAMPLECODE_TEMPDIR
    ):
        os.remove(SAMPLECODE_TEMPDIR)
    os.makedirs(SAMPLECODE_TEMPDIR, exist_ok=True)

    filenames = get_filenames(args.full_test)
    if len(filenames) == 0 and len(whl_error) == 0:
        logger.info("-----API_PR.spec is the same as API_DEV.spec-----")
        sys.exit(0)
    logger.info("API_PR is diff from API_DEV: %s", filenames)

    # One worker per CPU unless --threads says otherwise.
    threads = args.threads if args.threads else multiprocessing.cpu_count()
    po = multiprocessing.Pool(threads)
    results = po.map_async(execute_samplecode, filenames.keys())
    po.close()
    po.join()

    # Each item is (success, filename, stdout, seconds).
    result = results.get()

    # delete temp files
    if not args.debug:
        shutil.rmtree(SAMPLECODE_TEMPDIR)

    # From here on the summary is sent to stdout as well.
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logger.addHandler(stdout_handler)
    logger.info("----------------End of the Check--------------------")
    if len(whl_error) != 0:
        logger.info("%s is not in whl.", whl_error)
        logger.info("")
        logger.info("Please check the whl package and API_PR.spec!")
        logger.info("You can follow these steps in order to generate API.spec:")
        logger.info("1. cd ${paddle_path}, compile paddle;")
        logger.info("2. pip install build/python/dist/(build whl package);")
        logger.info(
            "3. run 'python tools/print_signatures.py paddle > paddle/fluid/API.spec'."
        )
        for temp in result:
            if not temp[0]:
                logger.info(
                    "In addition, mistakes found in sample codes: %s", temp[1]
                )
        logger.info("----------------------------------------------------")
        sys.exit(1)
    else:
        timeovered_test = {}
        for succeeded, tfname, _msg, elapsed in result:
            if not succeeded:
                logger.info(
                    "In addition, mistakes found in sample codes: %s", tfname
                )
                SUMMARY_INFO['failed'].append(tfname)
            else:
                SUMMARY_INFO['success'].append(tfname)
            if elapsed > 10:
                timeovered_test[tfname] = elapsed

        if len(timeovered_test):
            logger.info(
                "%d sample codes ran time over 10s", len(timeovered_test)
            )
            if args.debug:
                for k, v in timeovered_test.items():
                    logger.info(f'{k} - {v}s')
        if len(SUMMARY_INFO['success']):
            logger.info(
                "%d sample codes ran success", len(SUMMARY_INFO['success'])
            )
        # Every other key is a `required` value that did not match.
        for k, v in SUMMARY_INFO.items():
            if k not in ['success', 'failed', 'skiptest', 'nocodes']:
                logger.info(
                    "%d sample codes required not match for %s", len(v), k
                )
        if len(SUMMARY_INFO['skiptest']):
            logger.info(
                "%d sample codes skipped", len(SUMMARY_INFO['skiptest'])
            )
            if args.debug:
                logger.info('\n'.join(SUMMARY_INFO['skiptest']))
        if len(SUMMARY_INFO['nocodes']):
            logger.info(
                "%d apis don't have sample codes", len(SUMMARY_INFO['nocodes'])
            )
            if args.debug:
                logger.info('\n'.join(SUMMARY_INFO['nocodes']))
        if len(SUMMARY_INFO['failed']):
            logger.info(
                "%d sample codes ran failed", len(SUMMARY_INFO['failed'])
            )
            logger.info('\n'.join(SUMMARY_INFO['failed']))
            logger.info(
                "Mistakes found in sample codes. Please recheck the sample codes."
            )
            sys.exit(1)

    logger.info("Sample code check is successful!")

    # NOTE(review): args.build_doc is parsed but never consulted here; the
    # doc build is triggered by cpu mode alone. Confirm this is intended.
    if args.mode == "cpu":
        # As cpu mode is also run with the GPU whl, so skip it in gpu mode.
        exec_gen_doc()