提交 c7eb18eb 编写于 作者: B barrierye

Update the demo to use the IMDB model ensemble and fix bugs

上级 67ca0f84
......@@ -13,26 +13,23 @@
# limitations under the License.
# Demo client for the IMDB model-ensemble pipeline server.
# Sends tokenized movie-review text to the PyServer started by the companion
# server script and fetches the averaged prediction of the bow/cnn models.
# NOTE(review): the original span was unresolved diff residue that also fed a
# hard-coded UCI-housing `x` array and fetched "combine_op_output"; only the
# IMDB (post-change) path is kept here.
from paddle_serving_client.pyclient import PyClient
import numpy as np
from paddle_serving_app.reader import IMDBDataset
from line_profiler import LineProfiler

client = PyClient()
client.connect('localhost:8080')

# Wrap predict() so line_profiler can report per-line timings of the call.
lp = LineProfiler()
lp_wrapper = lp(client.predict)

# "text | label" format expected by IMDBDataset.get_words_and_label.
words = 'i am very sad | 0'
imdb_dataset = IMDBDataset()
imdb_dataset.load_resource('imdb.vocab')

for i in range(1):
    word_ids, label = imdb_dataset.get_words_and_label(words)
    fetch_map = lp_wrapper(
        feed={"words": word_ids}, fetch=["combined_prediction"])
    print(fetch_map)
#lp.print_stats()
......@@ -28,46 +28,42 @@ logging.basicConfig(
class CombineOp(Op):
    """Ensemble op that averages the predictions of its upstream model ops.

    NOTE(review): the original span was unresolved diff residue that also
    contained the old price-summing body (``cnt += data["price"]`` emitting
    ``"combine_op_output"``); only the prediction-averaging (post-change)
    version is kept here.
    """

    def preprocess(self, input_data):
        """Average the "prediction" field over all upstream channel data.

        Args:
            input_data: dict mapping upstream op name -> ChannelData; each
                parsed entry is assumed to carry a "prediction" value —
                TODO confirm against the upstream ops' fetch_names.

        Returns:
            dict with a single key "combined_prediction".
        """
        combined_prediction = 0
        for op_name, channeldata in input_data.items():
            logging.debug("CombineOp preprocess: {}".format(op_name))
            data = channeldata.parse()
            logging.info("{}: {}".format(op_name, data["prediction"]))
            combined_prediction += data["prediction"]
        # NOTE(review): the divisor 2 hard-codes the number of upstream ops
        # (bow + cnn); len(input_data) would generalize but changes behavior
        # on empty input, so it is kept as-is pending confirmation.
        data = {"combined_prediction": combined_prediction / 2}
        return data
# Build the DAG: read -> {bow, cnn} -> combine, then serve it on port 8080.
# NOTE(review): the original span was unresolved diff residue that also
# defined uci1_op/uci2_op on the same ports (9393/9292) and a second,
# conflicting combine_op over them; only the IMDB ensemble graph is kept.
read_op = Op(name="read", inputs=None)
bow_op = Op(name="bow",
            inputs=[read_op],
            server_model="imdb_bow_model",
            server_port="9393",
            device="cpu",
            client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
            server_name="127.0.0.1:9393",
            fetch_names=["prediction"],
            concurrency=1,
            timeout=0.01,  # seconds; exercises the timeout/retry path
            retry=2)
cnn_op = Op(name="cnn",
            inputs=[read_op],
            server_model="imdb_cnn_model",
            server_port="9292",
            device="cpu",
            client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
            server_name="127.0.0.1:9292",
            fetch_names=["prediction"],
            concurrency=1,
            timeout=-1,  # -1 disables the timeout
            retry=1)
combine_op = CombineOp(
    name="combine", inputs=[bow_op, cnn_op], concurrency=1, timeout=-1, retry=1)
pyserver = PyServer(profile=False, retry=1)
pyserver.add_ops([read_op, bow_op, cnn_op, combine_op])
pyserver.prepare_server(port=8080, worker_num=2)
pyserver.run_server()
......@@ -30,9 +30,10 @@ class PyClient(object):
def _pack_data_for_infer(self, feed_data):
# Pack each feed tensor into a protobuf Request: variable name, raw value
# bytes, and int32 shape bytes (parallel lists, one entry per feed var).
# NOTE(review): this span is unresolved diff residue -- both the old type
# check (reject anything that is not an np.ndarray) and the new one (coerce
# list -> np.ndarray, then reject) appear below; only the list-coercing
# branch should survive. As written, a list input still raises in the first
# check before the coercion is reached.
req = general_python_service_pb2.Request()
for name, data in feed_data.items():
if not isinstance(data, np.ndarray):
raise TypeError(
"only numpy array type is supported temporarily.")
# New (post-change) branch: accept plain Python lists as well.
if isinstance(data, list):
data = np.array(data)
elif not isinstance(data, np.ndarray):
raise TypeError("only list and numpy array type is supported.")
req.feed_var_names.append(name)
req.feed_insts.append(data.tobytes())
# Shape is serialized as int32 so the server can reconstruct the array.
req.shape.append(np.array(data.shape, dtype="int32").tobytes())
......
......@@ -462,10 +462,10 @@ class Op(object):
call_future = None
error_info = None
if self.with_serving():
for i in range(self._retry):
_profiler.record("{}{}-midp_0".format(self.name,
concurrency_idx))
if self._timeout > 0:
_profiler.record("{}{}-midp_0".format(self.name,
concurrency_idx))
if self._timeout > 0:
for i in range(self._retry):
try:
call_future = func_timeout.func_timeout(
self._timeout,
......@@ -475,19 +475,25 @@ class Op(object):
logging.error("error: timeout")
error_info = "{}({}): timeout".format(
self.name, concurrency_idx)
if i + 1 < self._retry:
error_info = None
logging.warn(
self._log("warn: timeout, retry({})".
format(i + 1)))
except Exception as e:
logging.error("error: {}".format(e))
error_info = "{}({}): {}".format(
self.name, concurrency_idx, e)
else:
call_future = self.midprocess(data)
_profiler.record("{}{}-midp_1".format(self.name,
concurrency_idx))
if i + 1 < self._retry:
error_info = None
logging.warn(
self._log("warn: timeout, retry({})".format(i +
1)))
logging.warn(self._log(e))
# TODO
break
else:
break
else:
call_future = self.midprocess(data)
_profiler.record("{}{}-midp_1".format(self.name,
concurrency_idx))
_profiler.record("{}{}-postp_0".format(self.name,
concurrency_idx))
if error_info is not None:
......@@ -843,7 +849,6 @@ class PyServer(object):
return op.start(concurrency_idx)
def _run_ops(self):
#TODO
for op in self._ops:
op_concurrency = op.get_concurrency()
logging.debug("run op: {}, op_concurrency: {}".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册