提交 f89dc91a 编写于 作者: B barrierye

support ndarray2str

上级 92fb727f
...@@ -22,21 +22,19 @@ from paddle_serving_app.reader import IMDBDataset ...@@ -22,21 +22,19 @@ from paddle_serving_app.reader import IMDBDataset
logging.basicConfig( logging.basicConfig(
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s', format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M', datefmt='%Y-%m-%d %H:%M',
# level=logging.DEBUG)
level=logging.INFO) level=logging.INFO)
class ImdbOp(Op): class ImdbOp(Op):
def load_user_resources(self):
self.imdb_dataset = IMDBDataset()
self.imdb_dataset.load_resource('imdb.vocab')
def preprocess(self, input_data): def preprocess(self, input_data):
data = input_data.parse() data = input_data.parse()
imdb_dataset = IMDBDataset() word_ids, _ = self.imdb_dataset.get_words_and_label(data['words'])
imdb_dataset.load_resource('imdb.vocab')
word_ids, _ = imdb_dataset.get_words_and_label(data['words'])
return {"words": word_ids} return {"words": word_ids}
# def postprocess(self, fetch_data):
# return {key: str(value) for key, value in fetch_data.items()}
class CombineOp(Op): class CombineOp(Op):
def preprocess(self, input_data): def preprocess(self, input_data):
...@@ -45,7 +43,7 @@ class CombineOp(Op): ...@@ -45,7 +43,7 @@ class CombineOp(Op):
data = channeldata.parse() data = channeldata.parse()
logging.info("{}: {}".format(op_name, data["prediction"])) logging.info("{}: {}".format(op_name, data["prediction"]))
combined_prediction += data["prediction"] combined_prediction += data["prediction"]
data = {"prediction": str(combined_prediction / 2)} data = {"prediction": combined_prediction / 2}
return data return data
...@@ -77,7 +75,7 @@ combine_op = CombineOp( ...@@ -77,7 +75,7 @@ combine_op = CombineOp(
server = PipelineServer() server = PipelineServer()
server.add_ops([read_op, bow_op, cnn_op, combine_op]) server.add_ops([read_op, bow_op, cnn_op, combine_op])
# server.set_response_op(bow_op) #server.set_response_op(bow_op)
server.set_response_op(combine_op) server.set_response_op(combine_op)
server.prepare_server('config.yml') server.prepare_server('config.yml')
server.run_server() server.run_server()
...@@ -34,7 +34,8 @@ class ChannelDataEcode(enum.Enum): ...@@ -34,7 +34,8 @@ class ChannelDataEcode(enum.Enum):
NOT_IMPLEMENTED = 2 NOT_IMPLEMENTED = 2
TYPE_ERROR = 3 TYPE_ERROR = 3
RPC_PACKAGE_ERROR = 4 RPC_PACKAGE_ERROR = 4
UNKNOW = 5 CLIENT_ERROR = 5
UNKNOW = 6
class ChannelDataType(enum.Enum): class ChannelDataType(enum.Enum):
......
...@@ -204,6 +204,9 @@ class Op(object): ...@@ -204,6 +204,9 @@ class Op(object):
threads.append(t) threads.append(t)
return threads return threads
def load_user_resources(self):
pass
def _run(self, concurrency_idx, input_channel, output_channels, def _run(self, concurrency_idx, input_channel, output_channels,
client_type): client_type):
def get_log_func(op_info_prefix): def get_log_func(op_info_prefix):
...@@ -220,6 +223,9 @@ class Op(object): ...@@ -220,6 +223,9 @@ class Op(object):
self.init_client(client_type, self._client_config, self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names) self._server_endpoints, self._fetch_names)
# load user resources
self.load_user_resources()
self._is_run = True self._is_run = True
while self._is_run: while self._is_run:
self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid)) self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
...@@ -319,6 +325,16 @@ class Op(object): ...@@ -319,6 +325,16 @@ class Op(object):
continue continue
self._profiler_record("{}-midp#{}_1".format(op_info_prefix, self._profiler_record("{}-midp#{}_1".format(op_info_prefix,
tid)) tid))
# op client return None
if midped_data is None:
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.CLIENT_ERROR.value,
error_info=log(
"predict failed. pls check the server side."),
data_id=data_id),
output_channels)
continue
else: else:
midped_data = preped_data midped_data = preped_data
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
import grpc import grpc
import numpy as np import numpy as np
from numpy import array
from .proto import pipeline_service_pb2 from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc from .proto import pipeline_service_pb2_grpc
...@@ -51,5 +52,10 @@ class PipelineClient(object): ...@@ -51,5 +52,10 @@ class PipelineClient(object):
for idx, key in enumerate(resp.key): for idx, key in enumerate(resp.key):
if key not in fetch: if key not in fetch:
continue continue
fetch_map[key] = resp.value[idx] data = resp.value[idx]
try:
data = eval(resp.value[idx])
except Exception as e:
pass
fetch_map[key] = data
return fetch_map return fetch_map
...@@ -140,18 +140,26 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer): ...@@ -140,18 +140,26 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
if resp.ecode == ChannelDataEcode.OK.value: if resp.ecode == ChannelDataEcode.OK.value:
if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value: if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
feed = channeldata.parse() feed = channeldata.parse()
# ndarray to string
for name, var in feed.items(): for name, var in feed.items():
resp.value.append(var.tostring()) #TODO: no shape and type resp.value.append(var.__repr__())
resp.key.append(name) resp.key.append(name)
elif channeldata.datatype == ChannelDataType.DICT.value: elif channeldata.datatype == ChannelDataType.DICT.value:
feed = channeldata.parse() feed = channeldata.parse()
for name, var in feed.items(): for name, var in feed.items():
resp.value.append(str(var)) if not isinstance(var, str):
resp.ecode = ChannelDataEcode.TYPE_ERROR.value
resp.error_info = self._log(
"fetch var type must be str({}).".format(
type(var)))
break
resp.value.append(var)
resp.key.append(name) resp.key.append(name)
else: else:
raise TypeError( resp.ecode = ChannelDataEcode.TYPE_ERROR.value
self._log("Error type({}) in datatype.".format( resp.error_info = self._log(
channeldata.datatype))) "Error type({}) in datatype.".format(channeldata.datatype))
logging.error(resp.error_info)
else: else:
resp.error_info = channeldata.error_info resp.error_info = channeldata.error_info
return resp return resp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册