提交 f068e257 编写于 作者: T TeslaZhao

python pipeline support a new feature, only receiving the frist arrival data from pre-ops

上级 25845a28
......@@ -28,6 +28,7 @@ import logging
import enum
import os
import copy
import time
_LOGGER = logging.getLogger(__name__)
......@@ -261,7 +262,11 @@ class ProcessChannel(object):
maintains the data obtained from queue.
"""
def __init__(self, manager, name=None, maxsize=0):
def __init__(self,
manager,
name=None,
maxsize=0,
channel_recv_frist_arrive=False):
# For queue multiprocess: after putting an object on
# an empty queue there may be an infinitessimal delay
# before the queue's :meth:`~Queue.empty`
......@@ -285,6 +290,9 @@ class ProcessChannel(object):
self._base_cursor = manager.Value('i', 0)
self._output_buf = manager.list()
self._cur_max_dataid = manager.Value('i', -1)
self._channel_recv_frist_arrive = channel_recv_frist_arrive
def get_maxsize(self):
return self._maxsize
......@@ -327,9 +335,10 @@ class ProcessChannel(object):
def push(self, channeldata, op_name=None):
_LOGGER.debug(
self._log(
"(data_id={} log_id={}) Op({}) Enter channel::push producers:{}".
"(data_id={} log_id={}) Op({}) Enter channel::push producers:{}, time:{}".
format(channeldata.id, channeldata.log_id, op_name,
len(self._producers))))
len(self._producers), time.time())))
if len(self._producers) == 0:
_LOGGER.critical(
self._log(
......@@ -357,16 +366,55 @@ class ProcessChannel(object):
self._cv.notify_all()
notify_all_time = _time()
_LOGGER.debug(
"(data_id={}) Op({}) channel push cost! enter_cv:{} ms, push_que:{} ms, notify:{} ms, data_size:{}".
"(data_id={}) Op({}) channel push cost! enter_cv:{} ms, push_que:{} ms, notify:{} ms, data_size:{}, time:{}".
format(channeldata.id, op_name, (enter_cv_time - start_time)
* 1000, (push_que_time - enter_cv_time) * 1000, (
notify_all_time - push_que_time) * 1000,
channeldata.get_size()))
channeldata.get_size(), time.time()))
_LOGGER.debug(
self._log(
"(data_id={} log_id={}) Op({}) Pushed data into internal queue.".
format(channeldata.id, channeldata.log_id, op_name)))
return True
elif self._channel_recv_frist_arrive == True:
start_time = _time()
with self._cv:
_LOGGER.debug(
"(data_id={}) Op({}) Channel({}) enter channel_recv_first_arrive. _cur_max_dataid:{}".
format(channeldata.id, op_name, self.name,
self._cur_max_dataid.value))
if channeldata.id > self._cur_max_dataid.value:
enter_cv_time = _time()
push_que_time = enter_cv_time
while self._stop.value == 0:
try:
self._que.put((channeldata.id, {
op_name: channeldata
}),
timeout=0)
push_que_time = _time()
self._cur_max_dataid.value = channeldata.id
break
except Queue.Full:
self._cv.wait()
if self._stop.value == 1:
raise ChannelStopError()
self._cv.notify_all()
notify_all_time = _time()
_LOGGER.debug(
"(data_id={}) Op({}) channel push cost! enter_cv:{} ms, push_que:{} ms, notify:{} ms, data_size:{}, time:{}".
format(channeldata.id, op_name, (
enter_cv_time - start_time) * 1000, (
push_que_time - enter_cv_time) * 1000, (
notify_all_time - push_que_time) * 1000,
channeldata.get_size(), time.time()))
else:
# log and drop it
_LOGGER.debug(
"(data_id={}) Op({}) send data is dropped! cur_max_dataid:{}".
format(channeldata.id, op_name,
self._cur_max_dataid.value))
return True
elif op_name is None:
_LOGGER.critical(
self._log(
......@@ -416,8 +464,8 @@ class ProcessChannel(object):
_LOGGER.debug(
self._log(
"(data_id={} log_id={}) Op({}) Pushed data into internal_queue.".
format(data_id, log_id, op_name)))
"(data_id={} log_id={}) Op({}) Pushed data into internal_queue. time:{}".
format(data_id, log_id, op_name, time.time())))
self._cv.notify_all()
return True
......@@ -466,9 +514,9 @@ class ProcessChannel(object):
key = list(resp.keys())[0]
data_id = resp[key].id
_LOGGER.debug(
"(data_id={}) op({}) front cost enter_cv:{} ms, queue_get:{} ms".
"(data_id={}) op({}) front cost enter_cv:{} ms, queue_get:{} ms, time:{}".
format(data_id, op_name, (time_2 - time_1) / 1000.0, (
time_3 - time_2) / 1000.0))
time_3 - time_2) / 1000.0, time.time()))
if resp is not None:
list_values = list(resp.values())
_LOGGER.debug(
......@@ -503,9 +551,9 @@ class ProcessChannel(object):
list_values = list(channeldata.values())
_LOGGER.debug(
self._log(
"(data_id={} log_id={}) Op({}) Pop ready item into output_buffer".
"(data_id={} log_id={}) Op({}) Pop ready item into output_buffer, time:{}".
format(list_values[0].id, list_values[0].log_id,
op_name)))
op_name, time.time())))
break
except Queue.Empty:
if timeout is not None:
......@@ -563,8 +611,9 @@ class ProcessChannel(object):
list_values = list(resp.values())
_LOGGER.debug(
self._log(
"(data_id={} log_id={}) Op({}) Got data from output_buffer".
format(list_values[0].id, list_values[0].log_id, op_name)))
"(data_id={} log_id={}) Op({}) Got data from output_buffer, time:{}".
format(list_values[0].id, list_values[0].log_id, op_name,
time.time())))
return resp
def stop(self):
......@@ -603,7 +652,7 @@ class ThreadChannel(Queue.PriorityQueue):
maintains the data obtained from queue.
"""
def __init__(self, name=None, maxsize=-1):
def __init__(self, name=None, maxsize=-1, channel_recv_frist_arrive=False):
Queue.Queue.__init__(self, maxsize=maxsize)
self._maxsize = maxsize
self.name = name
......@@ -621,6 +670,9 @@ class ThreadChannel(Queue.PriorityQueue):
self._base_cursor = 0
self._output_buf = []
self._channel_recv_frist_arrive = channel_recv_frist_arrive
self._cur_max_dataid = -1
def get_maxsize(self):
return self._maxsize
......@@ -664,6 +716,7 @@ class ThreadChannel(Queue.PriorityQueue):
_LOGGER.debug(
self._log("(data_id={} log_id={}) Op({}) Pushing data".format(
channeldata.id, channeldata.log_id, op_name)))
if len(self._producers) == 0:
_LOGGER.critical(
self._log(
......@@ -690,6 +743,29 @@ class ThreadChannel(Queue.PriorityQueue):
"(data_id={} log_id={}) Op({}) Pushed data into internal_queue.".
format(channeldata.id, channeldata.log_id, op_name)))
return True
elif self._channel_recv_frist_arrive is True:
with self._cv:
if channeldata.id > self._cur_max_dataid:
while self._stop is False:
try:
self.put((channeldata.id, {
op_name: channeldata
}),
timeout=0)
self._cur_max_dataid = channeldata.id
break
except Queue.Full:
self._cv.wait()
if self._stop:
raise ChannelStopError()
self._cv.notify_all()
else:
# log and drop it
_LOGGER.debug(
"(data_id={}) Op({}) send data is dropped! cur_max_dataid:{}".
format(channeldata.id, op_name, self._cur_max_dataid))
return True
elif op_name is None:
_LOGGER.critical(
self._log(
......
......@@ -63,6 +63,7 @@ class DAGExecutor(object):
self._retry = dag_conf["retry"]
self._server_use_profile = dag_conf["use_profile"]
channel_size = dag_conf["channel_size"]
channel_recv_frist_arrive = dag_conf["channel_recv_frist_arrive"]
self._is_thread_op = dag_conf["is_thread_op"]
tracer_conf = dag_conf["tracer"]
......@@ -79,7 +80,7 @@ class DAGExecutor(object):
self._dag = DAG(self.name, response_op, self._server_use_profile,
self._is_thread_op, channel_size, build_dag_each_worker,
self._tracer)
self._tracer, channel_recv_frist_arrive)
(in_channel, out_channel, pack_rpc_func,
unpack_rpc_func) = self._dag.build()
self._dag.start()
......@@ -480,7 +481,8 @@ class DAG(object):
"""
def __init__(self, request_name, response_op, use_profile, is_thread_op,
channel_size, build_dag_each_worker, tracer):
channel_size, build_dag_each_worker, tracer,
channel_recv_frist_arrive):
self._request_name = request_name
self._response_op = response_op
self._use_profile = use_profile
......@@ -488,6 +490,7 @@ class DAG(object):
self._channel_size = channel_size
self._build_dag_each_worker = build_dag_each_worker
self._tracer = tracer
self._channel_recv_frist_arrive = channel_recv_frist_arrive
if not self._is_thread_op:
self._manager = PipelineProcSyncManager()
_LOGGER.info("[DAG] Succ init")
......@@ -543,10 +546,15 @@ class DAG(object):
channel = None
if self._is_thread_op:
channel = ThreadChannel(
name=name_gen.next(), maxsize=self._channel_size)
name=name_gen.next(),
maxsize=self._channel_size,
channel_recv_frist_arrive=self._channel_recv_frist_arrive)
else:
channel = ProcessChannel(
self._manager, name=name_gen.next(), maxsize=self._channel_size)
self._manager,
name=name_gen.next(),
maxsize=self._channel_size,
channel_recv_frist_arrive=self._channel_recv_frist_arrive)
_LOGGER.debug("[DAG] Generate channel: {}".format(channel.name))
return channel
......
......@@ -506,7 +506,7 @@ class Op(object):
os._exit(-1)
channel.add_producer(self.name)
self._outputs.append(channel)
_LOGGER.info("op:{} add output_channel {}".format(self.name, channel))
_LOGGER.debug("op:{} add output_channel {}".format(self.name, channel))
def clean_output_channels(self):
self._outputs = []
......@@ -1333,6 +1333,8 @@ class Op(object):
break
end = int(round(_time() * 1000000))
in_time = end - start
_LOGGER.debug("op:{} in_time_end:{}".format(op_info_prefix,
time.time()))
# parse channeldata batch
try:
......@@ -1346,6 +1348,8 @@ class Op(object):
if len(parsed_data_dict) == 0:
# data in the whole batch is all error data
continue
_LOGGER.debug("op:{} parse_end:{}".format(op_info_prefix,
time.time()))
# print
front_cost = int(round(_time() * 1000000)) - start
......@@ -1360,6 +1364,8 @@ class Op(object):
= self._run_preprocess(parsed_data_dict, op_info_prefix, logid_dict)
end = profiler.record("prep#{}_1".format(op_info_prefix))
prep_time = end - start
_LOGGER.debug("op:{} preprocess_end:{}, cost:{}".format(
op_info_prefix, time.time(), prep_time))
try:
# put error requests into output channel, skip process and postprocess stage
for data_id, err_channeldata in err_channeldata_dict.items():
......@@ -1381,6 +1387,8 @@ class Op(object):
= self._run_process(preped_data_dict, op_info_prefix, skip_process_dict, logid_dict)
end = profiler.record("midp#{}_1".format(op_info_prefix))
midp_time = end - start
_LOGGER.debug("op:{} process_end:{}, cost:{}".format(
op_info_prefix, time.time(), midp_time))
try:
for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels(
......@@ -1402,6 +1410,8 @@ class Op(object):
end = profiler.record("postp#{}_1".format(op_info_prefix))
postp_time = end - start
after_postp_time = _time()
_LOGGER.debug("op:{} postprocess_end:{}, cost:{}".format(
op_info_prefix, time.time(), postp_time))
try:
for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels(
......@@ -1690,9 +1700,10 @@ class RequestOp(Op):
else:
dict_data[name] = self.proto_tensor_2_numpy(one_tensor)
_LOGGER.debug("RequestOp unpack one request. log_id:{}, clientip:{} \
name:{}, method:{}".format(log_id, request.clientip, request.name,
request.method))
_LOGGER.info("RequestOp unpack one request. log_id:{}, clientip:{} \
name:{}, method:{}, time:{}"
.format(log_id, request.clientip, request.name,
request.method, time.time()))
return dict_data, log_id, None, ""
......
......@@ -14,6 +14,7 @@
# pylint: disable=doc-string-missing
import grpc
import sys
import time
import numpy as np
from numpy import *
import logging
......@@ -168,10 +169,11 @@ class PipelineClient(object):
"feed must be dict type with format: {name: value}.")
if fetch is not None and not isinstance(fetch, list):
raise TypeError("fetch must be list type with format: [name].")
print("PipelineClient::predict pack_data time:{}".format(time.time()))
req = self._pack_request_package(feed_dict, pack_tensor_format, profile)
req.logid = log_id
if not asyn:
print("PipelineClient::predict before time:{}".format(time.time()))
resp = self._stub.inference(req)
return self._unpack_response_package(resp, fetch)
else:
......
......@@ -22,6 +22,7 @@ from contextlib import closing
import multiprocessing
import yaml
import io
import time
from .proto import pipeline_service_pb2_grpc, pipeline_service_pb2
from . import operator
......@@ -47,8 +48,9 @@ class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer):
_LOGGER.info("[PipelineServicer] succ init")
def inference(self, request, context):
_LOGGER.info("(log_id={}) inference request name:{} self.name:{}".
format(request.logid, request.name, self._name))
_LOGGER.info(
"(log_id={}) inference request name:{} self.name:{} time:{}".format(
request.logid, request.name, self._name, time.time()))
if request.name != "" and request.name != self._name:
_LOGGER.error("(log_id={}) name dismatch error. request.name:{},"
"server.name={}".format(request.logid, request.name,
......@@ -469,6 +471,7 @@ class ServerYamlConfChecker(object):
"channel_size": 0,
"is_thread_op": True,
"tracer": {},
"channel_recv_frist_arrive": False,
}
conf_type = {
......@@ -477,6 +480,7 @@ class ServerYamlConfChecker(object):
"use_profile": bool,
"channel_size": int,
"is_thread_op": bool,
"channel_recv_frist_arrive": bool,
}
conf_qualification = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册