未验证 提交 0814608f 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #243 from MRXLT/general-server-fix

add shape check for input shape
......@@ -74,7 +74,7 @@ class Client(object):
self.fetch_names_ = []
self.client_handle_ = None
self.result_handle_ = None
self.feed_shapes_ = []
self.feed_shapes_ = {}
self.feed_types_ = {}
self.feed_names_to_idx_ = {}
self.rpath()
......@@ -106,13 +106,23 @@ class Client(object):
0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.feed_shapes_ = [var.shape for var in model_conf.feed_var]
self.feed_names_to_idx_ = {}
self.fetch_names_to_type_ = {}
self.fetch_names_to_idx_ = {}
self.lod_tensor_set = set()
self.feed_tensor_len = {}
for i, var in enumerate(model_conf.feed_var):
self.feed_names_to_idx_[var.alias_name] = i
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set.add(var.alias_name)
else:
counter = 1
for dim in self.feed_shapes_[var.alias_name]:
counter *= dim
self.feed_tensor_len[var.alias_name] = counter
for i, var in enumerate(model_conf.fetch_var):
self.fetch_names_to_idx_[var.alias_name] = i
......@@ -137,6 +147,14 @@ class Client(object):
def get_fetch_names(self):
return self.fetch_names_
def shape_check(self, feed, key):
seq_shape = 1
if key in self.lod_tensor_set:
return
if len(feed[key]) != self.feed_tensor_len[key]:
raise SystemExit("The shape of feed tensor {} not match.".format(
key))
def predict(self, feed={}, fetch=[]):
int_slot = []
float_slot = []
......@@ -145,6 +163,7 @@ class Client(object):
fetch_names = []
for key in feed:
self.shape_check(feed, key)
if key not in self.feed_names_:
continue
if self.feed_types_[key] == int_type:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册