diff --git a/python/pipeline/pipeline_client.py b/python/pipeline/pipeline_client.py index 1df3e76130820b6989941540109ec4ab8c4b49ad..cf5e2e7758e458d55100ace47039dc17bb07677a 100644 --- a/python/pipeline/pipeline_client.py +++ b/python/pipeline/pipeline_client.py @@ -16,6 +16,7 @@ import grpc import numpy as np from numpy import * import logging +import functools from .proto import pipeline_service_pb2 from .proto import pipeline_service_pb2_grpc @@ -31,7 +32,7 @@ class PipelineClient(object): self._stub = pipeline_service_pb2_grpc.PipelineServiceStub( self._channel) - def _pack_data_for_infer(self, feed_dict): + def _pack_request_package(self, feed_dict): req = pipeline_service_pb2.Request() for key, value in feed_dict.items(): if not isinstance(value, str): @@ -40,15 +41,7 @@ class PipelineClient(object): req.value.append(value) return req - def predict(self, feed_dict, fetch): - if not isinstance(feed_dict, dict): - raise TypeError( - "feed must be dict type with format: {name: value}.") - if not isinstance(fetch, list): - raise TypeError( - "fetch_with_type must be list type with format: [name].") - req = self._pack_data_for_infer(feed_dict) - resp = self._stub.inference(req) + def _unpack_response_package(self, resp, fetch): if resp.ecode != 0: return {"ecode": resp.ecode, "error_info": resp.error_info} fetch_map = {"ecode": resp.ecode} @@ -62,3 +55,30 @@ class PipelineClient(object): pass fetch_map[key] = data return fetch_map + + def predict(self, feed_dict, fetch, asyn=False): + if not isinstance(feed_dict, dict): + raise TypeError( + "feed must be dict type with format: {name: value}.") + if not isinstance(fetch, list): + raise TypeError("fetch must be list type with format: [name].") + req = self._pack_request_package(feed_dict) + if not asyn: + resp = self._stub.inference(req) + return self._unpack_response_package(resp) + else: + call_future = self._stub.inference.future(req) + return PipelinePredictFuture( + call_future, + functools.partial( + self._unpack_response_package, fetch=fetch)) + + +class PipelinePredictFuture(object): + def __init__(self, call_future, callback_func): + self.call_future_ = call_future + self.callback_func_ = callback_func + + def result(self): + resp = self.call_future_.result() + return self.callback_func_(resp)