diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index f86dab617cc59a400bc915b2d497f112335f3bab..03fae970b8163be0a02478c9248637f4ce53fca9 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -74,7 +74,7 @@ class Client(object): self.fetch_names_ = [] self.client_handle_ = None self.result_handle_ = None - self.feed_shapes_ = [] + self.feed_shapes_ = {} self.feed_types_ = {} self.feed_names_to_idx_ = {} self.rpath() @@ -85,7 +85,6 @@ class Client(object): lib_path = os.path.join(lib_path, 'lib') os.popen('patchelf --set-rpath {} {}'.format(lib_path, client_path)) - def load_client_config(self, path): from .serving_client import PredictorClient from .serving_client import PredictorRes @@ -106,13 +105,16 @@ class Client(object): 0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] - self.feed_shapes_ = [var.shape for var in model_conf.feed_var] self.feed_names_to_idx_ = {} self.fetch_names_to_type_ = {} self.fetch_names_to_idx_ = {} + self.lod_tensor_set = set() for i, var in enumerate(model_conf.feed_var): self.feed_names_to_idx_[var.alias_name] = i self.feed_types_[var.alias_name] = var.feed_type + self.feed_shapes_[var.alias_name] = var.shape + if var.is_lod_tensor: + self.lod_tensor_set.add(var.alias_name) for i, var in enumerate(model_conf.fetch_var): self.fetch_names_to_idx_[var.alias_name] = i @@ -128,9 +130,8 @@ class Client(object): predictor_sdk.set_server_endpoints(endpoints) sdk_desc = predictor_sdk.gen_desc() print(sdk_desc) - self.client_handle_.create_predictor_by_desc( - sdk_desc.SerializeToString()) - + self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString( + )) def get_feed_names(self): return self.feed_names_ @@ -138,6 +139,16 @@ class Client(object): def get_fetch_names(self): return self.fetch_names_ + def shape_check(self, feed, key): + seq_shape = 1 + if key in self.lod_tensor_set: + return + for shape in self.feed_shapes_[key]: + seq_shape *= shape + if len(feed[key]) != seq_shape: + raise SystemExit("The shape of feed tensor {} not match.".format( + key)) + def predict(self, feed={}, fetch=[]): int_slot = [] float_slot = [] @@ -145,6 +156,7 @@ class Client(object): float_feed_names = [] fetch_names = [] for key in feed: + self.shape_check(feed, key) if key not in self.feed_names_: continue if self.feed_types_[key] == int_type: @@ -158,16 +170,18 @@ class Client(object): if key in self.fetch_names_: fetch_names.append(key) - ret = self.client_handle_.predict( - float_slot, float_feed_names, int_slot, - int_feed_names, fetch_names, self.result_handle_) + ret = self.client_handle_.predict(float_slot, float_feed_names, + int_slot, int_feed_names, fetch_names, + self.result_handle_) result_map = {} for i, name in enumerate(fetch_names): if self.fetch_names_to_type_[name] == int_type: - result_map[name] = self.result_handle_.get_int64_by_name(name)[0] + result_map[name] = self.result_handle_.get_int64_by_name(name)[ + 0] elif self.fetch_names_to_type_[name] == float_type: - result_map[name] = self.result_handle_.get_float_by_name(name)[0] + result_map[name] = self.result_handle_.get_float_by_name(name)[ + 0] return result_map