sampcd_processor.py 26.2 KB
Newer Older
1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
T
tianshuo78520a 已提交
14 15

import os
16
import sys
T
tianshuo78520a 已提交
17
import subprocess
18 19 20
import multiprocessing
import math
import platform
21 22 23
import inspect
import paddle
import paddle.fluid
Z
zhangchunle 已提交
24
import json
Z
zhangchunle 已提交
25
import re
26 27
"""
please make sure to run in the tools path
28
usage: python sample_test.py {arg1} 
29 30 31
arg1: the first arg defined running in gpu version or cpu version

for example, you can run cpu version python2 testing like this:
32 33 34

    python sampcd_processor.py cpu 

35
"""
T
tianshuo78520a 已提交
36 37 38


def find_all(srcstr, substr):
    """
    Collect the starting index of every occurrence of ``substr`` in
    ``srcstr``. Occurrences may overlap, because each search resumes one
    character after the previous hit.

    Args:
        srcstr(str): the string to search in
        substr(str): the string to search for

    Returns:
        list: the starting indices of all occurrences, in ascending order
    """
    positions = []
    pos = srcstr.find(substr)
    while pos >= 0:
        positions.append(pos)
        # resume right after the current hit so overlapping matches count
        pos = srcstr.find(substr, pos + 1)
    return positions


def check_indent(cdline):
    """
    Measure the leading indentation of a single code line.

    Leading blank spaces count as 1 column each and leading tabs as 4
    columns each. Counting stops at the first character that is neither.

    Args:
        cdline(str) : a single line of code from the source file

    Returns:
        int : the width of the leading whitespace, in blank spaces
    """
    width = 0
    for ch in cdline:
        if ch == ' ':
            width += 1
        elif ch == '\t':
            width += 4
        else:
            # first non-blank character: the indentation is over
            break
    return width


Z
zhangchunle 已提交
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
def return_sample(htype, srccom, mode):
    """
    this function will return sample code or sample result.

    ``srccom`` is expected to start with `` code-block:: <mode>`` (the
    pattern is anchored at a line start). The returned text is everything
    after that header line, up to but not including the first following
    line that starts with exactly 4 spaces (8 spaces for methods) and then
    a non-space character.

    Args:
        htype(str): the type of hint banners, def/class/method.
        srccom(str): the source comment of some API whose
                     example codes will be extracted and run.
        mode(str): return mode. python/text.

    Returns:
        r: sample code or sample result, or None when no such block
           can be matched.
    """
    # method docstrings sit one level deeper than def/class docstrings
    indent = '        ' if htype == 'method' else '    '
    pattern = re.compile(
        '^ code-block:: %s\n(.*?)\n%s[^ ]' % (re.escape(mode), indent),
        re.MULTILINE | re.DOTALL)
    match = pattern.search(srccom)
    if match is None:
        return None
    return match.group(1)


113 114 115
def _extract_sample_code_lines(srccom, sampcd_begin, htype):
    """
    Return the dedented, non-blank lines of the python code-block that
    starts at index ``sampcd_begin`` of ``srccom``.

    An empty list is returned when the block holds no code at all.
    """
    sampcd = return_sample(htype, srccom[sampcd_begin:], 'python')
    if sampcd is None:
        # no terminating line was matched: take everything after the header
        sampcd = srccom[sampcd_begin + len(" code-block:: python") + 1:]
    lines = sampcd.split("\n")
    # remove starting empty lines (lines made only of blanks/tabs)
    while lines and lines[0].replace(' ', '').replace('\t', '') == '':
        lines.pop(0)
    if not lines:
        return []
    # the minimum indent is the indent of the first non-empty line
    min_indent = check_indent(lines[0])
    code_lines = []
    for cdline in lines:
        # blank lines (only spaces/tabs) are dropped
        if cdline.strip() == '':
            continue
        # a line dedented below the first code line ends the block
        if check_indent(cdline) < min_indent:
            break
        code_lines.append(cdline.replace('\t', '    ')[min_indent:])
    return code_lines


def _normalize_sample_text(text):
    """
    Strip surrounding blank lines and per-line indentation, so that the
    captured program output can be compared with an indented
    ``code-block:: text`` section of a docstring.
    """
    return '\n'.join(line.strip() for line in text.strip().splitlines())


