Commit c7eb18eb authored by barrierye

update demo for imdb model ensemble && bug fix

Parent 67ca0f84
@@ -13,26 +13,23 @@
 # limitations under the License.
 from paddle_serving_client.pyclient import PyClient
 import numpy as np
+from paddle_serving_app.reader import IMDBDataset
 from line_profiler import LineProfiler
 
 client = PyClient()
 client.connect('localhost:8080')
-x = np.array(
-    [
-        0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584,
-        0.6283, 0.4919, 0.1856, 0.0795, -0.0332
-    ],
-    dtype='float')
 
 lp = LineProfiler()
 lp_wrapper = lp(client.predict)
+words = 'i am very sad | 0'
+imdb_dataset = IMDBDataset()
+imdb_dataset.load_resource('imdb.vocab')
 for i in range(1):
-    fetch_map = lp_wrapper(feed={"x": x}, fetch=["combine_op_output"])
-    # fetch_map = client.predict(
-    #     feed={"x": x}, fetch_with_type={"combine_op_output": "float"})
+    word_ids, label = imdb_dataset.get_words_and_label(words)
+    fetch_map = lp_wrapper(
+        feed={"words": word_ids}, fetch=["combined_prediction"])
     print(fetch_map)
 #lp.print_stats()
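The updated demo feeds one IMDB sample in the 'sentence | label' format that IMDBDataset.get_words_and_label consumes. As a rough sketch of what that parsing involves, assuming whitespace tokenization against a word-to-id vocab (the helper and toy vocab below are hypothetical stand-ins, not the paddle_serving_app implementation):

def get_words_and_label_sketch(line, vocab):
    # Split 'i am very sad | 0' into the sentence and its label.
    sent, label = line.rsplit('|', 1)
    # Map each token to its vocab id, falling back to <unk>.
    word_ids = [vocab.get(w, vocab.get('<unk>', 0))
                for w in sent.strip().split()]
    return word_ids, int(label)

vocab = {'i': 1, 'am': 2, 'very': 3, 'sad': 4, '<unk>': 0}
print(get_words_and_label_sketch('i am very sad | 0', vocab))
# ([1, 2, 3, 4], 0)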
@@ -28,46 +28,42 @@ logging.basicConfig(
 class CombineOp(Op):
     def preprocess(self, input_data):
-        cnt = 0
+        combined_prediction = 0
         for op_name, channeldata in input_data.items():
-            logging.debug("CombineOp preprocess: {}".format(op_name))
             data = channeldata.parse()
-            cnt += data["price"]
-        data = {"combine_op_output": cnt}
+            logging.info("{}: {}".format(op_name, data["prediction"]))
+            combined_prediction += data["prediction"]
+        data = {"combined_prediction": combined_prediction / 2}
         return data
 
 read_op = Op(name="read", inputs=None)
-uci1_op = Op(name="uci1",
-             inputs=[read_op],
-             server_model="./uci_housing_model",
-             server_port="9393",
-             device="cpu",
-             client_config="uci_housing_client/serving_client_conf.prototxt",
-             server_name="127.0.0.1:9393",
-             fetch_names=["price"],
-             concurrency=1,
-             timeout=0.1,
-             retry=2)
-uci2_op = Op(name="uci2",
-             inputs=[read_op],
-             server_model="./uci_housing_model",
-             server_port="9292",
-             device="cpu",
-             client_config="uci_housing_client/serving_client_conf.prototxt",
-             server_name="127.0.0.1:9292",
-             fetch_names=["price"],
-             concurrency=1,
-             timeout=-1,
-             retry=1)
+bow_op = Op(name="bow",
+            inputs=[read_op],
+            server_model="imdb_bow_model",
+            server_port="9393",
+            device="cpu",
+            client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
+            server_name="127.0.0.1:9393",
+            fetch_names=["prediction"],
+            concurrency=1,
+            timeout=0.01,
+            retry=2)
+cnn_op = Op(name="cnn",
+            inputs=[read_op],
+            server_model="imdb_cnn_model",
+            server_port="9292",
+            device="cpu",
+            client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
+            server_name="127.0.0.1:9292",
+            fetch_names=["prediction"],
+            concurrency=1,
+            timeout=-1,
+            retry=1)
 combine_op = CombineOp(
-    name="combine",
-    inputs=[uci1_op, uci2_op],
-    concurrency=1,
-    timeout=-1,
-    retry=1)
+    name="combine", inputs=[bow_op, cnn_op], concurrency=1, timeout=-1, retry=1)
 
 pyserver = PyServer(profile=False, retry=1)
-pyserver.add_ops([read_op, uci1_op, uci2_op, combine_op])
+pyserver.add_ops([read_op, bow_op, cnn_op, combine_op])
 pyserver.prepare_server(port=8080, worker_num=2)
 pyserver.run_server()
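The ensemble itself is a plain average: CombineOp.preprocess sums the "prediction" arrays from its two parent ops and divides by 2 under the new "combined_prediction" key. A standalone sketch of the same combination, where np.mean generalizes the hard-coded divide-by-two to any number of parents (the combine helper below is illustrative, not the Op API):

import numpy as np

def combine(parsed_inputs):
    # parsed_inputs maps parent op name -> parsed channel data dict.
    preds = [d["prediction"] for d in parsed_inputs.values()]
    return {"combined_prediction": np.mean(preds, axis=0)}

print(combine({
    "bow": {"prediction": np.array([0.2, 0.8])},
    "cnn": {"prediction": np.array([0.4, 0.6])},
}))
# {'combined_prediction': array([0.3, 0.7])}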
@@ -30,9 +30,10 @@ class PyClient(object):
     def _pack_data_for_infer(self, feed_data):
         req = general_python_service_pb2.Request()
         for name, data in feed_data.items():
-            if not isinstance(data, np.ndarray):
-                raise TypeError(
-                    "only numpy array type is supported temporarily.")
+            if isinstance(data, list):
+                data = np.array(data)
+            elif not isinstance(data, np.ndarray):
+                raise TypeError("only list and numpy array type is supported.")
             req.feed_var_names.append(name)
             req.feed_insts.append(data.tobytes())
             req.shape.append(np.array(data.shape, dtype="int32").tobytes())
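The client-side fix widens _pack_data_for_infer to accept plain Python lists by converting them to numpy arrays first, so every feed variable is serialized the same way: raw value bytes plus an int32 shape. A self-contained sketch of that normalization step (the function name is hypothetical):

import numpy as np

def pack_feed_var(data):
    # Lists are normalized to ndarrays; anything else is rejected.
    if isinstance(data, list):
        data = np.array(data)
    elif not isinstance(data, np.ndarray):
        raise TypeError("only list and numpy array type is supported.")
    return data.tobytes(), np.array(data.shape, dtype="int32").tobytes()

inst, shape = pack_feed_var([1, 2, 3])
print(np.frombuffer(shape, dtype="int32"))  # [3]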
@@ -462,10 +462,10 @@ class Op(object):
         call_future = None
         error_info = None
         if self.with_serving():
-            for i in range(self._retry):
-                _profiler.record("{}{}-midp_0".format(self.name,
-                                                      concurrency_idx))
-                if self._timeout > 0:
+            _profiler.record("{}{}-midp_0".format(self.name,
+                                                  concurrency_idx))
+            if self._timeout > 0:
+                for i in range(self._retry):
                     try:
                         call_future = func_timeout.func_timeout(
                             self._timeout,
@@ -475,19 +475,25 @@
                         logging.error("error: timeout")
                         error_info = "{}({}): timeout".format(
                             self.name, concurrency_idx)
+                        if i + 1 < self._retry:
+                            error_info = None
+                            logging.warn(
+                                self._log("warn: timeout, retry({})".
+                                          format(i + 1)))
                     except Exception as e:
                         logging.error("error: {}".format(e))
                         error_info = "{}({}): {}".format(
                             self.name, concurrency_idx, e)
-                else:
-                    call_future = self.midprocess(data)
-                _profiler.record("{}{}-midp_1".format(self.name,
-                                                      concurrency_idx))
-                if i + 1 < self._retry:
-                    error_info = None
-                    logging.warn(
-                        self._log("warn: timeout, retry({})".format(i +
-                                                                    1)))
+                        logging.warn(self._log(e))
+                        # TODO
+                        break
+                    else:
+                        break
+            else:
+                call_future = self.midprocess(data)
+            _profiler.record("{}{}-midp_1".format(self.name,
+                                                  concurrency_idx))
         _profiler.record("{}{}-postp_0".format(self.name,
                                                concurrency_idx))
         if error_info is not None:
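The bug fix here restructures the control flow: the retry loop now lives inside the timeout branch, so retries apply only when a timeout is configured; a timed-out call is retried up to self._retry times, while any other exception aborts immediately. A condensed sketch of that pattern using the real func_timeout package (the helper name and toy values are illustrative):

import time
from func_timeout import func_timeout, FunctionTimedOut

def call_with_retry(fn, timeout, retry):
    error_info = None
    for i in range(retry):
        try:
            return func_timeout(timeout, fn), None
        except FunctionTimedOut:
            error_info = "timeout"
            if i + 1 < retry:
                error_info = None  # a pending retry clears the error
        except Exception as e:
            return None, str(e)  # non-timeout errors are not retried
    return None, error_info

result, err = call_with_retry(lambda: time.sleep(0.05) or "ok", 0.01, 2)
print(result, err)  # None timeout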
@@ -843,7 +849,6 @@ class PyServer(object):
         return op.start(concurrency_idx)
 
     def _run_ops(self):
-        #TODO
         for op in self._ops:
             op_concurrency = op.get_concurrency()
             logging.debug("run op: {}, op_concurrency: {}".format(
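In _run_ops, each op is fanned out to op.get_concurrency() workers, each of which runs op.start(concurrency_idx). A minimal sketch of such a fan-out with multiprocessing (this launcher is an assumption for illustration, not PyServer's actual process model):

from multiprocessing import Process

def run_ops(ops):
    workers = []
    for op in ops:
        for idx in range(op.get_concurrency()):
            # One worker process per concurrency slot of every op.
            p = Process(target=op.start, args=(idx, ))
            p.start()
            workers.append(p)
    return workers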