提交 f5f06648 编写于 作者: M MRXLT

add shape check

上级 d2c6e6f8
...@@ -118,7 +118,8 @@ class Client(object): ...@@ -118,7 +118,8 @@ class Client(object):
self.producers = [] self.producers = []
self.consumer = None self.consumer = None
self.profile_ = _Profiler() self.profile_ = _Profiler()
self.numpy_input = True self.all_numpy_input = True
self.has_numpy_input = False
def rpath(self): def rpath(self):
lib_path = os.path.dirname(paddle_serving_client.__file__) lib_path = os.path.dirname(paddle_serving_client.__file__)
...@@ -272,9 +273,10 @@ class Client(object): ...@@ -272,9 +273,10 @@ class Client(object):
if isinstance(feed_i[key], np.ndarray): if isinstance(feed_i[key], np.ndarray):
#int_slot.append(np.reshape(feed_i[key], (-1)).tolist()) #int_slot.append(np.reshape(feed_i[key], (-1)).tolist())
int_slot.append(feed_i[key]) int_slot.append(feed_i[key])
self.has_numpy_input = True
else: else:
int_slot.append(feed_i[key]) int_slot.append(feed_i[key])
self.numpy_input = False self.all_numpy_input = False
elif self.feed_types_[key] == float_type: elif self.feed_types_[key] == float_type:
if i == 0: if i == 0:
float_feed_names.append(key) float_feed_names.append(key)
...@@ -285,9 +287,10 @@ class Client(object): ...@@ -285,9 +287,10 @@ class Client(object):
if isinstance(feed_i[key], np.ndarray): if isinstance(feed_i[key], np.ndarray):
#float_slot.append(np.reshape(feed_i[key], (-1)).tolist()) #float_slot.append(np.reshape(feed_i[key], (-1)).tolist())
float_slot.append(feed_i[key]) float_slot.append(feed_i[key])
self.has_numpy_input = True
else: else:
float_slot.append(feed_i[key]) float_slot.append(feed_i[key])
self.numpy_input = False self.all_numpy_input = False
int_slot_batch.append(int_slot) int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot) float_slot_batch.append(float_slot)
...@@ -295,14 +298,18 @@ class Client(object): ...@@ -295,14 +298,18 @@ class Client(object):
self.profile_.record('py_client_infer_0') self.profile_.record('py_client_infer_0')
result_batch = self.result_handle_ result_batch = self.result_handle_
if self.numpy_input: if self.all_numpy_input:
res = self.client_handle_.numpy_predict( res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch, float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch, self.pid) int_feed_names, int_shape, fetch_names, result_batch, self.pid)
else: elif self.has_numpy_input == False:
res = self.client_handle_.batch_predict( res = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch, float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch, self.pid) int_feed_names, int_shape, fetch_names, result_batch, self.pid)
else:
raise SystemExit(
"Please make sure the inputs are all in list type or all in numpy.array type"
)
self.profile_.record('py_client_infer_1') self.profile_.record('py_client_infer_1')
self.profile_.record('py_postpro_0') self.profile_.record('py_postpro_0')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册