# srccom: raw comments in the source, including ''' and original indent
def sampcd_extract_and_run(srccom, name, htype="def", hname=""):
    """
    Extract and run sample codes from source comment and
    the result will be returned.

    Each `` code-block:: python`` section is written to
    ``samplecode_temp/<name>_example[_N].py`` and executed in a subprocess.
    When the comment has as many `` code-block:: text`` sections as python
    ones, each sample's stdout is compared with its text section after
    indentation and surrounding blank lines are normalized.

    Args:
        srccom(str): the source comment of some API whose
                     example codes will be extracted and run.
        name(str): the name of the API.
        htype(str): the type of hint banners, def/class/method.
        hname(str): the name of the hint banners, e.t. def hname.

    Returns:
        result: True or False
    """
    result = True

    def sampcd_header_print(name, sampcd, htype, hname, index):
        """
        print hint banner headers.

        Args:
            name(str): the name of the API.
            sampcd(str): sample code string
            htype(str): the type of hint banners, def/class/method.
            hname(str): the name of the hint banners, e.t. def hname.
            index(int): 1-based number of the sample code in the comment.
        """
        print_header(htype, hname)
        print("Sample code ", str(index), " extracted for ", name, "   :")
        print(sampcd)
        print("----example code check----\n")
        print("executing sample code .....")
        print("execution result:")

    wlist_hint = (
        "If you think the sample code of this api is not suitable for the return result, please add the white list in FIle: tools/wlist_return.json first. And you must have one TPM(saxon-zh or swtkiwi or Boyan-Liu) approve for the white list."
    )

    sampcd_begins = find_all(srccom, " code-block:: python")
    sampre_begins = find_all(srccom, " code-block:: text")

    if len(sampcd_begins) == 0:
        # detect sample codes using >>> to format
        # and consider this situation as wrong
        print_header(htype, hname)
        if srccom.find("Examples:") != -1:
            print("----example code check----\n")
            if srccom.find(">>>") != -1:
                print(
                    "Deprecated sample code style:\n\n    Examples:\n\n        >>>codeline\n        >>>codeline\n\n\n ",
                    "Please use '.. code-block:: python' to ",
                    "format sample code.\n")
                result = False
        else:
            print("Error: No sample code!\n")
            result = False

    if len(sampcd_begins) != len(sampre_begins) and hname not in wlist_return:
        print_header(htype, hname)
        if len(sampre_begins) == 0:
            print("Error: Cannot find the return result of the sample code.")
        else:
            print("Error: Found %s result but %s python file." %
                  (len(sampre_begins), len(sampcd_begins)))
            print(wlist_hint)
        return False

    for y, sampcd_begin in enumerate(sampcd_begins, 1):
        code_lines = _extract_sample_code_lines(srccom, sampcd_begin, htype)
        if not code_lines:
            print_header(htype, hname)
            print("Error: Sample code ", str(y), " of ", name, " is empty.")
            result = False
            continue
        sampcd = '\n'.join(code_lines)
        # restrict the devices visible to the sample code
        if sys.argv[1] == "cpu":
            sampcd = '\nimport os\n' + 'os.environ["CUDA_VISIBLE_DEVICES"] = ""\n' + sampcd
        elif sys.argv[1] == "gpu":
            sampcd = '\nimport os\n' + 'os.environ["CUDA_VISIBLE_DEVICES"] = "0"\n' + sampcd

        if len(sampcd_begins) > 1:
            tfname = name + "_example_" + str(y) + ".py"
        else:
            tfname = name + "_example" + ".py"
        with open("samplecode_temp/" + tfname, 'w') as tempf:
            tempf.write(sampcd)

        py_major = platform.python_version()[0]
        if py_major == "2":
            cmd = ["python", "samplecode_temp/" + tfname]
        elif py_major == "3":
            cmd = ["python3", "samplecode_temp/" + tfname]
        else:
            print("Error: fail to parse python version!")
            sys.exit(1)
        subprc = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = subprc.communicate()
        # decode once up front: under Python 3 communicate() returns bytes,
        # which never compare equal to the str values below
        msg = output.decode(encoding='utf-8', errors='replace')
        err = error.decode(encoding='utf-8', errors='replace')

        if msg == "" and len(sampre_begins) != 0:
            print_header(htype, hname)
            print(
                "Error: Your sample code have returned a result, but execute sample code don't get result!"
            )
            print(wlist_hint)
            result = False
        elif len(sampcd_begins) == len(sampre_begins):
            sampre = return_sample(htype, srccom[sampre_begins[y - 1]:],
                                   'text')
            if (sampre is None or _normalize_sample_text(msg) !=
                    _normalize_sample_text(sampre)):
                print_header(htype, hname)
                print(
                    "Error: Mistake found in the return result of sample code."
                )
                print("There maybe three reasons for this error:")
                print(
                    "1. The input of the sample code is a random number.Please add the white list in FIle: tools/wlist_return.json first .And you must have one TPM(saxon-zh or swtkiwi or Boyan-Liu) approve for the white list."
                )
                print(
                    "2. The return value of the sample code is incorrect. Please check the code and reset the return value."
                )
                print("3. " + wlist_hint)
                result = False
        elif name not in wlist_return:
            print_header(htype, hname)
            print("Error: " + wlist_hint)
            result = False

        if subprc.returncode != 0:
            print("\nSample code error found in ", name, ":\n")
            sampcd_header_print(name, sampcd, htype, hname, y)
            print("subprocess return code: ", str(subprc.returncode))
            print("Error Raised from Sample Code ", name, " :\n")
            print(err)
            print(msg)
            result = False

    return result
