提交 f4137508 编写于 作者: G guru4elephant 提交者: dongdaxiang

unify predict and batch predict interface and call batch_predict inside the predict function

上级 bafb844b
...@@ -155,44 +155,23 @@ class Client(object): ...@@ -155,44 +155,23 @@ class Client(object):
raise SystemExit("The shape of feed tensor {} not match.".format( raise SystemExit("The shape of feed tensor {} not match.".format(
key)) key))
def predict(self, feed={}, fetch=[]): def predict(self, feed=None, fetch=None):
int_slot = [] fetch_list = []
float_slot = [] if isinstance(fetch, str):
int_feed_names = [] fetch_list = [fetch]
float_feed_names = [] elif isinstance(fetch, list):
fetch_names = [] fetch_list = fetch
else:
for key in feed: raise ValueError("fetch only accepts string and list of string")
self.shape_check(feed, key)
if key not in self.feed_names_: feed_batch = []
continue if isinstance(feed, dict):
if self.feed_types_[key] == int_type: feed_batch.append(feed)
int_feed_names.append(key) elif isinstance(feed, list):
int_slot.append(feed[key]) feed_batch = feed
elif self.feed_types_[key] == float_type: else:
float_feed_names.append(key) raise ValueError("feed only accepts dict and list of dict")
float_slot.append(feed[key])
for key in fetch:
if key in self.fetch_names_:
fetch_names.append(key)
ret = self.client_handle_.predict(float_slot, float_feed_names,
int_slot, int_feed_names, fetch_names,
self.result_handle_, self.pid)
result_map = {}
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = self.result_handle_.get_int64_by_name(name)[
0]
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = self.result_handle_.get_float_by_name(name)[
0]
return result_map
def batch_predict(self, feed_batch=[], fetch=[]):
int_slot_batch = [] int_slot_batch = []
float_slot_batch = [] float_slot_batch = []
int_feed_names = [] int_feed_names = []
...@@ -200,10 +179,20 @@ class Client(object): ...@@ -200,10 +179,20 @@ class Client(object):
fetch_names = [] fetch_names = []
counter = 0 counter = 0
batch_size = len(feed_batch) batch_size = len(feed_batch)
for feed in feed_batch:
for key in fetch_list:
if key in self.fetch_names_:
fetch_names.append(key)
if len(fetch_names) == 0:
raise ValueError(
"fetch names should not be empty or out of saved fetch list")
return {}
for feed_i in feed_batch:
int_slot = [] int_slot = []
float_slot = [] float_slot = []
for key in feed: for key in feed_i:
if key not in self.feed_names_: if key not in self.feed_names_:
continue continue
if self.feed_types_[key] == int_type: if self.feed_types_[key] == int_type:
...@@ -213,15 +202,11 @@ class Client(object): ...@@ -213,15 +202,11 @@ class Client(object):
elif self.feed_types_[key] == float_type: elif self.feed_types_[key] == float_type:
if counter == 0: if counter == 0:
float_feed_names.append(key) float_feed_names.append(key)
float_slot.append(feed[key]) float_slot.append(feed_i[key])
counter += 1 counter += 1
int_slot_batch.append(int_slot) int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot) float_slot_batch.append(float_slot)
for key in fetch:
if key in self.fetch_names_:
fetch_names.append(key)
result_batch = self.result_handle_ result_batch = self.result_handle_
res = self.client_handle_.batch_predict( res = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, int_slot_batch, int_feed_names, float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册