app.py 3.0 KB
Newer Older
Q
qizhaoaoe 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
import numpy as np
import gradio as gr

from download import get_model_path, get_data_path

from plsc.data.preprocess import Resize
from plsc.engine.inference import Predictor

predictor = None


def model_inference(image):
    global predictor

    if predictor is None:

        model_path = "paddlecv://models/swin/v2.5/swin_base_patch4_window7_224_infer.pdmodel"
        params_path = "paddlecv://models/swin/v2.5/swin_base_patch4_window7_224_infer.pdiparams"
        label_path = "paddlecv://dataset/imagenet2012_labels.txt"
        infer_model = get_model_path(model_path)
        infer_params = get_model_path(params_path)

        def parse_labels(label_path):
            labels = []
            with open(label_path, 'r') as f:
                for line in f:
                    if len(line) < 2:
                        continue
                    label = line.strip().split(',')[1]
                    labels.append(label)
            return np.array(labels)

        labels = parse_labels(get_data_path(label_path))

        def preprocess(img):
            resize = Resize(size=224, interpolation="bicubic", backend="pil")
            img = np.array(resize(img))
            scale = 1.0 / 255.0
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            img = (img * scale - mean) / std
            img = img[np.newaxis, :, :, :]
            img = img.transpose((0, 3, 1, 2))
            return {'x': img.astype('float32')}

        def postprocess(logits):
            def softmax(x, epsilon=1e-6):
                exp_x = np.exp(x)
                sfm = (exp_x + epsilon) / (np.sum(exp_x) + epsilon)
                return sfm

            pred = np.array(logits).squeeze()
            pred = softmax(pred)
            class_idx = pred.argsort()[::-1]
            class_idx_top5 = class_idx[:5]
            return class_idx_top5, pred[class_idx_top5], labels[class_idx_top5]

        predictor = Predictor(
            model_file=infer_model,
            params_file=infer_params,
            preprocess_fn=preprocess,
            postprocess_fn=postprocess)

    class_ids, scores, classes = predictor.predict(image)
    json_out = {
        "class_ids": class_ids.tolist(),
        "scores": scores.tolist(),
        "labels": classes.tolist()
    }
    return image, json_out


def clear_all():
    return None, None, None


with gr.Blocks() as demo:
    gr.Markdown("Classification based on SwinTransformer")

    with gr.Column(scale=1, min_width=100):
        img_in = gr.Image(
            value="https://plsc.bj.bcebos.com/dataset/test_images/zebra.png",
            label="Input").style(height=200)

        with gr.Row():
            btn1 = gr.Button("Clear")
            btn2 = gr.Button("Submit")

        img_out = gr.Image(label="Output").style(height=200)
        json_out = gr.JSON(label="jsonOutput")

    btn2.click(fn=model_inference, inputs=img_in, outputs=[img_out, json_out])
    btn1.click(fn=clear_all, inputs=None, outputs=[img_in, img_out, json_out])
    gr.Button.style(1)

demo.launch()