app.py 4.9 KB
Newer Older
H
HydrogenSulfate 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
import base64
import os
import subprocess
from io import BytesIO
from typing import Any, Dict, List, Union

import gradio as gr
import numpy as np
import requests
from paddleclas import PaddleClas
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm


def download_with_progressbar(url: str, save_path: str):
    """Download a tar archive from ``url`` to ``save_path`` and extract it.

    Nothing is done (neither download nor extraction) when ``save_path``
    already exists. The data is first written to ``save_path + ".part"`` and
    only renamed to ``save_path`` once it is complete. A failed or
    interrupted download is therefore retried on the next run instead of
    being mistaken for a finished one. The archive is extracted with
    ``tar -xf`` into the current working directory.

    Args:
        url (str): url of the archive to download
        save_path (str): path for saving downloaded file

    Raises:
        requests.HTTPError: if the server answers with an error status.
        Exception: if the server reports no content length, or fewer/more
            bytes than announced were received.
        subprocess.CalledProcessError: if ``tar`` exits with a non-zero status.
    """
    print(f"Auto downloading {url} to {save_path}")
    if os.path.exists(save_path):
        print("File already exist, skip...")
        return

    partial_path = save_path + ".part"
    try:
        with requests.get(url, stream=True) as response:
            # fail early instead of saving an HTTP error page as the archive
            response.raise_for_status()
            total_size_in_bytes = int(
                response.headers.get("content-length", 0))
            block_size = 1024  # 1 Kibibyte
            progress_bar = tqdm(
                total=total_size_in_bytes, unit="iB", unit_scale=True)
            with progress_bar, open(partial_path, "wb") as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
                downloaded_bytes = progress_bar.n
        if total_size_in_bytes == 0 or downloaded_bytes != total_size_in_bytes:
            raise Exception(
                f"Something went wrong while downloading file from {url}")
        # only a complete download ever appears under save_path
        os.replace(partial_path, save_path)
    finally:
        # remove the leftover partial file if anything above failed
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Finished downloading")

    print(f"Try decompression at {save_path}")
    # Argument list (no shell): the path is never interpreted by a shell.
    # check=True raises on extraction failure instead of ignoring it.
    subprocess.run(["tar", "-xf", save_path], check=True)
    print(f"Finished decompression at {save_path}")


def image_to_base64(image: Image.Image) -> str:
    """Encode a Pillow image as a base64 string of its JPEG bytes.

    JPEG cannot store alpha or palette data. Images whose mode the JPEG
    writer does not accept (e.g. "RGBA", "P", "LA") are converted to "RGB"
    first; without this, ``Image.save`` raises instead of encoding them.
    The caller's image is not modified.

    Args:
        image (Image.Image): image to be encoded

    Returns:
        str: ASCII base64 encoding of the JPEG-compressed image
    """
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffer = BytesIO()  # in-memory byte stream
    image.save(buffer, format="JPEG")  # write the image into the stream as JPEG
    jpeg_bytes = buffer.getvalue()  # raw encoded bytes from the stream
    return base64.b64encode(jpeg_bytes).decode("ascii")  # bytes -> base64 text


# UGC: Define the inference fn() for your models
def model_inference(image) -> tuple:
    """Run PP-ShiTuV2 recognition on ``image`` and render the results.

    Args:
        image: input image from the ``gr.Image`` component, or None when no
            image is present. Assumed to be a numpy array (gradio's default
            image type) — TODO confirm against the component config.

    Returns:
        tuple: (drawn image to display as a Pillow image, dict with key
        "base64" holding the drawn image as base64 JPEG and key "result"
        holding ``str()`` of the raw recognition results).
        Returns (None, None) when ``image`` is None.
    """
    # Submitting with an empty input (e.g. right after "Clear") passes None;
    # clear the outputs instead of handing None to the engine.
    if image is None:
        return None, None

    results = clas_engine.predict(image, print_pred=True, predict_type="shitu")

    # bs = 1, fetch the first (and only) batch result
    results = list(results)[0]

    image_draw_box = draw_bbox_results(image, results)

    im_show = Image.fromarray(image_draw_box)

    json_out = {"base64": image_to_base64(im_show), "result": str(results)}
    return im_show, json_out


def draw_bbox_results(image: Union[np.ndarray, Image.Image],
                      results: List[Dict[str, Any]]) -> np.ndarray:
    """Draw recognition bounding boxes and their labels onto an image.

    Args:
        image (Union[np.ndarray, Image.Image]): image to be drawn. A numpy
            array is converted with ``Image.fromarray``. A Pillow image is
            copied first so the caller's image is left untouched.
        results (List[Dict[str, Any]]): one dict per detected object, read
            for "bbox" (xmin, ymin, xmax, ymax), "rec_docs" (label text;
            entries whose value is None are skipped) and "rec_scores"
            (formatted with two decimals).

    Returns:
        np.ndarray: drawn image
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    else:
        # draw on a copy so the caller's image is not modified in place
        image = image.copy()
    draw = ImageDraw.Draw(image)
    font_size = 18
    font = ImageFont.truetype("./simfang.ttf", font_size, encoding="utf-8")

    color = (0, 102, 255)  # label background colour

    for result in results:
        # empty results
        if result["rec_docs"] is None:
            continue

        xmin, ymin, xmax, ymax = result["bbox"]
        text = "{}, {:.2f}".format(result["rec_docs"], result["rec_scores"])
        th = font_size
        # FreeTypeFont.getsize was removed in Pillow 10. getbbox (Pillow >= 8)
        # returns (left, top, right, bottom), and its right edge equals the
        # width getsize reported.
        if hasattr(font, "getbbox"):
            tw = font.getbbox(text)[2]
        else:
            tw = font.getsize(text)[0]
        # put the label above the box, clamped to the top edge of the image
        start_y = max(0, ymin - th)

        draw.rectangle(
            [(xmin + 1, start_y), (xmin + tw + 1, start_y + th)], fill=color)

        draw.text((xmin + 1, start_y), text, fill=(255, 255, 255), font=font)

        draw.rectangle(
            [(xmin, ymin), (xmax, ymax)], outline=(255, 0, 0), width=2)

    return np.array(image)


def clear_all():
    """Return one None per component wired to the Clear button."""
    return (None,) * 3


# Fetch drink_dataset_v2.0.tar into the working directory and unpack it
# (the helper skips everything when the archive is already present).
dataset_url = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v2.0.tar"
download_with_progressbar(
    dataset_url,
    os.path.join("./", dataset_url.rsplit("/", 1)[-1]),
)

# PaddleClas engine for the PP-ShiTuV2 pipeline, running on CPU.
clas_engine = PaddleClas(model_name="PP-ShiTuV2", use_gpu=False)

with gr.Blocks() as demo:
    gr.Markdown("PP-ShiTuV2")

    with gr.Column(scale=1, min_width=100):
        # Input image, pre-filled with a demo picture from the drink dataset.
        img_in = gr.Image(
            value="https://gitee.com/PaddlePaddle/PaddleClas/raw/release/2.5/docs/images/recognition/drink_data_demo/test_images/100.jpeg",
            label="Input")

        with gr.Row():
            btn1 = gr.Button("Clear")
            btn2 = gr.Button("Submit")
        img_out = gr.Image(label="Output").style(height=400)
        json_out = gr.JSON(label="jsonOutput")

    # Submit: run recognition and show the drawn image plus its JSON details.
    btn2.click(fn=model_inference, inputs=img_in, outputs=[img_out, json_out])
    # Clear: reset the input image and both outputs.
    btn1.click(fn=clear_all, inputs=None, outputs=[img_in, img_out, json_out])

demo.launch()