From 5cf3c3bbf7b987a627a197f3c808570a23e86498 Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Tue, 17 Jan 2023 12:05:08 +0800 Subject: [PATCH] fix cache directory and app style (#5705) LGTM --- modelcenter/PLSC-ViT/APP/app.py | 17 +++++++++++++---- modelcenter/PLSC-ViT/APP/download.py | 13 ++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/modelcenter/PLSC-ViT/APP/app.py b/modelcenter/PLSC-ViT/APP/app.py index 8520015b..89e38d59 100644 --- a/modelcenter/PLSC-ViT/APP/app.py +++ b/modelcenter/PLSC-ViT/APP/app.py @@ -1,13 +1,20 @@ import gradio as gr from predictor import Predictor - model_path = "paddlecv://models/vit/v2.4/imagenet2012-ViT-B_16-224_infer.pdmodel" params_path = "paddlecv://models/vit/v2.4/imagenet2012-ViT-B_16-224_infer.pdiparams" label_path = "paddlecv://dataset/imagenet2012_labels.txt" +predictor = None + + def model_inference(image): - predictor = Predictor(model_path=model_path, params_path=params_path, label_path=label_path) + global predictor + if predictor is None: + predictor = Predictor( + model_path=model_path, + params_path=params_path, + label_path=label_path) scores, labels = predictor.predict(image) json_out = {"scores": scores.tolist(), "labels": labels.tolist()} return image, json_out @@ -22,13 +29,15 @@ with gr.Blocks() as demo: with gr.Column(scale=1, min_width=100): - img_in = gr.Image(value="https://plsc.bj.bcebos.com/dataset/test_images/cat.jpg",label="Input").style(height=200) + img_in = gr.Image( + value="https://plsc.bj.bcebos.com/dataset/test_images/cat.jpg", + label="Input") with gr.Row(): btn1 = gr.Button("Clear") btn2 = gr.Button("Submit") - img_out = gr.Image(label="Output").style(height=200) + img_out = gr.Image(label="Output") json_out = gr.JSON(label="jsonOutput") btn2.click(fn=model_inference, inputs=img_in, outputs=[img_out, json_out]) diff --git a/modelcenter/PLSC-ViT/APP/download.py b/modelcenter/PLSC-ViT/APP/download.py index 9fb8ba64..648d568e 100644 --- a/modelcenter/PLSC-ViT/APP/download.py +++ b/modelcenter/PLSC-ViT/APP/download.py @@ -33,10 +33,10 @@ __all__ = [ 'get_data_path', ] -WEIGHTS_HOME = osp.expanduser("~/.cache/paddlecv/models") -CONFIGS_HOME = osp.expanduser("~/.cache/paddlecv/configs") -DICTS_HOME = osp.expanduser("~/.cache/paddlecv/dicts") -DATA_HOME = osp.expanduser("~/.cache/paddlecv/dataset") +WEIGHTS_HOME = osp.expanduser("~/.cache/paddlecv/models/plsc") +CONFIGS_HOME = osp.expanduser("~/.cache/paddlecv/configs/plsc") +DICTS_HOME = osp.expanduser("~/.cache/paddlecv/dicts/plsc/") +DATA_HOME = osp.expanduser("~/.cache/paddlecv/dataset/plsc") # dict of {dataset_name: (download_info, sub_dirs)} # download info: [(url, md5sum)] @@ -68,7 +68,7 @@ def get_model_path(path): if not is_url(path): return path url = parse_url(path) - path, _ = get_path(url, WEIGHTS_HOME, path_depth=2) + path, _ = get_path(url, WEIGHTS_HOME, path_depth=3) return path @@ -79,7 +79,7 @@ def get_data_path(path): if not is_url(path): return path url = parse_url(path) - path, _ = get_path(url, DATA_HOME, path_depth=2) + path, _ = get_path(url, DATA_HOME, path_depth=1) return path @@ -162,7 +162,6 @@ def _download(url, path, md5sum=None): raise RuntimeError("Download from {} failed. " "Retry limit reached".format(url)) - # NOTE: windows path join may incur \, which is invalid in url if sys.platform == "win32": url = url.replace('\\', '/') -- GitLab