T
tianshuo78520a 已提交
280 281 282


def single_defcom_extract(start_from, srcls, is_class_begin=False):
    """
    Extract the docstring body that follows a def/class header.

    Scanning starts on the line after ``start_from``. Lines before the first
    line beginning with a triple-quote mark are ignored. The line holding
    that mark is not included. Every later line is collected until a line
    that begins with the same kind of mark. Scanning always stops at a
    top-level ``def ``/``class `` line. When ``is_class_begin`` is true it
    also stops at the first ``    def `` line (tabs counted as 4 spaces).

    Args:
        start_from(int): the line num of "def" header
        srcls(list): the source file in lines
        is_class_begin(bool): whether start_from is the beginning of a
            class. A class body without its own docstring must not borrow
            the docstring of its first method.

    Returns:
        string : the concatenated lines strictly between the opening and
            closing quote marks, or "" if no docstring was found.
    """
    collected = []
    quote = None  # '"""' or "'''" once the opening mark has been seen
    for x in range(start_from + 1, len(srcls)):
        line = srcls[x]
        if is_class_begin and line.replace('\t', '    ').startswith('    def '):
            break
        if line.startswith('def ') or line.startswith('class '):
            break
        compact = line.replace(" ", '').replace("\t", '').replace("\n", '')
        if quote is None:
            # still looking for the opening mark of the docstring
            if compact.startswith("\"\"\""):
                quote = "\"\"\""
            elif compact.startswith("\'\'\'"):
                quote = "\'\'\'"
            continue
        if compact.startswith(quote):
            break  # closing mark of the same style
        collected.append(line)
    return ''.join(collected)


336 337 338
def print_header(htype, name):
    """
    Print the banner that precedes every diagnostic about an API.

    Args:
        htype(str): the kind of API, e.g. def/class/method.
        name(str): the name shown in the banner.
    """
    banner_rule = "-----------------------"
    print(htype, " name:", name)
    print(banner_rule)
339

T
tianshuo78520a 已提交
340

341
def _module_prefix(srcfile_name):
    """
    Build the dotted prefix used to look names up in ``methods``.

    E.g. ``../python/paddle/fluid/layers/nn.py`` becomes ``layers.nn.``.
    The first four path components are dropped, mirroring how
    get_filenames builds ``methods`` from ``module.split('.')[2:]``.
    """
    parts = srcfile_name.replace('.py', '').split('/')
    return ''.join(part + '.' for part in parts[4:])


def _parse_all_list(srcc, allidx):
    """Parse the names listed in the ``__all__ = [...]`` literal at allidx."""
    alllist_b = allidx + len("__all__")
    rest = srcc[alllist_b:]
    allstr = srcc[alllist_b + rest.find("[") + 1:alllist_b + rest.find("]")]
    allstr = allstr.replace("\n", '').replace(" ", '').replace(
        "'", '').replace("\"", '')
    alllist = allstr.split(',')
    if '' in alllist:
        alllist.remove('')
    return alllist


