# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import time

import cv2
import numpy as np

# Prepend the current working directory to the import path so the local
# `ppcls` and `tools` packages below can be imported. This presumably expects
# the script to be launched from the repository root — confirm.
sys.path.insert(0, ".")
from ppcls.utils import logger
from tools.infer.utils import parse_args, create_paddle_predictor, preprocess, postprocess
from tools.infer.utils import get_image_list, get_image_list_from_label_file, calc_topk_acc


class Predictor(object):
    """Image-classification front end over a Paddle inference predictor.

    ``normal_predict`` classifies the images named by ``args.image_file`` and
    logs the top-k results (optionally scoring them against ground-truth
    labels); ``benchmark_predict`` times inference on random inputs.
    All configuration is read from ``self.args``; no module globals are used.
    """

    def __init__(self, args):
        """Create the Paddle predictor and bind its first input/output tensor.

        Args:
            args: parsed command-line options (the object returned by
                ``parse_args``); stored as ``self.args``.
        """
        # HALF precision prediction only works when using TensorRT.
        if args.use_fp16 is True:
            assert args.use_tensorrt is True, \
                "use_fp16 requires use_tensorrt to be enabled"
        self.args = args

        self.paddle_predictor = create_paddle_predictor(args)
        # Only the model's first input and first output are used.
        input_names = self.paddle_predictor.get_input_names()
        self.input_tensor = self.paddle_predictor.get_input_handle(
            input_names[0])

        output_names = self.paddle_predictor.get_output_names()
        self.output_tensor = self.paddle_predictor.get_output_handle(
            output_names[0])

    def predict(self, batch_input):
        """Run one forward pass.

        Copies ``batch_input`` into the first input tensor, runs the predictor
        and returns a host copy of the first output tensor.
        """
        self.input_tensor.copy_from_cpu(batch_input)
        self.paddle_predictor.run()
        return self.output_tensor.copy_to_cpu()

    def normal_predict(self):
        """Classify every image in ``args.image_file`` and log top-k results.

        Images that OpenCV cannot read are logged and skipped. Readable images
        are grouped into batches of ``args.batch_size``; the final, possibly
        smaller, batch is always predicted, even when the trailing images in
        the list could not be read.

        When ``args.enable_calc_topk`` is set, the image list and labels come
        from ``get_image_list_from_label_file`` and top-k accuracy over the
        successfully read images is logged at the end.
        """
        args = self.args
        if args.enable_calc_topk:
            assert args.gt_label_path is not None and os.path.exists(args.gt_label_path), \
                "gt_label_path shoule not be None and must exist, please check its path."
            image_list, gt_labels = get_image_list_from_label_file(
                args.image_file, args.gt_label_path)
            predicts_map = {
                "prediction": [],
                "gt_label": [],
            }
        else:
            image_list = get_image_list(args.image_file)
            gt_labels = None
            # None means "do not collect accuracy statistics".
            predicts_map = None

        batch_input_list = []
        img_name_list = []
        for idx, img_path in enumerate(image_list):
            img = cv2.imread(img_path)
            if img is None:
                logger.warning(
                    "Image file failed to read and has been skipped. The path: {}".
                    format(img_path))
                continue
            # OpenCV decodes colour images in BGR order; flip to RGB.
            img = img[:, :, ::-1]
            batch_input_list.append(preprocess(img, args))
            img_name_list.append(os.path.basename(img_path))
            if predicts_map is not None:
                # Record the label only for images that were actually read,
                # so labels and predictions stay index-aligned.
                predicts_map["gt_label"].append(gt_labels[idx])

            if len(batch_input_list) == args.batch_size:
                self._predict_batch(batch_input_list, img_name_list,
                                    predicts_map)
                batch_input_list = []
                img_name_list = []

        # Flush the last partial batch. Doing this after the loop (rather than
        # on the last index) also covers lists whose final images failed to
        # load, which previously left pending images unpredicted.
        if batch_input_list:
            self._predict_batch(batch_input_list, img_name_list, predicts_map)

        if predicts_map is not None:
            topk_acc = calc_topk_acc(predicts_map)
            for idx, acc in enumerate(topk_acc):
                logger.info("Top-{} acc: {:.5f}".format(idx + 1, acc))

    def _predict_batch(self, batch_input_list, img_name_list, predicts_map):
        """Predict one batch, log each image's top-k result, record predictions.

        Args:
            batch_input_list: preprocessed images, stacked with ``np.array``.
            img_name_list: file names parallel to ``batch_input_list``.
            predicts_map: dict with a ``"prediction"`` list to append class ids
                to, or None when accuracy is not being computed.
        """
        batch_outputs = self.predict(np.array(batch_input_list))
        # postprocess yields one dict per image with "clas_ids" and "scores".
        batch_result_list = postprocess(batch_outputs, self.args.top_k)

        for filename, result_dict in zip(img_name_list, batch_result_list):
            clas_ids = result_dict["clas_ids"]
            scores_str = "[{}]".format(", ".join(
                "{:.2f}".format(r) for r in result_dict["scores"]))
            logger.info(
                "File:{}, Top-{} result: class id(s): {}, score(s): {}".format(
                    filename, self.args.top_k, clas_ids, scores_str))
            if predicts_map is not None:
                predicts_map["prediction"].append(clas_ids)

    def benchmark_predict(self, test_num=500, warmup_num=10,
                          input_shape=(3, 224, 224)):
        """Time inference on random inputs and print the mean latency per batch.

        Args:
            test_num: number of timed iterations (must be positive).
            warmup_num: number of untimed iterations run first.
            input_shape: per-sample input shape; the batch dimension is
                ``args.batch_size``. Defaults match the previous hard-coded
                3x224x224 input.

        Raises:
            ValueError: if ``test_num`` is not positive (the mean would divide
                by zero).
        """
        if test_num <= 0:
            raise ValueError(
                "test_num must be positive, got {}".format(test_num))
        args = self.args
        test_time = 0.0
        for i in range(warmup_num + test_num):
            inputs = np.random.rand(args.batch_size,
                                    *input_shape).astype(np.float32)
            start_time = time.perf_counter()
            self.predict(inputs)
            if i >= warmup_num:
                test_time += time.perf_counter() - start_time
            time.sleep(0.01)  # sleep for T4 GPU (outside the timed region)

        fp_message = "FP16" if args.use_fp16 else "FP32"
        trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
        print("{0}\t{1}\t{2}\tbatch size: {3}\ttime(ms): {4}".format(
            args.model, trt_msg, fp_message, args.batch_size,
            1000 * test_time / test_num))


if __name__ == "__main__":
    args = parse_args()
    # Fail fast if the model or params file is missing.
    assert os.path.exists(
        args.model_file), "The path of 'model_file' does not exist: {}".format(
            args.model_file)
    assert os.path.exists(
        args.params_file
    ), "The path of 'params_file' does not exist: {}".format(args.params_file)

    predictor = Predictor(args)
    if args.enable_benchmark:
        # benchmark_predict prints args.model in its report line.
        assert args.model is not None, "'model' must be set when benchmarking"
        predictor.benchmark_predict()
    else:
        predictor.normal_predict()