提交 527e0cb2 编写于 作者: W wangjiawei04

update pipeline

上级 c8ba9440
......@@ -27,7 +27,7 @@ import logging
import enum
import copy
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
class ChannelDataEcode(enum.Enum):
......@@ -102,12 +102,32 @@ class ChannelData(object):
def check_npdata(npdata):
ecode = ChannelDataEcode.OK.value
error_info = None
if isinstance(npdata, list):
# batch data
for sample in npdata:
if not isinstance(sample, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be dict, but get {}.".format(type(sample))
break
for _, value in sample.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be np.ndarray, but get {}.".format(type(value))
return ecode, error_info
elif isinstance(npdata, dict):
# batch_size = 1
for _, value in npdata.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be np.ndarray, but get {}.".format(type(value))
break
else:
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be dict, but get {}.".format(type(npdata))
return ecode, error_info
def parse(self):
......
......@@ -25,7 +25,7 @@ from .proto import pipeline_service_pb2
from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
from .util import NameGenerator
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
_op_name_gen = NameGenerator("Op")
......@@ -142,7 +142,7 @@ class Op(object):
_LOGGER.debug(self._log("get call_result"))
return call_result
def postprocess(self, fetch_dict):
def postprocess(self, input_dict, fetch_dict):
return fetch_dict
def stop(self):
......@@ -267,10 +267,10 @@ class Op(object):
midped_data = preped_data
return midped_data, error_channeldata
def _run_postprocess(self, midped_data, data_id, log_func):
def _run_postprocess(self, input_dict, midped_data, data_id, log_func):
output_data, error_channeldata = None, None
try:
postped_data = self.postprocess(midped_data)
postped_data = self.postprocess(input_dict, midped_data)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
......@@ -359,8 +359,8 @@ class Op(object):
# postprocess
self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
output_data, error_channeldata = self._run_postprocess(midped_data,
data_id, log)
output_data, error_channeldata = self._run_postprocess(
parsed_data, midped_data, data_id, log)
self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......
......@@ -20,7 +20,7 @@ import functools
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
class PipelineClient(object):
......@@ -52,7 +52,7 @@ class PipelineClient(object):
return {"ecode": resp.ecode, "error_info": resp.error_info}
fetch_map = {"ecode": resp.ecode}
for idx, key in enumerate(resp.key):
if key not in fetch:
if fetch is not None and key not in fetch:
continue
data = resp.value[idx]
try:
......@@ -62,16 +62,16 @@ class PipelineClient(object):
fetch_map[key] = data
return fetch_map
def predict(self, feed_dict, fetch, asyn=False):
def predict(self, feed_dict, fetch=None, asyn=False):
if not isinstance(feed_dict, dict):
raise TypeError(
"feed must be dict type with format: {name: value}.")
if not isinstance(fetch, list):
if fetch is not None and not isinstance(fetch, list):
raise TypeError("fetch must be list type with format: [name].")
req = self._pack_request_package(feed_dict)
if not asyn:
resp = self._stub.inference(req)
return self._unpack_response_package(resp)
return self._unpack_response_package(resp, fetch)
else:
call_future = self._stub.inference.future(req)
return PipelinePredictFuture(
......
......@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod
from .profiler import TimeProfiler
from .util import NameGenerator
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
_profiler = TimeProfiler()
......@@ -384,7 +384,7 @@ class PipelineServer(object):
def prepare_server(self, yml_file):
with open(yml_file) as f:
yml_config = yaml.load(f.read())
yml_config = yaml.load(f.read(), Loader=yaml.FullLoader)
self._port = yml_config.get('port', 8080)
if not self._port_is_available(self._port):
raise SystemExit("Prot {} is already used".format(self._port))
......
......@@ -24,7 +24,7 @@ else:
raise Exception("Error Python version")
import time
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
class TimeProfiler(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册