#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import multiprocessing
import os
import platform
import subprocess
import sys
"""
please make sure to run in the tools path
usage: python sampcd_processor.py {arg1}
arg1: the first arg defines whether to run in the gpu version or the cpu version

for example, you can run the cpu version testing like this:

    python sampcd_processor.py cpu

"""
T
tianshuo78520a 已提交
31 32 33


def find_all(srcstr, substr):
    """
    Find every occurrence of a substring in the source string
    and return their starting indices as a list.

    The search resumes one character after the start of the previous
    hit, so overlapping occurrences are all reported.

    Args:
        srcstr(str): the parent string
        substr(str): the substring to look for

    Returns:
        list: the starting indices of the substrings found,
              in ascending order
    """
    indices = []
    position = srcstr.find(substr)
    while position != -1:
        indices.append(position)
        position = srcstr.find(substr, position + 1)
    return indices


def check_indent(cdline):
    """
    Check the indent of a given code line.

    The indent is the number of leading blank characters, i.e.
    blank spaces and tabs. A tab is interpreted as 4 single
    blank spaces.

    Args:
        cdline(str) : a single line of code from the source file

    Returns:
        int : the indent, counted in interpreted blank spaces
    """
    indent = 0
    for c in cdline:
        if c == '\t':
            indent += 4
        elif c == ' ':
            indent += 1
        else:
            # first non-blank character: the indent is complete
            break
    return indent


82 83 84
def _dedent_sample_code(raw_text):
    """Return the dedented, non-blank code lines of the sample at the head of ``raw_text``."""
    lines = raw_text.split("\n")
    # skip the blank lines between the directive and the first code line
    first = 0
    while first < len(lines) and lines[first].strip(' \t') == '':
        first += 1
    if first == len(lines):
        return []

    # the minimum indent is the indent of the first non-empty line
    min_indent = check_indent(lines[first])
    kept = []
    for cdline in lines[first:]:
        # ignore empty lines or those only with spaces/tabs
        if cdline.strip() == '':
            continue
        # a line indented less than the first one ends the sample
        if check_indent(cdline) < min_indent:
            break
        kept.append(cdline.replace('\t', '    ')[min_indent:])
    return kept


def sampcd_extract_and_run(srccom, name, htype="def", hname=""):
    """
    Extract and run sample codes from source comment and
    the result will be returned.

    Every sample is written to ``samplecode_temp/`` (the directory must
    already exist), executed in a child interpreter and removed again.
    The first command line argument (``sys.argv[1]``, "cpu" or "gpu")
    decides which ``CUDA_VISIBLE_DEVICES`` value is injected.

    Args:
        srccom(str): the source comment of some API whose
                     example codes will be extracted and run.
                     It is the raw comment, with its original indent.
        name(str): the name of the API.
        htype(str): the type of hint banners, def/class/method.
        hname(str): the name of the hint banners, e.g. def hname.

    Returns:
        result: True or False
    """
    directive = " code-block:: python"
    result = True

    def sampcd_header_print(name, sampcd, htype, hname, index):
        """
        print hint banner headers.

        Args:
            name(str): the name of the API.
            sampcd(str): sample code string
            htype(str): the type of hint banners, def/class/method.
            hname(str): the name of the hint banners, e.g. def hname.
            index(int): the 1-based position of the sample in the comment.
        """
        print_header(htype, hname)
        print("Sample code ", str(index), " extracted for ", name, "   :")
        print(sampcd)
        print("----example code check----\n")
        print("executing sample code .....")
        print("execution result:")

    sampcd_begins = find_all(srccom, directive)
    if len(sampcd_begins) == 0:
        print_header(htype, hname)
        # sample codes formatted with >>> are detected
        # and considered as wrong
        if srccom.find("Examples:") != -1:
            print("----example code check----\n")
            if srccom.find(">>>") != -1:
                print(
                    "Deprecated sample code style:\n\n    Examples:\n\n        >>>codeline\n        >>>codeline\n\n\n ",
                    "Please use '.. code-block:: python' to ",
                    "format sample code.\n")
                result = False
        else:
            print("Error: No sample code!\n")
            result = False

    # tolerate a missing argument instead of raising IndexError here
    run_mode = sys.argv[1] if len(sys.argv) > 1 else ""

    for y, sampcd_begin in enumerate(sampcd_begins, 1):
        code_lines = _dedent_sample_code(
            srccom[sampcd_begin + len(directive) + 1:])
        if not code_lines:
            # the directive is followed by nothing but blank lines
            print_header(htype, hname)
            print("Error: No sample code!\n")
            result = False
            continue

        sampcd = '\n'.join(code_lines)
        if run_mode == "cpu":
            sampcd = '\nimport os\n' + 'os.environ["CUDA_VISIBLE_DEVICES"] = ""\n' + sampcd
        elif run_mode == "gpu":
            sampcd = '\nimport os\n' + 'os.environ["CUDA_VISIBLE_DEVICES"] = "0"\n' + sampcd
        sampcd += '\nprint(' + '\"' + name + ' sample code is executed successfully!\")'

        # pick the interpreter before touching the disk, so that the
        # early exit below cannot leave a temp file behind
        py_major = platform.python_version()[0]
        if py_major == "2":
            interpreter = "python"
        elif py_major == "3":
            interpreter = "python3"
        else:
            print("Error: fail to parse python version!")
            sys.exit(1)

        if len(sampcd_begins) > 1:
            tfname = name + "_example_" + str(y) + ".py"
        else:
            tfname = name + "_example" + ".py"
        tfpath = "samplecode_temp/" + tfname

        with open(tfpath, 'w') as tempf:
            tempf.write(sampcd)
        try:
            subprc = subprocess.Popen(
                [interpreter, tfpath],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)
            output, error = subprc.communicate()
            # undecodable bytes must not abort the whole check
            msg = output.decode('utf-8', 'replace')
            err = error.decode('utf-8', 'replace')

            if subprc.returncode != 0:
                print("\nSample code error found in ", name, ":\n")
                sampcd_header_print(name, sampcd, htype, hname, y)
                print("subprocess return code: ", str(subprc.returncode))
                print("Error Raised from Sample Code ", name, " :\n")
                print(err)
                # msg is the returned code execution report
                print(msg)
                result = False
        finally:
            os.remove(tfpath)
    return result
