get_pr_ut.py 19.5 KB
Newer Older
C
chalsliu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" For the PR that only modified the unit test, get cases in pull request. """

import json
17 18
import os
import platform
19
import re
20
import ssl
21
import subprocess
22
import time
23
import urllib.request
24 25

import requests
C
chalsliu 已提交
26 27 28
from github import Github

def normalize_paddle_root(path):
    """Return *path* with a trailing '/' and every run of '/' collapsed to one.

    >>> normalize_paddle_root('/paddle')
    '/paddle/'
    >>> normalize_paddle_root('/paddle///')
    '/paddle/'
    """
    return re.sub(r'/+', '/', path + '/')


# Root of the Paddle checkout; always ends with exactly one '/'.
PADDLE_ROOT = normalize_paddle_root(os.getenv('PADDLE_ROOT', '/paddle/'))
# NOTE: disables TLS certificate verification for every HTTPS connection in
# this process that relies on the stdlib default context (e.g. urllib.request).
ssl._create_default_https_context = ssl._create_unverified_context


class PRChecker:
    """Select the unit tests a Paddle pull request needs to run.

    ``get_pr_ut`` maps the files changed by the PR named in ``GIT_PR_ID`` to
    unit tests via the precision-test ``ut_file_map.json``. An empty result
    means the full test suite should run.
    """

    # Number of changed files at which file collection stops; a PR reaching
    # this count is treated as needing the full test suite.
    MAX_PR_FILES = 30

    UT_FILE_MAP_URL = 'https://paddle-docker-tar.bj.bcebos.com/precision_test_map_store/ut_file_map.json'
    PREC_DELTA_URL = 'https://paddle-docker-tar.bj.bcebos.com/precision_test_map_store/prec_delta'

    # Comment patterns, applied to file content in which every line has been
    # prefixed with "<lineno>|" (see get_comment_of_file).
    py_prog_oneline = re.compile(r'\d+\|\s*#.*')
    py_prog_multiline_a = re.compile('"""(.*?)"""', re.DOTALL)
    py_prog_multiline_b = re.compile("'''(.*?)'''", re.DOTALL)
    cc_prog_oneline = re.compile(r'\d+\|\s*//.*')
    # Historical misspelling of cc_prog_oneline, kept for compatibility.
    cc_prog_online = cc_prog_oneline
    cc_prog_multiline = re.compile(r'\d+\|\s*/\*.*?\*/', re.DOTALL)
    # Historical hunk-header pattern. It only matches headers that spell out
    # both ",count" fields, so the diff parser uses hunk_prog instead.
    lineno_prog = re.compile(r'@@ \-\d+,\d+ \+(\d+),(\d+) @@')
    # Unified-diff hunk header "@@ -a[,b] +c[,d] @@"; an omitted count means 1.
    hunk_prog = re.compile(r'@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@')

    def __init__(self):
        self.github = Github(os.getenv('GITHUB_API_TOKEN'), timeout=60)
        self.repo = self.github.get_repo('PaddlePaddle/Paddle')
        self.pr = None
        self.suffix = ''
        self.full_case = False
        # Parsed PR diff, filled lazily by get_pr_diff_lines().
        self._pr_diff_lines = None

    def init(self):
        """Load the pull request named by ``GIT_PR_ID``.

        Exits with status 0 when ``GIT_PR_ID`` is unset or empty. Sets
        ``self.suffix`` from ``PREC_SUFFIX`` (when set) and ``self.full_case``
        when the PR's last commit message contains ``test=allcase``.
        """
        pr_id = os.getenv('GIT_PR_ID')
        if not pr_id:
            print('PREC No PR ID')
            raise SystemExit(0)
        suffix = os.getenv('PREC_SUFFIX')
        if suffix:
            self.suffix = suffix
        self.pr = self.repo.get_pull(int(pr_id))
        last_commit = self._get_last_commit()
        if last_commit is None:
            print('PREC no commit found in PR')
            return
        if 'test=allcase' in last_commit.message:
            print('PREC test=allcase is set')
            self.full_case = True

    def _get_last_commit(self):
        """Return the git commit object of the PR's last listed commit, or None.

        Pages through ``self.pr.get_commits()`` until an empty page; any API
        error stops the paging (best effort) and keeps what was found so far.
        """
        last_commit = None
        page = 0
        while True:
            try:
                commits = self.pr.get_commits().get_page(page)
                if not commits:
                    break
                last_commit = commits[-1].commit
            except Exception:  # best effort: stop paging on any API error
                break
            page += 1
        return last_commit

    def __wget_with_retry(self, url):
        """Download *url* with wget, making up to 5 attempts.

        Odd attempts use wget's default proxy settings; even attempts disable
        the proxy explicitly. Sleeps ``10 * attempt`` seconds between attempts.

        Returns:
            bool: True once wget exits with status 0, False if all attempts fail.
        """
        attempts = 5
        for ix in range(1, attempts + 1):
            if ix % 2 == 1:
                proxy = ''
            elif platform.system() == 'Windows':
                proxy = '-Y off'
            else:
                proxy = '--no-proxy'
            code = subprocess.call(
                'wget -q {} --no-check-certificate {}'.format(proxy, url),
                shell=True,
            )
            if code == 0:
                return True
            print(
                'PREC download {} error, retry {} time(s) after {} secs.[proxy_option={}]'.format(
                    url, ix, ix * 10, proxy
                )
            )
            if ix < attempts:
                time.sleep(ix * 10)
        return False

    def __urlretrieve(self, url, filename):
        """Download *url* to *filename* with urllib, making up to 5 attempts.

        Odd attempts bypass proxies; even attempts use the proxies found by
        ``urllib.request.getproxies()``. The chosen opener is installed
        globally. Sleeps ``10 * attempt`` seconds between attempts.

        Returns:
            bool: True on success, False if every attempt fails.
        """
        attempts = 5
        # ProxyHandler({}) disables proxy use entirely (see urllib.request docs).
        without_proxy = {}
        with_proxy = urllib.request.getproxies()
        for ix in range(1, attempts + 1):
            proxies = without_proxy if ix % 2 == 1 else with_proxy
            opener = urllib.request.build_opener(
                urllib.request.ProxyHandler(proxies), urllib.request.HTTPHandler
            )
            urllib.request.install_opener(opener)
            try:
                urllib.request.urlretrieve(url, filename)
            except Exception as e:  # any download failure triggers a retry
                print(e)
                print(
                    'PREC download {} error, retry {} time(s) after {} secs.[proxy_option={}]'.format(
                        url, ix, ix * 10, proxies
                    )
                )
                if ix < attempts:
                    time.sleep(ix * 10)
            else:
                return True
        return False

    def get_pr_files(self):
        """Return ``{PADDLE_ROOT + path: status}`` for the files in the PR.

        Collection stops once at least ``MAX_PR_FILES`` files were seen;
        get_pr_ut then requests the full suite.
        """
        file_dict = {}
        file_count = 0
        page = 0
        while True:
            files = self.pr.get_files().get_page(page)
            if not files:
                break
            for f in files:
                file_dict[PADDLE_ROOT + f.filename] = f.status
                file_count += 1
            if file_count >= self.MAX_PR_FILES:
                break
            page += 1
        print("pr modify files: %s" % file_dict)
        return file_dict

    def get_is_white_file(self, filename):
        """Return True if *filename* is a "white" file that may be filtered out.

        Names containing "cmakelist" (case-insensitive) and paths under the
        build-related prefixes below are never white.
        """
        not_white_files = (
            PADDLE_ROOT + 'cmake/',
            PADDLE_ROOT + 'patches/',
            PADDLE_ROOT + 'tools/dockerfile/',
            PADDLE_ROOT + 'tools/windows/',
            PADDLE_ROOT + 'tools/test_runner.py',
            PADDLE_ROOT + 'tools/parallel_UT_rule.py',
        )
        if 'cmakelist' in filename.lower():
            return False
        return not filename.startswith(not_white_files)

    def __get_comment_by_filetype(self, content, filetype):
        """Return the comment lines of line-numbered *content*.

        *filetype* is 'py' or 'cc'; any other value yields [].
        """
        if filetype == 'py':
            progs = (
                self.py_prog_oneline,
                self.py_prog_multiline_a,
                self.py_prog_multiline_b,
            )
        elif filetype == 'cc':
            progs = (self.cc_prog_oneline, self.cc_prog_multiline)
        else:
            return []
        result = []
        for prog in progs:
            result.extend(self.__get_comment_by_prog(content, prog))
        return result

    def __get_comment_by_prog(self, content, prog):
        """Return every match of *prog* in *content*, split into single lines."""
        result = []
        for match in prog.findall(content):
            result.extend(match.split('\n'))
        return result

    def get_comment_of_file(self, f):
        """Return the comment lines of source file *f* as "<lineno>|<text>".

        Only .h/.cc/.cu (C++ comments) and .py (Python comments) are
        supported; other files return [] without being opened.
        """
        # TODO: read the file content from GitHub instead of the local checkout.
        if f.endswith(('.h', '.cc', '.cu')):
            filetype = 'cc'
        elif f.endswith('.py'):
            filetype = 'py'
        else:
            return []
        with open(f, encoding="utf-8") as fd:
            inputs = ''.join(
                '{}|{}'.format(lineno, line) for lineno, line in enumerate(fd, 1)
            )
        return self.__get_comment_by_filetype(inputs, filetype)

    def get_pr_diff_lines(self):
        """Return ``{path: ['<lineno>|<text>', ...]}`` for lines the PR adds.

        The diff is downloaded once per instance and then cached.
        """
        if getattr(self, '_pr_diff_lines', None) is None:
            response = requests.get(self.pr.diff_url, timeout=60)
            self._pr_diff_lines = self._parse_diff_lines(response.text)
        return self._pr_diff_lines

    def _parse_diff_lines(self, diff_text):
        """Parse unified-diff text into ``{path: ['<lineno>|<text>', ...]}``.

        Only added ('+') lines are recorded, numbered by their line in the
        new file. Paths come from "+++ b/<path>" headers; deleted files
        ("+++ /dev/null") contribute nothing.
        """
        file_to_diff_lines = {}
        data = diff_text.split('\n')
        total = len(data)
        filename = None
        ix = 0
        while ix < total:
            line = data[ix]
            ix += 1
            if line.startswith('diff --git'):
                filename = None
                continue
            if line.startswith('+++'):
                path = line.rstrip('\r\n')
                filename = None if path == '+++ /dev/null' else path[6:]
                continue
            hunk = self.hunk_prog.match(line)
            if not hunk:
                continue
            # Walk the hunk body by its line counts, so added lines that
            # happen to start with "+++" are not mistaken for file headers.
            old_left = int(hunk.group(2) or 1)
            lineno = int(hunk.group(3))
            new_left = int(hunk.group(4) or 1)
            while ix < total and (old_left > 0 or new_left > 0):
                body = data[ix]
                ix += 1
                marker = body[:1]
                if marker == '\\':
                    # "\ No newline at end of file" is not counted by the header.
                    continue
                if marker == '-':
                    old_left -= 1
                    continue
                if marker == '+':
                    if filename is not None:
                        file_to_diff_lines.setdefault(filename, []).append(
                            '{}|{}'.format(lineno, body[1:])
                        )
                    new_left -= 1
                else:
                    # Context line (' ' prefix, or blank if whitespace was stripped).
                    old_left -= 1
                    new_left -= 1
                lineno += 1
        return file_to_diff_lines

    def is_only_comment(self, f):
        """Return True if every line the PR adds to *f* is a comment line.

        *f* is a path under PADDLE_ROOT. Returns False when the diff records
        no added lines for *f*.
        """
        diff_lines = self.get_pr_diff_lines().get(f.replace(PADDLE_ROOT, '', 1))
        if not diff_lines:
            return False
        comment_lines = set(self.get_comment_of_file(f))
        if any(line not in comment_lines for line in diff_lines):
            return False
        print('PREC {} is only comment'.format(f))
        return True

    def get_all_count(self):
        """Return the total test count reported by ``ctest -N`` in the build dir.

        Raises:
            RuntimeError: if the output contains no "Total Tests:" line.
        """
        proc = subprocess.run(
            "cd {}build && ctest -N".format(PADDLE_ROOT),
            shell=True,
            stdout=subprocess.PIPE,
        )
        all_counts = None
        for line in proc.stdout.splitlines():
            if b'Total Tests:' in line:
                all_counts = int(line.split()[-1])
        if all_counts is None:
            raise RuntimeError(
                'PREC cannot find "Total Tests:" in `ctest -N` output'
            )
        return all_counts

    @staticmethod
    def _list_file_contains(list_path, unittest_path):
        """Return True if a line of *list_path* equals the stem of *unittest_path*.

        The stem is the basename up to its first '.'; lines are compared
        after stripping surrounding whitespace.
        """
        unittest_name = os.path.split(unittest_path)[1].split(".")[0]
        with open(list_path, 'r') as fd:
            return any(line.strip() == unittest_name for line in fd)

    def file_is_unnit_test(self, unittest_path):
        """Return True if *unittest_path* is listed in ``build/all_ut_list``."""
        all_ut_file = PADDLE_ROOT + 'build/all_ut_list'
        print("PADDLE_ROOT:", PADDLE_ROOT)
        print("all_ut_file path:", all_ut_file)
        print("build_path:", PADDLE_ROOT + 'build/')
        return self._list_file_contains(all_ut_file, unittest_path)

    @staticmethod
    def _merge_ut_names(ut_list, names):
        """Append each name (minus trailing CR/LF) to *ut_list* unless empty or present."""
        seen = set(ut_list)
        for raw in names:
            name = raw.rstrip('\r\n')
            if name and name not in seen:
                seen.add(name)
                ut_list.append(name)

    def _extend_with_prec_delta(self, ut_list):
        """Download ``prec_delta`` and merge its test names into *ut_list*.

        Exits with status 1 if the download fails.
        """
        if not self.__urlretrieve(self.PREC_DELTA_URL, 'prec_delta'):
            print('PREC download prec_delta failed')
            raise SystemExit(1)
        with open('prec_delta') as delta:
            self._merge_ut_names(ut_list, delta)

    def _print_precision_stats(self, ut_list):
        """Print the ipipe log parameters describing the selected tests."""
        print("ipipe_log_param_PRECISION_TEST: true")
        print("ipipe_log_param_PRECISION_TEST_Cases_count: %s" % len(ut_list))
        all_count = self.get_all_count()
        ratio = float(len(ut_list)) / all_count if all_count else 0.0
        print(
            "ipipe_log_param_PRECISION_TEST_Cases_ratio: %s" % format(ratio, '.2f')
        )

    def _split_pr_files(self, file_dict):
        """Split PR files into ``(files to map to tests, filtered-out files)``."""
        file_list = []
        filterFiles = []
        device_dirs = ('/xpu/', '/npu/', '/mlu/', '/ipu/')
        build_scripts = (
            PADDLE_ROOT + 'paddle/scripts/paddle_build.sh',
            PADDLE_ROOT + 'paddle/scripts/paddle_build.bat',
        )
        for filename, status in file_dict.items():
            if filename.startswith(PADDLE_ROOT + 'python/'):
                keep = True
            elif filename.startswith(PADDLE_ROOT + 'paddle/'):
                if filename.startswith(PADDLE_ROOT + 'paddle/infrt'):
                    keep = False
                elif filename.startswith(PADDLE_ROOT + 'paddle/scripts'):
                    keep = filename.startswith(build_scripts)
                else:
                    lowered = filename.lower()
                    keep = not any(d in lowered for d in device_dirs)
            else:
                keep = status == 'added' or not self.get_is_white_file(filename)
            if keep:
                file_list.append(filename)
            else:
                filterFiles.append(filename)
        return file_list, filterFiles

    def get_pr_ut(self):
        """Return the newline-separated unit tests to run for the PR.

        Returns '' (run the full suite) when test=allcase was requested, when
        the PR reaches MAX_PR_FILES changed files, or when some changed file
        cannot be mapped to tests. Exits with status 1 if a precision-test
        artifact cannot be downloaded.
        """
        if self.full_case:
            return ''

        if not self.__urlretrieve(self.UT_FILE_MAP_URL, 'ut_file_map.json'):
            print('PREC download file_ut.json failed')
            raise SystemExit(1)
        with open('ut_file_map.json') as jsonfile:
            file_ut_map = json.load(jsonfile)

        file_dict = self.get_pr_files()
        if len(file_dict) >= self.MAX_PR_FILES:
            return ''
        file_list, filterFiles = self._split_pr_files(file_dict)

        ut_list = []
        if not file_list:
            # Only filtered files changed: run the placeholder plus prec_delta.
            ut_list.append('filterfiles_placeholder')
            self._extend_with_prec_delta(ut_list)
            print("filterFiles: %s" % filterFiles)
            self._print_precision_stats(ut_list)
            print(
                "The unittests in prec delta is shown as following: %s"
                % ut_list
            )
            return '\n'.join(ut_list)

        normalize_path = (
            platform.system() in ("Darwin", "Windows") or self.suffix == ".py3"
        )
        notHitMapFiles = []
        hitMapFiles = {}
        for f in file_list:
            status = file_dict[f]
            if normalize_path:
                # Map keys appear to be rooted at /paddle/ (the default
                # PADDLE_ROOT); rewrite the local root to match.
                f_judge = f.replace(PADDLE_ROOT, '/paddle/', 1).replace('//', '/')
            else:
                f_judge = f

            if f_judge in file_ut_map:
                if status != 'removed' and self.is_only_comment(f):
                    ut_list.append('comment_placeholder')
                else:
                    hitMapFiles[f_judge] = len(file_ut_map[f_judge])
                    ut_list.extend(file_ut_map[f_judge])
            elif f_judge.endswith('.md'):
                ut_list.append('md_placeholder')
            elif (
                'tests/unittests/xpu' in f_judge
                or 'tests/unittests/npu' in f_judge
                or 'op_npu.cc' in f_judge
            ):
                ut_list.append('xpu_npu_placeholder')
            elif f_judge.endswith(('.h', '.cu', '.cc', '.py')):
                if status == 'added':
                    # A new file is acceptable if it is a newly added unit test.
                    added_ut_path = PADDLE_ROOT + 'added_ut'
                    print("PADDLE_ROOT:", PADDLE_ROOT)
                    print("adde_ut path:", added_ut_path)
                    if self._list_file_contains(added_ut_path, f_judge):
                        print(
                            "Adding new unit tests not hit mapFiles: %s" % f_judge
                        )
                    else:
                        notHitMapFiles.append(f_judge)
                elif status == 'removed':
                    print("remove file not hit mapFiles: %s" % f_judge)
                else:
                    if self.is_only_comment(f):
                        ut_list.append('comment_placeholder')
                    # NOTE(review): this is not an elif, so a comment-only
                    # change to a file that is not a registered unit test
                    # still forces the full suite; confirm this is intended.
                    if self.file_is_unnit_test(f_judge):
                        ut_list.append(os.path.split(f_judge)[1].split(".")[0])
                    else:
                        notHitMapFiles.append(f_judge)
            elif status == 'removed':
                print("remove file not hit mapFiles: %s" % f_judge)
            else:
                notHitMapFiles.append(f_judge)

        ut_list = list(set(ut_list))
        if notHitMapFiles:
            print("ipipe_log_param_PRECISION_TEST: false")
            print("notHitMapFiles: %s" % notHitMapFiles)
            if filterFiles:
                print("filterFiles: %s" % filterFiles)
            return ''
        if ut_list:
            self._extend_with_prec_delta(ut_list)
            print("hitMapFiles: %s" % hitMapFiles)
            self._print_precision_stats(ut_list)
            if filterFiles:
                print("filterFiles: %s" % filterFiles)
        return '\n'.join(ut_list)


if __name__ == '__main__':
    # Resolve the PR, then write the selected unit tests (one per line; an
    # empty file means get_pr_ut() chose the full suite) to ./ut_list.
    checker = PRChecker()
    checker.init()
    with open('ut_list', 'w') as out:
        out.write(checker.get_pr_ut())