# test_hubserving.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from ppcls.utils import logger
import cv2
import time
import requests
import json
import base64
import imghdr


def get_image_file_list(img_file):
    """Collect image file paths from a single file or a directory.

    A path is accepted when ``imghdr.what`` classifies its *content* as one of
    the formats in ``img_end``; the file extension is not consulted. For a
    directory, only its direct entries are inspected (no recursion),
    sub-directories are skipped, and the result is returned in sorted order
    so runs are reproducible.

    Args:
        img_file: Path to an image file or to a directory containing images.

    Returns:
        A non-empty list of image file paths.

    Raises:
        Exception: If ``img_file`` is None, does not exist, or contains no
            recognised image.
    """
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    # NOTE(review): imghdr is deprecated since Python 3.11 and removed in 3.13
    # (PEP 594); it reports e.g. 'jpeg'/'tiff'/'gif', the other spellings here
    # are harmless extras.
    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}

    def _is_image(path):
        # imghdr.what() opens the path, so non-regular files (e.g. directories)
        # must be rejected before calling it.
        return os.path.isfile(path) and imghdr.what(path) in img_end

    if os.path.isdir(img_file):
        candidates = [
            os.path.join(img_file, name)
            for name in sorted(os.listdir(img_file))
        ]
    else:
        candidates = [img_file]

    imgs_lists = [path for path in candidates if _is_image(path)]
    if not imgs_lists:
        raise Exception("not found any img file in {}".format(img_file))
    return imgs_lists


def cv2_to_base64(image):
    """Encode raw bytes as a base64 text string.

    Despite the name, no OpenCV object is involved: ``image`` is a
    bytes-like object (in this script, the raw contents of an image file),
    and the base64 output is returned as a ``str`` so it can be placed in a
    JSON payload.
    """
    encoded = base64.b64encode(image)
    return str(encoded, 'utf8')


def main(url, image_path, top_k=1):
    """Send each image under ``image_path`` to a serving endpoint and log results.

    Every image is base64-encoded and POSTed as JSON
    ``{'images': [<b64>], 'top_k': top_k}``. For each successful response the
    returned classes/scores and the server-reported elapsed time are logged.
    At the end, the average elapsed time and the average top-1 score over the
    successfully processed images are logged.

    Images that cannot be read, requests that fail, and responses whose
    ``status`` is not ``'000'`` are logged and skipped.

    Args:
        url: Full URL of the serving endpoint.
        image_path: A single image file or a directory of images.
        top_k: Number of top predictions requested per image.

    Raises:
        Exception: Propagated from ``get_image_file_list`` when no image is
            found at ``image_path``.
    """
    image_file_list = get_image_file_list(image_path)
    headers = {"Content-type": "application/json"}
    cnt = 0
    total_time = 0
    all_acc = 0.0

    for image_file in image_file_list:
        file_str = os.path.basename(image_file)
        # Read the raw bytes; the context manager guarantees the handle is
        # closed, and an unreadable file is skipped instead of aborting the run.
        try:
            with open(image_file, 'rb') as f:
                img = f.read()
        except OSError:
            logger.error("Loading image:{} failed".format(image_file))
            continue
        data = {'images': [cv2_to_base64(img)], 'top_k': top_k}

        try:
            r = requests.post(url=url, headers=headers, data=json.dumps(data))
            r.raise_for_status()
            # Parse the body once; a non-JSON body is reported like any other
            # request failure.
            resp = r.json()
        except Exception as e:
            logger.error("File:{}, {}".format(file_str, e))
            continue

        if resp['status'] != '000':
            logger.error(
                "File:{}, The parameters returned by the server are: {}".
                format(file_str, resp['msg']))
            continue

        # Assumes each result is a (classes, scores, elapse) triple with scores
        # in descending order -- format is defined by the server module; confirm.
        classes, scores, elapse = resp["results"][0]
        all_acc += scores[0]
        total_time += elapse
        cnt += 1

        scores = [round(x, 5) for x in scores]
        results = dict(zip(classes, scores))

        message = "No.{}, File:{}, The top-{} result(s):{}, Time cost:{:.3f}".format(
            cnt, file_str, top_k, results, elapse)
        logger.info(message)

    # Guard the averages: if every image failed there is nothing to divide by.
    if cnt == 0:
        logger.error("No image was processed successfully.")
        return
    logger.info("The average time cost: {}".format(float(total_time) / cnt))
    logger.info("The average top-1 score: {}".format(float(all_acc) / cnt))


if __name__ == '__main__':
    # Expected invocation: <script> server_url image_path [top_k]
    if len(sys.argv) not in (3, 4):
        logger.info("Usage: %s server_url image_path [top_k]" % sys.argv[0])
        # Signal the usage error to the calling shell.
        sys.exit(1)
    server_url = sys.argv[1]
    image_path = sys.argv[2]
    # top_k is optional; when omitted only the top-1 prediction is requested.
    top_k = int(sys.argv[3]) if len(sys.argv) == 4 else 1
    main(server_url, image_path, top_k)