get_pr_ut.py 19.1 KB
Newer Older
C
chalsliu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Select the unit test cases to run for a pull request based on the files it changes."""

import os
import json
18
import re
19 20
import time
import subprocess
21
import requests
22 23
import urllib.request
import ssl
Y
YUNSHEN XIE 已提交
24
import platform
C
chalsliu 已提交
25 26 27
from github import Github

# Root of the Paddle checkout. Changed-file paths are built by plain string
# concatenation (PADDLE_ROOT + 'python/...'), so normalize the root to end
# with exactly one '/' and collapse every run of repeated slashes (a single
# str.replace('//', '/') pass leaves '///' as '//').
PADDLE_ROOT = re.sub(r'/+', '/', os.getenv('PADDLE_ROOT', '/paddle/') + '/')

# SECURITY NOTE(review): this disables TLS certificate verification for all
# HTTPS connections that use the default ssl context in this process (the
# urllib downloads of the test maps). Presumably required by the CI download
# mirrors - confirm before removing.
ssl._create_default_https_context = ssl._create_unverified_context


class PRChecker(object):
    """Select the unit tests that should run for a GitHub pull request.

    Changed files are read through the GitHub API and mapped to unit tests
    with a downloaded ``ut_file_map.json``. ``get_pr_ut`` returns the
    newline-separated test list, or ``''`` when the full suite must run.
    """

    # Upper bound on changed files for selective testing; a PR reaching it
    # runs all cases. Presumably matches the GitHub API default page size.
    MAX_PR_FILES = 30

    # All patterns below run over "<lineno>|<line text>" numbered content.
    # Full-line Python comments.
    py_prog_oneline = re.compile(r'\d+\|\s*#.*')
    # Contents of Python triple-quoted strings (docstrings).
    py_prog_multiline_a = re.compile('"""(.*?)"""', re.DOTALL)
    py_prog_multiline_b = re.compile("'''(.*?)'''", re.DOTALL)
    # Full-line C/C++ comments.
    cc_prog_oneline = re.compile(r'\d+\|\s*//.*')
    # Backward-compatible alias for the historical misspelled attribute.
    cc_prog_online = cc_prog_oneline
    cc_prog_multiline = re.compile(r'\d+\|\s*/\*.*?\*/', re.DOTALL)
    # Unified-diff hunk header; a count may be omitted when it equals 1.
    lineno_prog = re.compile(r'@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@')

    def __init__(self):
        self.github = Github(os.getenv('GITHUB_API_TOKEN'), timeout=60)
        self.repo = self.github.get_repo('PaddlePaddle/Paddle')
        self.pr = None
        self.suffix = ''
        self.full_case = False
        # Parsed PR diff, fetched lazily once by _get_cached_pr_diff_lines.
        self._pr_diff_lines = None

    def init(self):
        """Load the pull request named by the ``GIT_PR_ID`` env variable.

        Exits with status 0 when no PR id is set. Reads ``PREC_SUFFIX`` into
        ``self.suffix``. Sets ``full_case`` when the newest commit message
        contains ``test=allcase`` or when no commit could be read.
        """
        pr_id = os.getenv('GIT_PR_ID')
        if not pr_id:
            print('PREC No PR ID')
            exit(0)
        suffix = os.getenv('PREC_SUFFIX')
        if suffix:
            self.suffix = suffix
        self.pr = self.repo.get_pull(int(pr_id))

        # Walk the commit pages to reach the newest commit. Stop at the first
        # empty page or API error (best effort).
        last_commit = None
        page = 0
        while True:
            try:
                commits = self.pr.get_commits().get_page(page)
                if not commits:
                    break
                last_commit = commits[-1].commit
            except Exception:
                break
            page += 1

        if last_commit is None:
            # Cannot inspect the commit message: be conservative.
            print('PREC no commit found, run all cases')
            self.full_case = True
        elif 'test=allcase' in last_commit.message:
            print('PREC test=allcase is set')
            self.full_case = True

    # todo: exception
    def __wget_with_retry(self, url):
        """Download ``url`` with ``wget``; up to 5 attempts.

        Returns True on success, False once every attempt failed.
        """
        for attempt in range(1, 6):
            # NOTE(review): ``attempt // 2 == 0`` holds only for attempt 1,
            # so only the first try keeps the default proxy settings;
            # confirm whether alternating (``% 2``) was intended.
            if attempt // 2 == 0:
                proxy = ''
            elif platform.system() == 'Windows':
                proxy = '-Y off'
            else:
                proxy = '--no-proxy'
            code = subprocess.call(
                'wget -q {} --no-check-certificate {}'.format(proxy, url),
                shell=True,
            )
            if code == 0:
                return True
            print(
                'PREC download {} error, retry {} time(s) after {} secs.[proxy_option={}]'.format(
                    url, attempt, attempt * 10, proxy
                )
            )
            time.sleep(attempt * 10)
        return False

    def __urlretrieve(self, url, filename):
        """Download ``url`` to ``filename`` with urllib; up to 5 attempts.

        The first attempt bypasses all proxies, later attempts use the
        proxies configured in the environment. Returns True on success,
        False once every attempt failed.
        """
        with_proxy = urllib.request.getproxies()
        # An empty mapping disables proxying. (An entry with an empty URL
        # would instead route the request to host '' and fail.)
        without_proxy = {}
        for attempt in range(1, 6):
            proxies = without_proxy if attempt // 2 == 0 else with_proxy
            opener = urllib.request.build_opener(
                urllib.request.ProxyHandler(proxies), urllib.request.HTTPHandler
            )
            urllib.request.install_opener(opener)
            try:
                urllib.request.urlretrieve(url, filename)
            except Exception as e:
                print(e)
                print(
                    'PREC download {} error, retry {} time(s) after {} secs.[proxy_option={}]'.format(
                        url, attempt, attempt * 10, proxies
                    )
                )
                # Back off, then move on to the next attempt.
                time.sleep(attempt * 10)
            else:
                return True
        return False

    def get_pr_files(self):
        """Return ``{PADDLE_ROOT + path: status}`` for files changed in the PR.

        Stops paging once ``MAX_PR_FILES`` files were collected; callers
        treat such a PR as too large for selective testing.
        """
        file_dict = {}
        page = 0
        while len(file_dict) < self.MAX_PR_FILES:
            files = self.pr.get_files().get_page(page)
            if not files:
                break
            for f in files:
                file_dict[PADDLE_ROOT + f.filename] = f.status
            page += 1
        print("pr modify files: %s" % file_dict)
        return file_dict

    def get_is_white_file(self, filename):
        """Return True if ``filename`` can be ignored when selecting tests.

        CMake list files and files under the build/CI paths below are never
        "white"; anything else is.
        """
        not_white_files = (
            PADDLE_ROOT + 'cmake/',
            PADDLE_ROOT + 'patches/',
            PADDLE_ROOT + 'tools/dockerfile/',
            PADDLE_ROOT + 'tools/windows/',
            PADDLE_ROOT + 'tools/test_runner.py',
            PADDLE_ROOT + 'tools/parallel_UT_rule.py',
        )
        if 'cmakelist' in filename.lower():
            return False
        return not filename.startswith(not_white_files)

    def __get_comment_by_filetype(self, content, filetype):
        """Return comment lines of numbered ``content`` for 'py' or 'cc'."""
        if filetype == 'py':
            progs = (
                self.py_prog_oneline,
                self.py_prog_multiline_a,
                self.py_prog_multiline_b,
            )
        elif filetype == 'cc':
            progs = (self.cc_prog_oneline, self.cc_prog_multiline)
        else:
            return []
        result = []
        for prog in progs:
            result.extend(self.__get_comment_by_prog(content, prog))
        return result

    def __get_comment_by_prog(self, content, prog):
        """Return every line of every ``prog`` match in ``content``."""
        result = []
        for match in prog.findall(content):
            result.extend(match.split('\n'))
        return result

    def get_comment_of_file(self, f):
        """Return comment lines of file ``f`` as ``"<lineno>|<text>"`` strings.

        ``.h``/``.cc``/``.cu`` files use C++ comment rules, ``.py`` files use
        Python rules; any other file yields ``[]`` without being read.
        """
        if f.endswith(('.h', '.cc', '.cu')):
            filetype = 'cc'
        elif f.endswith('.py'):
            filetype = 'py'
        else:
            return []
        # todo: get file from github instead of the local checkout
        with open(f, encoding="utf-8") as fd:
            inputs = ''.join(
                '{}|{}'.format(lineno, line) for lineno, line in enumerate(fd, 1)
            )
        return self.__get_comment_by_filetype(inputs, filetype)

    def get_pr_diff_lines(self):
        """Download the PR diff; return ``{path: ["<lineno>|<added text>"]}``."""
        r = requests.get(self.pr.diff_url, timeout=60)
        return self._parse_diff_lines(r.text)

    def _parse_diff_lines(self, diff_text):
        """Map each file of a unified diff to its added lines.

        Keys are the new-side paths with the ``+++ b/`` prefix removed;
        values are ``"<new lineno>|<text>"`` strings in diff order. Deleted
        files (``+++ /dev/null``) are skipped.
        """
        file_to_diff_lines = {}
        data = diff_text.split('\n')
        total = len(data)
        ix = 0
        while ix < total:
            if not data[ix].startswith('+++'):
                ix += 1
                continue
            if data[ix].rstrip('\r\n') == '+++ /dev/null':
                ix += 1
                continue
            filename = data[ix][6:].rstrip('\r')
            ix += 1
            while ix < total:
                result = self.lineno_prog.match(data[ix])
                if not result:
                    break
                lineno = int(result.group(1))
                count = result.group(2)
                length = int(count) if count is not None else 1
                ix += 1
                end = ix + length
                # Bound by ``total`` too, so a truncated diff cannot overrun.
                while ix < end and ix < total:
                    marker = data[ix][:1]
                    if marker in ('-', '\\'):
                        # Removed lines and "\ No newline at end of file"
                        # markers do not count toward the new-side length.
                        end += 1
                    else:
                        if marker == '+':
                            file_to_diff_lines.setdefault(filename, []).append(
                                '{}|{}'.format(lineno, data[ix][1:])
                            )
                        lineno += 1
                    ix += 1
            ix += 1
        return file_to_diff_lines

    def _get_cached_pr_diff_lines(self):
        """Return ``get_pr_diff_lines()``, downloading the diff only once."""
        if getattr(self, '_pr_diff_lines', None) is None:
            self._pr_diff_lines = self.get_pr_diff_lines()
        return self._pr_diff_lines

    def is_only_comment(self, f):
        """Return True if every line the PR adds to file ``f`` is a comment."""
        diff_lines = self._get_cached_pr_diff_lines().get(
            f.replace(PADDLE_ROOT, '', 1)
        )
        if not diff_lines:
            return False
        comment_lines = set(self.get_comment_of_file(f))
        if not all(line in comment_lines for line in diff_lines):
            return False
        print('PREC {} is only comment'.format(f))
        return True

    def get_all_count(self):
        """Return the total test count reported by ``ctest -N`` in build/.

        Raises:
            RuntimeError: if no ``Total Tests:`` line is in the output.
        """
        p = subprocess.Popen(
            "cd {}build && ctest -N".format(PADDLE_ROOT),
            shell=True,
            stdout=subprocess.PIPE,
        )
        out, _ = p.communicate()
        all_counts = None
        for line in out.splitlines():
            if b'Total Tests:' in line:
                all_counts = line.split()[-1]
        if all_counts is None:
            raise RuntimeError('PREC no "Total Tests:" line in `ctest -N` output')
        return int(all_counts)

    def file_is_unnit_test(self, unittest_path):
        """Return True if ``unittest_path`` names a test registered in ctest.

        The comparison uses the file's basename up to its first '.'.
        """
        all_ut_file = '%s/build/all_ut_file' % PADDLE_ROOT
        # Dump the test names listed by `ctest -N` into all_ut_file.
        os.system(
            "cd %s/build && ctest -N | awk -F ': ' '{print $2}' | sed '/^$/d' | sed '$d' > %s"
            % (PADDLE_ROOT, all_ut_file)
        )
        unittest_name = os.path.basename(unittest_path).split(".")[0]
        with open(all_ut_file, 'r') as f:
            return any(test.strip() == unittest_name for test in f)

    def _is_filtered_file(self, filename, status):
        """Return True if a changed file needs no tests (goes to filterFiles)."""
        if filename.startswith(PADDLE_ROOT + 'python/'):
            return False
        if filename.startswith(PADDLE_ROOT + 'paddle/'):
            if filename.startswith(PADDLE_ROOT + 'paddle/infrt'):
                return True
            if filename.startswith(PADDLE_ROOT + 'paddle/scripts'):
                return not filename.startswith(
                    (
                        PADDLE_ROOT + 'paddle/scripts/paddle_build.sh',
                        PADDLE_ROOT + 'paddle/scripts/paddle_build.bat',
                    )
                )
            lowered = filename.lower()
            return any(
                device in lowered
                for device in ('/xpu/', '/npu/', '/mlu/', '/ipu/')
            )
        if status == 'added':
            return False
        return self.get_is_white_file(filename)

    def _is_in_added_ut(self, f_judge):
        """Return True if ``f_judge``'s test name is listed in PADDLE_ROOT/added_ut."""
        test_name = os.path.basename(f_judge).split(".")[0]
        with open('{}/added_ut'.format(PADDLE_ROOT)) as utfile:
            return any(line.strip('\n') == test_name for line in utfile)

    def _download_prec_delta(self):
        """Download ``prec_delta`` and return its lines; exit(1) on failure."""
        ret = self.__urlretrieve(
            'https://paddle-docker-tar.bj.bcebos.com/tmp_test/prec_delta',
            'prec_delta',
        )
        if not ret:
            print('PREC download prec_delta failed')
            exit(1)
        with open('prec_delta') as delta:
            return [ut.rstrip('\r\n') for ut in delta]

    def _print_precision_stats(self, ut_list):
        """Print the ipipe precision-test summary lines for ``ut_list``."""
        print("ipipe_log_param_PRECISION_TEST: true")
        print("ipipe_log_param_PRECISION_TEST_Cases_count: %s" % len(ut_list))
        ratio = format(float(len(ut_list)) / float(self.get_all_count()), '.2f')
        print("ipipe_log_param_PRECISION_TEST_Cases_ratio: %s" % ratio)

    def get_pr_ut(self):
        """Return the unit tests to run for the PR, joined by newlines.

        Returns ``''`` (run the full suite) when test=allcase applies, the
        PR touches ``MAX_PR_FILES`` files or more, or a changed file maps to
        no known test. Calls exit(1) if a mapping file cannot be downloaded.
        """
        if self.full_case:
            return ''

        ret = self.__urlretrieve(
            'https://paddle-docker-tar.bj.bcebos.com/tmp_test/ut_file_map.json',
            'ut_file_map.json',
        )
        if not ret:
            print('PREC download file_ut.json failed')
            exit(1)
        with open('ut_file_map.json') as jsonfile:
            file_ut_map = json.load(jsonfile)

        file_dict = self.get_pr_files()
        if len(file_dict) >= self.MAX_PR_FILES:
            # Too many changed files: run all cases.
            return ''

        file_list = []
        filterFiles = []
        for filename, status in file_dict.items():
            if self._is_filtered_file(filename, status):
                filterFiles.append(filename)
            else:
                file_list.append(filename)

        ut_list = []
        if not file_list:
            # Every change was filtered out: run the placeholder plus the
            # prec_delta list only.
            ut_list.append('filterfiles_placeholder')
            ut_list.extend(self._download_prec_delta())
            print("filterFiles: %s" % filterFiles)
            self._print_precision_stats(ut_list)
            return '\n'.join(ut_list)

        notHitMapFiles = []
        hitMapFiles = {}
        current_system = platform.system()
        # On these platforms/suffixes the local root differs from the
        # '/paddle/' root presumably used by ut_file_map.json keys.
        remap_root = (
            current_system in ("Darwin", "Windows") or self.suffix == ".py3"
        )
        for f in file_list:
            status = file_dict[f]
            if remap_root:
                f_judge = f.replace(PADDLE_ROOT, '/paddle/', 1).replace('//', '/')
            else:
                f_judge = f

            if f_judge in file_ut_map:
                if status != 'removed' and self.is_only_comment(f):
                    ut_list.append('comment_placeholder')
                else:
                    hitMapFiles[f_judge] = len(file_ut_map[f_judge])
                    ut_list.extend(file_ut_map[f_judge])
            elif f_judge.endswith('.md'):
                ut_list.append('md_placeholder')
            elif (
                'tests/unittests/xpu' in f_judge
                or 'tests/unittests/npu' in f_judge
                or 'op_npu.cc' in f_judge
            ):
                ut_list.append('xpu_npu_placeholder')
            elif f_judge.endswith(('.h', '.cu', '.cc', '.py')):
                if status == 'added':
                    # A newly added test file is fine if it is in added_ut.
                    if self._is_in_added_ut(f_judge):
                        print("Adding new unit tests not hit mapFiles: %s" % f_judge)
                    else:
                        notHitMapFiles.append(f_judge)
                elif status == 'removed':
                    print("remove file not hit mapFiles: %s" % f_judge)
                else:
                    # NOTE(review): unlike the mapped branch above, a
                    # comment-only change here still falls through to the
                    # unit-test check (and to notHitMapFiles); confirm intent.
                    if self.is_only_comment(f):
                        ut_list.append('comment_placeholder')
                    # NOTE(review): appends the path minus extension, while
                    # file_is_unnit_test matched on the basename only.
                    if self.file_is_unnit_test(f_judge):
                        ut_list.append(f_judge.split(".")[0])
                    else:
                        notHitMapFiles.append(f_judge)
            elif status == 'removed':
                print("remove file not hit mapFiles: %s" % f_judge)
            else:
                notHitMapFiles.append(f_judge)

        ut_list = list(set(ut_list))
        if notHitMapFiles:
            print("ipipe_log_param_PRECISION_TEST: false")
            print("notHitMapFiles: %s" % notHitMapFiles)
            if filterFiles:
                print("filterFiles: %s" % filterFiles)
            return ''

        if ut_list:
            # Merge prec_delta without duplicates (entries are compared
            # after stripping line endings).
            seen = set(ut_list)
            for ut in self._download_prec_delta():
                if ut not in seen:
                    seen.add(ut)
                    ut_list.append(ut)
            print("hitMapFiles: %s" % hitMapFiles)
            self._print_precision_stats(ut_list)
            if filterFiles:
                print("filterFiles: %s" % filterFiles)
        return '\n'.join(ut_list)


if __name__ == '__main__':
    # Resolve the PR from the environment, then write the selected test
    # names (newline separated) to ./ut_list. The file is opened before
    # get_pr_ut() runs, so an early exit leaves it truncated.
    checker = PRChecker()
    checker.init()
    with open('ut_list', 'w') as ut_list_file:
        ut_list_file.write(checker.get_pr_ut())