提交 b24dea39 编写于 作者: B barriery

add auto-batching-generator

上级 b91d81c1
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from time import time as _time
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
......@@ -43,7 +43,9 @@ class Op(object):
client_config=None,
concurrency=1,
timeout=-1,
retry=1):
retry=1,
batch_size=1,
auto_batchint_timeout=None):
if name is None:
name = _op_name_gen.next()
self.name = name # to identify the type of OP, it must be globally unique
......@@ -62,6 +64,11 @@ class Op(object):
self._input = None
self._outputs = []
self._batch_size = batch_size
self._auto_batchint_timeout = auto_batchint_timeout
if self._auto_batchint_timeout is not None and self._auto_batchint_timeout <= 0:
self._auto_batchint_timeout = None
self._server_use_profile = False
# only for multithread
......@@ -337,8 +344,32 @@ class Op(object):
dictdata=postped_data,
data_id=data_id)
return output_data, error_channeldata
def _auto_batching_generator(self, input_channel, op_name, batch_size, timeout):
while True:
batch = []
while len(batch) == 0:
endtime = None
if timeout is not None:
endtime = _time() + timeout
for idx in range(batch_size):
try:
channeldata_dict = None
if timeout is not None:
remaining = endtime - _time()
if remaining <= 0.0:
_LOGGER.info(log("auto-batching timeout"))
break
channeldata_dict = input_channel.front(op_name, timeout)
else:
channeldata_dict = input_channel.front(op_name)
batch.append(channeldata_dict)
except ChannelTimeoutError:
_LOGGER.info(log("auto-batching timeout"))
break
yield batch
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
is_thread_op):
def get_log_func(op_info_prefix):
def log_func(info_str):
......@@ -351,65 +382,57 @@ class Op(object):
tid = threading.current_thread().ident
# init op
self.concurrency_idx = concurrency_idx
try:
if is_thread_op:
with self._for_init_op_lock:
if not self._succ_init_op:
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(True)
# init client
self.client = self.init_client(
client_type, self._client_config,
self._server_endpoints, self._fetch_names)
# user defined
self.init_op()
self._succ_init_op = True
self._succ_close_op = False
else:
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(True)
# init client
self.client = self.init_client(client_type, self._client_config,
self._server_endpoints,
self._fetch_names)
# user defined
self.init_op()
self._initialize(is_thread_op)
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
batch_generator = self._auto_batching_generator(
input_channel=input_channel,
op_name=self.name,
batch_size=self._batch_size,
timeout=self._auto_batching_timeout)
while True:
#self._profiler_record("get#{}_0".format(op_info_prefix))
channeldata_dict_batch = None
try:
channeldata_dict = input_channel.front(self.name)
channeldata_dict_batch = next(batch_generator)
except ChannelStopError:
_LOGGER.debug(log("stop."))
if is_thread_op:
with self._for_close_op_lock:
if not self._succ_close_op:
self._profiler = None
self.client = None
self._succ_init_op = False
self._succ_close_op = True
self._finalize(is_thread_op)
break
#self._profiler_record("get#{}_1".format(op_info_prefix))
_LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
(data_id, error_channeldata, parsed_data, client_need_profile,
profile_set) = self._parse_channeldata(channeldata_dict)
# error data in predecessor Op
if error_channeldata is not None:
try:
# error_channeldata with profile info
self._push_to_output_channels(error_channeldata,
output_channels)
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
continue
# parse channeldata batch
try:
# parse channeldata batch
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
nor_dataid_list = []
err_dataid_list = []
nor_datas = {}
err_datas = {}
for channeldata_dict in channeldata_dict_batch:
(data_id, error_channeldata, parsed_data,
client_need_profile, profile_set) = \
self._parse_channeldata(channeldata_dict)
if error_channeldata is None:
nor_dataid_list.append(data_id)
nor_datas[data_id] = {
"pd": parsed_data,
"np": client_need_profile,
"ps": profile_set,
}
else:
# error data in predecessor Op
try:
# error_channeldata with profile info
self._push_to_output_channels(error_channeldata,
output_channels)
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
# preprecess
self._profiler_record("prep#{}_0".format(op_info_prefix))
......@@ -463,7 +486,6 @@ class Op(object):
continue
# push data to channel (if run succ)
#self._profiler_record("push#{}_0".format(op_info_prefix))
try:
self._push_to_output_channels(
output_data,
......@@ -473,7 +495,45 @@ class Op(object):
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
#self._profiler_record("push#{}_1".format(op_info_prefix))
def _initialize(self, is_thread_op):
if is_thread_op:
with self._for_init_op_lock:
if not self._succ_init_op:
# for the threaded version of Op, each thread cannot get its concurrency_idx
self.concurrency_idx = None
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(True)
# init client
self.client = self.init_client(
client_type, self._client_config,
self._server_endpoints, self._fetch_names)
# user defined
self.init_op()
self._succ_init_op = True
self._succ_close_op = False
else:
self.concurrency_idx = concurrency_idx
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(True)
# init client
self.client = self.init_client(
client_type, self._client_config,
self._server_endpoints,
self._fetch_names)
# user defined
self.init_op()
def _finalize(self, is_thread_op):
if is_thread_op:
with self._for_close_op_lock:
if not self._succ_close_op:
self._profiler = None
self.client = None
self._succ_init_op = False
self._succ_close_op = True
def _log(self, info):
return "{} {}".format(self.name, info)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册