T
tianshuo78520a 已提交
202 203 204


def single_defcom_extract(start_from, srcls, is_class_begin=False):
    """
    Extract the comment (docstring) body of a def function/class/method.

    The scan starts on the line after ``start_from`` and looks for the
    first line that begins with three double or three single quotes.
    Text sharing the line with the opening quotes is not collected,
    unless the docstring also closes on that very line.

    Args:
        start_from(int): the line num of "def" header
        srcls(list): the source file in lines
        is_class_begin(bool): whether the start_from is a beginning of a
        class. For a sole class body itself may end up with its method
        if it has no docstring. But the body of a common def function
        can only be ended up by a none-indented def/class

    Returns:
        string : the extracted comment body, exclusive of its quote
                 marks; "" when no docstring is found.
    """
    fcombody = ""  # def comment body
    quote = None  # the quote mark of the docstring being read, if any
    for x in range(start_from + 1, len(srcls)):
        line = srcls[x]
        if is_class_begin and line.replace('\t', '    ').startswith(
                '    def '):
            break
        if line.startswith('def ') or line.startswith('class '):
            break

        if quote is None:
            text = line.strip()
            for mark in ('"""', "'''"):
                if text.startswith(mark):
                    quote = mark
                    break
            if quote is None:
                continue
            rest = text[len(quote):]
            end = rest.find(quote)
            if end != -1:
                # one-line docstring: opening and closing marks share the
                # line, so the comment is the text between them
                fcombody = rest[:end] + "\n"
                break
            continue

        end = line.find(quote)
        if end != -1:
            # closing mark; keep the text in front of it, if there is any
            if line[:end].strip():
                fcombody += line[:end] + "\n"
            break
        fcombody += line
    return fcombody


258 259 260
def print_header(htype, name):
    """
    Print the banner that introduces the report of one API.

    Args:
        htype(str): the type of the hint banner, def/class/method.
        name(str): the name shown in the banner.
    """
    separator = "-----------------------"
    print(htype, " name:", name)
    print(separator)
261

T
tianshuo78520a 已提交
262

263
def _names_from_list_text(text):
    """Split the inside of a ``[...]`` literal into bare, non-empty names."""
    for noise in ("\n", " ", "'", "\""):
        text = text.replace(noise, '')
    return [item for item in text.split(',') if item]


