提交 f5f06648 编写于 作者: M MRXLT

add shape check

上级 d2c6e6f8
......@@ -118,7 +118,8 @@ class Client(object):
self.producers = []
self.consumer = None
self.profile_ = _Profiler()
self.numpy_input = True
self.all_numpy_input = True
self.has_numpy_input = False
def rpath(self):
lib_path = os.path.dirname(paddle_serving_client.__file__)
......@@ -272,9 +273,10 @@ class Client(object):
if isinstance(feed_i[key], np.ndarray):
#int_slot.append(np.reshape(feed_i[key], (-1)).tolist())
int_slot.append(feed_i[key])
self.has_numpy_input = True
else:
int_slot.append(feed_i[key])
self.numpy_input = False
self.all_numpy_input = False
elif self.feed_types_[key] == float_type:
if i == 0:
float_feed_names.append(key)
......@@ -285,9 +287,10 @@ class Client(object):
if isinstance(feed_i[key], np.ndarray):
#float_slot.append(np.reshape(feed_i[key], (-1)).tolist())
float_slot.append(feed_i[key])
self.has_numpy_input = True
else:
float_slot.append(feed_i[key])
self.numpy_input = False
self.all_numpy_input = False
int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot)
......@@ -295,14 +298,18 @@ class Client(object):
self.profile_.record('py_client_infer_0')
result_batch = self.result_handle_
if self.numpy_input:
if self.all_numpy_input:
res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch, self.pid)
else:
elif self.has_numpy_input == False:
res = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch, self.pid)
else:
raise SystemExit(
"Please make sure the inputs are all in list type or all in numpy.array type"
)
self.profile_.record('py_client_infer_1')
self.profile_.record('py_postpro_0')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册