def _parse_ops_all_list(srcls):
    """
    Collect the bracket contents of every line starting with ``__all__``
    (layers/ops.py style). Each line contributes one entry.
    """
    alllist = []
    for line in srcls:
        if not line.startswith("__all__"):
            continue
        lb = line.find('[')
        rb = line.find(']')
        if lb == -1:
            continue
        alllist.append(line[lb + 1:rb].replace("'", '').replace(
            " ", '').replace("\"", ''))
    if '' in alllist:
        alllist.remove('')
    return alllist


def _ops_doc_names(srcls, wlist):
    """
    Return the names of layers/ops.py entries documented through a
    ``<name>.__doc__`` line, excluding whitelisted ones.

    They are skipped by the regular def/class scan. Their own sample codes
    are not run here.
    """
    names = []
    for line in srcls:
        pos = line.find("__doc__")
        if pos == -1:
            continue
        opname = line[:pos - 1]
        if opname in wlist:
            continue
        names.append(opname)
    return names


def _class_name(line):
    """
    Return the class name declared on a ``class ...`` header line.

    The name ends at the first '(' or ':'. This means ``class A:`` (no base
    list) yields ``A``.
    """
    header = line.replace(" ", '')[len('class'):]
    cuts = [pos for pos in (header.find('('), header.find(':')) if pos != -1]
    if cuts:
        return header[:min(cuts)]
    return header[:-1]


def _check_class_methods(srcls, class_line, cn, srcfile_str, srcfile_name,
                         wlist):
    """
    Run the sample codes of the public methods of class ``cn``, whose header
    is at ``srcls[class_line]``.

    Returns False if any checked method's sample code failed, else True.
    """
    passed = True
    for x in range(class_line + 1, len(srcls)):
        # the class body ends at the next top-level definition
        if srcls[x].startswith('def ') or srcls[x].startswith('class '):
            break
        thisl = srcls[x].replace('\t', '    ')
        if not thisl.startswith('    def '):  # only method headers matter
            continue
        indent = len(thisl) - len(thisl.lstrip())
        mn = thisl[indent + len('def '):thisl.find('(')]  # method name
        name = cn + "." + mn  # full name
        if '%s%s' % (srcfile_str, name) not in methods:
            continue  # class method not in api.spec
        if mn.startswith('_'):
            continue
        if name in wlist or name + "@" + srcfile_name in wlist:
            continue
        # gather the method's lines, re-based to the method's own indent,
        # and send them to single_defcom_extract
        thismethod = [thisl[indent:]]
        for y in range(x + 1, len(srcls)):
            body_line = srcls[y].replace('\t', '    ')
            if (body_line.startswith('def ') or
                    body_line.startswith('class ') or
                    body_line.startswith('    def ')):
                break  # end of method
            thismethod.append(body_line[indent:])
        thismtdcom = single_defcom_extract(0, thismethod)
        if thismtdcom != "" and not sampcd_extract_and_run(
                thismtdcom, name, "method", name):
            passed = False
    return passed


def srccoms_extract(srcfile, wlist):
    """
    Given a source file ``srcfile``, this function will
    extract its API(doc comments) and run sample codes in the
    API.

    A top-level function or class is checked only when it is listed in the
    file's ``__all__`` and its prefixed name is in the global ``methods``
    list, which get_filenames fills. Public methods of a checked class are
    checked when their ``Class.method`` name is in ``methods``. Names in
    ``wlist`` (plain or ``name@<file>``) are skipped.

    Args:
        srcfile(file): the source file
        wlist(list): white list

    Returns:
        result: True or False
    """
    process_result = True
    srcc = srcfile.read()
    # rewind the file and read it again, this time as a list of lines
    srcfile.seek(0, 0)
    srcls = srcfile.readlines()  # source lines

    # 1. fetch __all__ list; a file without one exposes nothing to check
    allidx = srcc.find("__all__")
    if allidx == -1:
        return process_result

    srcfile_str = _module_prefix(srcfile.name)
    if srcfile.name.find("ops.py") != -1:
        alllist = _parse_ops_all_list(srcls)
        # ops.py also has normal formatted functions; the names documented
        # through __doc__ are marked as handled so the scan below ignores them
        handled = _ops_doc_names(srcls, wlist)
    else:
        alllist = _parse_all_list(srcc, allidx)
        handled = []

    # 2. walk the top-level def / class headers
    for i, line in enumerate(srcls):
        if line.startswith('def '):  # a function header is detected in line i
            f_header = line.replace(" ", '')
            fn = f_header[len('def'):f_header.find('(')]  # function name
            if "%s%s" % (srcfile_str, fn) not in methods:
                continue
            if fn in handled or fn not in alllist:
                continue
            if fn in wlist or fn + "@" + srcfile.name in wlist:
                continue
            fcombody = single_defcom_extract(i, srcls)
            if fcombody == "":  # if no comment
                print("WARNING: no comments in function ", fn,
                      ", but it deserves.")
                continue
            if not sampcd_extract_and_run(fcombody, fn, "def", fn):
                process_result = False
        elif line.startswith('class '):
            cn = _class_name(line)  # class name
            if '%s%s' % (srcfile_str, cn) not in methods:
                continue
            if cn in handled or cn not in alllist:
                continue
            if cn in wlist or cn + "@" + srcfile.name in wlist:
                continue
            # class comment
            classcom = single_defcom_extract(i, srcls, True)
            if classcom != "":
                if not sampcd_extract_and_run(classcom, cn, "class", cn):
                    process_result = False
            else:
                print("WARNING: no comments in class itself ", cn,
                      ", but it deserves.\n")
            # handling methods in class bodies
            if not _check_class_methods(srcls, i, cn, srcfile_str,
                                        srcfile.name, wlist):
                process_result = False
    return process_result
