diff --git a/core/sdk-cpp/include/endpoint_config.h b/core/sdk-cpp/include/endpoint_config.h index 6edb6ed8ab3b7c62be12c35c1658a2adf3140341..f814b659a24c4e5cb6e352c9a39bcfef1df147c1 100644 --- a/core/sdk-cpp/include/endpoint_config.h +++ b/core/sdk-cpp/include/endpoint_config.h @@ -22,23 +22,23 @@ namespace baidu { namespace paddle_serving { namespace sdk_cpp { -#define PARSE_CONF_ITEM(conf, item, name, fail) \ - do { \ - if (conf.has_##name()) { \ - item.set(conf.name()); \ - } else { \ - LOG(ERROR) << "Not found key in configue: " << #name; \ - } \ +#define PARSE_CONF_ITEM(conf, item, name, fail) \ + do { \ + if (conf.has_##name()) { \ + item.set(conf.name()); \ + } else { \ + VLOG(2) << "Not found key in configue: " << #name; \ + } \ } while (0) -#define ASSIGN_CONF_ITEM(dest, src, fail) \ - do { \ - if (!src.init) { \ - LOG(ERROR) << "Cannot assign an unintialized item: " << #src \ - << " to dest: " << #dest; \ - return fail; \ - } \ - dest = src.value; \ +#define ASSIGN_CONF_ITEM(dest, src, fail) \ + do { \ + if (!src.init) { \ + VLOG(2) << "Cannot assign an unintialized item: " << #src \ + << " to dest: " << #dest; \ + return fail; \ + } \ + dest = src.value; \ } while (0) template diff --git a/python/examples/deeplabv3/N0060.jpg b/python/examples/deeplabv3/N0060.jpg new file mode 100644 index 0000000000000000000000000000000000000000..feac2837eaa5ae5db414d9769a0c5a830dde268d Binary files /dev/null and b/python/examples/deeplabv3/N0060.jpg differ diff --git a/python/examples/deeplabv3/deeplabv3_client.py b/python/examples/deeplabv3/deeplabv3_client.py new file mode 100644 index 0000000000000000000000000000000000000000..75ea6b0a01868af30c94fb0686159571c2c1c966 --- /dev/null +++ b/python/examples/deeplabv3/deeplabv3_client.py @@ -0,0 +1,34 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, File2Image, Resize, Transpose, BGR2RGB, SegPostprocess +import sys +import cv2 + +client = Client() +client.load_client_config("seg_client/serving_client_conf.prototxt") +client.connect(["127.0.0.1:9494"]) + +preprocess = Sequential( + [File2Image(), Resize( + (512, 512), interpolation=cv2.INTER_LINEAR)]) + +postprocess = SegPostprocess(2) + +filename = "N0060.jpg" +im = preprocess(filename) +fetch_map = client.predict(feed={"image": im}, fetch=["output"]) +fetch_map["filename"] = filename +postprocess(fetch_map) diff --git a/python/examples/faster_rcnn_model/label_list.txt b/python/examples/faster_rcnn_model/label_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..d7d43a94adf73208f997f0efd6581bef11ca734e --- /dev/null +++ b/python/examples/faster_rcnn_model/label_list.txt @@ -0,0 +1,81 @@ +background +person +bicycle +car +motorcycle +airplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +couch +potted plant +bed +dining table +toilet +tv +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/python/examples/faster_rcnn_model/new_test_client.py b/python/examples/faster_rcnn_model/new_test_client.py new file mode 100755 index 0000000000000000000000000000000000000000..283df71689107983f98833af244181ad1be7f99c --- /dev/null +++ b/python/examples/faster_rcnn_model/new_test_client.py @@ -0,0 +1,43 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +import sys +from paddle_serving_app.reader.pddet import Detection +from paddle_serving_app.reader import File2Image, Sequential, Normalize, Resize, Transpose, Div, BGR2RGB, RCNNPostprocess +import numpy as np + +preprocess = Sequential([ + File2Image(), BGR2RGB(), Div(255.0), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False), + Resize(640, 640), Transpose((2, 0, 1)) +]) + +postprocess = RCNNPostprocess("label_list.txt", "output") + +client = Client() +client.load_client_config(sys.argv[1]) +client.connect(['127.0.0.1:9393']) + +for i in range(100): + im = preprocess(sys.argv[2]) + fetch_map = client.predict( + feed={ + "image": im, + "im_info": np.array(list(im.shape[1:]) + [1.0]), + "im_shape": np.array(list(im.shape[1:]) + [1.0]) + }, + fetch=["multiclass_nms"]) + fetch_map["image"] = sys.argv[2] + postprocess(fetch_map) diff --git a/python/examples/imagenet/image_rpc_client.py b/python/examples/imagenet/image_rpc_client.py index f905179629f0dfc8c9da09b0cae90bae7be3687e..4d74d2ed26a757a6f7978d8071286d3d4bcd5dfb 100644 --- a/python/examples/imagenet/image_rpc_client.py +++ b/python/examples/imagenet/image_rpc_client.py @@ -13,22 +13,24 @@ # limitations under the License. import sys -from image_reader import ImageReader from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, File2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize import time client = Client() client.load_client_config(sys.argv[1]) client.connect(["127.0.0.1:9393"]) -reader = ImageReader() + +seq = Sequential([ + File2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)), + Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) +]) +print(seq) start = time.time() +image_file = "daisy.jpg" for i in range(1000): - with open("./data/n01440764_10026.JPEG", "rb") as f: - img = f.read() - img = reader.process_image(img) + img = seq(image_file) fetch_map = client.predict(feed={"image": img}, fetch=["score"]) end = time.time() print(end - start) - -#print(fetch_map["score"]) diff --git a/python/examples/mobilenet/daisy.jpg b/python/examples/mobilenet/daisy.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7edeca63e5f32e68550ef720d81f59df58a8eabc Binary files /dev/null and b/python/examples/mobilenet/daisy.jpg differ diff --git a/python/examples/mobilenet/mobilenet_tutorial.py b/python/examples/mobilenet/mobilenet_tutorial.py new file mode 100644 index 0000000000000000000000000000000000000000..9550a5ff705d23d3f6a97d8498d5a8b1e4f152b7 --- /dev/null +++ b/python/examples/mobilenet/mobilenet_tutorial.py @@ -0,0 +1,32 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, File2Image, Resize +from paddle_serving_app.reader import CenterCrop, RGB2BGR, Transpose, Div, Normalize + +client = Client() +client.load_client_config( + "mobilenet_v2_imagenet_client/serving_client_conf.prototxt") +client.connect(["127.0.0.1:9393"]) + +seq = Sequential([ + File2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)), + Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True) +]) + +image_file = "daisy.jpg" +img = seq(image_file) +fetch_map = client.predict(feed={"image": img}, fetch=["feature_map"]) +print(fetch_map["feature_map"].reshape(-1)) diff --git a/python/examples/resnet_v2_50/daisy.jpg b/python/examples/resnet_v2_50/daisy.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7edeca63e5f32e68550ef720d81f59df58a8eabc Binary files /dev/null and b/python/examples/resnet_v2_50/daisy.jpg differ diff --git a/python/examples/resnet_v2_50/resnet50_v2_tutorial.py b/python/examples/resnet_v2_50/resnet50_v2_tutorial.py new file mode 100644 index 0000000000000000000000000000000000000000..8d916cbd8145cdc73424a05fdb2855412f4d4fe2 --- /dev/null +++ b/python/examples/resnet_v2_50/resnet50_v2_tutorial.py @@ -0,0 +1,32 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, File2Image, Resize, CenterCrop +from apddle_serving_app.reader import RGB2BGR, Transpose, Div, Normalize + +client = Client() +client.load_client_config( + "resnet_v2_50_imagenet_client/serving_client_conf.prototxt") +client.connect(["127.0.0.1:9393"]) + +seq = Sequential([ + File2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)), + Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True) +]) + +image_file = "daisy.jpg" +img = seq(image_file) +fetch_map = client.predict(feed={"image": img}, fetch=["feature_map"]) +print(fetch_map["feature_map"].reshape(-1)) diff --git a/python/examples/unet_for_image_seg/N0060.jpg b/python/examples/unet_for_image_seg/N0060.jpg new file mode 100644 index 0000000000000000000000000000000000000000..feac2837eaa5ae5db414d9769a0c5a830dde268d Binary files /dev/null and b/python/examples/unet_for_image_seg/N0060.jpg differ diff --git a/python/examples/unet_for_image_seg/seg_client.py b/python/examples/unet_for_image_seg/seg_client.py new file mode 100644 index 0000000000000000000000000000000000000000..9e76b060955ec74492312c8896efaf3946a3f7ab --- /dev/null +++ b/python/examples/unet_for_image_seg/seg_client.py @@ -0,0 +1,33 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, File2Image, Resize, Transpose, BGR2RGB, SegPostprocess +import sys +import cv2 + +client = Client() +client.load_client_config("unet_client/serving_client_conf.prototxt") +client.connect(["127.0.0.1:9494"]) + +preprocess = Sequential( + [File2Image(), Resize( + (512, 512), interpolation=cv2.INTER_LINEAR)]) + +postprocess = SegPostprocess(2) + +im = preprocess("N0060.jpg") +fetch_map = client.predict(feed={"image": im}, fetch=["output"]) +fetch_map["filename"] = filename +postprocess(fetch_map) diff --git a/python/paddle_serving_app/__init__.py b/python/paddle_serving_app/__init__.py index 3db901249df41f5d9cd5846d131ec6cfed376a18..860876030695baee15d3ace68c6af386290cfbb0 100644 --- a/python/paddle_serving_app/__init__.py +++ b/python/paddle_serving_app/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from .reader.chinese_bert_reader import ChineseBertReader -from .reader.image_reader import ImageReader +from .reader.image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, CenterCrop, Resize from .reader.lac_reader import LACReader from .reader.senta_reader import SentaReader from .models import ServingModels diff --git a/python/paddle_serving_app/models/model_list.py b/python/paddle_serving_app/models/model_list.py index 6709c8aea06c3fa0cce2acdc0cbaf7d4a9c9c64e..a2019997968ce21a30669b2acd1421355b1e0fdd 100644 --- a/python/paddle_serving_app/models/model_list.py +++ b/python/paddle_serving_app/models/model_list.py @@ -20,78 +20,49 @@ from collections import OrderedDict class ServingModels(object): def __init__(self): self.model_dict = OrderedDict() - #senta - for key in [ - "senta_bilstm", "senta_bow", "senta_cnn", "senta_gru", - "senta_lstm" - ]: - self.model_dict[ - key] = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SentimentAnalysis/" + key + ".tar.gz" - #image classification - for key in [ - "alexnet_imagenet", - "darknet53-imagenet", - "densenet121_imagenet", - "densenet161_imagenet", - "densenet169_imagenet", - "densenet201_imagenet", - "densenet264_imagenet" - "dpn107_imagenet", - "dpn131_imagenet", - "dpn68_imagenet", - "dpn92_imagenet", - "dpn98_imagenet", - "efficientnetb0_imagenet", - "efficientnetb1_imagenet", - "efficientnetb2_imagenet", - "efficientnetb3_imagenet", - "efficientnetb4_imagenet", - "efficientnetb5_imagenet", - "efficientnetb6_imagenet", - "googlenet_imagenet", - "inception_v4_imagenet", - "inception_v2_imagenet", - "nasnet_imagenet", - "pnasnet_imagenet", - "resnet_v2_101_imagenet", - "resnet_v2_151_imagenet", - "resnet_v2_18_imagenet", - "resnet_v2_34_imagenet", - "resnet_v2_50_imagenet", - "resnext101_32x16d_wsl", - "resnext101_32x32d_wsl", - "resnext101_32x48d_wsl", - "resnext101_32x8d_wsl", - "resnext101_32x4d_imagenet", - "resnext101_64x4d_imagenet", - "resnext101_vd_32x4d_imagenet", - "resnext101_vd_64x4d_imagenet", - "resnext152_64x4d_imagenet", - "resnext152_vd_64x4d_imagenet", - "resnext50_64x4d_imagenet", - "resnext50_vd_32x4d_imagenet", - "resnext50_vd_64x4d_imagenet", - "se_resnext101_32x4d_imagenet", - "se_resnext50_32x4d_imagenet", - "shufflenet_v2_imagenet", - "vgg11_imagenet", - "vgg13_imagenet", - "vgg16_imagenet", - "vgg19_imagenet", - "xception65_imagenet", - "xception71_imagenet", - ]: - self.model_dict[ - key] = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageClassification/" + key + ".tar.gz" + self.model_dict[ + "SentimentAnalysis"] = ["senta_bilstm", "senta_bow", "senta_cnn"] + self.model_dict["SemanticRepresentation"] = ["ernie_base"] + self.model_dict["ChineseWordSegmentation"] = ["lac"] + self.model_dict["ObjectDetection"] = ["faster_rcnn", "yolov3"] + self.model_dict["ImageSegmentation"] = ["unet", "deeplabv3"] + self.model_dict["ImageClassification"] = [ + "resnet_v2_50_imagenet", "mobilenet_v2_imagenet" + ] + + image_class_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageClassification/" + image_seg_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageSegmentation/" + object_detection_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ObjectDetection/" + senta_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SentimentAnalysis/" + semantic_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SemanticRepresentation/" + wordseg_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/ChineseWordSegmentation/" + + self.url_dict = {} + + def pack_url(model_dict, key, url): + for i, value in enumerate(model_dict[key]): + self.url_dict[model_dict[key][i]] = url + model_dict[key][ + i] + ".tar.gz" + + pack_url(self.model_dict, "SentimentAnalysis", senta_url) + pack_url(self.model_dict, "SemanticRepresentation", semantic_url) + pack_url(self.model_dict, "ChineseWordSegmentation", wordseg_url) + pack_url(self.model_dict, "ObjectDetection", object_detection_url) + pack_url(self.model_dict, "ImageSegmentation", image_seg_url) + pack_url(self.model_dict, "ImageClassification", image_class_url) def get_model_list(self): - return (self.model_dict.keys()) + return self.model_dict def download(self, model_name): - if model_name in self.model_dict: - url = self.model_dict[model_name] + if model_name in self.url_dict: + url = self.url_dict[model_name] r = os.system('wget ' + url + ' --no-check-certificate') + def get_tutorial(self, model_name): + if model_name in self.tutorial_url: + return "Tutorial of {} to be added".format(model_name) + if __name__ == "__main__": models = ServingModels() diff --git a/python/paddle_serving_app/package.py b/python/paddle_serving_app/package.py index 98e42f365397e6ecae5171c47eb1cfabee182a7d..e27914931d4f64c98627cd54025fcf87ac0f241d 100644 --- a/python/paddle_serving_app/package.py +++ b/python/paddle_serving_app/package.py @@ -20,6 +20,7 @@ Usage: """ import argparse +import sys from .models import ServingModels @@ -29,6 +30,8 @@ def parse_args(): # pylint: disable=doc-string-missing "--get_model", type=str, default="", help="Download a specific model") parser.add_argument( '--list_model', nargs='*', default=None, help="List Models") + parser.add_argument( + '--tutorial', type=str, default="", help="Get running command") return parser.parse_args() @@ -36,18 +39,33 @@ if __name__ == "__main__": args = parse_args() if args.list_model != None: model_handle = ServingModels() - model_names = model_handle.get_model_list() - for key in model_names: - print(key) + model_dict = model_handle.get_model_list() + # Task level model list + # Text Classification, Semantic Representation + # Image Classification, Object Detection, Image Segmentation + for key in model_dict: + print("-----------------------------------------------") + print("{}: {}".format(key, " | ".join(model_dict[key]))) + elif args.get_model != "": model_handle = ServingModels() - model_names = model_handle.get_model_list() - if args.get_model not in model_names: + model_dict = model_handle.url_dict + if args.get_model not in model_dict: print( "Your model name does not exist in current model list, stay tuned" ) sys.exit(0) model_handle.download(args.get_model) + elif args.tutorial != "": + model_handle = ServingModels() + model_dict = model_handle.url_dict + if args.get_model not in model_dict: + print( + "Your model name does not exist in current model list, stay tuned" + ) + sys.exit(0) + tutorial_str = model_handle.get_tutorial() + print(tutorial_str) else: print("Wrong argument") print(""" diff --git a/python/paddle_serving_app/reader/__init__.py b/python/paddle_serving_app/reader/__init__.py index 847ddc47ac89114f2012bc6b9990a69abfe39fb3..01cad9e6bbdbe11191e3bc44ec2c63f2db3939bc 100644 --- a/python/paddle_serving_app/reader/__init__.py +++ b/python/paddle_serving_app/reader/__init__.py @@ -11,3 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, RCNNPostprocess, SegPostprocess diff --git a/python/paddle_serving_app/reader/daisy.jpg b/python/paddle_serving_app/reader/daisy.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7edeca63e5f32e68550ef720d81f59df58a8eabc Binary files /dev/null and b/python/paddle_serving_app/reader/daisy.jpg differ diff --git a/python/paddle_serving_app/reader/functional.py b/python/paddle_serving_app/reader/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..4240641dd99fceb278ff60a5ba1dbb5275e534aa --- /dev/null +++ b/python/paddle_serving_app/reader/functional.py @@ -0,0 +1,68 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np + + +def transpose(img, transpose_target): + img = img.transpose(transpose_target) + return img + + +def normalize(img, mean, std, channel_first): + # need to optimize here + if channel_first: + img_mean = np.array(mean).reshape((3, 1, 1)) + img_std = np.array(std).reshape((3, 1, 1)) + else: + img_mean = np.array(mean).reshape((1, 1, 3)) + img_std = np.array(std).reshape((1, 1, 3)) + img -= img_mean + img /= img_std + return img + + +def crop(img, target_size, center): + height, width = img.shape[:2] + size = target_size + if center == True: + w_start = (width - size) // 2 + h_start = (height - size) // 2 + else: + w_start = np.random.randint(0, width - size + 1) + h_start = np.random.randint(0, height - size + 1) + w_end = w_start + size + h_end = h_start + size + img = img[h_start:h_end, w_start:w_end, :] + return img + + +def resize(img, target_size, max_size=2147483647, interpolation=None): + if isinstance(target_size, tuple): + resized_width = min(target_size[0], max_size) + resized_height = min(target_size[1], max_size) + else: + im_max_size = max(img.shape[0], img.shape[1]) + percent = float(target_size) / min(img.shape[0], img.shape[1]) + if np.round(percent * im_max_size) > max_size: + percent = float(max_size) / float(im_max_size) + resized_width = int(round(img.shape[1] * percent)) + resized_height = int(round(img.shape[0] * percent)) + if interpolation: + resized = cv2.resize( + img, (resized_width, resized_height), interpolation=interpolation) + else: + resized = cv2.resize(img, (resized_width, resized_height)) + return resized diff --git a/python/paddle_serving_app/reader/image_reader.py b/python/paddle_serving_app/reader/image_reader.py index 2647eb6fdf3ca0f1682ca794051b9d0dd95a9a07..b0abd21437a96127e5627f4e26114b72e4c1d891 100644 --- a/python/paddle_serving_app/reader/image_reader.py +++ b/python/paddle_serving_app/reader/image_reader.py @@ -11,9 +11,472 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import cv2 +import os +import urllib import numpy as np +import base64 +import functional as F +from PIL import Image, ImageDraw +import json + +_cv2_interpolation_to_str = {cv2.INTER_LINEAR: "cv2.INTER_LINEAR", None: "None"} + + +def generate_colormap(num_classes): + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) + color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) + color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + j += 1 + lab >>= 3 + color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)] + return color_map + + +class SegPostprocess(object): + def __init__(self, class_num): + self.class_num = class_num + + def __call__(self, image_with_result): + if "filename" not in image_with_result: + raise ("filename should be specified in postprocess") + img_name = image_with_result["filename"] + ori_img = cv2.imread(img_name, -1) + ori_shape = ori_img.shape + mask = None + for key in image_with_result: + if ".lod" in key or "filename" in key: + continue + mask = image_with_result[key] + if mask is None: + raise ("segment mask should be specified in postprocess") + mask = mask.astype("uint8") + mask_png = mask.reshape((512, 512, 1)) + #score_png = mask_png[:, :, np.newaxis] + score_png = mask_png + score_png = np.concatenate([score_png] * 3, axis=2) + color_map = generate_colormap(self.class_num) + for i in range(score_png.shape[0]): + for j in range(score_png.shape[1]): + score_png[i, j] = color_map[score_png[i, j, 0]] + ext_pos = img_name.rfind(".") + img_name_fix = img_name[:ext_pos] + "_" + img_name[ext_pos + 1:] + mask_save_name = img_name_fix + "_mask.png" + cv2.imwrite(mask_save_name, mask_png, [cv2.CV_8UC1]) + vis_result_name = img_name_fix + "_result.png" + result_png = score_png + + result_png = cv2.resize( + result_png, + ori_shape[:2], + fx=0, + fy=0, + interpolation=cv2.INTER_CUBIC) + cv2.imwrite(vis_result_name, result_png, [cv2.CV_8UC1]) + + +class RCNNPostprocess(object): + def __init__(self, label_file, output_dir): + self.output_dir = output_dir + self.label_file = label_file + self.label_list = [] + with open(label_file) as fin: + for line in fin: + self.label_list.append(line.strip()) + self.clsid2catid = {i: i for i in range(len(self.label_list))} + self.catid2name = {i: name for i, name in enumerate(self.label_list)} + + def _offset_to_lengths(self, lod): + offset = lod[0] + lengths = [offset[i + 1] - offset[i] for i in range(len(offset) - 1)] + return [lengths] + + def _bbox2out(self, results, clsid2catid, is_bbox_normalized=False): + xywh_res = [] + for t in results: + bboxes = t['bbox'][0] + lengths = t['bbox'][1][0] + if bboxes.shape == (1, 1) or bboxes is None: + continue + + k = 0 + for i in range(len(lengths)): + num = lengths[i] + for j in range(num): + dt = bboxes[k] + clsid, score, xmin, ymin, xmax, ymax = dt.tolist() + catid = (clsid2catid[int(clsid)]) + + if is_bbox_normalized: + xmin, ymin, xmax, ymax = \ + self.clip_bbox([xmin, ymin, xmax, ymax]) + w = xmax - xmin + h = ymax - ymin + im_shape = t['im_shape'][0][i].tolist() + im_height, im_width = int(im_shape[0]), int(im_shape[1]) + xmin *= im_width + ymin *= im_height + w *= im_width + h *= im_height + else: + w = xmax - xmin + 1 + h = ymax - ymin + 1 + + bbox = [xmin, ymin, w, h] + coco_res = { + 'category_id': catid, + 'bbox': bbox, + 'score': score + } + xywh_res.append(coco_res) + k += 1 + return xywh_res + + def _get_bbox_result(self, fetch_map, fetch_name, clsid2catid): + result = {} + is_bbox_normalized = False + output = fetch_map[fetch_name] + lod = [fetch_map[fetch_name + '.lod']] + lengths = self._offset_to_lengths(lod) + np_data = np.array(output) + result['bbox'] = (np_data, lengths) + result['im_id'] = np.array([[0]]) + + bbox_results = self._bbox2out([result], clsid2catid, is_bbox_normalized) + return bbox_results + + def color_map(self, num_classes): + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) + color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) + color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + j += 1 + lab >>= 3 + color_map = np.array(color_map).reshape(-1, 3) + return color_map + + def draw_bbox(self, image, catid2name, bboxes, threshold, color_list): + """ + draw bbox on image + """ + draw = ImageDraw.Draw(image) + + for dt in np.array(bboxes): + catid, bbox, score = dt['category_id'], dt['bbox'], dt['score'] + if score < threshold: + continue + + xmin, ymin, w, h = bbox + xmax = xmin + w + ymax = ymin + h + + color = tuple(color_list[catid]) + + # draw bbox + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=2, + fill=color) + + # draw label + text = "{} {:.2f}".format(catid2name[catid], score) + tw, th = draw.textsize(text) + draw.rectangle( + [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color) + draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) + + return image + + def visualize(self, infer_img, bbox_results, catid2name, num_classes): + image = Image.open(infer_img).convert('RGB') + color_list = self.color_map(num_classes) + image = self.draw_bbox(image, self.catid2name, bbox_results, 0.5, + color_list) + image_path = os.path.split(infer_img)[-1] + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + out_path = os.path.join(self.output_dir, image_path) + image.save(out_path, quality=95) + + def __call__(self, image_with_bbox): + fetch_name = "" + for key in image_with_bbox: + if key == "image": + continue + if ".lod" in key: + continue + fetch_name = key + bbox_result = self._get_bbox_result(image_with_bbox, fetch_name, + self.clsid2catid) + if os.path.isdir(self.output_dir) is False: + os.mkdir(self.output_dir) + self.visualize(image_with_bbox["image"], bbox_result, self.catid2name, + len(self.label_list)) + if os.path.isdir(self.output_dir) is False: + os.mkdir(self.output_dir) + bbox_file = os.path.join(self.output_dir, 'bbox.json') + with open(bbox_file, 'w') as f: + json.dump(bbox_result, f, indent=4) + + def __repr__(self): + return self.__class__.__name__ + "label_file: {1}, output_dir: {2}".format( + self.label_file, self.output_dir) + + +class Sequential(object): + """ + Args: + sequence (sequence of ``Transform`` objects): list of transforms to chain. + + This API references some of the design pattern of torchvision + Users can simply use this API in training as well + + Example: + >>> image_reader.Sequnece([ + >>> transforms.CenterCrop(10), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img): + for t in self.transforms: + img = t(img) + return img + + def __repr__(self): + format_string_ = self.__class__.__name__ + '(' + for t in self.transforms: + format_string_ += '\n' + format_string_ += ' {0}'.format(t) + format_string_ += '\n)' + return format_string_ + + +class RGB2BGR(object): + def __init__(self): + pass + + def __call__(self, img): + return img[:, :, ::-1] + + def __repr__(self): + return self.__class__.__name__ + "()" + + +class BGR2RGB(object): + def __init__(self): + pass + + def __call__(self, img): + return img[:, :, ::-1] + + def __repr__(self): + return self.__class__.__name__ + "()" + + +class File2Image(object): + def __init__(self): + pass + + def __call__(self, img_path): + fin = open(img_path) + sample = fin.read() + data = np.fromstring(sample, np.uint8) + img = cv2.imdecode(data, cv2.IMREAD_COLOR) + ''' + img = cv2.imread(img_path, -1) + channels = img.shape[2] + ori_h = img.shape[0] + ori_w = img.shape[1] + ''' + return img + + def __repr__(self): + return self.__class__.__name__ + "()" + + +class URL2Image(object): + def __init__(self): + pass + + def __call__(self, img_url): + resp = urllib.urlopen(img_url) + sample = resp.read() + data = np.fromstring(sample, np.uint8) + img = cv2.imdecode(data, cv2.IMREAD_COLOR) + return img + + def __repr__(self): + return self.__class__.__name__ + "()" + + +class Base64ToImage(object): + def __init__(self): + pass + + def __call__(self, img_base64): + img = base64.b64decode(img_base64) + return img + + def __repr__(self): + return self.__class__.__name__ + "()" + + +class Div(object): + """ divide by some float number """ + + def __init__(self, value): + self.value = value + + def __call__(self, img): + """ + Args: + img (numpy array): (int8 numpy array) + + Returns: + img (numpy array): (float32 numpy array) + """ + img = img.astype('float32') / self.value + + return img + + def __repr__(self): + return self.__class__.__name__ + "({})".format(self.value) + + +class Normalize(object): + """Normalize a tensor image with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` + + .. note:: + This transform acts out of place, i.e., it does not mutate the input tensor. + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + + """ + + def __init__(self, mean, std, channel_first=False): + self.mean = mean + self.std = std + self.channel_first = channel_first + + def __call__(self, img): + """ + Args: + img (numpy array): (C, H, W) to be normalized. + + Returns: + Tensor: Normalized Tensor image. + """ + return F.normalize(img, self.mean, self.std, self.channel_first) + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, + self.std) + + +class Lambda(object): + """Apply a user-defined lambda as a transform. + Very shame to just copy from + https://github.com/pytorch/vision/blob/master/torchvision/transforms/transforms.py#L301 + + Args: + lambd (function): Lambda/function to be used for transform. + """ + + def __init__(self, lambd): + assert callable(lambd), repr(type(lambd) + .__name__) + " object is not callable" + self.lambd = lambd + + def __call__(self, img): + return self.lambd(img) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class CenterCrop(object): + """Crops the given Image at the center. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + + def __init__(self, size): + self.size = size + + def __call__(self, img): + """ + Args: + img (numpy array): Image to be cropped. + + Returns: + numpy array Image: Cropped image. + """ + return F.crop(img, self.size, True) + + def __repr__(self): + return self.__class__.__name__ + '(size={0})'.format(self.size) + + +class Resize(object): + """Resize the input numpy array Image to the given size. + + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Desired interpolation. Default is + ``None`` + """ + + def __init__(self, size, max_size=2147483647, interpolation=None): + self.size = size + self.max_size = max_size + self.interpolation = interpolation + + def __call__(self, img): + return F.resize(img, self.size, self.max_size, self.interpolation) + + def __repr__(self): + return self.__class__.__name__ + '(size={0}, max_size={1}, interpolation={2})'.format( + self.size, self.max_size, + _cv2_interpolation_to_str[self.interpolation]) + + +class Transpose(object): + def __init__(self, transpose_target): + self.transpose_target = transpose_target + + def __call__(self, img): + return F.transpose(img, self.transpose_target) + return img + + def __repr__(self): + format_string = self.__class__.__name__ + \ + "({})".format(self.transpose_target) + return format_string class ImageReader(): diff --git a/python/paddle_serving_app/reader/test_image_reader.py b/python/paddle_serving_app/reader/test_image_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..f2dc52771919651f586e1e9720fe0ae8f82e8c12 --- /dev/null +++ b/python/paddle_serving_app/reader/test_image_reader.py @@ -0,0 +1,30 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from image_reader import File2Image +from image_reader import URL2Image +from image_reader import Sequential +from image_reader import Normalize +from image_reader import CenterCrop +from image_reader import Resize + +seq = Sequential([ + File2Image(), CenterCrop(30), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Resize((5, 5)) +]) + +url = "daisy.jpg" +for x in range(100): + img = seq(url) + print(img.shape) diff --git a/python/paddle_serving_app/version.py b/python/paddle_serving_app/version.py index 80f647be56d09740adfb9d68dd47bb0b1fa2c985..766bf4e397e46153193b1e3cac6fed5323241c45 100644 --- a/python/paddle_serving_app/version.py +++ b/python/paddle_serving_app/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Paddle Serving App version string """ -serving_app_version = "0.0.1" +serving_app_version = "0.0.3" diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index 3380934931d5872afca81934724f72614bb64a13..09d3a7e2a72d871d0b6015747150bd27be6cde27 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -125,7 +125,11 @@ class Client(object): lib_path = os.path.dirname(paddle_serving_client.__file__) client_path = os.path.join(lib_path, 'serving_client.so') lib_path = os.path.join(lib_path, 'lib') - os.system('patchelf --set-rpath {} {}'.format(lib_path, client_path)) + ld_path = os.getenv('LD_LIBRARY_PATH') + if ld_path == None: + os.environ['LD_LIBRARY_PATH'] = lib_path + elif ld_path not in lib_path: + os.environ['LD_LIBRARY_PATH'] = ld_path + ':' + lib_path def load_client_config(self, path): from .serving_client import PredictorClient diff --git a/python/paddle_serving_client/version.py b/python/paddle_serving_client/version.py index 99322ee8280a66a54371b296905d54f0766b016d..4870767dfcb95f9502dfa5880a85b1c11c62964f 100644 --- a/python/paddle_serving_client/version.py +++ b/python/paddle_serving_client/version.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Paddle Serving Client version string """ -serving_client_version = "0.2.0" -serving_server_version = "0.2.0" -module_proto_version = "0.2.0" +serving_client_version = "0.2.2" +serving_server_version = "0.2.2" +module_proto_version = "0.2.2" diff --git a/python/paddle_serving_server/version.py b/python/paddle_serving_server/version.py index 99322ee8280a66a54371b296905d54f0766b016d..4870767dfcb95f9502dfa5880a85b1c11c62964f 100644 --- a/python/paddle_serving_server/version.py +++ b/python/paddle_serving_server/version.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Paddle Serving Client version string """ -serving_client_version = "0.2.0" -serving_server_version = "0.2.0" -module_proto_version = "0.2.0" +serving_client_version = "0.2.2" +serving_server_version = "0.2.2" +module_proto_version = "0.2.2" diff --git a/python/paddle_serving_server_gpu/version.py b/python/paddle_serving_server_gpu/version.py index 99322ee8280a66a54371b296905d54f0766b016d..4870767dfcb95f9502dfa5880a85b1c11c62964f 100644 --- a/python/paddle_serving_server_gpu/version.py +++ b/python/paddle_serving_server_gpu/version.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Paddle Serving Client version string """ -serving_client_version = "0.2.0" -serving_server_version = "0.2.0" -module_proto_version = "0.2.0" +serving_client_version = "0.2.2" +serving_server_version = "0.2.2" +module_proto_version = "0.2.2"