提交 a96b9922 编写于 作者: B barrierye

update code

上级 a1286818
......@@ -14,6 +14,8 @@
from paddle_serving_client.pyclient import PyClient
import numpy as np
from line_profiler import LineProfiler
client = PyClient()
client.connect('localhost:8080')
......@@ -24,7 +26,14 @@ x = np.array(
],
dtype='float')
lp = LineProfiler()
lp_wrapper = lp(client.predict)
for i in range(5):
fetch_map = client.predict(
fetch_map = lp_wrapper(
feed={"x": x}, fetch_with_type={"combine_op_output": "float"})
# fetch_map = client.predict(
# feed={"x": x}, fetch_with_type={"combine_op_output": "float"})
print(fetch_map)
lp.print_stats()
......@@ -72,7 +72,7 @@ cnn_op = UciOp(
client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["price"],
concurrency=2)
concurrency=1)
bow_op = UciOp(
name="bow_op",
......@@ -86,7 +86,7 @@ bow_op = UciOp(
client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["price"],
concurrency=2)
concurrency=1)
combine_op = CombineOp(
name="combine_op",
......@@ -94,7 +94,7 @@ combine_op = CombineOp(
in_dtype='float',
outputs=[out_channel],
out_dtype='float',
concurrency=2)
concurrency=1)
logging.info(read_channel.debug())
logging.info(combine_channel.debug())
......
......@@ -47,13 +47,11 @@ class Channel(Queue.Queue):
and can only be called during initialization.
"""
def __init__(self, name=None, maxsize=-1, timeout=None, batchsize=1):
def __init__(self, name=None, maxsize=-1, timeout=None):
Queue.Queue.__init__(self, maxsize=maxsize)
self._maxsize = maxsize
self._timeout = timeout
self._name = name
#self._batchsize = batchsize
# self._pushbatch = []
self._cv = threading.Condition()
......@@ -81,14 +79,14 @@ class Channel(Queue.Queue):
self.get_consumers()))
def add_producer(self, op_name):
""" not thread safe, and can only be called during initialization """
""" not thread safe, and can only be called during initialization. """
if op_name in self._producers:
raise ValueError(
self._log("producer({}) is already in channel".format(op_name)))
self._producers.append(op_name)
def add_consumer(self, op_name):
""" not thread safe, and can only be called during initialization """
""" not thread safe, and can only be called during initialization. """
if op_name in self._consumers:
raise ValueError(
self._log("consumer({}) is already in channel".format(op_name)))
......@@ -107,7 +105,7 @@ class Channel(Queue.Queue):
"expected number of producers to be greater than 0, but the it is 0."
))
elif len(self._producers) == 1:
self._cv.acquire()
with self._cv:
while True:
try:
self.put(data, timeout=0)
......@@ -115,7 +113,6 @@ class Channel(Queue.Queue):
except Queue.Empty:
self._cv.wait()
self._cv.notify_all()
self._cv.release()
logging.debug(self._log("{} push data succ!".format(op_name)))
return True
elif op_name is None:
......@@ -126,10 +123,13 @@ class Channel(Queue.Queue):
producer_num = len(self._producers)
data_id = data.id
put_data = None
self._cv.acquire()
with self._cv:
logging.debug(self._log("{} get lock ~".format(op_name)))
if data_id not in self._push_res:
self._push_res[data_id] = {name: None for name in self._producers}
self._push_res[data_id] = {
name: None
for name in self._producers
}
self._producer_res_count[data_id] = 0
self._push_res[data_id][op_name] = data
if self._producer_res_count[data_id] + 1 == producer_num:
......@@ -141,8 +141,8 @@ class Channel(Queue.Queue):
if put_data is None:
logging.debug(
self._log("{} push data succ, not not push to queue.".format(
op_name)))
self._log("{} push data succ, not not push to queue.".
format(op_name)))
else:
while True:
try:
......@@ -154,7 +154,6 @@ class Channel(Queue.Queue):
logging.debug(
self._log("multi | {} push data succ!".format(op_name)))
self._cv.notify_all()
self._cv.release()
return True
def front(self, op_name=None):
......@@ -165,8 +164,8 @@ class Channel(Queue.Queue):
"expected number of consumers to be greater than 0, but the it is 0."
))
elif len(self._consumers) == 1:
self._cv.acquire()
resp = None
with self._cv:
while resp is None:
try:
resp = self.get(timeout=0)
......@@ -180,7 +179,7 @@ class Channel(Queue.Queue):
self._log(
"There are multiple consumers, so op_name cannot be None."))
self._cv.acquire()
with self._cv:
# data_idx = consumer_idx - base_idx
while self._consumers[op_name] - self._consumer_base_idx >= len(
self._front_res):
......@@ -211,7 +210,6 @@ class Channel(Queue.Queue):
self._idx_consumer_num[new_consumer_idx] += 1
self._cv.notify_all()
self._cv.release()
logging.debug(self._log("multi | {} get data succ!".format(op_name)))
return resp # reference, read only
......@@ -224,7 +222,6 @@ class Op(object):
in_dtype,
outputs,
out_dtype,
batchsize=1,
server_model=None,
server_port=None,
device=None,
......@@ -240,7 +237,6 @@ class Op(object):
self._in_dtype = in_dtype
self.set_outputs(outputs)
self._out_dtype = out_dtype
# self._batch_size = batchsize
self._client = None
if client_config is not None and \
server_name is not None and \
......@@ -395,10 +391,9 @@ class GeneralPythonService(
raise TypeError(
self._log('data must be ChannelData type, but get {}'.
format(type(data))))
self._cv.acquire()
with self._cv:
self._globel_resp_dict[data.id] = data
self._cv.notify_all()
self._cv.release()
def _get_next_id(self):
with self._id_lock:
......@@ -406,12 +401,12 @@ class GeneralPythonService(
return self._id_counter - 1
def _get_data_in_globel_resp_dict(self, data_id):
self._cv.acquire()
resp = None
with self._cv:
while data_id not in self._globel_resp_dict:
self._cv.wait()
resp = self._globel_resp_dict.pop(data_id)
self._cv.notify_all()
self._cv.release()
return resp
def _pack_data_for_infer(self, request):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册