From f41375083034ac4e777b7c19591280b8431b54ab Mon Sep 17 00:00:00 2001 From: guru4elephant Date: Sun, 22 Mar 2020 07:57:28 +0800 Subject: [PATCH] unify predict and batch predict interface and call batch_predict inside the predict function --- python/paddle_serving_client/__init__.py | 73 ++++++++++-------------- 1 file changed, 29 insertions(+), 44 deletions(-) diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index ce0eb8c8..d9a4a3ff 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -155,44 +155,23 @@ class Client(object): raise SystemExit("The shape of feed tensor {} not match.".format( key)) - def predict(self, feed={}, fetch=[]): - int_slot = [] - float_slot = [] - int_feed_names = [] - float_feed_names = [] - fetch_names = [] - - for key in feed: - self.shape_check(feed, key) - if key not in self.feed_names_: - continue - if self.feed_types_[key] == int_type: - int_feed_names.append(key) - int_slot.append(feed[key]) - elif self.feed_types_[key] == float_type: - float_feed_names.append(key) - float_slot.append(feed[key]) - - for key in fetch: - if key in self.fetch_names_: - fetch_names.append(key) + def predict(self, feed=None, fetch=None): + fetch_list = [] + if isinstance(fetch, str): + fetch_list = [fetch] + elif isinstance(fetch, list): + fetch_list = fetch + else: + raise ValueError("fetch only accepts string and list of string") + + feed_batch = [] + if isinstance(feed, dict): + feed_batch.append(feed) + elif isinstance(feed, list): + feed_batch = feed + else: + raise ValueError("feed only accepts dict and list of dict") - ret = self.client_handle_.predict(float_slot, float_feed_names, - int_slot, int_feed_names, fetch_names, - self.result_handle_, self.pid) - - result_map = {} - for i, name in enumerate(fetch_names): - if self.fetch_names_to_type_[name] == int_type: - result_map[name] = self.result_handle_.get_int64_by_name(name)[ - 0] - elif self.fetch_names_to_type_[name] == float_type: - result_map[name] = self.result_handle_.get_float_by_name(name)[ - 0] - - return result_map - - def batch_predict(self, feed_batch=[], fetch=[]): int_slot_batch = [] float_slot_batch = [] int_feed_names = [] @@ -200,10 +179,20 @@ class Client(object): fetch_names = [] counter = 0 batch_size = len(feed_batch) - for feed in feed_batch: + + for key in fetch_list: + if key in self.fetch_names_: + fetch_names.append(key) + + if len(fetch_names) == 0: + raise ValueError( + "fetch names should not be empty or out of saved fetch list") + return {} + + for feed_i in feed_batch: int_slot = [] float_slot = [] - for key in feed: + for key in feed_i: if key not in self.feed_names_: continue if self.feed_types_[key] == int_type: @@ -213,15 +202,11 @@ class Client(object): elif self.feed_types_[key] == float_type: if counter == 0: float_feed_names.append(key) - float_slot.append(feed[key]) + float_slot.append(feed_i[key]) counter += 1 int_slot_batch.append(int_slot) float_slot_batch.append(float_slot) - for key in fetch: - if key in self.fetch_names_: - fetch_names.append(key) - result_batch = self.result_handle_ res = self.client_handle_.batch_predict( float_slot_batch, float_feed_names, int_slot_batch, int_feed_names, -- GitLab