提交 a96b9922 编写于 作者: B barrierye

update code

上级 a1286818
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
from paddle_serving_client.pyclient import PyClient from paddle_serving_client.pyclient import PyClient
import numpy as np import numpy as np
from line_profiler import LineProfiler
client = PyClient() client = PyClient()
client.connect('localhost:8080') client.connect('localhost:8080')
...@@ -24,7 +26,14 @@ x = np.array( ...@@ -24,7 +26,14 @@ x = np.array(
], ],
dtype='float') dtype='float')
lp = LineProfiler()
lp_wrapper = lp(client.predict)
for i in range(5): for i in range(5):
fetch_map = client.predict( fetch_map = lp_wrapper(
feed={"x": x}, fetch_with_type={"combine_op_output": "float"}) feed={"x": x}, fetch_with_type={"combine_op_output": "float"})
# fetch_map = client.predict(
# feed={"x": x}, fetch_with_type={"combine_op_output": "float"})
print(fetch_map) print(fetch_map)
lp.print_stats()
...@@ -72,7 +72,7 @@ cnn_op = UciOp( ...@@ -72,7 +72,7 @@ cnn_op = UciOp(
client_config="uci_housing_client/serving_client_conf.prototxt", client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393", server_name="127.0.0.1:9393",
fetch_names=["price"], fetch_names=["price"],
concurrency=2) concurrency=1)
bow_op = UciOp( bow_op = UciOp(
name="bow_op", name="bow_op",
...@@ -86,7 +86,7 @@ bow_op = UciOp( ...@@ -86,7 +86,7 @@ bow_op = UciOp(
client_config="uci_housing_client/serving_client_conf.prototxt", client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393", server_name="127.0.0.1:9393",
fetch_names=["price"], fetch_names=["price"],
concurrency=2) concurrency=1)
combine_op = CombineOp( combine_op = CombineOp(
name="combine_op", name="combine_op",
...@@ -94,7 +94,7 @@ combine_op = CombineOp( ...@@ -94,7 +94,7 @@ combine_op = CombineOp(
in_dtype='float', in_dtype='float',
outputs=[out_channel], outputs=[out_channel],
out_dtype='float', out_dtype='float',
concurrency=2) concurrency=1)
logging.info(read_channel.debug()) logging.info(read_channel.debug())
logging.info(combine_channel.debug()) logging.info(combine_channel.debug())
......
...@@ -47,13 +47,11 @@ class Channel(Queue.Queue): ...@@ -47,13 +47,11 @@ class Channel(Queue.Queue):
and can only be called during initialization. and can only be called during initialization.
""" """
def __init__(self, name=None, maxsize=-1, timeout=None, batchsize=1): def __init__(self, name=None, maxsize=-1, timeout=None):
Queue.Queue.__init__(self, maxsize=maxsize) Queue.Queue.__init__(self, maxsize=maxsize)
self._maxsize = maxsize self._maxsize = maxsize
self._timeout = timeout self._timeout = timeout
self._name = name self._name = name
#self._batchsize = batchsize
# self._pushbatch = []
self._cv = threading.Condition() self._cv = threading.Condition()
...@@ -81,14 +79,14 @@ class Channel(Queue.Queue): ...@@ -81,14 +79,14 @@ class Channel(Queue.Queue):
self.get_consumers())) self.get_consumers()))
def add_producer(self, op_name): def add_producer(self, op_name):
""" not thread safe, and can only be called during initialization """ """ not thread safe, and can only be called during initialization. """
if op_name in self._producers: if op_name in self._producers:
raise ValueError( raise ValueError(
self._log("producer({}) is already in channel".format(op_name))) self._log("producer({}) is already in channel".format(op_name)))
self._producers.append(op_name) self._producers.append(op_name)
def add_consumer(self, op_name): def add_consumer(self, op_name):
""" not thread safe, and can only be called during initialization """ """ not thread safe, and can only be called during initialization. """
if op_name in self._consumers: if op_name in self._consumers:
raise ValueError( raise ValueError(
self._log("consumer({}) is already in channel".format(op_name))) self._log("consumer({}) is already in channel".format(op_name)))
...@@ -107,7 +105,7 @@ class Channel(Queue.Queue): ...@@ -107,7 +105,7 @@ class Channel(Queue.Queue):
"expected number of producers to be greater than 0, but the it is 0." "expected number of producers to be greater than 0, but the it is 0."
)) ))
elif len(self._producers) == 1: elif len(self._producers) == 1:
self._cv.acquire() with self._cv:
while True: while True:
try: try:
self.put(data, timeout=0) self.put(data, timeout=0)
...@@ -115,7 +113,6 @@ class Channel(Queue.Queue): ...@@ -115,7 +113,6 @@ class Channel(Queue.Queue):
except Queue.Empty: except Queue.Empty:
self._cv.wait() self._cv.wait()
self._cv.notify_all() self._cv.notify_all()
self._cv.release()
logging.debug(self._log("{} push data succ!".format(op_name))) logging.debug(self._log("{} push data succ!".format(op_name)))
return True return True
elif op_name is None: elif op_name is None:
...@@ -126,10 +123,13 @@ class Channel(Queue.Queue): ...@@ -126,10 +123,13 @@ class Channel(Queue.Queue):
producer_num = len(self._producers) producer_num = len(self._producers)
data_id = data.id data_id = data.id
put_data = None put_data = None
self._cv.acquire() with self._cv:
logging.debug(self._log("{} get lock ~".format(op_name))) logging.debug(self._log("{} get lock ~".format(op_name)))
if data_id not in self._push_res: if data_id not in self._push_res:
self._push_res[data_id] = {name: None for name in self._producers} self._push_res[data_id] = {
name: None
for name in self._producers
}
self._producer_res_count[data_id] = 0 self._producer_res_count[data_id] = 0
self._push_res[data_id][op_name] = data self._push_res[data_id][op_name] = data
if self._producer_res_count[data_id] + 1 == producer_num: if self._producer_res_count[data_id] + 1 == producer_num:
...@@ -141,8 +141,8 @@ class Channel(Queue.Queue): ...@@ -141,8 +141,8 @@ class Channel(Queue.Queue):
if put_data is None: if put_data is None:
logging.debug( logging.debug(
self._log("{} push data succ, not not push to queue.".format( self._log("{} push data succ, not not push to queue.".
op_name))) format(op_name)))
else: else:
while True: while True:
try: try:
...@@ -154,7 +154,6 @@ class Channel(Queue.Queue): ...@@ -154,7 +154,6 @@ class Channel(Queue.Queue):
logging.debug( logging.debug(
self._log("multi | {} push data succ!".format(op_name))) self._log("multi | {} push data succ!".format(op_name)))
self._cv.notify_all() self._cv.notify_all()
self._cv.release()
return True return True
def front(self, op_name=None): def front(self, op_name=None):
...@@ -165,8 +164,8 @@ class Channel(Queue.Queue): ...@@ -165,8 +164,8 @@ class Channel(Queue.Queue):
"expected number of consumers to be greater than 0, but the it is 0." "expected number of consumers to be greater than 0, but the it is 0."
)) ))
elif len(self._consumers) == 1: elif len(self._consumers) == 1:
self._cv.acquire()
resp = None resp = None
with self._cv:
while resp is None: while resp is None:
try: try:
resp = self.get(timeout=0) resp = self.get(timeout=0)
...@@ -180,7 +179,7 @@ class Channel(Queue.Queue): ...@@ -180,7 +179,7 @@ class Channel(Queue.Queue):
self._log( self._log(
"There are multiple consumers, so op_name cannot be None.")) "There are multiple consumers, so op_name cannot be None."))
self._cv.acquire() with self._cv:
# data_idx = consumer_idx - base_idx # data_idx = consumer_idx - base_idx
while self._consumers[op_name] - self._consumer_base_idx >= len( while self._consumers[op_name] - self._consumer_base_idx >= len(
self._front_res): self._front_res):
...@@ -211,7 +210,6 @@ class Channel(Queue.Queue): ...@@ -211,7 +210,6 @@ class Channel(Queue.Queue):
self._idx_consumer_num[new_consumer_idx] += 1 self._idx_consumer_num[new_consumer_idx] += 1
self._cv.notify_all() self._cv.notify_all()
self._cv.release()
logging.debug(self._log("multi | {} get data succ!".format(op_name))) logging.debug(self._log("multi | {} get data succ!".format(op_name)))
return resp # reference, read only return resp # reference, read only
...@@ -224,7 +222,6 @@ class Op(object): ...@@ -224,7 +222,6 @@ class Op(object):
in_dtype, in_dtype,
outputs, outputs,
out_dtype, out_dtype,
batchsize=1,
server_model=None, server_model=None,
server_port=None, server_port=None,
device=None, device=None,
...@@ -240,7 +237,6 @@ class Op(object): ...@@ -240,7 +237,6 @@ class Op(object):
self._in_dtype = in_dtype self._in_dtype = in_dtype
self.set_outputs(outputs) self.set_outputs(outputs)
self._out_dtype = out_dtype self._out_dtype = out_dtype
# self._batch_size = batchsize
self._client = None self._client = None
if client_config is not None and \ if client_config is not None and \
server_name is not None and \ server_name is not None and \
...@@ -395,10 +391,9 @@ class GeneralPythonService( ...@@ -395,10 +391,9 @@ class GeneralPythonService(
raise TypeError( raise TypeError(
self._log('data must be ChannelData type, but get {}'. self._log('data must be ChannelData type, but get {}'.
format(type(data)))) format(type(data))))
self._cv.acquire() with self._cv:
self._globel_resp_dict[data.id] = data self._globel_resp_dict[data.id] = data
self._cv.notify_all() self._cv.notify_all()
self._cv.release()
def _get_next_id(self): def _get_next_id(self):
with self._id_lock: with self._id_lock:
...@@ -406,12 +401,12 @@ class GeneralPythonService( ...@@ -406,12 +401,12 @@ class GeneralPythonService(
return self._id_counter - 1 return self._id_counter - 1
def _get_data_in_globel_resp_dict(self, data_id): def _get_data_in_globel_resp_dict(self, data_id):
self._cv.acquire() resp = None
with self._cv:
while data_id not in self._globel_resp_dict: while data_id not in self._globel_resp_dict:
self._cv.wait() self._cv.wait()
resp = self._globel_resp_dict.pop(data_id) resp = self._globel_resp_dict.pop(data_id)
self._cv.notify_all() self._cv.notify_all()
self._cv.release()
return resp return resp
def _pack_data_for_infer(self, request): def _pack_data_for_infer(self, request):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册