提交 c26a8942 编写于 作者: B barriery 提交者: GitHub

Merge pull request #25 from barrierye/pipeling-log

use PriorityQueue instead of Queue in Channel && add logid in ServingClient
......@@ -37,6 +37,7 @@ message InferenceRequest {
repeated string feed_var_names = 2;
repeated string fetch_var_names = 3;
required bool is_python = 4 [ default = false ];
required uint64 log_id = 5 [ default = 0 ];
};
message InferenceResponse {
......
......@@ -21,6 +21,7 @@ option cc_generic_services = true;
message RequestAndResponse {
required int32 a = 1;
required float b = 2;
required uint64 log_id = 3 [ default = 0 ];
};
service LoadGeneralModelService {
......
......@@ -268,6 +268,7 @@ class PdsCodeGenerator : public CodeGenerator {
" const $input_name$* request,\n"
" $output_name$* response,\n"
" google::protobuf::Closure* done) {\n"
" std::cout << \"WTFFFFFFFFFFFFFFFF\";\n"
" struct timeval tv;\n"
" gettimeofday(&tv, NULL);"
" long start = tv.tv_sec * 1000000 + tv.tv_usec;",
......@@ -1013,6 +1014,7 @@ class PdsCodeGenerator : public CodeGenerator {
" brpc::ClosureGuard done_guard(done);\n"
" brpc::Controller* cntl = \n"
" static_cast<brpc::Controller*>(cntl_base);\n"
" cntl->set_log_id(request->log_id());\n"
" ::baidu::paddle_serving::predictor::InferService* svr = \n"
" "
"::baidu::paddle_serving::predictor::InferServiceManager::instance("
......@@ -1050,6 +1052,7 @@ class PdsCodeGenerator : public CodeGenerator {
" brpc::ClosureGuard done_guard(done);\n"
" brpc::Controller* cntl = \n"
" static_cast<brpc::Controller*>(cntl_base);\n"
" cntl->set_log_id(request->log_id());\n"
" ::baidu::paddle_serving::predictor::InferService* svr = \n"
" "
"::baidu::paddle_serving::predictor::InferServiceManager::instance("
......
......@@ -192,14 +192,16 @@ public class Client {
private InferenceRequest _packInferenceRequest(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) throws IllegalArgumentException {
Iterable<String> fetch,
long log_id) throws IllegalArgumentException {
List<String> feed_var_names = new ArrayList<String>();
feed_var_names.addAll(feed_batch.get(0).keySet());
InferenceRequest.Builder req_builder = InferenceRequest.newBuilder()
.addAllFeedVarNames(feed_var_names)
.addAllFetchVarNames(fetch)
.setIsPython(false);
.setIsPython(false)
.setLogId(log_id);
for (HashMap<String, INDArray> feed_data: feed_batch) {
FeedInst.Builder inst_builder = FeedInst.newBuilder();
for (String name: feed_var_names) {
......@@ -332,76 +334,151 @@ public class Client {
public Map<String, INDArray> predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch) {
return predict(feed, fetch, false);
return predict(feed, fetch, false, 0);
}
public Map<String, INDArray> predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
long log_id) {
return predict(feed, fetch, false, log_id);
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch) {
return ensemble_predict(feed, fetch, false);
return ensemble_predict(feed, fetch, false, 0);
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
long log_id) {
return ensemble_predict(feed, fetch, false, log_id);
}
public PredictFuture asyn_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch) {
return asyn_predict(feed, fetch, false);
return asyn_predict(feed, fetch, false, 0);
}
public PredictFuture asyn_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
long log_id) {
return asyn_predict(feed, fetch, false, log_id);
}
public Map<String, INDArray> predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
return predict(feed, fetch, need_variant_tag, 0);
}
public Map<String, INDArray> predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag,
long log_id) {
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed);
return predict(feed_batch, fetch, need_variant_tag);
return predict(feed_batch, fetch, need_variant_tag, log_id);
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
return ensemble_predict(feed, fetch, need_variant_tag, 0);
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag,
long log_id) {
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed);
return ensemble_predict(feed_batch, fetch, need_variant_tag);
return ensemble_predict(feed_batch, fetch, need_variant_tag, log_id);
}
public PredictFuture asyn_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
return asyn_predict(feed, fetch, need_variant_tag, 0);
}
public PredictFuture asyn_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag,
long log_id) {
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed);
return asyn_predict(feed_batch, fetch, need_variant_tag);
return asyn_predict(feed_batch, fetch, need_variant_tag, log_id);
}
public Map<String, INDArray> predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return predict(feed_batch, fetch, false);
return predict(feed_batch, fetch, false, 0);
}
public Map<String, INDArray> predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
long log_id) {
return predict(feed_batch, fetch, false, log_id);
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return ensemble_predict(feed_batch, fetch, false);
return ensemble_predict(feed_batch, fetch, false, 0);
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
long log_id) {
return ensemble_predict(feed_batch, fetch, false, log_id);
}
public PredictFuture asyn_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return asyn_predict(feed_batch, fetch, false);
return asyn_predict(feed_batch, fetch, false, 0);
}
public PredictFuture asyn_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
long log_id) {
return asyn_predict(feed_batch, fetch, false, log_id);
}
public Map<String, INDArray> predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
return predict(feed_batch, fetch, need_variant_tag, 0);
}
public Map<String, INDArray> predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag,
long log_id) {
try {
profiler_.record("java_prepro_0");
InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
InferenceRequest req = _packInferenceRequest(
feed_batch, fetch, log_id);
profiler_.record("java_prepro_1");
profiler_.record("java_client_infer_0");
......@@ -415,7 +492,7 @@ public class Client {
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet());
if (list.size() != 1) {
System.out.format("predict failed: please use ensemble_predict impl.\n");
System.out.format("Failed to predict: please use ensemble_predict impl.\n");
return null;
}
profiler_.record("java_postpro_1");
......@@ -423,7 +500,7 @@ public class Client {
return list.get(0).getValue();
} catch (StatusRuntimeException e) {
System.out.format("predict failed: %s\n", e.toString());
System.out.format("Failed to predict: %s\n", e.toString());
return null;
}
}
......@@ -432,9 +509,18 @@ public class Client {
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
return ensemble_predict(feed_batch, fetch, need_variant_tag, 0);
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag,
long log_id) {
try {
profiler_.record("java_prepro_0");
InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
InferenceRequest req = _packInferenceRequest(
feed_batch, fetch, log_id);
profiler_.record("java_prepro_1");
profiler_.record("java_client_infer_0");
......@@ -449,7 +535,7 @@ public class Client {
return ensemble_result;
} catch (StatusRuntimeException e) {
System.out.format("predict failed: %s\n", e.toString());
System.out.format("Failed to predict: %s\n", e.toString());
return null;
}
}
......@@ -458,7 +544,16 @@ public class Client {
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
return asyn_predict(feed_batch, fetch, need_variant_tag, 0);
}
public PredictFuture asyn_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag,
long log_id) {
InferenceRequest req = _packInferenceRequest(
feed_batch, fetch, log_id);
ListenableFuture<InferenceResponse> future = futureStub_.inference(req);
PredictFuture predict_future = new PredictFuture(future,
(InferenceResponse resp) -> {
......
......@@ -37,6 +37,7 @@ message InferenceRequest {
repeated string feed_var_names = 2;
repeated string fetch_var_names = 3;
required bool is_python = 4 [ default = false ];
required uint64 log_id = 5 [ default = 0 ];
};
message InferenceResponse {
......
......@@ -466,10 +466,11 @@ class MultiLangClient(object):
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
def _pack_inference_request(self, feed, fetch, is_python):
def _pack_inference_request(self, feed, fetch, is_python, log_id):
req = multi_lang_general_model_service_pb2.InferenceRequest()
req.fetch_var_names.extend(fetch)
req.is_python = is_python
req.log_id = log_id
feed_batch = None
if isinstance(feed, dict):
feed_batch = [feed]
......@@ -602,12 +603,13 @@ class MultiLangClient(object):
fetch,
need_variant_tag=False,
asyn=False,
is_python=True):
is_python=True,
log_id=0):
if not asyn:
try:
self.profile_.record('py_prepro_0')
req = self._pack_inference_request(
feed, fetch, is_python=is_python)
feed, fetch, is_python=is_python, log_id=log_id)
self.profile_.record('py_prepro_1')
self.profile_.record('py_client_infer_0')
......@@ -626,7 +628,8 @@ class MultiLangClient(object):
except grpc.RpcError as e:
return {"serving_status_code": e.code()}
else:
req = self._pack_inference_request(feed, fetch, is_python=is_python)
req = self._pack_inference_request(
feed, fetch, is_python=is_python, log_id=log_id)
call_future = self.stub_.Inference.future(
req, timeout=self.rpc_timeout_s_)
return MultiLangPredictFuture(
......
......@@ -502,6 +502,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names)
is_python = request.is_python
log_id = request.log_id
feed_batch = []
for feed_inst in request.insts:
feed_dict = {}
......@@ -530,7 +531,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
data.shape = list(feed_inst.tensor_array[idx].shape)
feed_dict[name] = data
feed_batch.append(feed_dict)
return feed_batch, fetch_names, is_python
return feed_batch, fetch_names, is_python, log_id
def _pack_inference_response(self, ret, fetch_names, is_python):
resp = multi_lang_general_model_service_pb2.InferenceResponse()
......@@ -583,10 +584,13 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
return resp
def Inference(self, request, context):
feed_dict, fetch_names, is_python = self._unpack_inference_request(
request)
feed_dict, fetch_names, is_python, log_id = \
self._unpack_inference_request(request)
ret = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
feed=feed_dict,
fetch=fetch_names,
need_variant_tag=True,
log_id=log_id)
return self._pack_inference_response(ret, fetch_names, is_python)
def GetClientConfig(self, request, context):
......
......@@ -556,6 +556,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names)
is_python = request.is_python
log_id = request.log_id
feed_batch = []
for feed_inst in request.insts:
feed_dict = {}
......@@ -637,10 +638,13 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
return resp
def Inference(self, request, context):
feed_dict, fetch_names, is_python = self._unpack_inference_request(
request)
feed_dict, fetch_names, is_python, log_id \
= self._unpack_inference_request(request)
ret = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
feed=feed_dict,
fetch=fetch_names,
need_variant_tag=True,
log_id=log_id)
return self._pack_inference_response(ret, fetch_names, is_python)
def GetClientConfig(self, request, context):
......
......@@ -181,6 +181,14 @@ class ChannelData(object):
os._exit(-1)
return feed
def __cmp__(self, other):
if self.id < other.id:
return -1
elif self.id == other.id:
return 0
else:
return 1
def __str__(self):
return "type[{}], ecode[{}], id[{}]".format(
ChannelDataType(self.datatype).name, self.ecode, self.id)
......@@ -222,7 +230,7 @@ class ProcessChannel(object):
# see more:
# - https://bugs.python.org/issue18277
# - https://hg.python.org/cpython/rev/860fc6a2bd21
self._que = manager.Queue(maxsize=maxsize)
self._que = manager.PriorityQueue(maxsize=maxsize)
self._maxsize = maxsize
self.name = name
self._stop = manager.Value('i', 0)
......@@ -489,7 +497,7 @@ class ProcessChannel(object):
self._cv.notify_all()
class ThreadChannel(Queue.Queue):
class ThreadChannel(Queue.PriorityQueue):
"""
(Thread version)The channel used for communication between Ops.
......
......@@ -29,7 +29,7 @@ from .operator import Op, RequestOp, ResponseOp, VirtualOp
from .channel import (ThreadChannel, ProcessChannel, ChannelData,
ChannelDataEcode, ChannelDataType, ChannelStopError)
from .profiler import TimeProfiler, PerformanceTracer
from .util import NameGenerator
from .util import NameGenerator, ThreadIdGenerator, PipelineProcSyncManager
from .proto import pipeline_service_pb2
_LOGGER = logging.getLogger(__name__)
......@@ -74,9 +74,9 @@ class DAGExecutor(object):
if self._tracer is not None:
self._tracer.start()
self._id_lock = threading.Lock()
self._id_counter = 0
self._reset_max_id = 1000000000000000000
self._id_generator = ThreadIdGenerator(
max_id=1000000000000000000, base_counter=0, step=1)
self._cv_pool = {}
self._cv_for_cv_pool = threading.Condition()
self._fetch_buffer = {}
......@@ -98,13 +98,7 @@ class DAGExecutor(object):
_LOGGER.info("[DAG Executor] Stop")
def _get_next_data_id(self):
data_id = None
with self._id_lock:
if self._id_counter >= self._reset_max_id:
_LOGGER.info("[DAG Executor] Reset request id")
self._id_counter -= self._reset_max_id
data_id = self._id_counter
self._id_counter += 1
data_id = self._id_generator.next()
cond_v = threading.Condition()
with self._cv_for_cv_pool:
self._cv_pool[data_id] = cond_v
......@@ -330,7 +324,7 @@ class DAG(object):
self._build_dag_each_worker = build_dag_each_worker
self._tracer = tracer
if not self._is_thread_op:
self._manager = multiprocessing.Manager()
self._manager = PipelineProcSyncManager()
_LOGGER.info("[DAG] Succ init")
def get_use_ops(self, response_op):
......
......@@ -40,18 +40,11 @@ logger_config = {
"format": "%(asctime)s %(message)s",
},
},
"filters": {
"info_only_filter": {
"()": SectionLevelFilter,
"levels": [logging.INFO],
},
},
"handlers": {
"f_pipeline.log": {
"class": "logging.FileHandler",
"level": "INFO",
"formatter": "normal_fmt",
"filters": ["info_only_filter"],
"filename": os.path.join(log_dir, "pipeline.log"),
},
"f_pipeline.log.wf": {
......
......@@ -22,6 +22,7 @@ import logging
import func_timeout
import os
import sys
import collections
import numpy as np
from numpy import *
......@@ -127,7 +128,7 @@ class Op(object):
fetch_names):
if self.with_serving == False:
_LOGGER.info("Op({}) has no client (and it also do not "
"run the process function".format(self.name))
"run the process function)".format(self.name))
return None
if client_type == 'brpc':
client = Client()
......@@ -199,7 +200,7 @@ class Op(object):
(_, input_dict), = input_dicts.items()
return input_dict
def process(self, feed_batch):
def process(self, feed_batch, typical_logid):
err, err_info = ChannelData.check_batch_npdata(feed_batch)
if err != 0:
_LOGGER.critical(
......@@ -207,7 +208,7 @@ class Op(object):
"preprocess func.".format(err_info)))
os._exit(-1)
call_result = self.client.predict(
feed=feed_batch, fetch=self._fetch_names)
feed=feed_batch, fetch=self._fetch_names, log_id=typical_logid)
if isinstance(self.client, MultiLangClient):
if call_result is None or call_result["serving_status_code"] != 0:
return None
......@@ -294,8 +295,8 @@ class Op(object):
def _run_preprocess(self, parsed_data_dict, op_info_prefix):
_LOGGER.debug("{} Running preprocess".format(op_info_prefix))
preped_data_dict = {}
err_channeldata_dict = {}
preped_data_dict = collections.OrderedDict()
err_channeldata_dict = collections.OrderedDict()
for data_id, parsed_data in parsed_data_dict.items():
preped_data, error_channeldata = None, None
try:
......@@ -326,57 +327,68 @@ class Op(object):
def _run_process(self, preped_data_dict, op_info_prefix):
_LOGGER.debug("{} Running process".format(op_info_prefix))
midped_data_dict = {}
err_channeldata_dict = {}
midped_data_dict = collections.OrderedDict()
err_channeldata_dict = collections.OrderedDict()
if self.with_serving:
data_ids = preped_data_dict.keys()
typical_logid = data_ids[0]
if len(data_ids) != 1:
for data_id in data_ids:
_LOGGER.info(
"(logid={}) {} During access to PaddleServingService,"
" we selected logid={} (from batch: {}) as a "
"representative for logging.".format(
data_id, op_info_prefix, typical_logid, data_ids))
feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
midped_batch = None
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
try:
midped_batch = self.process(feed_batch)
midped_batch = self.process(feed_batch, typical_logid)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = "{} Failed to process(batch: {}): {}".format(
op_info_prefix, data_ids, e)
error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
typical_logid, op_info_prefix, data_ids, e)
_LOGGER.error(error_info, exc_info=True)
else:
for i in range(self._retry):
try:
midped_batch = func_timeout.func_timeout(
self._timeout, self.process, args=(feed_batch, ))
self._timeout,
self.process,
args=(feed_batch, typical_logid))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
error_info = "{} Failed to process(batch: {}): " \
error_info = "(logid={}) {} Failed to process(batch: {}): " \
"exceeded retry count.".format(
op_info_prefix, data_ids)
typical_logid, op_info_prefix, data_ids)
_LOGGER.error(error_info)
else:
_LOGGER.warning(
"{} Failed to process(batch: {}): timeout, and retrying({}/{})"
.format(op_info_prefix, data_ids, i + 1,
self._retry))
"(logid={}) {} Failed to process(batch: {}): timeout,"
" and retrying({}/{})...".format(
typical_logid, op_info_prefix, data_ids, i +
1, self._retry))
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = "{} Failed to process(batch: {}): {}".format(
op_info_prefix, data_ids, e)
error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
typical_logid, op_info_prefix, data_ids, e)
_LOGGER.error(error_info, exc_info=True)
break
else:
break
if ecode != ChannelDataEcode.OK.value:
for data_id in data_ids:
_LOGGER.error("(logid={}) {}".format(data_id, error_info))
err_channeldata_dict[data_id] = ChannelData(
ecode=ecode, error_info=error_info, data_id=data_id)
elif midped_batch is None:
# op client return None
error_info = "{} Failed to predict, please check if PaddleServingService" \
" is working properly.".format(op_info_prefix)
error_info = "(logid={}) {} Failed to predict, please check if " \
"PaddleServingService is working properly.".format(
typical_logid, op_info_prefix)
_LOGGER.error(error_info)
for data_id in data_ids:
_LOGGER.error("(logid={}) {}".format(data_id, error_info))
err_channeldata_dict[data_id] = ChannelData(
ecode=ChannelDataEcode.CLIENT_ERROR.value,
error_info=error_info,
......@@ -396,8 +408,8 @@ class Op(object):
def _run_postprocess(self, parsed_data_dict, midped_data_dict,
op_info_prefix):
_LOGGER.debug("{} Running postprocess".format(op_info_prefix))
postped_data_dict = {}
err_channeldata_dict = {}
postped_data_dict = collections.OrderedDict()
err_channeldata_dict = collections.OrderedDict()
for data_id, midped_data in midped_data_dict.items():
postped_data, err_channeldata = None, None
try:
......@@ -476,7 +488,7 @@ class Op(object):
yield batch
def _parse_channeldata_batch(self, batch, output_channels):
parsed_data_dict = {}
parsed_data_dict = collections.OrderedDict()
need_profile_dict = {}
profile_dict = {}
for channeldata_dict in batch:
......
......@@ -13,13 +13,100 @@
# limitations under the License.
import sys
import logging
import threading
import multiprocessing
if sys.version_info.major == 2:
import Queue
from Queue import PriorityQueue
elif sys.version_info.major == 3:
import queue as Queue
from queue import PriorityQueue
else:
raise Exception("Error Python version")
_LOGGER = logging.getLogger(__name__)
class NameGenerator(object):
# use unsafe-id-generator
def __init__(self, prefix):
self._idx = -1
self._prefix = prefix
self._id_generator = UnsafeIdGenerator(1000000000000000000)
def next(self):
next_id = self._id_generator.next()
return "{}{}".format(self._prefix, next_id)
class UnsafeIdGenerator(object):
def __init__(self, max_id, base_counter=0, step=1):
self._base_counter = base_counter
self._counter = self._base_counter
self._step = step
self._max_id = max_id # for reset
def next(self):
if self._counter >= self._max_id:
self._counter = self._base_counter
_LOGGER.info("Reset Id: {}".format(self._counter))
next_id = self._counter
self._counter += self._step
return next_id
class ThreadIdGenerator(UnsafeIdGenerator):
def __init__(self, max_id, base_counter=0, step=1, lock=None):
# if you want to use your lock, you may need to use Reentrant-Lock
self._lock = lock
if self._lock is None:
self._lock = threading.Lock()
super(ThreadIdGenerator, self).__init__(max_id, base_counter, step)
def next(self):
self._idx += 1
return "{}{}".format(self._prefix, self._idx)
next_id = None
with self._lock:
if self._counter >= self._max_id:
self._counter = self._base_counter
_LOGGER.info("Reset Id: {}".format(self._counter))
next_id = self._counter
self._counter += self._step
return next_id
class ProcessIdGenerator(UnsafeIdGenerator):
def __init__(self, max_id, base_counter=0, step=1, lock=None):
# if you want to use your lock, you may need to use Reentrant-Lock
self._lock = lock
if self._lock is None:
self._lock = multiprocessing.Lock()
self._base_counter = base_counter
self._counter = multiprocessing.Manager().Value('i', 0)
self._step = step
self._max_id = max_id
def next(self):
next_id = None
with self._lock:
if self._counter.value >= self._max_id:
self._counter.value = self._base_counter
_LOGGER.info("Reset Id: {}".format(self._counter.value))
next_id = self._counter.value
self._counter.value += self._step
return next_id
def PipelineProcSyncManager():
"""
add PriorityQueue into SyncManager, see more:
https://stackoverflow.com/questions/25324560/strange-queue-priorityqueue-behaviour-with-multiprocessing-in-python-2-7-6?answertab=active#tab-top
"""
class PipelineManager(multiprocessing.managers.SyncManager):
pass
PipelineManager.register("PriorityQueue", PriorityQueue)
m = PipelineManager()
m.start()
return m
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册