sampcd_processor.py 28.3 KB
Newer Older
1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
T
tianshuo78520a 已提交
14 15

import os
16
import sys
T
tianshuo78520a 已提交
17
import subprocess
18 19 20
import multiprocessing
import math
import platform
21 22 23
import inspect
import paddle
import paddle.fluid
24 25
"""
Please make sure to run this script from the ``tools`` directory.

usage: python sampcd_processor.py {arg1}

    arg1: whether to run the sample codes in the gpu version or the cpu
          version, i.e. ``gpu`` or ``cpu``.

For example, you can run the cpu version testing like this:

    python sampcd_processor.py cpu

"""
T
tianshuo78520a 已提交
34 35 36


def find_all(srcstr, substr):
    """
    Locate every occurrence of ``substr`` inside ``srcstr``.

    The search resumes one character after the start of the previous
    hit, so overlapping occurrences are reported as well.

    Args:
        srcstr(str): the string to search in
        substr(str): the substring to look for

    Returns:
        list: the starting indices of all occurrences, in ascending order
    """
    positions = []
    start = 0
    while True:
        hit = srcstr.find(substr, start)
        if hit == -1:
            return positions
        positions.append(hit)
        start = hit + 1


def check_indent(cdline):
    """
    Measure the leading indentation of a single code line.

    Each leading blank space counts as one column and each leading tab
    counts as four columns (i.e. '\t' == '    '). Counting stops at the
    first character that is neither a space nor a tab.

    Args:
        cdline(str) : a single line of code from the source file

    Returns:
        int : the indentation expressed as a number of blank spaces
    """
    widths = {' ': 1, '\t': 4}
    total = 0
    for ch in cdline:
        if ch not in widths:
            break
        total += widths[ch]
    return total


85 86 87
# srccom: raw comments in the source,including ''' and original indent
def sampcd_extract_and_run(srccom, name, htype="def", hname=""):
    """
    Extract and run sample codes from source comment and
    the result will be returned.

    Every ``.. code-block:: python`` section of ``srccom`` is dedented,
    prefixed with a ``CUDA_VISIBLE_DEVICES`` setting chosen by the
    ``cpu``/``gpu`` command-line argument (``sys.argv[1]``), written to
    ``samplecode_temp/`` and executed in a separate python process.

    Args:
        srccom(str): the source comment of some API whose
                     example codes will be extracted and run.
        name(str): the name of the API.
        htype(str): the type of hint banners, def/class/method.
        hname(str): the name of the hint  banners , e.t. def hname.

    Returns:
        result: True if every sample code ran successfully, otherwise False
    """
    code_block_mark = " code-block:: python"
    result = True

    def sampcd_header_print(name, sampcd, htype, hname, index):
        """
        print hint banner headers.

        Args:
            name(str): the name of the API.
            sampcd(str): sample code string
            htype(str): the type of hint banners, def/class/method.
            hname(str): the name of the hint  banners , e.t. def hname.
            index(int): 1-based number of the sample code in the comment.
        """
        print_header(htype, hname)
        print("Sample code ", str(index), " extracted for ", name, "   :")
        print(sampcd)
        print("----example code check----\n")
        print("executing sample code .....")
        print("execution result:")

    sampcd_begins = find_all(srccom, code_block_mark)

    if len(sampcd_begins) == 0:
        print_header(htype, hname)
        # detect sample codes using >>> to format
        # and consider this situation as wrong
        if srccom.find("Examples:") != -1:
            print("----example code check----\n")
            if srccom.find(">>>") != -1:
                print(
                    "Deprecated sample code style:\n\n    Examples:\n\n        >>>codeline\n        >>>codeline\n\n\n ",
                    "Please use '.. code-block:: python' to ",
                    "format sample code.\n")
                result = False
        else:
            print("Error: No sample code!\n")
            result = False

    for y, sampcd_begin in enumerate(sampcd_begins, 1):
        sampcd = srccom[sampcd_begin + len(code_block_mark) + 1:]
        sampcd = sampcd.split("\n")
        # remove starting empty lines; strip() also treats a lone '\r'
        # (CRLF sources) as empty, which would otherwise pin the indent to 0
        while sampcd and sampcd[0].strip() == '':
            sampcd.pop(0)
        if not sampcd:
            # the directive is followed by blank lines only: nothing to run
            print_header(htype, hname)
            print("Error: sample code ", str(y), " of ", name,
                  " is empty!\n")
            result = False
            continue

        # the minimum indent, which is the indent of the first
        # non-empty line
        min_indent = check_indent(sampcd[0])
        sampcd_to_write = []
        for cdline in sampcd:
            # handle empty lines or those only with spaces/tabs
            if cdline.strip() == '':
                continue
            # a line indented less than the first one ends the code block
            if check_indent(cdline) < min_indent:
                break
            sampcd_to_write.append(cdline.replace('\t', '    ')[min_indent:])

        sampcd = '\n'.join(sampcd_to_write)
        if sys.argv[1] == "cpu":
            sampcd = '\nimport os\n' + 'os.environ["CUDA_VISIBLE_DEVICES"] = ""\n' + sampcd
        if sys.argv[1] == "gpu":
            sampcd = '\nimport os\n' + 'os.environ["CUDA_VISIBLE_DEVICES"] = "0"\n' + sampcd
        sampcd += '\nprint(' + '\"' + name + ' sample code is executed successfully!\")'

        if len(sampcd_begins) > 1:
            tfname = name + "_example_" + str(y) + ".py"
        else:
            tfname = name + "_example" + ".py"
        tfpath = "samplecode_temp/" + tfname
        with open(tfpath, 'w') as tempf:
            tempf.write(sampcd)

        if platform.python_version()[0] == "2":
            cmd = ["python", tfpath]
        elif platform.python_version()[0] == "3":
            cmd = ["python3", tfpath]
        else:
            print("Error: fail to parse python version!")
            sys.exit(1)

        subprc = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = subprc.communicate()
        # msg is the returned code execution report; undecodable bytes are
        # replaced so that odd output cannot crash the whole check
        msg = output.decode(encoding='utf-8', errors='replace')
        err = error.decode(encoding='utf-8', errors='replace')

        if subprc.returncode != 0:
            print("\nSample code error found in ", name, ":\n")
            sampcd_header_print(name, sampcd, htype, hname, y)
            print("subprocess return code: ", str(subprc.returncode))
            print("Error Raised from Sample Code ", name, " :\n")
            print(err)
            print(msg)
            result = False

    return result
T
tianshuo78520a 已提交
206 207 208


def single_defcom_extract(start_from, srcls, is_class_begin=False):
    """
    Collect the docstring body of the def/class whose header is at
    ``srcls[start_from]``.

    Scanning starts on the line after the header. The first line whose
    squeezed text (spaces, tabs and newlines removed) begins with a
    triple quote opens the docstring. The next line beginning with the
    same triple quote closes it. Lines in between are returned verbatim.
    Scanning also stops at any unindented ``def ``/``class `` line and,
    when ``is_class_begin`` is set, at the first member ``    def `` line.

    Args:
        start_from(int): the line num of "def" header
        srcls(list): the source file in lines
        is_class_begin(bool): whether the start_from is a beginning a class.
            A class body without its own docstring must not pick up the
            docstring of its first method, while a common def function
            only ends at an unindented def/class.

    Returns:
        string : the extracted comment body, without its opening and
            closing quote lines.
    """
    collected = []
    closing_mark = None  # the triple quote that opened the docstring, if any
    for idx in range(start_from + 1, len(srcls)):
        line = srcls[idx]
        if is_class_begin and line.replace('\t', '    ').startswith('    def '):
            break
        if line.startswith('def ') or line.startswith('class '):
            break
        squeezed = line.replace(" ", '').replace("\t", '').replace("\n", '')
        if closing_mark is None:
            # not inside the docstring yet: only look for its opening quote
            for mark in ("\"\"\"", "\'\'\'"):
                if squeezed.startswith(mark):
                    closing_mark = mark
                    break
            continue
        if squeezed.startswith(closing_mark):
            break
        collected.append(line)
    return "".join(collected)


262 263 264
def print_header(htype, name):
    """
    Print a hint banner: the hint type and name, then a dashed rule.

    Args:
        htype(str): the type of hint banner, e.g. def/class/method.
        name(str): the name shown in the banner.
    """
    banner_fields = (htype, " name:", name)
    print(*banner_fields)
    print("-----------------------")
265

T
tianshuo78520a 已提交
266

267
def _split_names(text):
    """Turn the inside of a ``[...]`` name list into a list of bare names."""
    # drop every kind of whitespace (incl. '\r' of CRLF files) and quotes
    cleaned = ''.join(text.split()).replace("'", '').replace("\"", '')
    return [name for name in cleaned.split(',') if name != '']


def _module_prefix(path):
    """Dotted module prefix (path parts from index 4 on) used to match API names."""
    parts = path.replace('.py', '').split('/')
    return ''.join(part + '.' for part in parts[4:])


def _header_name(header, keyword):
    """Name declared by a ``def``/``class`` header line (up to '(' or ':')."""
    squeezed = header.replace(" ", '')[len(keyword):]
    end = len(squeezed)
    for stop in ('(', ':'):
        pos = squeezed.find(stop)
        if pos != -1:
            end = min(end, pos)
    return squeezed[:end].strip()


def _all_list(srcc):
    """Names of the first ``__all__ = [...]`` literal in the source text."""
    begin = srcc.find("__all__") + len("__all__")
    rest = srcc[begin:]
    return _split_names(srcc[begin + rest.find("[") + 1:begin + rest.find("]")])


def _ops_all_list(srcls):
    """Names of every single-line ``__all__ ... [...]`` statement (layers/ops.py)."""
    names = []
    for line in srcls:
        if not line.startswith("__all__"):
            continue
        lb = line.find('[')
        rb = line.find(']')
        if lb == -1:
            continue
        names.extend(_split_names(line[lb + 1:rb]))
    return names


def _ops_doc_names(srcls, wlist):
    """Non-whitelisted names whose docstring is assigned via ``__doc__`` (layers/ops.py).

    The sample codes of these docstrings are not run by this script; the
    names are only excluded from the normal def/class check.
    """
    handled = []
    for line in srcls:
        pos = line.find("__doc__")
        if pos == -1:
            continue
        opname = line[:pos - 1]
        if opname in wlist:
            continue
        handled.append(opname)
    return handled


def _check_class_methods(class_line, srcls, cn, srcfile_str, srcfile_name,
                         wlist, api_names):
    """Run the sample codes of the public methods of the class at ``class_line``."""
    result = True
    for x in range(class_line + 1, len(srcls)):  # from the next line of class header
        if srcls[x].startswith('def ') or srcls[x].startswith('class '):
            break
        thisl = srcls[x].replace('\t', '    ')
        if not thisl.startswith('    def '):  # not a member method header
            continue
        indent = len(thisl) - len(thisl.lstrip())
        mn = thisl[indent + len('def '):thisl.find('(')]  # method name
        name = cn + "." + mn  # full name
        if '%s%s' % (srcfile_str, name) not in api_names:  # not in api.spec
            continue
        if mn.startswith('_'):
            continue
        if name in wlist or name + "@" + srcfile_name in wlist:
            continue
        # gather the method body, re-based to column 0, until the next
        # top-level def/class or member def, then extract its docstring
        thismethod = [thisl[indent:]]
        for y in range(x + 1, len(srcls)):
            body_line = srcls[y].replace('\t', '    ')
            if body_line.startswith(('def ', 'class ', '    def ')):
                break  # end of method
            thismethod.append(body_line[indent:])
        thismtdcom = single_defcom_extract(0, thismethod)
        if thismtdcom != "":
            if not sampcd_extract_and_run(thismtdcom, name, "method", name):
                result = False
    return result


def srccoms_extract(srcfile, wlist):
    """
    Given a source file ``srcfile``, this function will
    extract its API(doc comments) and run sample codes in the
    API.

    Only top-level functions/classes listed in the file's ``__all__`` and
    recorded in the module-level ``methods`` list (filled by
    ``get_filenames``) are checked; names in ``wlist`` (bare or as
    ``name@<file path>``) are skipped. Public methods of a checked class
    are checked as ``Class.method``.

    Args:
        srcfile(file): the source file, opened for reading
        wlist(list): white list

    Returns:
        result: True or False
    """
    process_result = True
    srcc = srcfile.read()
    # set file pointer to its beginning to also get the source lines
    srcfile.seek(0, 0)
    srcls = srcfile.readlines()  # source lines

    # 1. fetch __all__ list; a file without one exports nothing to check
    if srcc.find("__all__") == -1:
        return process_result

    srcfile_str = _module_prefix(srcfile.name)
    wset = set(wlist)
    api_names = set(methods)  # module-level list built by get_filenames()
    if srcfile.name.find("ops.py") != -1:
        # layers/ops.py builds __all__ over several lines and assigns some
        # docstrings through __doc__; those names are ignored below
        alllist = set(_ops_all_list(srcls))
        handled = set(_ops_doc_names(srcls, wset))
    else:
        alllist = set(_all_list(srcc))
        handled = set()

    # 2. check the top-level def and class headers
    for i, line in enumerate(srcls):
        if line.startswith('def '):  # a function header is detected in line i
            fn = _header_name(line, 'def')  # function name
            if "%s%s" % (srcfile_str, fn) not in api_names:
                continue
            if fn in handled or fn not in alllist:
                continue
            if fn in wset or fn + "@" + srcfile.name in wset:
                continue
            fcombody = single_defcom_extract(i, srcls)
            if fcombody == "":  # if no comment
                print_header("def", fn)
                print("WARNING: no comments in function ", fn,
                      ", but it deserves.")
            elif not sampcd_extract_and_run(fcombody, fn, "def", fn):
                process_result = False

        elif line.startswith('class '):
            cn = _header_name(line, 'class')  # class name
            if '%s%s' % (srcfile_str, cn) not in api_names:
                continue
            if cn in handled or cn not in alllist:
                continue
            if cn in wset or cn + "@" + srcfile.name in wset:
                continue
            # class comment
            classcom = single_defcom_extract(i, srcls, True)
            if classcom != "":
                if not sampcd_extract_and_run(classcom, cn, "class", cn):
                    process_result = False
            else:
                print("WARNING: no comments in class itself ", cn,
                      ", but it deserves.\n")
            # handling methods in class bodies
            if not _check_class_methods(i, srcls, cn, srcfile_str,
                                        srcfile.name, wset, api_names):
                process_result = False

    return process_result
T
tianshuo78520a 已提交
437 438


439
def test(file_list):
440
    process_result = True
441
    for file in file_list:
442 443 444 445
        with open(file, 'r') as src:
            if not srccoms_extract(src, wlist):
                process_result = False
    return process_result
446 447


448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504
def get_filenames(path):
    '''
    Given a path ``path``, this function will
    get the modules that pending for check.

    As a side effect it rebinds the module-level ``methods`` list to the
    API names (module path without its first two packages, then the
    class/function name, or ``Class.method`` for methods) that
    ``srccoms_extract`` is allowed to check.

    Args:
        path(path): the path of API.spec, relative to the parent
            directory of the current working directory

    Returns:

        list: the modules pending for check .

    '''
    filenames = []
    global methods
    methods = []
    seen_methods = set()
    API_spec = '%s/%s' % (os.path.abspath(os.path.join(os.getcwd(), "..")),
                          path)
    with open(API_spec) as f:
        for line in f.readlines():
            api = line.split(' ', 1)[0]
            if not api.strip():
                # a blank spec line would make eval() raise SyntaxError
                continue
            # NOTE(review): eval() trusts this repository's own API spec
            # file; never run this script on an untrusted spec.
            try:
                api_obj = eval(api)
                module = api_obj.__module__
            except AttributeError:
                continue
            module_parts = module.split('.')
            if len(module_parts) > 2:
                filename = '../python/%s.py' % '/'.join(module_parts)
                if filename not in filenames:
                    filenames.append(filename)
            else:
                # no source file can be derived; only the name is recorded
                print("\n----Exception in get api filename----\n")
                print("\n" + api + 'module is ' + module + "\n")
            # get all methods
            api_parts = api.split('.')
            if inspect.isclass(api_obj):
                name = api_parts[-1]
            elif inspect.ismethod(api_obj) or (
                    # python3 has no unbound methods: a function reached
                    # through its class is a plain function there
                    inspect.isfunction(api_obj) and len(api_parts) > 1 and
                    inspect.isclass(eval('.'.join(api_parts[:-1])))):
                name = '%s.%s' % (api_parts[-2], api_parts[-1])
            elif inspect.isfunction(api_obj):
                name = api_parts[-1]
            else:
                name = ''
                print("\n----Exception in get api methods----\n")
                print("\n" + line + "\n")
                print("\n" + api + ' method is None!!!' + "\n")
            method = ''.join('%s.' % part for part in module_parts[2:]) + name
            if method not in seen_methods:
                seen_methods.add(method)
                methods.append(method)
    return filenames


505 506 507 508 509 510 511 512
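'''
Important constant lists:

    wlist : a list of API that should not trigger the example check.
            It is composed of wlist_temp + wlist_inneed + wlist_ignore +
            wlist_nosample + wlist_nopass + wlist_no_op_pass +
            wlist_ci_nopass.
    gpu_not_white : API that are white-listed on CPU only; they are
            removed from wlist when the check runs with the 'gpu' argument.

    srcfile: (not a constant list) the source .py code file being checked.
'''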

T
tianshuo78520a 已提交
513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531
# White lists of API excluded from the sample code check. They are all
# concatenated into ``wlist`` below. An entry is either a bare API name
# (``func``, ``Class`` or ``Class.method``) or ``name@<source file path>``,
# which excludes the name only for that file.
wlist_inneed = [
    "append_LARS", "BuildStrategy.debug_graphviz_path",
    "BuildStrategy.enable_sequential_execution",
    "BuildStrategy.fuse_elewise_add_act_ops",
    "BuildStrategy.fuse_relu_depthwise_conv",
    "BuildStrategy.gradient_scale_strategy", "BuildStrategy.reduce_strategy",
    "BuildStrategy.remove_unnecessary_lock", "BuildStrategy.sync_batch_norm",
    "DynamicRNN.step_input", "DynamicRNN.static_input", "DynamicRNN.block",
    "DynamicRNN.update_memory", "DynamicRNN.output",
    "transpiler.DistributeTranspilerConfig",
    "transpiler.DistributeTranspilerConfig.slice_var_up",
    "transpiler.DistributeTranspilerConfig.split_method",
    "transpiler.DistributeTranspilerConfig.min_block_size",
    "DistributeTranspilerConfig.slice_var_up",
    "DistributeTranspilerConfig.split_method", "ModelAverage.apply",
    "ModelAverage.restore", "DistributeTranspilerConfig",
    "DistributeTranspilerConfig.min_block_size",
    "ExecutionStrategy.allow_op_delay", "load", "Accuracy.update",
    "ChunkEvaluator.update", "ExecutionStrategy.num_iteration_per_drop_scope",
    "ExecutionStrategy.num_threads", "CompiledProgram._with_inference_optimize",
    "CompositeMetric.add_metric", "CompositeMetric.update",
    "CompositeMetric.eval", "DetectionMAP.get_map_var", "MetricBase",
    "MetricBase.reset", "MetricBase.get_config", "MetricBase.update",
    "MetricBase.eval", "Accuracy.eval", "Auc.update", "Auc.eval",
    "EditDistance.update", "EditDistance.eval",
    "ExponentialMovingAverage.apply", "ExponentialMovingAverage.restore",
    "ExponentialMovingAverage.update", "StaticRNN.step", "StaticRNN.step_input",
    "StaticRNN.step_output", "StaticRNN.update_memory", "DetectionMAP.reset",
    'StaticRNN.output', "cuda_places", "CUDAPinnedPlace", "CUDAPlace",
    "Program.parse_from_string"
]

wlist_nosample = [
    'Compressor', 'Compressor.config', 'Compressor.run', 'run_check',
    'HDFSClient.upload', 'HDFSClient.download', 'HDFSClient.is_exist',
    'HDFSClient.is_dir', 'HDFSClient.delete', 'HDFSClient.rename',
    'HDFSClient.makedirs', 'HDFSClient.ls', 'HDFSClient.lsr', 'multi_download',
    'multi_upload', 'TrainingDecoder.block',
    'QuantizeTranspiler.training_transpile',
    'QuantizeTranspiler.freeze_program', 'AutoMixedPrecisionLists',
    'Uniform.sample', 'Uniform.log_prob', 'Uniform.entropy',
    'Categorical.kl_divergence', 'Categorical.entropy',
    'MultivariateNormalDiag.entropy', 'MultivariateNormalDiag.kl_divergence',
    'RNNCell', 'RNNCell.call', 'RNNCell.get_initial_states', 'GRUCell.call',
    'LSTMCell.call', 'Decoder', 'Decoder.initialize', 'Decoder.step',
    'Decoder.finalize', 'fused_elemwise_activation', 'search_pyramid_hash',
    'convert_dist_to_sparse_program', 'load_persistables_for_increment',
    'load_persistables_for_inference', 'cache', 'buffered', 'xmap_readers'
]

wlist_no_op_pass = ['gelu', 'erf']

wlist_ci_nopass = [
    'DecodeHelper', 'DecodeHelper.initialize', 'DecodeHelper.sample',
    'DecodeHelper.next_inputs', 'TrainingHelper.initialize',
    'TrainingHelper.sample', 'TrainingHelper.next_inputs',
    'GreedyEmbeddingHelper.initialize', 'GreedyEmbeddingHelper.sample',
    'GreedyEmbeddingHelper.next_inputs', 'LayerList.append', 'HDFSClient',
    'InitState', 'TracedLayer', 'SampleEmbeddingHelper.sample',
    'BasicDecoder.initialize', 'BasicDecoder.step', 'ParameterList.append',
    'GreedyEmbeddingHelper', 'SampleEmbeddingHelper', 'BasicDecoder', 'lstm',
    'partial_sum'
]

wlist_nopass = [
    'StateCell', 'StateCell.compute_state', 'TrainingDecoder',
    'TrainingDecoder.step_input', 'TrainingDecoder.static_input',
    'TrainingDecoder.output', 'BeamSearchDecoder', 'GradClipByValue',
    'GradClipByNorm', 'Variable.detach', 'Variable.numpy', 'Variable.set_value',
    'Variable.gradient', 'BeamSearchDecoder.decode',
    'BeamSearchDecoder.read_array', 'CompiledProgram',
    'CompiledProgram.with_data_parallel', 'append_backward', 'guard',
    'to_variable', 'op_freq_statistic', 'save_dygraph', 'load_dygraph',
    'ParallelExecutor', 'ParallelExecutor.run',
    'ParallelExecutor.drop_local_exe_scopes', 'GradClipByGlobalNorm',
    'extend_with_decoupled_weight_decay', 'switch', 'Normal', 'memory_usage',
    'decorate', 'PiecewiseDecay', 'InverseTimeDecay', 'PolynomialDecay',
    'NoamDecay', 'start_profiler', 'profiler', 'tree_conv', 'multiclass_nms2',
    'DataFeedDesc', 'Conv2D', 'Conv3D', 'Conv3DTranspose', 'Embedding', 'NCE',
    'PRelu', 'BilinearTensorProduct', 'GroupNorm', 'SpectralNorm', 'TreeConv',
    'prroi_pool'
]

wlist_temp = [
    'ChunkEvaluator',
    'EditDistance',
    'ErrorClipByValue',
    'Program.clone',
    'cuda_pinned_places',
    'DataFeeder',
    'elementwise_floordiv',
    'Layer',
    'Layer.create_parameter',
    'Layer.create_variable',
    'Layer.sublayers',
    'Layer.add_parameter',
    'Layer.add_sublayer',
    'Layer.parameters',
    'Tracer',
    'Layer.full_name',
    'InMemoryDataset',
    'layer_norm',
    'bipartite_match',
    'double_buffer',
    'cumsum',
    'thresholded_relu',
    'group_norm',
    'random_crop',
    'py_func',
    'row_conv',
    'hard_shrink',
    'ssd_loss',
    'retinanet_target_assign',
    'InMemoryDataset.global_shuffle',
    'InMemoryDataset.get_memory_data_size',
    'DetectionMAP',
    'hash',
    'InMemoryDataset.set_queue_num',
    'LayerNorm',
    'Preprocessor',
    'chunk_eval',
    'GRUUnit',
    'ExponentialMovingAverage',
    'QueueDataset.global_shuffle',
    'NumpyArrayInitializer',
    'create_py_reader_by_data',
    'InMemoryDataset.local_shuffle',
    'InMemoryDataset.get_shuffle_data_size',
    'size',
    'edit_distance',
    'nce',
    'BilinearInitializer',
    'NaturalExpDecay',
    'noam_decay',
    'retinanet_detection_output',
    'Pool2D',
    'PipelineOptimizer',
    'generate_mask_labels',
    'isfinite',
    'InMemoryDataset.set_fleet_send_batch_size',
    'cuda_profiler',
    'unfold',
    'Executor',
    'InMemoryDataset.load_into_memory',
    'ExponentialDecay',
    'BatchNorm',
    'deformable_conv',
    'InMemoryDataset.preload_into_memory',
    'py_reader',
    'linear_lr_warmup',
    'InMemoryDataset.wait_preload_done',
    'CosineDecay',
    'roi_perspective_transform',
    'unique',
    'ones_like',
    'LambOptimizer',
    'InMemoryDataset.release_memory',
    'Conv2DTranspose',
    'QueueDataset.local_shuffle',
    # wrong in dygraph/checkpoint.py  ok in io.py [duplicated name]
    'save_persistables@dygraph/checkpoint.py',
    'load_persistables@dygraph/checkpoint.py'
]

# white list of private API / redundant API
wlist_ignore = [
    'elementwise_pow', 'WeightedAverage.reset', 'ChunkEvaluator.eval',
    'NCE.forward', 'elementwise_div', 'BilinearTensorProduct.forward',
    'NoamDecay.step', 'elementwise_min', 'PiecewiseDecay.step',
    'Conv3DTranspose.forward', 'elementwise_add', 'IfElse.output',
    'IfElse.true_block', 'InverseTimeDecay.step', 'PolynomialDecay.step',
    'Precision.eval', 'enabled', 'elementwise_max', 'stop_gperf_profiler',
    'IfElse.false_block', 'WeightedAverage.add', 'Auc.trapezoid_area',
    'elementwise_mul', 'GroupNorm.forward', 'SpectralNorm.forward',
    'elementwise_sub', 'Switch.case', 'IfElse.input', 'prepare_context',
    'PRelu.forward', 'Recall.update', 'start_gperf_profiler',
    'TreeConv.forward', 'Conv2D.forward', 'Switch.default', 'elementwise_mod',
    'Precision.update', 'WeightedAverage.eval', 'Conv3D.forward',
    'Embedding.forward', 'Recall.eval', 'FC.forward', 'While.block',
    'DGCMomentumOptimizer'
]

# only white on CPU: every entry must also be in wlist, because the 'gpu'
# run removes these names from it with list.remove()
gpu_not_white = [
    "deformable_conv", "cuda_places", "CUDAPinnedPlace", "CUDAPlace",
    "cuda_profiler", 'DGCMomentumOptimizer'
]

wlist = wlist_temp + wlist_inneed + wlist_ignore + wlist_nosample + wlist_nopass + wlist_no_op_pass + wlist_ci_nopass
702 703

# Script entry: validate the 'cpu'/'gpu' argument, collect the files to
# check from the API spec, check them in parallel and report the outcome.
if len(sys.argv) < 2:
    print("Error: inadequate number of arguments")
    print('''If you are going to run it on
        CPU: >>> python sampcd_processor.py cpu
        GPU: >>> python sampcd_processor.py gpu
        ''')
    sys.exit("lack arguments")

if sys.argv[1] == "gpu":
    for _gnw in gpu_not_white:
        wlist.remove(_gnw)
elif sys.argv[1] != "cpu":
    print("Unrecognized argument:'", sys.argv[1], "' , 'cpu' or 'gpu' is ",
          "desired\n")
    sys.exit("Invalid arguments")
print("API check -- Example Code")
print("sample_test running under python", platform.python_version())
if not os.path.isdir("./samplecode_temp"):
    os.mkdir("./samplecode_temp")
cpus = multiprocessing.cpu_count()
filenames = get_filenames('paddle/fluid/API_PR.spec')
# core_avx.py is excluded from the check; guard the removal so a spec that
# does not reference it does not raise ValueError
_core_avx = '../python/paddle/fluid/core_avx.py'
if _core_avx in filenames:
    filenames.remove(_core_avx)
# float division: python2 floors int "/", which yields a chunk size of 0
# (and a ValueError from range()) whenever there are fewer files than cpus
one_part_filenum = max(1, int(math.ceil(len(filenames) / float(cpus))))
divided_file_list = [
    filenames[i:i + one_part_filenum]
    for i in range(0, len(filenames), one_part_filenum)
]

po = multiprocessing.Pool()
results = po.map_async(test, divided_file_list)
po.close()
po.join()

result = results.get()

# delete temp files
for root, dirs, files in os.walk("./samplecode_temp"):
    for fntemp in files:
        os.remove("./samplecode_temp/" + fntemp)
os.rmdir("./samplecode_temp")

print("----------------End of the Check--------------------")
for temp in result:
    if not temp:
        print("Mistakes found in sample codes")
        sys.exit(1)
print("Sample code check is successful!")