diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index d524201e45b0b948cbd3738fedb878e5f9b53e02..bea80f84bc9d29cabe4f31af612c694980b71d09 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -74,7 +74,7 @@ class Client(object): self.fetch_names_ = [] self.client_handle_ = None self.result_handle_ = None - self.feed_shapes_ = [] + self.feed_shapes_ = {} self.feed_types_ = {} self.feed_names_to_idx_ = {} self.rpath() @@ -106,13 +106,23 @@ class Client(object): 0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] - self.feed_shapes_ = [var.shape for var in model_conf.feed_var] self.feed_names_to_idx_ = {} self.fetch_names_to_type_ = {} self.fetch_names_to_idx_ = {} + self.lod_tensor_set = set() + self.feed_tensor_len = {} for i, var in enumerate(model_conf.feed_var): self.feed_names_to_idx_[var.alias_name] = i self.feed_types_[var.alias_name] = var.feed_type + self.feed_shapes_[var.alias_name] = var.shape + + if var.is_lod_tensor: + self.lod_tensor_set.add(var.alias_name) + else: + counter = 1 + for dim in self.feed_shapes_[var.alias_name]: + counter *= dim + self.feed_tensor_len[var.alias_name] = counter for i, var in enumerate(model_conf.fetch_var): self.fetch_names_to_idx_[var.alias_name] = i @@ -137,6 +147,14 @@ class Client(object): def get_fetch_names(self): return self.fetch_names_ + def shape_check(self, feed, key): + seq_shape = 1 + if key in self.lod_tensor_set: + return + if len(feed[key]) != self.feed_tensor_len[key]: + raise SystemExit("The shape of feed tensor {} not match.".format( + key)) + def predict(self, feed={}, fetch=[]): int_slot = [] float_slot = [] @@ -145,6 +163,7 @@ class Client(object): fetch_names = [] for key in feed: + self.shape_check(feed, key) if key not in self.feed_names_: continue if self.feed_types_[key] == int_type: