Commit c26a8942 authored by barriery, committed by GitHub

Merge pull request #25 from barrierye/pipeling-log

use PriorityQueue instead of Queue in Channel && add logid in ServingClient
@@ -37,6 +37,7 @@ message InferenceRequest {
   repeated string feed_var_names = 2;
   repeated string fetch_var_names = 3;
   required bool is_python = 4 [ default = false ];
+  required uint64 log_id = 5 [ default = 0 ];
 };

 message InferenceResponse {
......
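The new `log_id` field defaults to 0, so existing callers keep working while new callers can tag each request for end-to-end log correlation. A minimal sketch of setting it from Python, assuming the proto has been compiled to a `*_pb2` module; the import path and the "prob" fetch name are illustrative, not from this diff:

```python
# Hypothetical import path; the generated module name matches the .proto file.
from paddle_serving_client.proto import multi_lang_general_model_service_pb2 as pb2

req = pb2.InferenceRequest()
req.fetch_var_names.extend(["prob"])  # illustrative fetch variable name
req.is_python = True
req.log_id = 10086  # uint64; stays 0 if the caller never sets it
```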
@@ -21,6 +21,7 @@ option cc_generic_services = true;
 message RequestAndResponse {
   required int32 a = 1;
   required float b = 2;
+  required uint64 log_id = 3 [ default = 0 ];
 };

 service LoadGeneralModelService {
......
@@ -268,6 +268,7 @@ class PdsCodeGenerator : public CodeGenerator {
          " const $input_name$* request,\n"
          " $output_name$* response,\n"
          " google::protobuf::Closure* done) {\n"
+         " std::cout << \"WTFFFFFFFFFFFFFFFF\";\n"
          " struct timeval tv;\n"
          " gettimeofday(&tv, NULL);"
          " long start = tv.tv_sec * 1000000 + tv.tv_usec;",
@@ -1013,6 +1014,7 @@ class PdsCodeGenerator : public CodeGenerator {
          " brpc::ClosureGuard done_guard(done);\n"
          " brpc::Controller* cntl = \n"
          "     static_cast<brpc::Controller*>(cntl_base);\n"
+         " cntl->set_log_id(request->log_id());\n"
          " ::baidu::paddle_serving::predictor::InferService* svr = \n"
          "     "
          "::baidu::paddle_serving::predictor::InferServiceManager::instance("
@@ -1050,6 +1052,7 @@ class PdsCodeGenerator : public CodeGenerator {
          " brpc::ClosureGuard done_guard(done);\n"
          " brpc::Controller* cntl = \n"
          "     static_cast<brpc::Controller*>(cntl_base);\n"
+         " cntl->set_log_id(request->log_id());\n"
          " ::baidu::paddle_serving::predictor::InferService* svr = \n"
          "     "
          "::baidu::paddle_serving::predictor::InferServiceManager::instance("
......
@@ -192,14 +192,16 @@ public class Client {
     private InferenceRequest _packInferenceRequest(
             List<HashMap<String, INDArray>> feed_batch,
-            Iterable<String> fetch) throws IllegalArgumentException {
+            Iterable<String> fetch,
+            long log_id) throws IllegalArgumentException {
         List<String> feed_var_names = new ArrayList<String>();
         feed_var_names.addAll(feed_batch.get(0).keySet());

         InferenceRequest.Builder req_builder = InferenceRequest.newBuilder()
             .addAllFeedVarNames(feed_var_names)
             .addAllFetchVarNames(fetch)
-            .setIsPython(false);
+            .setIsPython(false)
+            .setLogId(log_id);
         for (HashMap<String, INDArray> feed_data: feed_batch) {
             FeedInst.Builder inst_builder = FeedInst.newBuilder();
             for (String name: feed_var_names) {
@@ -332,76 +334,151 @@ public class Client {
     public Map<String, INDArray> predict(
             HashMap<String, INDArray> feed,
             Iterable<String> fetch) {
-        return predict(feed, fetch, false);
+        return predict(feed, fetch, false, 0);
+    }
+
+    public Map<String, INDArray> predict(
+            HashMap<String, INDArray> feed,
+            Iterable<String> fetch,
+            long log_id) {
+        return predict(feed, fetch, false, log_id);
     }

     public Map<String, HashMap<String, INDArray>> ensemble_predict(
             HashMap<String, INDArray> feed,
             Iterable<String> fetch) {
-        return ensemble_predict(feed, fetch, false);
+        return ensemble_predict(feed, fetch, false, 0);
+    }
+
+    public Map<String, HashMap<String, INDArray>> ensemble_predict(
+            HashMap<String, INDArray> feed,
+            Iterable<String> fetch,
+            long log_id) {
+        return ensemble_predict(feed, fetch, false, log_id);
     }

     public PredictFuture asyn_predict(
             HashMap<String, INDArray> feed,
             Iterable<String> fetch) {
-        return asyn_predict(feed, fetch, false);
+        return asyn_predict(feed, fetch, false, 0);
+    }
+
+    public PredictFuture asyn_predict(
+            HashMap<String, INDArray> feed,
+            Iterable<String> fetch,
+            long log_id) {
+        return asyn_predict(feed, fetch, false, log_id);
     }

     public Map<String, INDArray> predict(
             HashMap<String, INDArray> feed,
             Iterable<String> fetch,
             Boolean need_variant_tag) {
+        return predict(feed, fetch, need_variant_tag, 0);
+    }
+
+    public Map<String, INDArray> predict(
+            HashMap<String, INDArray> feed,
+            Iterable<String> fetch,
+            Boolean need_variant_tag,
+            long log_id) {
         List<HashMap<String, INDArray>> feed_batch
             = new ArrayList<HashMap<String, INDArray>>();
         feed_batch.add(feed);
-        return predict(feed_batch, fetch, need_variant_tag);
+        return predict(feed_batch, fetch, need_variant_tag, log_id);
     }

     public Map<String, HashMap<String, INDArray>> ensemble_predict(
             HashMap<String, INDArray> feed,
             Iterable<String> fetch,
             Boolean need_variant_tag) {
+        return ensemble_predict(feed, fetch, need_variant_tag, 0);
+    }
+
+    public Map<String, HashMap<String, INDArray>> ensemble_predict(
+            HashMap<String, INDArray> feed,
+            Iterable<String> fetch,
+            Boolean need_variant_tag,
+            long log_id) {
         List<HashMap<String, INDArray>> feed_batch
             = new ArrayList<HashMap<String, INDArray>>();
         feed_batch.add(feed);
-        return ensemble_predict(feed_batch, fetch, need_variant_tag);
+        return ensemble_predict(feed_batch, fetch, need_variant_tag, log_id);
     }

     public PredictFuture asyn_predict(
             HashMap<String, INDArray> feed,
             Iterable<String> fetch,
             Boolean need_variant_tag) {
+        return asyn_predict(feed, fetch, need_variant_tag, 0);
+    }
+
+    public PredictFuture asyn_predict(
+            HashMap<String, INDArray> feed,
+            Iterable<String> fetch,
+            Boolean need_variant_tag,
+            long log_id) {
         List<HashMap<String, INDArray>> feed_batch
             = new ArrayList<HashMap<String, INDArray>>();
         feed_batch.add(feed);
-        return asyn_predict(feed_batch, fetch, need_variant_tag);
+        return asyn_predict(feed_batch, fetch, need_variant_tag, log_id);
     }

     public Map<String, INDArray> predict(
             List<HashMap<String, INDArray>> feed_batch,
             Iterable<String> fetch) {
-        return predict(feed_batch, fetch, false);
+        return predict(feed_batch, fetch, false, 0);
+    }
+
+    public Map<String, INDArray> predict(
+            List<HashMap<String, INDArray>> feed_batch,
+            Iterable<String> fetch,
+            long log_id) {
+        return predict(feed_batch, fetch, false, log_id);
     }

     public Map<String, HashMap<String, INDArray>> ensemble_predict(
             List<HashMap<String, INDArray>> feed_batch,
             Iterable<String> fetch) {
-        return ensemble_predict(feed_batch, fetch, false);
+        return ensemble_predict(feed_batch, fetch, false, 0);
+    }
+
+    public Map<String, HashMap<String, INDArray>> ensemble_predict(
+            List<HashMap<String, INDArray>> feed_batch,
+            Iterable<String> fetch,
+            long log_id) {
+        return ensemble_predict(feed_batch, fetch, false, log_id);
     }

     public PredictFuture asyn_predict(
             List<HashMap<String, INDArray>> feed_batch,
             Iterable<String> fetch) {
-        return asyn_predict(feed_batch, fetch, false);
+        return asyn_predict(feed_batch, fetch, false, 0);
+    }
+
+    public PredictFuture asyn_predict(
+            List<HashMap<String, INDArray>> feed_batch,
+            Iterable<String> fetch,
+            long log_id) {
+        return asyn_predict(feed_batch, fetch, false, log_id);
     }

     public Map<String, INDArray> predict(
             List<HashMap<String, INDArray>> feed_batch,
             Iterable<String> fetch,
             Boolean need_variant_tag) {
+        return predict(feed_batch, fetch, need_variant_tag, 0);
+    }
+
+    public Map<String, INDArray> predict(
+            List<HashMap<String, INDArray>> feed_batch,
+            Iterable<String> fetch,
+            Boolean need_variant_tag,
+            long log_id) {
         try {
             profiler_.record("java_prepro_0");
-            InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
+            InferenceRequest req = _packInferenceRequest(
+                feed_batch, fetch, log_id);
             profiler_.record("java_prepro_1");

             profiler_.record("java_client_infer_0");
@@ -415,7 +492,7 @@ public class Client {
                 = new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
                     ensemble_result.entrySet());
             if (list.size() != 1) {
-                System.out.format("predict failed: please use ensemble_predict impl.\n");
+                System.out.format("Failed to predict: please use ensemble_predict impl.\n");
                 return null;
             }
             profiler_.record("java_postpro_1");
@@ -423,7 +500,7 @@ public class Client {
             return list.get(0).getValue();
         } catch (StatusRuntimeException e) {
-            System.out.format("predict failed: %s\n", e.toString());
+            System.out.format("Failed to predict: %s\n", e.toString());
             return null;
         }
     }
@@ -432,9 +509,18 @@ public class Client {
             List<HashMap<String, INDArray>> feed_batch,
             Iterable<String> fetch,
             Boolean need_variant_tag) {
+        return ensemble_predict(feed_batch, fetch, need_variant_tag, 0);
+    }
+
+    public Map<String, HashMap<String, INDArray>> ensemble_predict(
+            List<HashMap<String, INDArray>> feed_batch,
+            Iterable<String> fetch,
+            Boolean need_variant_tag,
+            long log_id) {
         try {
             profiler_.record("java_prepro_0");
-            InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
+            InferenceRequest req = _packInferenceRequest(
+                feed_batch, fetch, log_id);
             profiler_.record("java_prepro_1");

             profiler_.record("java_client_infer_0");
@@ -449,7 +535,7 @@ public class Client {
             return ensemble_result;
         } catch (StatusRuntimeException e) {
-            System.out.format("predict failed: %s\n", e.toString());
+            System.out.format("Failed to predict: %s\n", e.toString());
             return null;
         }
     }
@@ -458,7 +544,16 @@ public class Client {
             List<HashMap<String, INDArray>> feed_batch,
             Iterable<String> fetch,
             Boolean need_variant_tag) {
-        InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
+        return asyn_predict(feed_batch, fetch, need_variant_tag, 0);
+    }
+
+    public PredictFuture asyn_predict(
+            List<HashMap<String, INDArray>> feed_batch,
+            Iterable<String> fetch,
+            Boolean need_variant_tag,
+            long log_id) {
+        InferenceRequest req = _packInferenceRequest(
+            feed_batch, fetch, log_id);
         ListenableFuture<InferenceResponse> future = futureStub_.inference(req);
         PredictFuture predict_future = new PredictFuture(future,
             (InferenceResponse resp) -> {
......
@@ -37,6 +37,7 @@ message InferenceRequest {
   repeated string feed_var_names = 2;
   repeated string fetch_var_names = 3;
   required bool is_python = 4 [ default = false ];
+  required uint64 log_id = 5 [ default = 0 ];
 };

 message InferenceResponse {
......
@@ -466,10 +466,11 @@ class MultiLangClient(object):
             if var.is_lod_tensor:
                 self.lod_tensor_set_.add(var.alias_name)

-    def _pack_inference_request(self, feed, fetch, is_python):
+    def _pack_inference_request(self, feed, fetch, is_python, log_id):
         req = multi_lang_general_model_service_pb2.InferenceRequest()
         req.fetch_var_names.extend(fetch)
         req.is_python = is_python
+        req.log_id = log_id
         feed_batch = None
         if isinstance(feed, dict):
             feed_batch = [feed]
@@ -602,12 +603,13 @@ class MultiLangClient(object):
                 fetch,
                 need_variant_tag=False,
                 asyn=False,
-                is_python=True):
+                is_python=True,
+                log_id=0):
         if not asyn:
             try:
                 self.profile_.record('py_prepro_0')
                 req = self._pack_inference_request(
-                    feed, fetch, is_python=is_python)
+                    feed, fetch, is_python=is_python, log_id=log_id)
                 self.profile_.record('py_prepro_1')

                 self.profile_.record('py_client_infer_0')
@@ -626,7 +628,8 @@ class MultiLangClient(object):
             except grpc.RpcError as e:
                 return {"serving_status_code": e.code()}
         else:
-            req = self._pack_inference_request(feed, fetch, is_python=is_python)
+            req = self._pack_inference_request(
+                feed, fetch, is_python=is_python, log_id=log_id)
             call_future = self.stub_.Inference.future(
                 req, timeout=self.rpc_timeout_s_)
             return MultiLangPredictFuture(
......
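Since `log_id` is a keyword argument defaulting to 0, existing callers are unchanged. An illustrative client call (the endpoint, feed name, and fetch name are hypothetical, not from this diff):

```python
# Sketch of passing a caller-supplied log_id through MultiLangClient.predict();
# server-side and pipeline logs can then be joined on this id.
import numpy as np
from paddle_serving_client import MultiLangClient

client = MultiLangClient()
client.connect(["127.0.0.1:9393"])  # hypothetical endpoint
fetch_map = client.predict(
    feed={"x": np.array([0.5], dtype="float32")},  # illustrative feed
    fetch=["prob"],
    log_id=10086)
```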
@@ -502,6 +502,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
         feed_names = list(request.feed_var_names)
         fetch_names = list(request.fetch_var_names)
         is_python = request.is_python
+        log_id = request.log_id
         feed_batch = []
         for feed_inst in request.insts:
             feed_dict = {}
@@ -530,7 +531,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
                 data.shape = list(feed_inst.tensor_array[idx].shape)
                 feed_dict[name] = data
             feed_batch.append(feed_dict)
-        return feed_batch, fetch_names, is_python
+        return feed_batch, fetch_names, is_python, log_id

     def _pack_inference_response(self, ret, fetch_names, is_python):
         resp = multi_lang_general_model_service_pb2.InferenceResponse()
@@ -583,10 +584,13 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
         return resp

     def Inference(self, request, context):
-        feed_dict, fetch_names, is_python = self._unpack_inference_request(
-            request)
+        feed_dict, fetch_names, is_python, log_id = \
+            self._unpack_inference_request(request)
         ret = self.bclient_.predict(
-            feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
+            feed=feed_dict,
+            fetch=fetch_names,
+            need_variant_tag=True,
+            log_id=log_id)
         return self._pack_inference_response(ret, fetch_names, is_python)

     def GetClientConfig(self, request, context):
......
@@ -556,6 +556,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
         feed_names = list(request.feed_var_names)
         fetch_names = list(request.fetch_var_names)
         is_python = request.is_python
+        log_id = request.log_id
         feed_batch = []
         for feed_inst in request.insts:
             feed_dict = {}
@@ -637,10 +638,13 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
         return resp

     def Inference(self, request, context):
-        feed_dict, fetch_names, is_python = self._unpack_inference_request(
-            request)
+        feed_dict, fetch_names, is_python, log_id \
+            = self._unpack_inference_request(request)
         ret = self.bclient_.predict(
-            feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
+            feed=feed_dict,
+            fetch=fetch_names,
+            need_variant_tag=True,
+            log_id=log_id)
         return self._pack_inference_response(ret, fetch_names, is_python)

     def GetClientConfig(self, request, context):
......
@@ -181,6 +181,14 @@ class ChannelData(object):
                 os._exit(-1)
         return feed

+    def __cmp__(self, other):
+        if self.id < other.id:
+            return -1
+        elif self.id == other.id:
+            return 0
+        else:
+            return 1
+
     def __str__(self):
         return "type[{}], ecode[{}], id[{}]".format(
             ChannelDataType(self.datatype).name, self.ecode, self.id)
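`__cmp__` gives ChannelData a total order by request id, which is what lets the priority channels below hand out the oldest request first. Note that `__cmp__` is only honored by Python 2; under Python 3, `PriorityQueue` orders items with the `<` operator, so an equivalent would need rich comparisons. A sketch under that assumption, not part of this commit:

```python
# Hypothetical Python 3 variant: total_ordering derives the remaining
# comparisons from __lt__ and __eq__, matching the __cmp__ semantics above.
import functools

@functools.total_ordering
class ChannelDataPy3(object):
    def __init__(self, id):
        self.id = id

    def __lt__(self, other):   # lower request id == higher priority
        return self.id < other.id

    def __eq__(self, other):
        return self.id == other.id
```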
@@ -222,7 +230,7 @@ class ProcessChannel(object):
         #   see more:
         #   - https://bugs.python.org/issue18277
         #   - https://hg.python.org/cpython/rev/860fc6a2bd21
-        self._que = manager.Queue(maxsize=maxsize)
+        self._que = manager.PriorityQueue(maxsize=maxsize)
         self._maxsize = maxsize
         self.name = name
         self._stop = manager.Value('i', 0)
@@ -489,7 +497,7 @@ class ProcessChannel(object):
             self._cv.notify_all()


-class ThreadChannel(Queue.Queue):
+class ThreadChannel(Queue.PriorityQueue):
     """
     (Thread version)The channel used for communication between Ops.
......
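With both channel flavors now backed by a priority queue and ChannelData comparing by request id, a channel dequeues pending requests in id order rather than raw arrival order. A toy demonstration of that ordering property (standalone, not from the diff):

```python
# PriorityQueue pops the smallest item first; tuples compare element-wise,
# so (id, payload) pairs are ordered by id.
from queue import PriorityQueue

q = PriorityQueue(maxsize=10)
for rid in [7, 3, 5]:
    q.put((rid, "payload-%d" % rid))

print(q.get())  # (3, 'payload-3') -- the lowest request id comes out first
```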
@@ -29,7 +29,7 @@ from .operator import Op, RequestOp, ResponseOp, VirtualOp
 from .channel import (ThreadChannel, ProcessChannel, ChannelData,
                       ChannelDataEcode, ChannelDataType, ChannelStopError)
 from .profiler import TimeProfiler, PerformanceTracer
-from .util import NameGenerator
+from .util import NameGenerator, ThreadIdGenerator, PipelineProcSyncManager
 from .proto import pipeline_service_pb2

 _LOGGER = logging.getLogger(__name__)
@@ -74,9 +74,9 @@ class DAGExecutor(object):
         if self._tracer is not None:
             self._tracer.start()

-        self._id_lock = threading.Lock()
-        self._id_counter = 0
-        self._reset_max_id = 1000000000000000000
+        self._id_generator = ThreadIdGenerator(
+            max_id=1000000000000000000, base_counter=0, step=1)
+
         self._cv_pool = {}
         self._cv_for_cv_pool = threading.Condition()
         self._fetch_buffer = {}
@@ -98,13 +98,7 @@ class DAGExecutor(object):
         _LOGGER.info("[DAG Executor] Stop")

     def _get_next_data_id(self):
-        data_id = None
-        with self._id_lock:
-            if self._id_counter >= self._reset_max_id:
-                _LOGGER.info("[DAG Executor] Reset request id")
-                self._id_counter -= self._reset_max_id
-            data_id = self._id_counter
-            self._id_counter += 1
+        data_id = self._id_generator.next()
         cond_v = threading.Condition()
         with self._cv_for_cv_pool:
             self._cv_pool[data_id] = cond_v
@@ -330,7 +324,7 @@ class DAG(object):
         self._build_dag_each_worker = build_dag_each_worker
         self._tracer = tracer
         if not self._is_thread_op:
-            self._manager = PipelineProcSyncManager()
+            self._manager = PipelineProcSyncManager()
         _LOGGER.info("[DAG] Succ init")

     def get_use_ops(self, response_op):
......
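The hand-rolled lock/counter/reset logic moves into `ThreadIdGenerator` (defined in util.py later in this diff), so DAGExecutor only asks for the next id. A short usage sketch; the import path is illustrative:

```python
# ThreadIdGenerator.next() is thread-safe and wraps back to base_counter
# once the counter reaches max_id, mirroring the old reset behaviour.
from paddle_serving_server.pipeline.util import ThreadIdGenerator  # assumed path

id_gen = ThreadIdGenerator(max_id=1000000000000000000, base_counter=0, step=1)
data_id = id_gen.next()   # 0, 1, 2, ... across concurrent callers
```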
@@ -40,18 +40,11 @@ logger_config = {
             "format": "%(asctime)s %(message)s",
         },
     },
-    "filters": {
-        "info_only_filter": {
-            "()": SectionLevelFilter,
-            "levels": [logging.INFO],
-        },
-    },
     "handlers": {
         "f_pipeline.log": {
             "class": "logging.FileHandler",
             "level": "INFO",
             "formatter": "normal_fmt",
-            "filters": ["info_only_filter"],
             "filename": os.path.join(log_dir, "pipeline.log"),
         },
         "f_pipeline.log.wf": {
......
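The net effect of dropping `info_only_filter`: pipeline.log previously kept records at exactly INFO level, while now the handler records everything at INFO level and above (WARNING and ERROR included). A plain-logging sketch of the new behaviour, with the directory name assumed from this config's `log_dir`:

```python
import logging
import os

log_dir = "PipelineServingLogs"   # assumed value of log_dir in this config
if not os.path.isdir(log_dir):
    os.mkdir(log_dir)

handler = logging.FileHandler(os.path.join(log_dir, "pipeline.log"))
handler.setLevel(logging.INFO)    # no level filter attached: INFO and above pass
```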
@@ -22,6 +22,7 @@ import logging
 import func_timeout
 import os
 import sys
+import collections
 import numpy as np
 from numpy import *
@@ -127,7 +128,7 @@ class Op(object):
                       fetch_names):
         if self.with_serving == False:
             _LOGGER.info("Op({}) has no client (and it also do not "
-                         "run the process function".format(self.name))
+                         "run the process function)".format(self.name))
             return None
         if client_type == 'brpc':
             client = Client()
@@ -199,7 +200,7 @@ class Op(object):
         (_, input_dict), = input_dicts.items()
         return input_dict

-    def process(self, feed_batch):
+    def process(self, feed_batch, typical_logid):
         err, err_info = ChannelData.check_batch_npdata(feed_batch)
         if err != 0:
             _LOGGER.critical(
@@ -207,7 +208,7 @@ class Op(object):
                     "preprocess func.".format(err_info)))
             os._exit(-1)
         call_result = self.client.predict(
-            feed=feed_batch, fetch=self._fetch_names)
+            feed=feed_batch, fetch=self._fetch_names, log_id=typical_logid)
         if isinstance(self.client, MultiLangClient):
             if call_result is None or call_result["serving_status_code"] != 0:
                 return None
@@ -294,8 +295,8 @@ class Op(object):

     def _run_preprocess(self, parsed_data_dict, op_info_prefix):
         _LOGGER.debug("{} Running preprocess".format(op_info_prefix))
-        preped_data_dict = {}
-        err_channeldata_dict = {}
+        preped_data_dict = collections.OrderedDict()
+        err_channeldata_dict = collections.OrderedDict()
         for data_id, parsed_data in parsed_data_dict.items():
             preped_data, error_channeldata = None, None
             try:
@@ -326,57 +327,68 @@ class Op(object):

     def _run_process(self, preped_data_dict, op_info_prefix):
         _LOGGER.debug("{} Running process".format(op_info_prefix))
-        midped_data_dict = {}
-        err_channeldata_dict = {}
+        midped_data_dict = collections.OrderedDict()
+        err_channeldata_dict = collections.OrderedDict()
         if self.with_serving:
             data_ids = preped_data_dict.keys()
+            typical_logid = data_ids[0]
+            if len(data_ids) != 1:
+                for data_id in data_ids:
+                    _LOGGER.info(
+                        "(logid={}) {} During access to PaddleServingService,"
+                        " we selected logid={} (from batch: {}) as a "
+                        "representative for logging.".format(
+                            data_id, op_info_prefix, typical_logid, data_ids))
             feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
             midped_batch = None
             ecode = ChannelDataEcode.OK.value
             if self._timeout <= 0:
                 try:
-                    midped_batch = self.process(feed_batch)
+                    midped_batch = self.process(feed_batch, typical_logid)
                 except Exception as e:
                     ecode = ChannelDataEcode.UNKNOW.value
-                    error_info = "{} Failed to process(batch: {}): {}".format(
-                        op_info_prefix, data_ids, e)
+                    error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
+                        typical_logid, op_info_prefix, data_ids, e)
                     _LOGGER.error(error_info, exc_info=True)
             else:
                 for i in range(self._retry):
                     try:
                         midped_batch = func_timeout.func_timeout(
-                            self._timeout, self.process, args=(feed_batch, ))
+                            self._timeout,
+                            self.process,
+                            args=(feed_batch, typical_logid))
                     except func_timeout.FunctionTimedOut as e:
                         if i + 1 >= self._retry:
                             ecode = ChannelDataEcode.TIMEOUT.value
-                            error_info = "{} Failed to process(batch: {}): " \
-                                "exceeded retry count.".format(
-                                    op_info_prefix, data_ids)
+                            error_info = "(logid={}) {} Failed to process(batch: {}): " \
+                                "exceeded retry count.".format(
+                                    typical_logid, op_info_prefix, data_ids)
                             _LOGGER.error(error_info)
                         else:
                             _LOGGER.warning(
-                                "{} Failed to process(batch: {}): timeout, and retrying({}/{})"
-                                .format(op_info_prefix, data_ids, i + 1,
-                                        self._retry))
+                                "(logid={}) {} Failed to process(batch: {}): timeout,"
+                                " and retrying({}/{})...".format(
+                                    typical_logid, op_info_prefix, data_ids, i +
+                                    1, self._retry))
                     except Exception as e:
                         ecode = ChannelDataEcode.UNKNOW.value
-                        error_info = "{} Failed to process(batch: {}): {}".format(
-                            op_info_prefix, data_ids, e)
+                        error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
+                            typical_logid, op_info_prefix, data_ids, e)
                         _LOGGER.error(error_info, exc_info=True)
                         break
                     else:
                         break
             if ecode != ChannelDataEcode.OK.value:
                 for data_id in data_ids:
+                    _LOGGER.error("(logid={}) {}".format(data_id, error_info))
                     err_channeldata_dict[data_id] = ChannelData(
                         ecode=ecode, error_info=error_info, data_id=data_id)
             elif midped_batch is None:
                 # op client return None
-                error_info = "{} Failed to predict, please check if PaddleServingService" \
-                    " is working properly.".format(op_info_prefix)
+                error_info = "(logid={}) {} Failed to predict, please check if " \
+                    "PaddleServingService is working properly.".format(
+                        typical_logid, op_info_prefix)
+                _LOGGER.error(error_info)
                 for data_id in data_ids:
+                    _LOGGER.error("(logid={}) {}".format(data_id, error_info))
                     err_channeldata_dict[data_id] = ChannelData(
                         ecode=ChannelDataEcode.CLIENT_ERROR.value,
                         error_info=error_info,
@@ -396,8 +408,8 @@ class Op(object):
     def _run_postprocess(self, parsed_data_dict, midped_data_dict,
                          op_info_prefix):
         _LOGGER.debug("{} Running postprocess".format(op_info_prefix))
-        postped_data_dict = {}
-        err_channeldata_dict = {}
+        postped_data_dict = collections.OrderedDict()
+        err_channeldata_dict = collections.OrderedDict()
         for data_id, midped_data in midped_data_dict.items():
             postped_data, err_channeldata = None, None
             try:
@@ -476,7 +488,7 @@ class Op(object):
                 yield batch

     def _parse_channeldata_batch(self, batch, output_channels):
-        parsed_data_dict = {}
+        parsed_data_dict = collections.OrderedDict()
         need_profile_dict = {}
         profile_dict = {}
         for channeldata_dict in batch:
......
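One batched call to the serving service carries many request ids, so the first id in the (now ordered) dict stands in for the whole batch in service-side logs; that is all `typical_logid` is. Restated as a standalone sketch for Python 3, where `dict.keys()` is a view and cannot be indexed directly:

```python
# Toy data; OrderedDict preserves arrival order, so the first key is the
# oldest request in the batch.
import collections

preped_data_dict = collections.OrderedDict([(101, "a"), (102, "b")])
data_ids = list(preped_data_dict.keys())   # list() needed under Python 3
typical_logid = data_ids[0]                # 101 represents the batch in logs
```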
@@ -13,13 +13,100 @@
 # limitations under the License.

 import sys
+import logging
+import threading
+import multiprocessing
+
+if sys.version_info.major == 2:
+    import Queue
+    from Queue import PriorityQueue
+elif sys.version_info.major == 3:
+    import queue as Queue
+    from queue import PriorityQueue
+else:
+    raise Exception("Error Python version")
+
+_LOGGER = logging.getLogger(__name__)


 class NameGenerator(object):
+    # use unsafe-id-generator
     def __init__(self, prefix):
-        self._idx = -1
         self._prefix = prefix
+        self._id_generator = UnsafeIdGenerator(1000000000000000000)

     def next(self):
-        self._idx += 1
-        return "{}{}".format(self._prefix, self._idx)
+        next_id = self._id_generator.next()
+        return "{}{}".format(self._prefix, next_id)
+
+
+class UnsafeIdGenerator(object):
+    def __init__(self, max_id, base_counter=0, step=1):
+        self._base_counter = base_counter
+        self._counter = self._base_counter
+        self._step = step
+        self._max_id = max_id  # for reset
+
+    def next(self):
+        if self._counter >= self._max_id:
+            self._counter = self._base_counter
+            _LOGGER.info("Reset Id: {}".format(self._counter))
+        next_id = self._counter
+        self._counter += self._step
+        return next_id
+
+
+class ThreadIdGenerator(UnsafeIdGenerator):
+    def __init__(self, max_id, base_counter=0, step=1, lock=None):
+        # if you want to use your lock, you may need to use Reentrant-Lock
+        self._lock = lock
+        if self._lock is None:
+            self._lock = threading.Lock()
+        super(ThreadIdGenerator, self).__init__(max_id, base_counter, step)
+
+    def next(self):
+        next_id = None
+        with self._lock:
+            if self._counter >= self._max_id:
+                self._counter = self._base_counter
+                _LOGGER.info("Reset Id: {}".format(self._counter))
+            next_id = self._counter
+            self._counter += self._step
+        return next_id
+
+
+class ProcessIdGenerator(UnsafeIdGenerator):
+    def __init__(self, max_id, base_counter=0, step=1, lock=None):
+        # if you want to use your lock, you may need to use Reentrant-Lock
+        self._lock = lock
+        if self._lock is None:
+            self._lock = multiprocessing.Lock()
+        self._base_counter = base_counter
+        self._counter = multiprocessing.Manager().Value('i', 0)
+        self._step = step
+        self._max_id = max_id
+
+    def next(self):
+        next_id = None
+        with self._lock:
+            if self._counter.value >= self._max_id:
+                self._counter.value = self._base_counter
+                _LOGGER.info("Reset Id: {}".format(self._counter.value))
+            next_id = self._counter.value
+            self._counter.value += self._step
+        return next_id
+
+
+def PipelineProcSyncManager():
+    """
+    add PriorityQueue into SyncManager, see more:
+    https://stackoverflow.com/questions/25324560/strange-queue-priorityqueue-behaviour-with-multiprocessing-in-python-2-7-6?answertab=active#tab-top
+    """
+
+    class PipelineManager(multiprocessing.managers.SyncManager):
+        pass
+
+    PipelineManager.register("PriorityQueue", PriorityQueue)
+    m = PipelineManager()
+    m.start()
+    return m
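The stock `SyncManager` only exposes `Queue`, so `PipelineProcSyncManager` registers `PriorityQueue` on a subclass and proxies it across processes. A self-contained toy round-trip showing the same technique (names prefixed with `_` are illustrative, not from this diff):

```python
# Register PriorityQueue on a SyncManager subclass and share it between
# processes; the proxy forwards put()/get() to the manager process.
import multiprocessing
from multiprocessing.managers import SyncManager

try:                      # Python 3
    from queue import PriorityQueue
except ImportError:       # Python 2
    from Queue import PriorityQueue


class _PipelineManager(SyncManager):
    pass


_PipelineManager.register("PriorityQueue", PriorityQueue)


def worker(que):
    que.put((2, "second"))
    que.put((1, "first"))


if __name__ == "__main__":
    m = _PipelineManager()
    m.start()
    que = m.PriorityQueue(maxsize=10)
    p = multiprocessing.Process(target=worker, args=(que, ))
    p.start()
    p.join()
    print(que.get())      # (1, 'first') -- priority order survives the proxy
    m.shutdown()
```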