def _is_whitelisted(api_name, src_path, wlist):
    """Tell whether ``api_name`` found in ``src_path`` must skip the example check."""
    if api_name in wlist:
        return True
    # "name@path" entries restrict the exemption to one source file; the
    # path part may be a suffix of the path the file was opened with
    normalized = src_path.replace("\\", "/")
    prefix = api_name + "@"
    for entry in wlist:
        if not entry.startswith(prefix):
            continue
        path_part = entry[len(prefix):]
        if path_part and (normalized == path_part or
                          normalized.endswith("/" + path_part)):
            return True
    return False


def _header_name(header_line, keyword):
    """Return the identifier declared by a ``def``/``class`` header line."""
    rest = header_line.replace(" ", '')[len(keyword):]
    # a class header may have no parentheses, so ':' also ends the name
    ends = [pos for pos in (rest.find('('), rest.find(':')) if pos != -1]
    if ends:
        return rest[:min(ends)]
    return rest.strip()


def _collect_all_list(srcfile_name, srcc, srcls, allidx):
    """Parse the names listed in ``__all__`` of the source."""
    alllist = []
    if srcfile_name.find("ops.py") != -1:
        # get all list for layers/ops.py: every line starting with
        # __all__ contributes the names between its brackets
        for line in srcls:
            if not line.startswith("__all__"):
                continue
            lb = line.find('[')
            if lb == -1:
                continue
            rb = line.find(']')
            alllist.extend(_names_from_list_text(line[lb + 1:rb]))
    else:
        tail = srcc[allidx + len("__all__"):]
        lb = tail.find("[")
        rb = tail.find("]")
        if lb != -1 and rb != -1:
            alllist = _names_from_list_text(tail[lb + 1:rb])
    return alllist


def _check_class(class_line, srcls, cn, src_path, wlist):
    """Check the docstring of class ``cn`` and of its public methods."""
    ok = True
    # class comment
    classcom = single_defcom_extract(class_line, srcls, True)
    if classcom != "":
        if not sampcd_extract_and_run(classcom, cn, "class", cn):
            ok = False
    else:
        print("WARNING: no comments in class itself ", cn,
              ", but it deserves.\n")

    # handling methods in class bodies, from the next line of class header
    for x in range(class_line + 1, len(srcls)):
        if srcls[x].startswith(('def ', 'class ')):
            break
        srcls[x] = srcls[x].replace('\t', '    ')
        thisl = srcls[x]
        if not thisl.startswith('    def '):
            continue
        # a member method header is detected
        indent = len(thisl) - len(thisl.lstrip())
        mn = thisl[indent + len('def '):thisl.find('(')]  # method name
        name = cn + "." + mn  # full name
        if mn.startswith('_'):
            continue
        if _is_whitelisted(name, src_path, wlist):
            continue
        # get all the lines of a single method body into thismethod(list)
        # and send it to single_defcom_extract
        thismethod = [thisl[indent:]]
        for y in range(x + 1, len(srcls)):
            srcls[y] = srcls[y].replace('\t', '    ')
            if srcls[y].startswith(('def ', 'class ', '    def ')):
                # end of method
                break
            thismethod.append(srcls[y][indent:])
        thismtdcom = single_defcom_extract(0, thismethod)
        if thismtdcom != "":
            if not sampcd_extract_and_run(thismtdcom, name, "method", name):
                ok = False
    return ok


