diff --git a/python/examples/pipeline/imdb_model_ensemble/test_pipeline_server.py b/python/examples/pipeline/imdb_model_ensemble/test_pipeline_server.py
index 3f96f619608c608aa4e21c22be8f05da218ae3a7..8ff147eb488ca7c2b8e17007c6e3596762ec0203 100644
--- a/python/examples/pipeline/imdb_model_ensemble/test_pipeline_server.py
+++ b/python/examples/pipeline/imdb_model_ensemble/test_pipeline_server.py
@@ -22,21 +22,19 @@ from paddle_serving_app.reader import IMDBDataset
 logging.basicConfig(
     format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
     datefmt='%Y-%m-%d %H:%M',
-    # level=logging.DEBUG)
     level=logging.INFO)
 
 
 class ImdbOp(Op):
+    def load_user_resources(self):
+        self.imdb_dataset = IMDBDataset()
+        self.imdb_dataset.load_resource('imdb.vocab')
+
     def preprocess(self, input_data):
         data = input_data.parse()
-        imdb_dataset = IMDBDataset()
-        imdb_dataset.load_resource('imdb.vocab')
-        word_ids, _ = imdb_dataset.get_words_and_label(data['words'])
+        word_ids, _ = self.imdb_dataset.get_words_and_label(data['words'])
         return {"words": word_ids}
 
-    # def postprocess(self, fetch_data):
-    #     return {key: str(value) for key, value in fetch_data.items()}
-
 
 class CombineOp(Op):
     def preprocess(self, input_data):
@@ -45,7 +43,7 @@ class CombineOp(Op):
             data = channeldata.parse()
             logging.info("{}: {}".format(op_name, data["prediction"]))
             combined_prediction += data["prediction"]
-        data = {"prediction": str(combined_prediction / 2)}
+        data = {"prediction": combined_prediction / 2}
         return data
 
 
@@ -77,7 +75,7 @@ combine_op = CombineOp(
 
 server = PipelineServer()
 server.add_ops([read_op, bow_op, cnn_op, combine_op])
-# server.set_response_op(bow_op)
+#server.set_response_op(bow_op)
 server.set_response_op(combine_op)
 server.prepare_server('config.yml')
 server.run_server()
diff --git a/python/pipeline/channel.py b/python/pipeline/channel.py
index 2a1d82d7558262d294e3e345583b34ec97fd6031..cbe81249817e7ba21928398b1ca0b502ed098530 100644
--- a/python/pipeline/channel.py
+++ b/python/pipeline/channel.py
@@ -34,7 +34,8 @@ class ChannelDataEcode(enum.Enum):
     NOT_IMPLEMENTED = 2
     TYPE_ERROR = 3
     RPC_PACKAGE_ERROR = 4
-    UNKNOW = 5
+    CLIENT_ERROR = 5
+    UNKNOW = 6
 
 
 class ChannelDataType(enum.Enum):
diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py
index b13408d2c214296d60e51b2d41dff08e8c34c5c8..230ffbd4bb3ca4050bfe2c0f49b06c8c43e4c682 100644
--- a/python/pipeline/operator.py
+++ b/python/pipeline/operator.py
@@ -204,6 +204,9 @@ class Op(object):
             threads.append(t)
         return threads
 
+    def load_user_resources(self):
+        pass
+
     def _run(self, concurrency_idx, input_channel, output_channels,
              client_type):
         def get_log_func(op_info_prefix):
@@ -220,6 +223,9 @@ class Op(object):
             self.init_client(client_type, self._client_config,
                              self._server_endpoints, self._fetch_names)
 
+        # load user resources
+        self.load_user_resources()
+
         self._is_run = True
         while self._is_run:
             self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
@@ -319,6 +325,16 @@ class Op(object):
                         continue
                 self._profiler_record("{}-midp#{}_1".format(op_info_prefix,
                                                             tid))
+                # op client return None
+                if midped_data is None:
+                    self._push_to_output_channels(
+                        ChannelData(
+                            ecode=ChannelDataEcode.CLIENT_ERROR.value,
+                            error_info=log(
+                                "predict failed. pls check the server side."),
+                            data_id=data_id),
+                        output_channels)
+                    continue
             else:
                 midped_data = preped_data
 
diff --git a/python/pipeline/pipeline_client.py b/python/pipeline/pipeline_client.py
index 06b6944366a5684960695271545ac9edbbfb291b..7e9af1617392eda5770b11afa7a6e8d6da3b8b88 100644
--- a/python/pipeline/pipeline_client.py
+++ b/python/pipeline/pipeline_client.py
@@ -14,6 +14,7 @@
 # pylint: disable=doc-string-missing
 import grpc
 import numpy as np
+from numpy import array
 from .proto import pipeline_service_pb2
 from .proto import pipeline_service_pb2_grpc
 
@@ -51,5 +52,10 @@ class PipelineClient(object):
         for idx, key in enumerate(resp.key):
             if key not in fetch:
                 continue
-            fetch_map[key] = resp.value[idx]
+            data = resp.value[idx]
+            try:
+                data = eval(resp.value[idx])
+            except Exception as e:
+                pass
+            fetch_map[key] = data
         return fetch_map
diff --git a/python/pipeline/pipeline_server.py b/python/pipeline/pipeline_server.py
index 904139467d771f0f31e2b60d29eff11a64df24e7..94d517fc3ebadd8f1e71e7257763161e51b92355 100644
--- a/python/pipeline/pipeline_server.py
+++ b/python/pipeline/pipeline_server.py
@@ -140,18 +140,26 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
         if resp.ecode == ChannelDataEcode.OK.value:
             if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
                 feed = channeldata.parse()
+                # ndarray to string
                 for name, var in feed.items():
-                    resp.value.append(var.tostring())  #TODO: no shape and type
+                    resp.value.append(var.__repr__())
                     resp.key.append(name)
             elif channeldata.datatype == ChannelDataType.DICT.value:
                 feed = channeldata.parse()
                 for name, var in feed.items():
-                    resp.value.append(str(var))
+                    if not isinstance(var, str):
+                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
+                        resp.error_info = self._log(
+                            "fetch var type must be str({}).".format(
+                                type(var)))
+                        break
+                    resp.value.append(var)
                     resp.key.append(name)
             else:
-                raise TypeError(
-                    self._log("Error type({}) in datatype.".format(
-                        channeldata.datatype)))
+                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
+                resp.error_info = self._log(
+                    "Error type({}) in datatype.".format(channeldata.datatype))
+                logging.error(resp.error_info)
         else:
             resp.error_info = channeldata.error_info
         return resp