diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index 03fae970b8163be0a02478c9248637f4ce53fca9..b0bfcbeb993009320a57a3284b46f196d8800ac1 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -109,12 +109,19 @@ class Client(object): self.fetch_names_to_type_ = {} self.fetch_names_to_idx_ = {} self.lod_tensor_set = set() + self.feed_tensor_len = {} for i, var in enumerate(model_conf.feed_var): self.feed_names_to_idx_[var.alias_name] = i self.feed_types_[var.alias_name] = var.feed_type self.feed_shapes_[var.alias_name] = var.shape + if var.is_lod_tensor: self.lod_tensor_set.add(var.alias_name) + else: + counter = 1 + for dim in self.feed_shapes_[var.alias_name]: + counter *= dim + self.feed_tensor_len[var.alias_name] = counter for i, var in enumerate(model_conf.fetch_var): self.fetch_names_to_idx_[var.alias_name] = i @@ -143,9 +150,7 @@ class Client(object): seq_shape = 1 if key in self.lod_tensor_set: return - for shape in self.feed_shapes_[key]: - seq_shape *= shape - if len(feed[key]) != seq_shape: + if len(feed[key]) != self.feed_tensor_len[key]: raise SystemExit("The shape of feed tensor {} not match.".format( key))