From 6691abb784026259c45148a31d61e873fd4c8216 Mon Sep 17 00:00:00 2001 From: wangjiawei04 Date: Fri, 15 Jan 2021 08:36:09 +0000 Subject: [PATCH] fix web service --- python/paddle_serving_server/web_service.py | 15 ++++++++++++--- python/paddle_serving_server_gpu/web_service.py | 13 +++++++++++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py index 1f035db9..fbe48180 100644 --- a/python/paddle_serving_server/web_service.py +++ b/python/paddle_serving_server/web_service.py @@ -20,7 +20,7 @@ from paddle_serving_server import OpMaker, OpSeqMaker, Server from paddle_serving_client import Client from contextlib import closing import socket - +import numpy as np from paddle_serving_server import pipeline from paddle_serving_server.pipeline import Op @@ -64,8 +64,8 @@ class WebService(object): f = open(client_config, 'r') model_conf = google.protobuf.text_format.Merge( str(f.read()), model_conf) - self.feed_names = [var.alias_name for var in model_conf.feed_var] - self.fetch_names = [var.alias_name for var in model_conf.fetch_var] + self.feed_vars = {var.name: var for var in model_conf.feed_var} + self.fetch_vars = {var.name: var for var in model_conf.fetch_var} def _launch_rpc_service(self): op_maker = OpMaker() @@ -201,6 +201,15 @@ class WebService(object): def preprocess(self, feed=[], fetch=[]): print("This API will be deprecated later. Please do not use it") is_batch = True + feed_dict = {} + for var_name in self.feed_vars.keys(): + feed_dict[var_name] = [] + for feed_ins in feed: + for key in feed_ins: + feed_dict[key].append(np.array(feed_ins[key]).reshape(list(self.feed_vars[key].shape))[np.newaxis,:]) + feed = {} + for key in feed_dict: + feed[key] = np.concatenate(feed_dict[key], axis=0) return feed, fetch, is_batch def postprocess(self, feed=[], fetch=[], fetch_map=None): diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index 4b89d90e..6e7fc2c1 100644 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -70,8 +70,8 @@ class WebService(object): f = open(client_config, 'r') model_conf = google.protobuf.text_format.Merge( str(f.read()), model_conf) - self.feed_names = [var.alias_name for var in model_conf.feed_var] - self.fetch_names = [var.alias_name for var in model_conf.fetch_var] + self.feed_vars = {var.name: var for var in model_conf.feed_var} + self.fetch_vars = {var.name: var for var in model_conf.fetch_var} def set_gpus(self, gpus): print("This API will be deprecated later. Please do not use it") @@ -278,6 +278,15 @@ class WebService(object): def preprocess(self, feed=[], fetch=[]): print("This API will be deprecated later. Please do not use it") is_batch = True + feed_dict = {} + for var_name in self.feed_vars.keys(): + feed_dict[var_name] = [] + for feed_ins in feed: + for key in feed_ins: + feed_dict[key].append(np.array(feed_ins[key]).reshape(list(self.feed_vars[key].shape))[np.newaxis,:]) + feed = {} + for key in feed_dict: + feed[key] = np.concatenate(feed_dict[key], axis=0) return feed, fetch, is_batch def postprocess(self, feed=[], fetch=[], fetch_map=None): -- GitLab