check_op_benchmark_result.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import json
import logging
import argparse


def check_path_exists(path):
    """Assert whether file/directory exists."""
    assert os.path.exists(path), "%s does not exist." % path


def parse_case_name(log_file_name):
    """Parse case name."""
    case_id, case_info = log_file_name.split("-")
    direction = case_info.split(".")[0].split("_")[-1]

    return "%s (%s)" % (case_id, direction)


def parse_log_file(log_file):
    """Load one case result from log file."""
    check_path_exists(log_file)

    result = None
    with open(log_file) as f:
        for line in f.read().strip().split('\n')[::-1]:
            try:
                result = json.loads(line)
                if result.get("disabled", False):
                    return None
                return result
            except ValueError:
                pass  # skip lines that are not JSON

    if result is None:
        logging.warning("Parse %s fail!" % log_file)

    return result


def load_benchmark_result_from_logs_dir(logs_dir):
    """Load benchmark result from logs directory."""
    check_path_exists(logs_dir)

    log_file_path = lambda log_file: os.path.join(logs_dir, log_file)
    result_lambda = lambda log_file: (
        log_file,
        parse_log_file(log_file_path(log_file)),
    )

    return dict(map(result_lambda, os.listdir(logs_dir)))


def check_speed_result(case_name, develop_data, pr_data, pr_result):
    """Check speed differences between develop and pr."""
    pr_gpu_time = pr_data.get("gpu_time")
    develop_gpu_time = develop_data.get("gpu_time")
    if develop_gpu_time != 0.0:
        gpu_time_diff = (pr_gpu_time - develop_gpu_time) / develop_gpu_time
        gpu_time_diff_str = "{:.5f}".format(gpu_time_diff * 100)
    else:
        gpu_time_diff = None
        gpu_time_diff_str = ""

    pr_total_time = pr_data.get("total")
    develop_total_time = develop_data.get("total")
    total_time_diff = (pr_total_time - develop_total_time) / develop_total_time

    logging.info("------ OP: %s ------" % case_name)
    logging.info(
        "GPU time change: %s (develop: %.7f -> PR: %.7f)"
        % (gpu_time_diff_str, develop_gpu_time, pr_gpu_time)
    )
    logging.info(
        "Total time change: %.5f%% (develop: %.7f -> PR: %.7f)"
        % (total_time_diff * 100, develop_total_time, pr_total_time)
    )
    logging.info("backward: %s" % pr_result.get("backward"))
    logging.info("parameters:")
    for line in pr_result.get("parameters").strip().split("\n"):
        logging.info("\t%s" % line)

    # gpu_time_diff is None when develop GPU time is 0.0; flag the case only
    # when GPU time grows by more than 5%.
    return gpu_time_diff is not None and gpu_time_diff > 0.05


def check_accuracy_result(case_name, pr_result):
    """Check accuracy result."""
    logging.info("------ OP: %s ------" % case_name)
    logging.info("Accuracy diff: %s" % pr_result.get("diff"))
    logging.info("backward: %s" % pr_result.get("backward"))
    logging.info("parameters:")
    for line in pr_result.get("parameters").strip().split("\n"):
        logging.info("\t%s" % line)

    return not pr_result.get("consistent")


def compare_benchmark_result(
    case_name, develop_result, pr_result, check_results
):
    """Compare the differences between develop and pr."""
    develop_speed = develop_result.get("speed")
    pr_speed = pr_result.get("speed")

    assert type(develop_speed) == type(
        pr_speed
    ), "The types of comparison results need to be consistent."

    if isinstance(develop_speed, dict) and isinstance(pr_speed, dict):
        if check_speed_result(case_name, develop_speed, pr_speed, pr_result):
            check_results["speed"].append(case_name)
    else:
        if check_accuracy_result(case_name, pr_result):
            check_results["accuracy"].append(case_name)


def update_api_info_file(fail_case_list, api_info_file):
    """Update api info file to auto retry benchmark test."""
    check_path_exists(api_info_file)

    # dict mapping api name to failed case id for performance check failures
    parse_case_id_f = lambda x: x.split()[0].rsplit('_', 1)
    fail_case_dict = dict(map(parse_case_id_f, fail_case_list))
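    # e.g. a hypothetical failed case name "abs_0 (forward)" contributes the
    # mapping {"abs": "0"}.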

    # list of api infos for performance check failures
    api_info_list = list()
    with open(api_info_file) as f:
        for line in f:
            line_list = line.split(',')
            case = line_list[0].split(':')[0]
            if case in fail_case_dict:
                line_list[0] = "%s:%s" % (case, fail_case_dict[case])
                api_info_list.append(','.join(line_list))

    # update api info file
    with open(api_info_file, 'w') as f:
        for api_info_line in api_info_list:
            f.write(api_info_line)


def summary_results(check_results, api_info_file):
    """Summary results and return exit code."""
    for case_name in check_results["speed"]:
        logging.error("Check speed result with case \"%s\" failed." % case_name)

    for case_name in check_results["accuracy"]:
        logging.error(
            "Check accuracy result with case \"%s\" failed." % case_name
        )

    if len(check_results["speed"]) and api_info_file:
        update_api_info_file(check_results["speed"], api_info_file)

    # Any failed case yields a distinct nonzero exit code for CI.
    if len(check_results["speed"]) or len(check_results["accuracy"]):
        return 8
    else:
        return 0


if __name__ == "__main__":
    """Load result from log directories and compare the differences."""
    logging.basicConfig(
        level=logging.INFO,
        format="[%(filename)s:%(lineno)d] [%(levelname)s] %(message)s",
    )

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--develop_logs_dir",
        type=str,
        required=True,
        help="Specify the benchmark result directory of develop branch.",
    )
    parser.add_argument(
        "--pr_logs_dir",
        type=str,
        required=True,
        help="Specify the benchmark result directory of PR branch.",
    )
    parser.add_argument(
        "--api_info_file",
        type=str,
        required=False,
        help="Specify the api info to run benchmark test.",
    )
    args = parser.parse_args()

    check_results = dict(accuracy=list(), speed=list())

    develop_result_dict = load_benchmark_result_from_logs_dir(
        args.develop_logs_dir
    )

    check_path_exists(args.pr_logs_dir)
    pr_log_files = os.listdir(args.pr_logs_dir)
    for log_file in sorted(pr_log_files):
        develop_result = develop_result_dict.get(log_file)
        pr_result = parse_log_file(os.path.join(args.pr_logs_dir, log_file))
        if develop_result is None or pr_result is None:
            continue
        case_name = parse_case_name(log_file)
        compare_benchmark_result(
            case_name, develop_result, pr_result, check_results
        )

    sys.exit(summary_results(check_results, args.api_info_file))