提交 dc95e6fe 编写于 作者: B barrierye

add asyn impl

上级 5ba1b5fd
...@@ -16,6 +16,7 @@ import grpc ...@@ -16,6 +16,7 @@ import grpc
import numpy as np import numpy as np
from numpy import * from numpy import *
import logging import logging
import functools
from .proto import pipeline_service_pb2 from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc from .proto import pipeline_service_pb2_grpc
...@@ -31,7 +32,7 @@ class PipelineClient(object): ...@@ -31,7 +32,7 @@ class PipelineClient(object):
self._stub = pipeline_service_pb2_grpc.PipelineServiceStub( self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
self._channel) self._channel)
def _pack_data_for_infer(self, feed_dict): def _pack_request_package(self, feed_dict):
req = pipeline_service_pb2.Request() req = pipeline_service_pb2.Request()
for key, value in feed_dict.items(): for key, value in feed_dict.items():
if not isinstance(value, str): if not isinstance(value, str):
...@@ -40,15 +41,7 @@ class PipelineClient(object): ...@@ -40,15 +41,7 @@ class PipelineClient(object):
req.value.append(value) req.value.append(value)
return req return req
def predict(self, feed_dict, fetch): def _unpack_response_package(self, resp, fetch):
if not isinstance(feed_dict, dict):
raise TypeError(
"feed must be dict type with format: {name: value}.")
if not isinstance(fetch, list):
raise TypeError(
"fetch_with_type must be list type with format: [name].")
req = self._pack_data_for_infer(feed_dict)
resp = self._stub.inference(req)
if resp.ecode != 0: if resp.ecode != 0:
return {"ecode": resp.ecode, "error_info": resp.error_info} return {"ecode": resp.ecode, "error_info": resp.error_info}
fetch_map = {"ecode": resp.ecode} fetch_map = {"ecode": resp.ecode}
...@@ -62,3 +55,30 @@ class PipelineClient(object): ...@@ -62,3 +55,30 @@ class PipelineClient(object):
pass pass
fetch_map[key] = data fetch_map[key] = data
return fetch_map return fetch_map
def predict(self, feed_dict, fetch, asyn=False):
if not isinstance(feed_dict, dict):
raise TypeError(
"feed must be dict type with format: {name: value}.")
if not isinstance(fetch, list):
raise TypeError("fetch must be list type with format: [name].")
req = self._pack_request_package(feed_dict)
if not asyn:
resp = self._stub.inference(req)
return self._unpack_response_package(resp)
else:
call_future = self._stub.inference.future(req)
return PipelinePredictFuture(
call_future,
functools.partial(
self._unpack_response_package, fetch=fetch))
class PipelinePredictFuture(object):
def __init__(self, call_future, callback_func):
self.call_future_ = call_future
self.callback_func_ = callback_func
def result(self):
resp = self.call_future_.result()
return self.callback_func_(resp)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册