Commit ce43cb50 authored by wangjiawei04

update pipeline

Parent b9782cd9
@@ -27,7 +27,7 @@ import logging
 import enum
 import copy
 
-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()
 
 class ChannelDataEcode(enum.Enum):
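Every module touched by this commit switches from a per-module logger (logging.getLogger(__name__)) to the root logger (logging.getLogger()), so log level and formatting are controlled in one place by whoever starts the process. A minimal sketch of the startup configuration this implies; the level and format string are illustrative choices, not part of the commit:

import logging

# Because _LOGGER = logging.getLogger() returns the root logger, a single
# basicConfig call at program start controls all pipeline log output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s")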
@@ -97,17 +97,37 @@ class ChannelData(object):
             error_info = "the value of data must " \
                     "be dict, but get {}.".format(type(dictdata))
         return ecode, error_info
 
     @staticmethod
     def check_npdata(npdata):
         ecode = ChannelDataEcode.OK.value
         error_info = None
-        for _, value in npdata.items():
-            if not isinstance(value, np.ndarray):
-                ecode = ChannelDataEcode.TYPE_ERROR.value
-                error_info = "the value of data must " \
-                        "be np.ndarray, but get {}.".format(type(value))
-                break
+        if isinstance(npdata, list):
+            # batch data
+            for sample in npdata:
+                if not isinstance(sample, dict):
+                    ecode = ChannelDataEcode.TYPE_ERROR.value
+                    error_info = "the value of data must " \
+                            "be dict, but get {}.".format(type(sample))
+                    break
+                for _, value in sample.items():
+                    if not isinstance(value, np.ndarray):
+                        ecode = ChannelDataEcode.TYPE_ERROR.value
+                        error_info = "the value of data must " \
+                                "be np.ndarray, but get {}.".format(type(value))
+                        return ecode, error_info
+        elif isinstance(npdata, dict):
+            # batch_size = 1
+            for _, value in npdata.items():
+                if not isinstance(value, np.ndarray):
+                    ecode = ChannelDataEcode.TYPE_ERROR.value
+                    error_info = "the value of data must " \
+                            "be np.ndarray, but get {}.".format(type(value))
+                    break
+        else:
+            ecode = ChannelDataEcode.TYPE_ERROR.value
+            error_info = "the value of data must " \
+                    "be dict, but get {}.".format(type(npdata))
         return ecode, error_info
 
     def parse(self):
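The reworked check_npdata accepts either a single {name: np.ndarray} dict (batch_size = 1) or a list of such dicts (batch data), and reports TYPE_ERROR for anything else. A rough usage sketch; the import path and the "image" field are assumptions and may differ from the installed package layout:

import numpy as np
# Assumed import path; adjust to the installed package layout.
from paddle_serving_server.pipeline.channel import ChannelData, ChannelDataEcode

single_sample = {"image": np.zeros((3, 224, 224), dtype=np.float32)}
batch_samples = [
    {"image": np.zeros((3, 224, 224), dtype=np.float32)},
    {"image": np.ones((3, 224, 224), dtype=np.float32)},
]

ecode, err = ChannelData.check_npdata(single_sample)   # OK: dict of np.ndarray
ecode, err = ChannelData.check_npdata(batch_samples)   # OK: list of dicts
ecode, err = ChannelData.check_npdata([1, 2, 3])       # TYPE_ERROR: items are not dicts
assert ecode == ChannelDataEcode.TYPE_ERROR.value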
@@ -25,7 +25,7 @@ from .proto import pipeline_service_pb2
 from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
 from .util import NameGenerator
 
-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()
 _op_name_gen = NameGenerator("Op")
@@ -142,7 +142,7 @@ class Op(object):
         _LOGGER.debug(self._log("get call_result"))
         return call_result
 
-    def postprocess(self, fetch_dict):
+    def postprocess(self, input_dict, fetch_dict):
         return fetch_dict
 
     def stop(self):
@@ -267,10 +267,10 @@ class Op(object):
             midped_data = preped_data
         return midped_data, error_channeldata
 
-    def _run_postprocess(self, midped_data, data_id, log_func):
+    def _run_postprocess(self, input_dict, midped_data, data_id, log_func):
         output_data, error_channeldata = None, None
         try:
-            postped_data = self.postprocess(midped_data)
+            postped_data = self.postprocess(input_dict, midped_data)
         except Exception as e:
             error_info = log_func(e)
             _LOGGER.error(error_info)
@@ -359,8 +359,8 @@ class Op(object):
             # postprocess
             self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
-            output_data, error_channeldata = self._run_postprocess(midped_data,
-                                                                   data_id, log)
+            output_data, error_channeldata = self._run_postprocess(
+                parsed_data, midped_data, data_id, log)
             self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
             if error_channeldata is not None:
                 self._push_to_output_channels(error_channeldata,
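Op.postprocess now receives both the data parsed from the op's input channel (parsed_data, forwarded through _run_postprocess as input_dict) and the model's prediction results (fetch_dict), so a user-defined op can combine request fields with inference output. A sketch of an op written against the new signature; the class name, the "request_tag" field, and the import path are hypothetical:

# Assumed import path; adjust to the installed package layout.
from paddle_serving_server.pipeline.operator import Op

class MyCombineOp(Op):
    def postprocess(self, input_dict, fetch_dict):
        # input_dict: the data this op received from its input channel (parsed_data)
        # fetch_dict: the inference results produced for this request
        result = dict(fetch_dict)
        result["request_tag"] = input_dict.get("request_tag")  # hypothetical field
        return result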
@@ -20,7 +20,7 @@ import functools
 from .proto import pipeline_service_pb2
 from .proto import pipeline_service_pb2_grpc
 
-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()
 
 class PipelineClient(object):
@@ -52,7 +52,7 @@ class PipelineClient(object):
             return {"ecode": resp.ecode, "error_info": resp.error_info}
         fetch_map = {"ecode": resp.ecode}
         for idx, key in enumerate(resp.key):
-            if key not in fetch:
+            if fetch is not None and key not in fetch:
                 continue
             data = resp.value[idx]
             try:
@@ -62,16 +62,16 @@ class PipelineClient(object):
                 fetch_map[key] = data
         return fetch_map
 
-    def predict(self, feed_dict, fetch, asyn=False):
+    def predict(self, feed_dict, fetch=None, asyn=False):
         if not isinstance(feed_dict, dict):
             raise TypeError(
                 "feed must be dict type with format: {name: value}.")
-        if not isinstance(fetch, list):
+        if fetch is not None and not isinstance(fetch, list):
             raise TypeError("fetch must be list type with format: [name].")
         req = self._pack_request_package(feed_dict)
         if not asyn:
             resp = self._stub.inference(req)
-            return self._unpack_response_package(resp)
+            return self._unpack_response_package(resp, fetch)
         else:
             call_future = self._stub.inference.future(req)
             return PipelinePredictFuture(
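On the client side, fetch now defaults to None: when it is omitted, _unpack_response_package keeps every key in the response, and when a list is given only those keys are returned. A usage sketch, assuming the usual PipelineClient entry points; the import path, endpoint, and field names are assumptions:

from paddle_serving_server.pipeline import PipelineClient  # import path may differ by release

client = PipelineClient()
client.connect(['127.0.0.1:18080'])  # hypothetical endpoint

# fetch omitted: every key in the response is kept in the returned fetch_map
all_fields = client.predict(feed_dict={"words": "hello"})

# fetch given: only the listed keys are kept
some_fields = client.predict(feed_dict={"words": "hello"}, fetch=["prediction"])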
@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcode
 from .profiler import TimeProfiler
 from .util import NameGenerator
 
-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()
 _profiler = TimeProfiler()
@@ -384,7 +384,7 @@ class PipelineServer(object):
     def prepare_server(self, yml_file):
         with open(yml_file) as f:
-            yml_config = yaml.load(f.read())
+            yml_config = yaml.load(f.read(), Loader=yaml.FullLoader)
         self._port = yml_config.get('port', 8080)
         if not self._port_is_available(self._port):
             raise SystemExit("Prot {} is already used".format(self._port))
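Passing Loader=yaml.FullLoader removes the deprecation warning that PyYAML 5.1+ emits for yaml.load without an explicit loader, and avoids constructing arbitrary Python objects from the config file. A standalone equivalent of what prepare_server now does; the file name is a placeholder:

import yaml

with open("config.yml") as f:  # placeholder path
    yml_config = yaml.load(f.read(), Loader=yaml.FullLoader)
    # yaml.safe_load(f) would be a stricter drop-in alternative

port = yml_config.get('port', 8080)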
@@ -24,7 +24,7 @@ else:
     raise Exception("Error Python version")
 import time
 
-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()
 
 class TimeProfiler(object):