T
tianshuo78520a 已提交
508 509


510
def test(file_list):
511
    process_result = True
512
    for file in file_list:
513 514 515 516
        with open(file, 'r') as src:
            if not srccoms_extract(src, wlist):
                process_result = False
    return process_result
517 518


519
def get_filenames():
    '''
    this function will get the modules that pending for check.

    It also resets the global ``methods`` list. Each entry is
    ``<module components after the first two>.<name>``, where name is
    ``Class.method`` for methods. srccoms_extract uses these entries to
    decide what to check.

    Returns:

        list: the modules pending for check .

    '''
    filenames = []
    global methods
    methods = []
    get_incrementapi()
    API_spec = 'dev_pr_diff_api.spec'
    with open(API_spec) as f:
        for line in f.readlines():
            api = line.replace('\n', '')
            if not api:
                continue
            # NOTE(review): eval on API names read from the spec diff, which
            # is generated from the repository's own API spec files.
            try:
                api_obj = eval(api)
                module = api_obj.__module__
            except AttributeError:
                continue
            module_parts = module.split('.')
            if len(module_parts) > 2:
                filename = '../python/' + '/'.join(module_parts) + '.py'
                if filename not in filenames:
                    filenames.append(filename)
            else:
                print("\n----Exception in get api filename----\n")
                print("\n" + api + 'module is ' + module + "\n")
            # get all methods
            api_parts = api.split('.')
            if inspect.isclass(api_obj):
                name = api_parts[-1]
            elif inspect.ismethod(api_obj):
                name = '%s.%s' % (api_parts[-2], api_parts[-1])
            elif inspect.isfunction(api_obj):
                # Python 3 has no unbound methods: a method looked up on its
                # class is a plain function. Use the parent object to keep the
                # Class.method form that srccoms_extract looks up.
                parent = eval('.'.join(api_parts[:-1])) if len(
                    api_parts) > 1 else None
                if inspect.isclass(parent):
                    name = '%s.%s' % (api_parts[-2], api_parts[-1])
                else:
                    name = api_parts[-1]
            else:
                name = ''
                print("\n----Exception in get api methods----\n")
                print("\n" + line + "\n")
                print("\n" + api + ' method is None!!!' + "\n")
            method = ''.join(part + '.' for part in module_parts[2:]) + name
            if method not in methods:
                methods.append(method)
    os.remove(API_spec)
    return filenames


