From cf279dfbbdaf69ee882958e003e9bcb6f32c85d7 Mon Sep 17 00:00:00 2001 From: MRXLT Date: Thu, 23 Apr 2020 06:57:29 +0000 Subject: [PATCH] refine app --- .../paddle_serving_app/models/model_list.py | 46 ++++++++ .../paddle_serving_app/reader/image_reader.py | 107 ++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 python/paddle_serving_app/models/model_list.py create mode 100644 python/paddle_serving_app/reader/image_reader.py diff --git a/python/paddle_serving_app/models/model_list.py b/python/paddle_serving_app/models/model_list.py new file mode 100644 index 00000000..73306449 --- /dev/null +++ b/python/paddle_serving_app/models/model_list.py @@ -0,0 +1,46 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +from collections import OrderedDict + + +class ServingModels(object): + def __init__(self): + self.model_dict = OrderedDict() + #senta + for key in [ + "senta_bilstm", "senta_bow", "senta_cnn", "senta_gru", + "senta_lstm" + ]: + self.model_dict[ + key] = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SentimentAnalysis/" + key + ".tar.gz" + #image classification + for key in ["alexnet_imagenet"]: + self.model_dict[ + key] = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageClassification/" + key + ".tar.gz" + + def get_model_list(self): + return (self.model_dict.keys()) + + def download(self, model_name): + if model_name in self.model_dict: + url = self.model_dict[model_name] + r = os.system('wget ' + url + ' --no-check-certificate') + + +if __name__ == "__main__": + models = ServingModels() + print(models.get_model_list()) diff --git a/python/paddle_serving_app/reader/image_reader.py b/python/paddle_serving_app/reader/image_reader.py new file mode 100644 index 00000000..2647eb6f --- /dev/null +++ b/python/paddle_serving_app/reader/image_reader.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np + + +class ImageReader(): + def __init__(self, + image_shape=[3, 224, 224], + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + resize_short_size=256, + interpolation=None, + crop_center=True): + self.image_mean = image_mean + self.image_std = image_std + self.image_shape = image_shape + self.resize_short_size = resize_short_size + self.interpolation = interpolation + self.crop_center = crop_center + + def resize_short(self, img, target_size, interpolation=None): + """resize image + + Args: + img: image data + target_size: resize short target size + interpolation: interpolation mode + + Returns: + resized image data + """ + percent = float(target_size) / min(img.shape[0], img.shape[1]) + resized_width = int(round(img.shape[1] * percent)) + resized_height = int(round(img.shape[0] * percent)) + if interpolation: + resized = cv2.resize( + img, (resized_width, resized_height), + interpolation=interpolation) + else: + resized = cv2.resize(img, (resized_width, resized_height)) + return resized + + def crop_image(self, img, target_size, center): + """crop image + + Args: + img: images data + target_size: crop target size + center: crop mode + + Returns: + img: cropped image data + """ + height, width = img.shape[:2] + size = target_size + if center == True: + w_start = (width - size) // 2 + h_start = (height - size) // 2 + else: + w_start = np.random.randint(0, width - size + 1) + h_start = np.random.randint(0, height - size + 1) + w_end = w_start + size + h_end = h_start + size + img = img[h_start:h_end, w_start:w_end, :] + return img + + def process_image(self, sample): + """ process_image """ + mean = self.image_mean + std = self.image_std + crop_size = self.image_shape[1] + + data = np.fromstring(sample, np.uint8) + img = cv2.imdecode(data, cv2.IMREAD_COLOR) + + if img is None: + print("img is None, pass it.") + return None + + if crop_size > 0: + target_size = self.resize_short_size + img = self.resize_short( + img, target_size, interpolation=self.interpolation) + img = self.crop_image( + img, target_size=crop_size, center=self.crop_center) + + img = img[:, :, ::-1] + + img = img.astype('float32').transpose((2, 0, 1)) / 255 + img_mean = np.array(mean).reshape((3, 1, 1)) + img_std = np.array(std).reshape((3, 1, 1)) + img -= img_mean + img /= img_std + return img -- GitLab