未验证 提交 2dd7db95 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #437 from MRXLT/server-robust

add feed check
...@@ -175,7 +175,6 @@ class Client(object): ...@@ -175,7 +175,6 @@ class Client(object):
return self.fetch_names_ return self.fetch_names_
def shape_check(self, feed, key): def shape_check(self, feed, key):
seq_shape = 1
if key in self.lod_tensor_set: if key in self.lod_tensor_set:
return return
if len(feed[key]) != self.feed_tensor_len[key]: if len(feed[key]) != self.feed_tensor_len[key]:
...@@ -192,7 +191,7 @@ class Client(object): ...@@ -192,7 +191,7 @@ class Client(object):
elif isinstance(fetch, list): elif isinstance(fetch, list):
fetch_list = fetch fetch_list = fetch
else: else:
raise ValueError("fetch only accepts string and list of string") raise ValueError("Fetch only accepts string and list of string")
feed_batch = [] feed_batch = []
if isinstance(feed, dict): if isinstance(feed, dict):
...@@ -200,7 +199,7 @@ class Client(object): ...@@ -200,7 +199,7 @@ class Client(object):
elif isinstance(feed, list): elif isinstance(feed, list):
feed_batch = feed feed_batch = feed
else: else:
raise ValueError("feed only accepts dict and list of dict") raise ValueError("Feed only accepts dict and list of dict")
int_slot_batch = [] int_slot_batch = []
float_slot_batch = [] float_slot_batch = []
...@@ -216,7 +215,7 @@ class Client(object): ...@@ -216,7 +215,7 @@ class Client(object):
if len(fetch_names) == 0: if len(fetch_names) == 0:
raise ValueError( raise ValueError(
"fetch names should not be empty or out of saved fetch list") "Fetch names should not be empty or out of saved fetch list.")
return {} return {}
for i, feed_i in enumerate(feed_batch): for i, feed_i in enumerate(feed_batch):
...@@ -224,7 +223,8 @@ class Client(object): ...@@ -224,7 +223,8 @@ class Client(object):
float_slot = [] float_slot = []
for key in feed_i: for key in feed_i:
if key not in self.feed_names_: if key not in self.feed_names_:
continue raise ValueError("Wrong feed name: {}.".format(key))
self.shape_check(feed_i, key)
if self.feed_types_[key] == int_type: if self.feed_types_[key] == int_type:
if i == 0: if i == 0:
int_feed_names.append(key) int_feed_names.append(key)
...@@ -233,6 +233,8 @@ class Client(object): ...@@ -233,6 +233,8 @@ class Client(object):
if i == 0: if i == 0:
float_feed_names.append(key) float_feed_names.append(key)
float_slot.append(feed_i[key]) float_slot.append(feed_i[key])
if len(int_slot) + len(float_slot) == 0:
raise ValueError("No feed data for predict.")
int_slot_batch.append(int_slot) int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot) float_slot_batch.append(float_slot)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册