# model_list.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
from collections import OrderedDict


class ServingModels(object):
    def __init__(self):
        self.model_dict = OrderedDict()
D
dongdaxiang 已提交
23 24
        self.model_dict[
            "SentimentAnalysis"] = ["senta_bilstm", "senta_bow", "senta_cnn"]
M
fix app  
MRXLT 已提交
25
        self.model_dict["SemanticRepresentation"] = ["ernie"]
D
dongdaxiang 已提交
26
        self.model_dict["ChineseWordSegmentation"] = ["lac"]
M
MRXLT 已提交
27
        self.model_dict["ObjectDetection"] = ["faster_rcnn", "yolov4"]
28
        self.model_dict["ImageSegmentation"] = [
J
Jiawei Wang 已提交
29
            "unet", "deeplabv3", "deeplabv3+cityscapes"
30
        ]
D
dongdaxiang 已提交
31
        self.model_dict["ImageClassification"] = [
D
dongdaxiang 已提交
32
            "resnet_v2_50_imagenet", "mobilenet_v2_imagenet"
D
dongdaxiang 已提交
33
        ]
D
dongdaxiang 已提交
34
        self.model_dict["TextDetection"] = ["ocr_detection"]
M
MRXLT 已提交
35
        self.model_dict["OCR"] = ["ocr_rec"]
D
dongdaxiang 已提交
36 37 38 39

        image_class_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageClassification/"
        image_seg_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageSegmentation/"
        object_detection_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ObjectDetection/"
M
MRXLT 已提交
40
        ocr_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/OCR/"
D
dongdaxiang 已提交
41
        senta_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SentimentAnalysis/"
M
MRXLT 已提交
42
        semantic_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SemanticModel/"
M
MRXLT 已提交
43
        wordseg_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/LexicalAnalysis/"
D
dongdaxiang 已提交
44
        ocr_det_url = "https://paddle-serving.bj.bcebos.com/ocr/"
D
dongdaxiang 已提交
45 46 47 48 49 50 51 52 53 54 55 56 57 58

        self.url_dict = {}

        def pack_url(model_dict, key, url):
            for i, value in enumerate(model_dict[key]):
                self.url_dict[model_dict[key][i]] = url + model_dict[key][
                    i] + ".tar.gz"

        pack_url(self.model_dict, "SentimentAnalysis", senta_url)
        pack_url(self.model_dict, "SemanticRepresentation", semantic_url)
        pack_url(self.model_dict, "ChineseWordSegmentation", wordseg_url)
        pack_url(self.model_dict, "ObjectDetection", object_detection_url)
        pack_url(self.model_dict, "ImageSegmentation", image_seg_url)
        pack_url(self.model_dict, "ImageClassification", image_class_url)
M
MRXLT 已提交
59
        pack_url(self.model_dict, "OCR", ocr_url)
D
dongdaxiang 已提交
60
        pack_url(self.model_dict, "TextDetection", ocr_det_url)
M
MRXLT 已提交
61 62

    def get_model_list(self):
D
dongdaxiang 已提交
63
        return self.model_dict
M
MRXLT 已提交
64 65

    def download(self, model_name):
D
dongdaxiang 已提交
66 67
        if model_name in self.url_dict:
            url = self.url_dict[model_name]
M
MRXLT 已提交
68 69
            r = os.system('wget ' + url + ' --no-check-certificate')

D
dongdaxiang 已提交
70 71 72 73
    def get_tutorial(self, model_name):
        if model_name in self.tutorial_url:
            return "Tutorial of {} to be added".format(model_name)

if __name__ == "__main__":
    # Script entry point: print the category -> model-names catalogue.
    catalogue = ServingModels().get_model_list()
    print(catalogue)