diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py index 1f035db9262ffbd8e031c9b0018877eb2ba6fad2..fbe48180867faf9f2baba71fc3c5c8cf6ab771e2 100644 --- a/python/paddle_serving_server/web_service.py +++ b/python/paddle_serving_server/web_service.py @@ -20,7 +20,7 @@ from paddle_serving_server import OpMaker, OpSeqMaker, Server from paddle_serving_client import Client from contextlib import closing import socket - +import numpy as np from paddle_serving_server import pipeline from paddle_serving_server.pipeline import Op @@ -64,8 +64,8 @@ class WebService(object): f = open(client_config, 'r') model_conf = google.protobuf.text_format.Merge( str(f.read()), model_conf) - self.feed_names = [var.alias_name for var in model_conf.feed_var] - self.fetch_names = [var.alias_name for var in model_conf.fetch_var] + self.feed_vars = {var.name: var for var in model_conf.feed_var} + self.fetch_vars = {var.name: var for var in model_conf.fetch_var} def _launch_rpc_service(self): op_maker = OpMaker() @@ -201,6 +201,15 @@ class WebService(object): def preprocess(self, feed=[], fetch=[]): print("This API will be deprecated later. Please do not use it") is_batch = True + feed_dict = {} + for var_name in self.feed_vars.keys(): + feed_dict[var_name] = [] + for feed_ins in feed: + for key in feed_ins: + feed_dict[key].append(np.array(feed_ins[key]).reshape(list(self.feed_vars[key].shape))[np.newaxis,:]) + feed = {} + for key in feed_dict: + feed[key] = np.concatenate(feed_dict[key], axis=0) return feed, fetch, is_batch def postprocess(self, feed=[], fetch=[], fetch_map=None): diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index 4b89d90ee6893c3fafd596dc8f6c5cabc3a248bf..6e7fc2c148dab721e74a7d1719c48849bbab3405 100644 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -70,8 +70,8 @@ class WebService(object): f = open(client_config, 'r') model_conf = google.protobuf.text_format.Merge( str(f.read()), model_conf) - self.feed_names = [var.alias_name for var in model_conf.feed_var] - self.fetch_names = [var.alias_name for var in model_conf.fetch_var] + self.feed_vars = {var.name: var for var in model_conf.feed_var} + self.fetch_vars = {var.name: var for var in model_conf.fetch_var} def set_gpus(self, gpus): print("This API will be deprecated later. Please do not use it") @@ -278,6 +278,15 @@ class WebService(object): def preprocess(self, feed=[], fetch=[]): print("This API will be deprecated later. Please do not use it") is_batch = True + feed_dict = {} + for var_name in self.feed_vars.keys(): + feed_dict[var_name] = [] + for feed_ins in feed: + for key in feed_ins: + feed_dict[key].append(np.array(feed_ins[key]).reshape(list(self.feed_vars[key].shape))[np.newaxis,:]) + feed = {} + for key in feed_dict: + feed[key] = np.concatenate(feed_dict[key], axis=0) return feed, fetch, is_batch def postprocess(self, feed=[], fetch=[], fetch_map=None):