573 574 575 576 577 578 579 580 581 582 583 584
def get_incrementapi():
    '''
    this function will get the apis that difference between API_DEV.spec and API_PR.spec.

    Both spec files are read from ``../paddle/fluid/`` relative to the
    current directory. Every API of API_PR.spec whose ``'document'`` digest
    differs from API_DEV.spec, or which is missing there, is written to
    ``dev_pr_diff_api.spec`` (one name per line, in API_PR.spec order).
    '''

    def get_api_md5(path):
        spec_path = '%s/%s' % (os.path.abspath(os.path.join(os.getcwd(), "..")),
                               path)
        digests = {}
        with open(spec_path) as spec_file:
            for entry in spec_file.readlines():
                if entry.find('document') == -1:
                    continue
                api_name = entry.split(' ', 1)[0]
                digests[api_name] = entry.split("'document', ")[1].replace(
                    ')', '').replace('\n', '')
        return digests

    dev_api = get_api_md5('paddle/fluid/API_DEV.spec')
    pr_api = get_api_md5('paddle/fluid/API_PR.spec')
    with open('dev_pr_diff_api.spec', 'w') as out:
        for key in pr_api:
            # new APIs and APIs whose documentation digest changed
            if key not in dev_api or dev_api[key] != pr_api[key]:
                out.write(key)
                out.write('\n')


604 605 606
# APIs that are whitelisted only when checking on CPU: with the "gpu"
# argument the script removes them from the white list, so their sample
# codes are executed.
gpu_not_white = [
    "deformable_conv",
    "cuda_places",
    "CUDAPinnedPlace",
    "CUDAPlace",
    "cuda_profiler",
    'DGCMomentumOptimizer',
]
609

Z
zhangchunle 已提交
610

Z
zhangchunle 已提交
611
def get_wlist(file_name):
    '''
    this function will get the white list of API.

    The file must hold a JSON object whose values are lists of API names.
    The lists are concatenated in the object's key order.

    Args:
        file_name(str): white list file name.

    Returns:

        wlist: a list of API that should not trigger the example check .

    Raises:
        TypeError: if a value of the JSON object is not a list.
    '''
    wlist = []
    with open(file_name, 'r') as load_f:
        load_dict = json.load(load_f)
    for key in load_dict:
        names = load_dict[key]
        if not isinstance(names, list):
            raise TypeError("white list entry %r in %s is not a list" %
                            (key, file_name))
        # extend in place; `wlist = wlist + names` copied the list each time
        wlist.extend(names)
    return wlist


Z
zhangchunle 已提交
631 632 633
wlist = get_wlist('wlist.json')

wlist_return = get_wlist('wlist_return.json')

if len(sys.argv) < 2:
    print("Error: inadequate number of arguments")
    print('''If you are going to run it on 
        "CPU: >>> python sampcd_processor.py cpu
        "GPU: >>> python sampcd_processor.py gpu
        ''')
    sys.exit("lack arguments")
else:
    if sys.argv[1] == "gpu":
        # GPU-only-checked APIs leave the white list; tolerate names that
        # are absent from wlist.json instead of raising ValueError
        for _gnw in gpu_not_white:
            if _gnw in wlist:
                wlist.remove(_gnw)
    elif sys.argv[1] != "cpu":
        print("Unrecognized argument:'", sys.argv[1], "' , 'cpu' or 'gpu' is ",
              "desired\n")
        sys.exit("Invalid arguments")
    print("API check -- Example Code")
    print("sample_test running under python", platform.python_version())
    cpus = multiprocessing.cpu_count()
    filenames = get_filenames()
    if len(filenames) == 0:
        print("-----API_PR.spec is the same as API_DEV.spec-----")
        sys.exit(0)
    elif '../python/paddle/fluid/core_avx.py' in filenames:
        filenames.remove('../python/paddle/fluid/core_avx.py')
    print("API_PR is diff from API_DEV: %s" % filenames)

    # the temp dir is created only once there is something to check, so
    # the early exit above does not leave it behind
    if not os.path.isdir("./samplecode_temp"):
        os.mkdir("./samplecode_temp")

    # float division so the ceiling is also correct under Python 2
    one_part_filenum = int(math.ceil(len(filenames) / float(cpus)))
    if one_part_filenum == 0:
        one_part_filenum = 1
    divided_file_list = [
        filenames[i:i + one_part_filenum]
        for i in range(0, len(filenames), one_part_filenum)
    ]

    po = multiprocessing.Pool()
    results = po.map_async(test, divided_file_list)
    po.close()
    po.join()

    result = results.get()

    # delete temp files
    for root, dirs, files in os.walk("./samplecode_temp"):
        for fntemp in files:
            os.remove("./samplecode_temp/" + fntemp)
    os.rmdir("./samplecode_temp")

    print("----------------End of the Check--------------------")
    if not all(result):
        print("Mistakes found in sample codes")
        sys.exit(1)
    print("Sample code check is successful!")