提交 f4137508 编写于 作者: G guru4elephant 提交者: dongdaxiang

unify predict and batch predict interface and call batch_predict inside the predict function

上级 bafb844b
......@@ -155,44 +155,23 @@ class Client(object):
raise SystemExit("The shape of feed tensor {} not match.".format(
key))
def predict(self, feed={}, fetch=[]):
int_slot = []
float_slot = []
int_feed_names = []
float_feed_names = []
fetch_names = []
for key in feed:
self.shape_check(feed, key)
if key not in self.feed_names_:
continue
if self.feed_types_[key] == int_type:
int_feed_names.append(key)
int_slot.append(feed[key])
elif self.feed_types_[key] == float_type:
float_feed_names.append(key)
float_slot.append(feed[key])
for key in fetch:
if key in self.fetch_names_:
fetch_names.append(key)
ret = self.client_handle_.predict(float_slot, float_feed_names,
int_slot, int_feed_names, fetch_names,
self.result_handle_, self.pid)
result_map = {}
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = self.result_handle_.get_int64_by_name(name)[
0]
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = self.result_handle_.get_float_by_name(name)[
0]
def predict(self, feed=None, fetch=None):
fetch_list = []
if isinstance(fetch, str):
fetch_list = [fetch]
elif isinstance(fetch, list):
fetch_list = fetch
else:
raise ValueError("fetch only accepts string and list of string")
return result_map
feed_batch = []
if isinstance(feed, dict):
feed_batch.append(feed)
elif isinstance(feed, list):
feed_batch = feed
else:
raise ValueError("feed only accepts dict and list of dict")
def batch_predict(self, feed_batch=[], fetch=[]):
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
......@@ -200,10 +179,20 @@ class Client(object):
fetch_names = []
counter = 0
batch_size = len(feed_batch)
for feed in feed_batch:
for key in fetch_list:
if key in self.fetch_names_:
fetch_names.append(key)
if len(fetch_names) == 0:
raise ValueError(
"fetch names should not be empty or out of saved fetch list")
return {}
for feed_i in feed_batch:
int_slot = []
float_slot = []
for key in feed:
for key in feed_i:
if key not in self.feed_names_:
continue
if self.feed_types_[key] == int_type:
......@@ -213,15 +202,11 @@ class Client(object):
elif self.feed_types_[key] == float_type:
if counter == 0:
float_feed_names.append(key)
float_slot.append(feed[key])
float_slot.append(feed_i[key])
counter += 1
int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot)
for key in fetch:
if key in self.fetch_names_:
fetch_names.append(key)
result_batch = self.result_handle_
res = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册