def srccoms_extract(srcfile, wlist):
    """
    Given a source file ``srcfile``, this function will
    extract its API(doc comments) and run sample codes in the
    API.

    Only the names listed in ``__all__`` of the file are checked; a file
    without ``__all__`` is accepted as is.

    Args:
        srcfile(file): the opened source file; it is read, rewound and
                       read again, and its ``name`` is consulted.
        wlist(list): white list, names (optionally "name@path") of the
                     APIs that must not be checked.

    Returns:
        result: True or False
    """
    process_result = True
    srcc = srcfile.read()

    # set file pointer to its beginning and get the source lines
    srcfile.seek(0, 0)
    srcls = srcfile.readlines()

    # 1. fetch __all__ list
    allidx = srcc.find("__all__")
    if allidx == -1:
        return process_result
    alllist = _collect_all_list(srcfile.name, srcc, srcls, allidx)
    handled = []

    # 2. get src contents in layers/ops.py
    if srcfile.name.find("ops.py") != -1:
        for i in range(len(srcls)):
            doc_pos = srcls[i].find("__doc__")
            if doc_pos == -1:
                continue
            opname = srcls[i][:doc_pos - 1]
            if _is_whitelisted(opname, srcfile.name, wlist):
                continue
            # the comment is taken from the line after the __doc__
            # assignment up to the first line holding """
            opcom = ""
            for j in range(i + 1, len(srcls)):
                opcom += srcls[j]
                if srcls[j].find("\"\"\"") != -1:
                    break
            # never let a later success hide an earlier failure
            if not sampcd_extract_and_run(opcom, opname, "def", opname):
                process_result = False
            # ops.py also has normal formatted functions
            # use list 'handled' to mark the functions have been handled
            # here, which will be ignored in the following step
            handled.append(opname)

    # 3. check the defs and classes listed in __all__
    for i in range(len(srcls)):
        line = srcls[i]
        if line.startswith('def '):
            # a function header is detected in line i
            fn = _header_name(line, 'def')
            if fn in handled or fn not in alllist:
                continue
            if _is_whitelisted(fn, srcfile.name, wlist):
                continue
            fcombody = single_defcom_extract(i, srcls)
            if fcombody == "":  # if no comment
                print_header("def", fn)
                print("WARNING: no comments in function ", fn,
                      ", but it deserves.")
                continue
            if not sampcd_extract_and_run(fcombody, fn, "def", fn):
                process_result = False
        elif line.startswith('class '):
            cn = _header_name(line, 'class')
            if cn in handled or cn not in alllist:
                continue
            if _is_whitelisted(cn, srcfile.name, wlist):
                continue
            if not _check_class(i, srcls, cn, srcfile.name, wlist):
                process_result = False
    return process_result
T
tianshuo78520a 已提交
422 423


424
def test(file_list):
    """
    Check the sample codes of every source file in ``file_list``.

    Each file is opened for reading and handed to ``srccoms_extract``
    together with the module-level white list ``wlist``.

    Args:
        file_list(list): the paths of the source files to check

    Returns:
        bool: True if every file passed, False if any of them failed
    """
    process_result = True
    for src_path in file_list:
        with open(src_path, 'r') as src:
            if not srccoms_extract(src, wlist):
                process_result = False
    return process_result
431 432


433 434 435 436 437 438 439 440 441
'''
Important constant lists:

    filenames : the modules pending for check.
    wlist : a list of APIs that should not trigger the example check.
            It is composed of wlist_temp + wlist_inneed + wlist_ignore.
    srcfile: the source .py code file
'''

