提交 b24dea39 编写于 作者: B barriery

add auto-batching-generator

上级 b91d81c1
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
from time import time as _time
import threading import threading
import multiprocessing import multiprocessing
from paddle_serving_client import MultiLangClient, Client from paddle_serving_client import MultiLangClient, Client
...@@ -43,7 +43,9 @@ class Op(object): ...@@ -43,7 +43,9 @@ class Op(object):
client_config=None, client_config=None,
concurrency=1, concurrency=1,
timeout=-1, timeout=-1,
retry=1): retry=1,
batch_size=1,
auto_batchint_timeout=None):
if name is None: if name is None:
name = _op_name_gen.next() name = _op_name_gen.next()
self.name = name # to identify the type of OP, it must be globally unique self.name = name # to identify the type of OP, it must be globally unique
...@@ -62,6 +64,11 @@ class Op(object): ...@@ -62,6 +64,11 @@ class Op(object):
self._input = None self._input = None
self._outputs = [] self._outputs = []
self._batch_size = batch_size
self._auto_batchint_timeout = auto_batchint_timeout
if self._auto_batchint_timeout is not None and self._auto_batchint_timeout <= 0:
self._auto_batchint_timeout = None
self._server_use_profile = False self._server_use_profile = False
# only for multithread # only for multithread
...@@ -337,8 +344,32 @@ class Op(object): ...@@ -337,8 +344,32 @@ class Op(object):
dictdata=postped_data, dictdata=postped_data,
data_id=data_id) data_id=data_id)
return output_data, error_channeldata return output_data, error_channeldata
def _auto_batching_generator(self, input_channel, op_name, batch_size, timeout):
while True:
batch = []
while len(batch) == 0:
endtime = None
if timeout is not None:
endtime = _time() + timeout
for idx in range(batch_size):
try:
channeldata_dict = None
if timeout is not None:
remaining = endtime - _time()
if remaining <= 0.0:
_LOGGER.info(log("auto-batching timeout"))
break
channeldata_dict = input_channel.front(op_name, timeout)
else:
channeldata_dict = input_channel.front(op_name)
batch.append(channeldata_dict)
except ChannelTimeoutError:
_LOGGER.info(log("auto-batching timeout"))
break
yield batch
def _run(self, concurrency_idx, input_channel, output_channels, client_type, def _run(self, concurrency_idx, input_channel, output_channels, client_type,
is_thread_op): is_thread_op):
def get_log_func(op_info_prefix): def get_log_func(op_info_prefix):
def log_func(info_str): def log_func(info_str):
...@@ -351,65 +382,57 @@ class Op(object): ...@@ -351,65 +382,57 @@ class Op(object):
tid = threading.current_thread().ident tid = threading.current_thread().ident
# init op # init op
self.concurrency_idx = concurrency_idx
try: try:
if is_thread_op: self._initialize(is_thread_op)
with self._for_init_op_lock:
if not self._succ_init_op:
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(True)
# init client
self.client = self.init_client(
client_type, self._client_config,
self._server_endpoints, self._fetch_names)
# user defined
self.init_op()
self._succ_init_op = True
self._succ_close_op = False
else:
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(True)
# init client
self.client = self.init_client(client_type, self._client_config,
self._server_endpoints,
self._fetch_names)
# user defined
self.init_op()
except Exception as e: except Exception as e:
_LOGGER.error(log(e)) _LOGGER.error(log(e))
os._exit(-1) os._exit(-1)
batch_generator = self._auto_batching_generator(
input_channel=input_channel,
op_name=self.name,
batch_size=self._batch_size,
timeout=self._auto_batching_timeout)
while True: while True:
#self._profiler_record("get#{}_0".format(op_info_prefix)) channeldata_dict_batch = None
try: try:
channeldata_dict = input_channel.front(self.name) channeldata_dict_batch = next(batch_generator)
except ChannelStopError: except ChannelStopError:
_LOGGER.debug(log("stop.")) _LOGGER.debug(log("stop."))
if is_thread_op: self._finalize(is_thread_op)
with self._for_close_op_lock:
if not self._succ_close_op:
self._profiler = None
self.client = None
self._succ_init_op = False
self._succ_close_op = True
break break
#self._profiler_record("get#{}_1".format(op_info_prefix))
_LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
(data_id, error_channeldata, parsed_data, client_need_profile, # parse channeldata batch
profile_set) = self._parse_channeldata(channeldata_dict) try:
# error data in predecessor Op # parse channeldata batch
if error_channeldata is not None: except ChannelStopError:
try: _LOGGER.debug(log("stop."))
# error_channeldata with profile info break
self._push_to_output_channels(error_channeldata, nor_dataid_list = []
output_channels) err_dataid_list = []
except ChannelStopError: nor_datas = {}
_LOGGER.debug(log("stop.")) err_datas = {}
break for channeldata_dict in channeldata_dict_batch:
continue (data_id, error_channeldata, parsed_data,
client_need_profile, profile_set) = \
self._parse_channeldata(channeldata_dict)
if error_channeldata is None:
nor_dataid_list.append(data_id)
nor_datas[data_id] = {
"pd": parsed_data,
"np": client_need_profile,
"ps": profile_set,
}
else:
# error data in predecessor Op
try:
# error_channeldata with profile info
self._push_to_output_channels(error_channeldata,
output_channels)
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
# preprecess # preprecess
self._profiler_record("prep#{}_0".format(op_info_prefix)) self._profiler_record("prep#{}_0".format(op_info_prefix))
...@@ -463,7 +486,6 @@ class Op(object): ...@@ -463,7 +486,6 @@ class Op(object):
continue continue
# push data to channel (if run succ) # push data to channel (if run succ)
#self._profiler_record("push#{}_0".format(op_info_prefix))
try: try:
self._push_to_output_channels( self._push_to_output_channels(
output_data, output_data,
...@@ -473,7 +495,45 @@ class Op(object): ...@@ -473,7 +495,45 @@ class Op(object):
except ChannelStopError: except ChannelStopError:
_LOGGER.debug(log("stop.")) _LOGGER.debug(log("stop."))
break break
#self._profiler_record("push#{}_1".format(op_info_prefix))
def _initialize(self, is_thread_op):
if is_thread_op:
with self._for_init_op_lock:
if not self._succ_init_op:
# for the threaded version of Op, each thread cannot get its concurrency_idx
self.concurrency_idx = None
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(True)
# init client
self.client = self.init_client(
client_type, self._client_config,
self._server_endpoints, self._fetch_names)
# user defined
self.init_op()
self._succ_init_op = True
self._succ_close_op = False
else:
self.concurrency_idx = concurrency_idx
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(True)
# init client
self.client = self.init_client(
client_type, self._client_config,
self._server_endpoints,
self._fetch_names)
# user defined
self.init_op()
def _finalize(self, is_thread_op):
if is_thread_op:
with self._for_close_op_lock:
if not self._succ_close_op:
self._profiler = None
self.client = None
self._succ_init_op = False
self._succ_close_op = True
def _log(self, info): def _log(self, info):
return "{} {}".format(self.name, info) return "{} {}".format(self.name, info)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册