提交 786ad6dc 编写于 作者: M MRXLT

add shape check for input shape

上级 0783866b
......@@ -74,7 +74,7 @@ class Client(object):
self.fetch_names_ = []
self.client_handle_ = None
self.result_handle_ = None
self.feed_shapes_ = []
self.feed_shapes_ = {}
self.feed_types_ = {}
self.feed_names_to_idx_ = {}
self.rpath()
......@@ -85,7 +85,6 @@ class Client(object):
lib_path = os.path.join(lib_path, 'lib')
os.popen('patchelf --set-rpath {} {}'.format(lib_path, client_path))
def load_client_config(self, path):
from .serving_client import PredictorClient
from .serving_client import PredictorRes
......@@ -106,13 +105,16 @@ class Client(object):
0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.feed_shapes_ = [var.shape for var in model_conf.feed_var]
self.feed_names_to_idx_ = {}
self.fetch_names_to_type_ = {}
self.fetch_names_to_idx_ = {}
self.lod_tensor_set = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_names_to_idx_[var.alias_name] = i
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set.add(var.alias_name)
for i, var in enumerate(model_conf.fetch_var):
self.fetch_names_to_idx_[var.alias_name] = i
......@@ -128,9 +130,8 @@ class Client(object):
predictor_sdk.set_server_endpoints(endpoints)
sdk_desc = predictor_sdk.gen_desc()
print(sdk_desc)
self.client_handle_.create_predictor_by_desc(
sdk_desc.SerializeToString())
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
))
def get_feed_names(self):
return self.feed_names_
......@@ -138,6 +139,16 @@ class Client(object):
def get_fetch_names(self):
return self.fetch_names_
def shape_check(self, feed, key):
seq_shape = 1
if key in self.lod_tensor_set:
return
for shape in self.feed_shapes_[key]:
seq_shape *= shape
if len(feed[key]) != seq_shape:
raise SystemExit("The shape of feed tensor {} not match.".format(
key))
def predict(self, feed={}, fetch=[]):
int_slot = []
float_slot = []
......@@ -145,6 +156,7 @@ class Client(object):
float_feed_names = []
fetch_names = []
for key in feed:
self.shape_check(feed, key)
if key not in self.feed_names_:
continue
if self.feed_types_[key] == int_type:
......@@ -158,16 +170,18 @@ class Client(object):
if key in self.fetch_names_:
fetch_names.append(key)
ret = self.client_handle_.predict(
float_slot, float_feed_names, int_slot,
int_feed_names, fetch_names, self.result_handle_)
ret = self.client_handle_.predict(float_slot, float_feed_names,
int_slot, int_feed_names, fetch_names,
self.result_handle_)
result_map = {}
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = self.result_handle_.get_int64_by_name(name)[0]
result_map[name] = self.result_handle_.get_int64_by_name(name)[
0]
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = self.result_handle_.get_float_by_name(name)[0]
result_map[name] = self.result_handle_.get_float_by_name(name)[
0]
return result_map
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册