diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index b150e9cca646820c715bcad6a2cb1765a0df25f7..958d9841fc9586e2806e34da5d4ec7eaaf9d8818 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -118,7 +118,8 @@ class Client(object): self.producers = [] self.consumer = None self.profile_ = _Profiler() - self.numpy_input = True + self.all_numpy_input = True + self.has_numpy_input = False def rpath(self): lib_path = os.path.dirname(paddle_serving_client.__file__) @@ -272,9 +273,10 @@ class Client(object): if isinstance(feed_i[key], np.ndarray): #int_slot.append(np.reshape(feed_i[key], (-1)).tolist()) int_slot.append(feed_i[key]) + self.has_numpy_input = True else: int_slot.append(feed_i[key]) - self.numpy_input = False + self.all_numpy_input = False elif self.feed_types_[key] == float_type: if i == 0: float_feed_names.append(key) @@ -285,9 +287,10 @@ class Client(object): if isinstance(feed_i[key], np.ndarray): #float_slot.append(np.reshape(feed_i[key], (-1)).tolist()) float_slot.append(feed_i[key]) + self.has_numpy_input = True else: float_slot.append(feed_i[key]) - self.numpy_input = False + self.all_numpy_input = False int_slot_batch.append(int_slot) float_slot_batch.append(float_slot) @@ -295,14 +298,18 @@ class Client(object): self.profile_.record('py_client_infer_0') result_batch = self.result_handle_ - if self.numpy_input: + if self.all_numpy_input: res = self.client_handle_.numpy_predict( float_slot_batch, float_feed_names, float_shape, int_slot_batch, int_feed_names, int_shape, fetch_names, result_batch, self.pid) - else: + elif self.has_numpy_input == False: res = self.client_handle_.batch_predict( float_slot_batch, float_feed_names, float_shape, int_slot_batch, int_feed_names, int_shape, fetch_names, result_batch, self.pid) + else: + raise SystemExit( + "Please make sure the inputs are all in list type or all in numpy.array type" + ) self.profile_.record('py_client_infer_1') self.profile_.record('py_postpro_0')