T
tianshuo78520a 已提交
442
# The source files whose sample codes are checked. All paths are relative
# to the directory the script is run from.
filenames = [
    "../python/paddle/fluid/layers/control_flow.py",
    "../python/paddle/fluid/layers/io.py",
    "../python/paddle/fluid/layers/nn.py",
    "../python/paddle/fluid/layers/ops.py",
    "../python/paddle/fluid/layers/tensor.py",
    "../python/paddle/fluid/layers/learning_rate_scheduler.py",
    "../python/paddle/fluid/layers/detection.py",
    "../python/paddle/fluid/layers/metric_op.py"
]
filenames += [
    "../python/paddle/fluid/dygraph/layers.py",
    "../python/paddle/fluid/dygraph/base.py",
    "../python/paddle/fluid/dygraph/nn.py",
    "../python/paddle/fluid/dygraph/tracer.py",
    "../python/paddle/fluid/dygraph/profiler.py",
    "../python/paddle/fluid/dygraph/parallel.py",
    "../python/paddle/fluid/dygraph/checkpoint.py",
    "../python/paddle/fluid/dygraph/learning_rate_scheduler.py",
    "../python/paddle/fluid/dygraph/backward_strategy.py"
]
filenames += [
    "../python/paddle/fluid/data_feeder.py",
    "../python/paddle/fluid/dataset.py",
    "../python/paddle/fluid/clip.py",
    "../python/paddle/fluid/metrics.py",
    "../python/paddle/fluid/executor.py",
    "../python/paddle/fluid/initializer.py",
    "../python/paddle/fluid/io.py",
    "../python/paddle/fluid/nets.py",
    "../python/paddle/fluid/optimizer.py",
    "../python/paddle/fluid/profiler.py",
    "../python/paddle/fluid/regularizer.py",
    "../python/paddle/fluid/backward.py",
    "../python/paddle/fluid/average.py",
    "../python/paddle/fluid/unique_name.py",
    "../python/paddle/fluid/framework.py",
    "../python/paddle/fluid/evaluator.py",
    "../python/paddle/fluid/param_attr.py"
]
# One of the three parts of the white list `wlist`.
wlist_inneed = [
    "append_LARS", "BuildStrategy.debug_graphviz_path",
    "BuildStrategy.enable_sequential_execution",
    "BuildStrategy.fuse_elewise_add_act_ops",
    "BuildStrategy.fuse_relu_depthwise_conv",
    "BuildStrategy.gradient_scale_strategy", "BuildStrategy.reduce_strategy",
    "BuildStrategy.remove_unnecessary_lock", "BuildStrategy.sync_batch_norm",
    "DynamicRNN.step_input", "DynamicRNN.static_input", "DynamicRNN.block",
    "DynamicRNN.update_memory", "DynamicRNN.output",
    "transpiler.DistributeTranspilerConfig",
    "transpiler.DistributeTranspilerConfig.slice_var_up",
    "transpiler.DistributeTranspilerConfig.split_method",
    "transpiler.DistributeTranspilerConfig.min_block_size",
    "DistributeTranspilerConfig.slice_var_up",
    "DistributeTranspilerConfig.split_method", "ModelAverage.apply",
    "ModelAverage.restore", "DistributeTranspilerConfig",
    "DistributeTranspilerConfig.min_block_size",
    "ExecutionStrategy.allow_op_delay", "load", "Accuracy.update",
    "ChunkEvaluator.update", "ExecutionStrategy.num_iteration_per_drop_scope",
    "ExecutionStrategy.num_threads", "CompiledProgram._with_inference_optimize",
    "CompositeMetric.add_metric", "CompositeMetric.update",
    "CompositeMetric.eval", "DetectionMAP.get_map_var", "MetricBase",
    "MetricBase.reset", "MetricBase.get_config", "MetricBase.update",
    "MetricBase.eval", "Accuracy.eval", "Auc.update", "Auc.eval",
    "EditDistance.update", "EditDistance.eval",
    "ExponentialMovingAverage.apply", "ExponentialMovingAverage.restore",
    "ExponentialMovingAverage.update", "StaticRNN.step", "StaticRNN.step_input",
    "StaticRNN.step_output", "StaticRNN.update_memory", "DetectionMAP.reset",
    'StaticRNN.output', "cuda_places", "CUDAPinnedPlace", "CUDAPlace",
    "Program.parse_from_string"
]
# One of the three parts of the white list `wlist`.
wlist_temp = [
    'ChunkEvaluator',
    'EditDistance',
    'ErrorClipByValue',
    'Program.clone',
    'cuda_pinned_places',
    'DataFeeder',
    'elementwise_floordiv',
    'Layer',
    'Layer.create_parameter',
    'Layer.create_variable',
    'Layer.sublayers',
    'Layer.add_parameter',
    'Layer.add_sublayer',
    'Layer.parameters',
    'Tracer',
    'Layer.full_name',
    'InMemoryDataset',
    'layer_norm',
    'bipartite_match',
    'double_buffer',
    'cumsum',
    'thresholded_relu',
    'group_norm',
    'random_crop',
    'py_func',
    'row_conv',
    'hard_shrink',
    'ssd_loss',
    'retinanet_target_assign',
    'InMemoryDataset.global_shuffle',
    'InMemoryDataset.get_memory_data_size',
    'DetectionMAP',
    'hash',
    'InMemoryDataset.set_queue_num',
    'LayerNorm',
    'Preprocessor',
    'chunk_eval',
    'GRUUnit',
    'ExponentialMovingAverage',
    'QueueDataset.global_shuffle',
    'NumpyArrayInitializer',
    'create_py_reader_by_data',
    'InMemoryDataset.local_shuffle',
    'InMemoryDataset.get_shuffle_data_size',
    'size',
    'edit_distance',
    'nce',
    'BilinearInitializer',
    'NaturalExpDecay',
    'noam_decay',
    'retinanet_detection_output',
    'Pool2D',
    'PipelineOptimizer',
    'generate_mask_labels',
    'isfinite',
    'InMemoryDataset.set_fleet_send_batch_size',
    'cuda_profiler',
    'unfold',
    'Executor',
    'InMemoryDataset.load_into_memory',
    'ExponentialDecay',
    'BatchNorm',
    'deformable_conv',
    'InMemoryDataset.preload_into_memory',
    'py_reader',
    'linear_lr_warmup',
    'InMemoryDataset.wait_preload_done',
    'CosineDecay',
    'roi_perspective_transform',
    'unique',
    'ones_like',
    'LambOptimizer',
    'InMemoryDataset.release_memory',
    'Conv2DTranspose',
    'QueueDataset.local_shuffle',
    # wrong in dygraph/checkpoint.py  ok in io.py [duplicated name]
    'save_persistables@dygraph/checkpoint.py',
    'load_persistables@dygraph/checkpoint.py'
]
# white list of private API/ redundant API
wlist_ignore = [
    'elementwise_pow',
    'WeightedAverage.reset',
    'ChunkEvaluator.eval',
    'NCE.forward',
    'elementwise_div',
    'BilinearTensorProduct.forward',
    'NoamDecay.step',
    'elementwise_min',
    'PiecewiseDecay.step',
    'Conv3DTranspose.forward',
    'elementwise_add',
    'IfElse.output',
    'IfElse.true_block',
    'InverseTimeDecay.step',
    'PolynomialDecay.step',
    'Precision.eval',
    'enabled',
    'elementwise_max',
    'stop_gperf_profiler',
    'IfElse.false_block',
    'WeightedAverage.add',
    'Auc.trapezoid_area',
    'elementwise_mul',
    'GroupNorm.forward',
    'SpectralNorm.forward',
    'elementwise_sub',
    'Switch.case',
    'IfElse.input',
    'prepare_context',
    'PRelu.forward',
    'Recall.update',
    'start_gperf_profiler',
    'TreeConv.forward',
    'Conv2D.forward',
    'Switch.default',
    'elementwise_mod',
    'Precision.update',
    'WeightedAverage.eval',
    'Conv3D.forward',
    'Embedding.forward',
    'Recall.eval',
    'FC.forward',
    'While.block',
]
606 607 608 609 610
# only white on CPU: these entries are taken out of the white list
# when the check runs in gpu mode
gpu_not_white = [
    "deformable_conv",
    "cuda_places",
    "CUDAPinnedPlace",
    "CUDAPlace",
    "cuda_profiler",
]
T
tianshuo78520a 已提交
611
# The effective white list: a fresh list holding the three partial lists
# in the order temp, inneed, ignore.
wlist = []
wlist += wlist_temp
wlist += wlist_inneed
wlist += wlist_ignore
612 613

