提交 fbea23f1 编写于 作者: M MRXLT

add feed check

上级 cb3dde17
......@@ -175,7 +175,6 @@ class Client(object):
return self.fetch_names_
def shape_check(self, feed, key):
seq_shape = 1
if key in self.lod_tensor_set:
return
if len(feed[key]) != self.feed_tensor_len[key]:
......@@ -192,7 +191,7 @@ class Client(object):
elif isinstance(fetch, list):
fetch_list = fetch
else:
raise ValueError("fetch only accepts string and list of string")
raise ValueError("Fetch only accepts string and list of string")
feed_batch = []
if isinstance(feed, dict):
......@@ -200,7 +199,7 @@ class Client(object):
elif isinstance(feed, list):
feed_batch = feed
else:
raise ValueError("feed only accepts dict and list of dict")
raise ValueError("Feed only accepts dict and list of dict")
int_slot_batch = []
float_slot_batch = []
......@@ -216,7 +215,7 @@ class Client(object):
if len(fetch_names) == 0:
raise ValueError(
"fetch names should not be empty or out of saved fetch list")
"Fetch names should not be empty or out of saved fetch list.")
return {}
for i, feed_i in enumerate(feed_batch):
......@@ -224,7 +223,8 @@ class Client(object):
float_slot = []
for key in feed_i:
if key not in self.feed_names_:
continue
raise ValueError("Wrong feed name: {}.".format(key))
self.shape_check(feed_i, key)
if self.feed_types_[key] == int_type:
if i == 0:
int_feed_names.append(key)
......@@ -233,6 +233,8 @@ class Client(object):
if i == 0:
float_feed_names.append(key)
float_slot.append(feed_i[key])
if len(int_slot) + len(float_slot) == 0:
raise ValueError("No feed data for predict.")
int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册