未验证 提交 c854d813 编写于 作者: J Jiawei Wang 提交者: GitHub

Merge pull request #967 from wangjiawei04/autoshape

Auto shape for Web Service
...@@ -20,7 +20,7 @@ from paddle_serving_server import OpMaker, OpSeqMaker, Server ...@@ -20,7 +20,7 @@ from paddle_serving_server import OpMaker, OpSeqMaker, Server
from paddle_serving_client import Client from paddle_serving_client import Client
from contextlib import closing from contextlib import closing
import socket import socket
import numpy as np
from paddle_serving_server import pipeline from paddle_serving_server import pipeline
from paddle_serving_server.pipeline import Op from paddle_serving_server.pipeline import Op
...@@ -64,8 +64,8 @@ class WebService(object): ...@@ -64,8 +64,8 @@ class WebService(object):
f = open(client_config, 'r') f = open(client_config, 'r')
model_conf = google.protobuf.text_format.Merge( model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf) str(f.read()), model_conf)
self.feed_names = [var.alias_name for var in model_conf.feed_var] self.feed_vars = {var.name: var for var in model_conf.feed_var}
self.fetch_names = [var.alias_name for var in model_conf.fetch_var] self.fetch_vars = {var.name: var for var in model_conf.fetch_var}
def _launch_rpc_service(self): def _launch_rpc_service(self):
op_maker = OpMaker() op_maker = OpMaker()
...@@ -201,6 +201,15 @@ class WebService(object): ...@@ -201,6 +201,15 @@ class WebService(object):
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
print("This API will be deprecated later. Please do not use it") print("This API will be deprecated later. Please do not use it")
is_batch = True is_batch = True
feed_dict = {}
for var_name in self.feed_vars.keys():
feed_dict[var_name] = []
for feed_ins in feed:
for key in feed_ins:
feed_dict[key].append(np.array(feed_ins[key]).reshape(list(self.feed_vars[key].shape))[np.newaxis,:])
feed = {}
for key in feed_dict:
feed[key] = np.concatenate(feed_dict[key], axis=0)
return feed, fetch, is_batch return feed, fetch, is_batch
def postprocess(self, feed=[], fetch=[], fetch_map=None): def postprocess(self, feed=[], fetch=[], fetch_map=None):
......
...@@ -70,8 +70,8 @@ class WebService(object): ...@@ -70,8 +70,8 @@ class WebService(object):
f = open(client_config, 'r') f = open(client_config, 'r')
model_conf = google.protobuf.text_format.Merge( model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf) str(f.read()), model_conf)
self.feed_names = [var.alias_name for var in model_conf.feed_var] self.feed_vars = {var.name: var for var in model_conf.feed_var}
self.fetch_names = [var.alias_name for var in model_conf.fetch_var] self.fetch_vars = {var.name: var for var in model_conf.fetch_var}
def set_gpus(self, gpus): def set_gpus(self, gpus):
print("This API will be deprecated later. Please do not use it") print("This API will be deprecated later. Please do not use it")
...@@ -278,6 +278,15 @@ class WebService(object): ...@@ -278,6 +278,15 @@ class WebService(object):
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
print("This API will be deprecated later. Please do not use it") print("This API will be deprecated later. Please do not use it")
is_batch = True is_batch = True
feed_dict = {}
for var_name in self.feed_vars.keys():
feed_dict[var_name] = []
for feed_ins in feed:
for key in feed_ins:
feed_dict[key].append(np.array(feed_ins[key]).reshape(list(self.feed_vars[key].shape))[np.newaxis,:])
feed = {}
for key in feed_dict:
feed[key] = np.concatenate(feed_dict[key], axis=0)
return feed, fetch, is_batch return feed, fetch, is_batch
def postprocess(self, feed=[], fetch=[], fetch_map=None): def postprocess(self, feed=[], fetch=[], fetch_map=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册