From 707ddaf183192405d974387e36e6f9b85ae9f98f Mon Sep 17 00:00:00 2001 From: MRXLT Date: Wed, 9 Sep 2020 08:51:51 +0000 Subject: [PATCH] fix imagenet --- python/examples/imagenet/benchmark.py | 3 ++- python/examples/imagenet/resnet50_web_service.py | 4 ++-- python/paddle_serving_app/reader/__init__.py | 2 +- python/paddle_serving_server/web_service.py | 2 +- python/paddle_serving_server_gpu/web_service.py | 2 +- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/examples/imagenet/benchmark.py b/python/examples/imagenet/benchmark.py index 0181b873..8bbb8875 100644 --- a/python/examples/imagenet/benchmark.py +++ b/python/examples/imagenet/benchmark.py @@ -90,6 +90,7 @@ def single_func(idx, resource): image = base64.b64encode( open("./image_data/n01440764/" + file_list[i]).read()) else: + image_path = "./image_data/n01440764/" + file_list[i] image = base64.b64encode(open(image_path, "rb").read()).decode( "utf-8") req = json.dumps({"feed": [{"image": image}], "fetch": ["score"]}) @@ -106,7 +107,7 @@ if __name__ == '__main__': endpoint_list = [ "127.0.0.1:9292", "127.0.0.1:9293", "127.0.0.1:9294", "127.0.0.1:9295" ] - turns = 100 + turns = 1 start = time.time() result = multi_thread_runner.run( single_func, args.thread, {"endpoint": endpoint_list, diff --git a/python/examples/imagenet/resnet50_web_service.py b/python/examples/imagenet/resnet50_web_service.py index e7d19149..a3f4709c 100644 --- a/python/examples/imagenet/resnet50_web_service.py +++ b/python/examples/imagenet/resnet50_web_service.py @@ -13,7 +13,7 @@ # limitations under the License. import sys from paddle_serving_client import Client -from paddle_serving_app.reader import Sequential, URL2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize +from paddle_serving_app.reader import Sequential, URL2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize, Base64ToImage if len(sys.argv) != 4: print("python resnet50_web_service.py model device port") @@ -30,7 +30,7 @@ else: class ImageService(WebService): def init_imagenet_setting(self): self.seq = Sequential([ - URL2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose( + Base64ToImage(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose( (2, 0, 1)), Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True) ]) diff --git a/python/paddle_serving_app/reader/__init__.py b/python/paddle_serving_app/reader/__init__.py index 93e2cd76..e9fd3154 100644 --- a/python/paddle_serving_app/reader/__init__.py +++ b/python/paddle_serving_app/reader/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from .chinese_bert_reader import ChineseBertReader -from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize +from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, Base64ToImage from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, ResizeByFactor from .image_reader import RCNNPostprocess, SegPostprocess, PadStride from .image_reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py index 78f57487..9430da83 100644 --- a/python/paddle_serving_server/web_service.py +++ b/python/paddle_serving_server/web_service.py @@ -123,7 +123,7 @@ class WebService(object): feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map) result = {"result": result} except ValueError as err: - result = {"result": err} + result = {"result": str(err)} return result def run_rpc_service(self): diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index 4154e824..ea72f186 100644 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -178,7 +178,7 @@ class WebService(object): feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map) result = {"result": result} except ValueError as err: - result = {"result": err} + result = {"result": str(err)} return result def run_rpc_service(self): -- GitLab