if len(sys.argv) > 1 and sys.argv[1] == "gpu":
    # In gpu mode the CPU-only entries must be checked as well. This is
    # done at import time, outside the __main__ guard, so that worker
    # processes which re-import this module build the same white list.
    for _gnw in gpu_not_white:
        if _gnw in wlist:
            wlist.remove(_gnw)

# The guard keeps the pool from being created again when the module is
# imported, e.g. by a multiprocessing worker started with "spawn".
if __name__ == '__main__':
    if len(sys.argv) < 2:
        print("Error: inadequate number of arguments")
        print('''If you are going to run it on 
        "CPU: >>> python sampcd_processor.py cpu
        "GPU: >>> python sampcd_processor.py gpu
        ''')
        sys.exit("lack arguments")
    if sys.argv[1] not in ("cpu", "gpu"):
        print("Unrecognized argument:'", sys.argv[1], "' , 'cpu' or 'gpu' is ",
              "desired\n")
        sys.exit("Invalid arguments")

    print("API check -- Example Code")
    print("sample_test running under python", platform.python_version())
    if not os.path.isdir("./samplecode_temp"):
        os.mkdir("./samplecode_temp")

    try:
        cpus = multiprocessing.cpu_count()
        # float() keeps the division exact under Python 2, and max()
        # guards the range() step below against being 0
        one_part_filenum = max(
            1, int(math.ceil(len(filenames) / float(cpus))))
        divided_file_list = [
            filenames[i:i + one_part_filenum]
            for i in range(0, len(filenames), one_part_filenum)
        ]
        po = multiprocessing.Pool()
        results = po.map_async(test, divided_file_list)
        po.close()
        po.join()
        result = results.get()
    finally:
        # delete temp files, even when a worker raised
        for root, dirs, files in os.walk("./samplecode_temp"):
            for fntemp in files:
                os.remove("./samplecode_temp/" + fntemp)
        os.rmdir("./samplecode_temp")

    print("----------------End of the Check--------------------")
    if not all(result):
        print("Mistakes found in sample codes")
        sys.exit(1)
    print("Sample code check is successful!")