提交 c74d74ea 编写于 作者: M MRXLT

fix shape check for numpy array

上级 fe537781
......@@ -203,7 +203,12 @@ class Client(object):
def shape_check(self, feed, key):
if key in self.lod_tensor_set:
return
if len(feed[key]) != self.feed_tensor_len[key]:
if isinstance(feed[key],
list) and len(feed[key]) != self.feed_tensor_len[key]:
raise SystemExit("The shape of feed tensor {} not match.".format(
key))
if type(feed[key]).__module__ == np.__name__ and np.size(feed[
key]) != self.feed_tensor_len[key]:
raise SystemExit("The shape of feed tensor {} not match.".format(
key))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册