predict_cls.py 2.7 KB
Newer Older
littletomatodonkey's avatar
littletomatodonkey 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))

import cv2
import numpy as np

from utils import logger
from utils import config
from utils.predictor import Predictor
from utils.get_image_list import get_image_list
T
Tingquan Gao 已提交
27 28
from python.preprocess import create_operators
from python.postprocess import build_postprocess
littletomatodonkey's avatar
littletomatodonkey 已提交
29 30


littletomatodonkey's avatar
littletomatodonkey 已提交
31
class ClsPredictor(Predictor):
    """Classification predictor layered on the project's ``Predictor`` base.

    The base class is initialised with the ``Global`` section of the
    configuration. On top of that, this class holds an optional list of
    pre-processing operators and an optional post-processing callable, built
    from the ``PreProcess`` and ``PostProcess`` sections when present.
    """

    def __init__(self, config):
        """Create the predictor and its optional pre/post-processing stages.

        Args:
            config: Configuration container supporting ``in`` and ``[]``.
                ``config["Global"]`` is forwarded to ``Predictor.__init__``.
                ``config["PreProcess"]["transform_ops"]`` (optional) is
                passed to ``create_operators``; ``config["PostProcess"]``
                (optional) is passed to ``build_postprocess``.
        """
        super().__init__(config["Global"])

        # Defaults used when the corresponding config sections are absent.
        self.preprocess_ops = []
        self.postprocess = None

        if "PreProcess" in config and "transform_ops" in config["PreProcess"]:
            self.preprocess_ops = create_operators(
                config["PreProcess"]["transform_ops"])
        if "PostProcess" in config:
            self.postprocess = build_postprocess(config["PostProcess"])

    def predict(self, images):
        """Pre-process ``images``, run the model and return its first output.

        Args:
            images: A single image or a ``list`` of images. Anything that is
                not a ``list`` is treated as one image and wrapped in a
                one-element batch. The caller's list is never modified.

        Returns:
            The result of ``copy_to_cpu()`` on the first output handle of
            ``self.paddle_predictor``.
        """
        # NOTE(review): ``self.paddle_predictor`` is not set in this class;
        # presumably the ``Predictor`` base class creates it -- confirm.
        input_names = self.paddle_predictor.get_input_names()
        input_tensor = self.paddle_predictor.get_input_handle(input_names[0])

        output_names = self.paddle_predictor.get_output_names()
        output_tensor = self.paddle_predictor.get_output_handle(
            output_names[0])

        if not isinstance(images, (list, )):
            images = [images]

        # Collect results in a fresh list so the caller's list keeps its
        # original, un-preprocessed elements.
        processed = []
        for image in images:
            for op in self.preprocess_ops:
                image = op(image)
            processed.append(image)
        batch = np.array(processed)

        input_tensor.copy_from_cpu(batch)
        self.paddle_predictor.run()
        batch_output = output_tensor.copy_to_cpu()
        return batch_output

littletomatodonkey's avatar
littletomatodonkey 已提交
64 65 66 67 68 69 70 71

def main(config):
    """Classify every image selected by ``Global.infer_imgs`` and print it.

    Args:
        config: Configuration container. ``config["Global"]["infer_imgs"]``
            is handed to ``get_image_list`` and
            ``config["Global"]["batch_size"]`` must equal 1. The whole
            object is passed to ``ClsPredictor``.

    Raises:
        AssertionError: If the configured batch size is not 1.
    """
    cls_predictor = ClsPredictor(config)
    image_list = get_image_list(config["Global"]["infer_imgs"])

    assert config["Global"]["batch_size"] == 1, \
        "only batch_size == 1 is supported"
    for image_file in image_list:
        img = cv2.imread(image_file)
        if img is None:
            # cv2.imread signals an unreadable/unsupported file by returning
            # None; report it and move on instead of failing on the slice.
            print(
                "Failed to read image, skipped: {}".format(image_file),
                file=sys.stderr)
            continue
        # Reverse the channel axis (OpenCV loads images in BGR order).
        img = img[:, :, ::-1]

        output = cls_predictor.predict(img)
        # ``postprocess`` stays None when the config has no PostProcess
        # section; in that case print the raw model output.
        if cls_predictor.postprocess is not None:
            output = cls_predictor.postprocess(output)
        print(output)
    return


if __name__ == "__main__":
    # Parse the command line, load the configuration it points to, and run
    # prediction. The loaded configuration gets its own name so the imported
    # ``config`` module is not rebound.
    args = config.parse_args()
    cfg = config.get_config(args.config, overrides=args.override, show=True)
    main(cfg)