提交 6f454bf3 编写于 作者: T TeslaZhao

Python pipeline mode supports tensor structure input and output

上级 b789a604
...@@ -289,16 +289,18 @@ class Client(object): ...@@ -289,16 +289,18 @@ class Client(object):
log_id=0): log_id=0):
self.profile_.record('py_prepro_0') self.profile_.record('py_prepro_0')
if feed is None or fetch is None: if feed is None:
raise ValueError("You should specify feed and fetch for prediction") raise ValueError("You should specify feed for prediction")
fetch_list = [] fetch_list = []
if isinstance(fetch, str): if isinstance(fetch, str):
fetch_list = [fetch] fetch_list = [fetch]
elif isinstance(fetch, list): elif isinstance(fetch, list):
fetch_list = fetch fetch_list = fetch
elif fetch == None:
pass
else: else:
raise ValueError("Fetch only accepts string and list of string") raise ValueError("Fetch only accepts string or list of string")
feed_batch = [] feed_batch = []
if isinstance(feed, dict): if isinstance(feed, dict):
...@@ -439,6 +441,8 @@ class Client(object): ...@@ -439,6 +441,8 @@ class Client(object):
model_engine_names = result_batch_handle.get_engine_names() model_engine_names = result_batch_handle.get_engine_names()
for mi, engine_name in enumerate(model_engine_names): for mi, engine_name in enumerate(model_engine_names):
result_map = {} result_map = {}
if len(fetch_names) == 0:
fetch_names = result_batch_handle.get_tensor_alias_names(mi)
# result map needs to be a numpy array # result map needs to be a numpy array
for i, name in enumerate(fetch_names): for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int64_type: if self.fetch_names_to_type_[name] == int64_type:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册