提交 6143f51a 编写于 作者: M MRXLT

improve shape check

上级 786ad6dc
......@@ -109,12 +109,19 @@ class Client(object):
self.fetch_names_to_type_ = {}
self.fetch_names_to_idx_ = {}
self.lod_tensor_set = set()
self.feed_tensor_len = {}
for i, var in enumerate(model_conf.feed_var):
self.feed_names_to_idx_[var.alias_name] = i
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set.add(var.alias_name)
else:
counter = 1
for dim in self.feed_shapes_[var.alias_name]:
counter *= dim
self.feed_tensor_len[var.alias_name] = counter
for i, var in enumerate(model_conf.fetch_var):
self.fetch_names_to_idx_[var.alias_name] = i
......@@ -143,9 +150,7 @@ class Client(object):
seq_shape = 1
if key in self.lod_tensor_set:
return
for shape in self.feed_shapes_[key]:
seq_shape *= shape
if len(feed[key]) != seq_shape:
if len(feed[key]) != self.feed_tensor_len[key]:
raise SystemExit("The shape of feed tensor {} not match.".format(
key))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册