提交 00c64d90 编写于 作者: W wangjiawei04

fix grpc impl, sync with brpc

上级 7ea59bda
...@@ -38,20 +38,9 @@ python test_asyn_client.py ...@@ -38,20 +38,9 @@ python test_asyn_client.py
python test_batch_client.py python test_batch_client.py
``` ```
### 通用 pb 预测
``` shell
python test_general_pb_client.py
```
### 预测超时 ### 预测超时
``` shell ``` shell
python test_timeout_client.py python test_timeout_client.py
``` ```
### List 输入
``` shell
python test_list_input_client.py
```
...@@ -18,7 +18,7 @@ import functools ...@@ -18,7 +18,7 @@ import functools
import time import time
import threading import threading
import grpc import grpc
import numpy as np
client = Client() client = Client()
client.connect(["127.0.0.1:9393"]) client.connect(["127.0.0.1:9393"])
...@@ -43,7 +43,8 @@ x = [ ...@@ -43,7 +43,8 @@ x = [
] ]
task_count = 0 task_count = 0
for i in range(3): for i in range(3):
future = client.predict(feed={"x": x}, fetch=["price"], asyn=True) new_data = np.array(x).astype("float32").reshape((1,13))
future = client.predict(feed={"x": new_data}, fetch=["price"], batch=False, asyn=True)
task_count += 1 task_count += 1
future.add_done_callback(functools.partial(call_back)) future.add_done_callback(functools.partial(call_back))
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
from paddle_serving_client import MultiLangClient as Client from paddle_serving_client import MultiLangClient as Client
import numpy as np
client = Client() client = Client()
client.connect(["127.0.0.1:9393"]) client.connect(["127.0.0.1:9393"])
...@@ -24,8 +24,11 @@ x = [ ...@@ -24,8 +24,11 @@ x = [
] ]
for i in range(3): for i in range(3):
batch_feed = [{"x": x} for j in range(batch_size)] new_data = np.array(x).astype("float32").reshape((1, 1, 13))
fetch_map = client.predict(feed=batch_feed, fetch=["price"]) batch_data = np.concatenate([new_data, new_data, new_data], axis=0)
print(batch_data.shape)
fetch_map = client.predict(feed={"x":batch_data}, fetch=["price"], batch=True)
if fetch_map["serving_status_code"] == 0: if fetch_map["serving_status_code"] == 0:
print(fetch_map) print(fetch_map)
else: else:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import MultiLangClient as Client
client = Client()
client.connect(["127.0.0.1:9393"])
x = [
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283,
0.4919, 0.1856, 0.0795, -0.0332
]
for i in range(3):
fetch_map = client.predict(feed={"x": x}, fetch=["price"], is_python=False)
if fetch_map["serving_status_code"] == 0:
print(fetch_map)
else:
print(fetch_map["serving_status_code"])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import MultiLangClient as Client
import numpy as np
client = Client()
client.connect(["127.0.0.1:9393"])
x = [
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283,
0.4919, 0.1856, 0.0795, -0.0332
]
for i in range(3):
fetch_map = client.predict(feed={"x": np.array(x)}, fetch=["price"])
if fetch_map["serving_status_code"] == 0:
print(fetch_map)
else:
print(fetch_map["serving_status_code"])
...@@ -14,16 +14,27 @@ ...@@ -14,16 +14,27 @@
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
from paddle_serving_client import MultiLangClient as Client from paddle_serving_client import MultiLangClient as Client
import numpy as np
client = Client() client = Client()
client.connect(["127.0.0.1:9393"]) client.connect(["127.0.0.1:9393"])
"""
for data in test_reader():
new_data = np.zeros((1, 1, 13)).astype("float32")
new_data[0] = data[0][0]
fetch_map = client.predict(
feed={"x": new_data}, fetch=["price"], batch=True)
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
print(fetch_map)
"""
x = [ x = [
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283,
0.4919, 0.1856, 0.0795, -0.0332 0.4919, 0.1856, 0.0795, -0.0332
] ]
for i in range(3): for i in range(3):
fetch_map = client.predict(feed={"x": x}, fetch=["price"]) new_data = np.array(x).astype("float32").reshape((1,13))
fetch_map = client.predict(feed={"x": new_data}, fetch=["price"], batch=False)
if fetch_map["serving_status_code"] == 0: if fetch_map["serving_status_code"] == 0:
print(fetch_map) print(fetch_map)
else: else:
......
...@@ -15,17 +15,18 @@ ...@@ -15,17 +15,18 @@
from paddle_serving_client import MultiLangClient as Client from paddle_serving_client import MultiLangClient as Client
import grpc import grpc
import numpy as np
client = Client() client = Client()
client.connect(["127.0.0.1:9393"]) client.connect(["127.0.0.1:9393"])
client.set_rpc_timeout_ms(1) client.set_rpc_timeout_ms(40)
x = [ x = [
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283,
0.4919, 0.1856, 0.0795, -0.0332 0.4919, 0.1856, 0.0795, -0.0332
] ]
for i in range(3): for i in range(3):
fetch_map = client.predict(feed={"x": x}, fetch=["price"]) new_data = np.array(x).astype("float32").reshape((1,13))
fetch_map = client.predict(feed={"x": new_data}, fetch=["price"], batch=False)
if fetch_map["serving_status_code"] == 0: if fetch_map["serving_status_code"] == 0:
print(fetch_map) print(fetch_map)
elif fetch_map["serving_status_code"] == grpc.StatusCode.DEADLINE_EXCEEDED: elif fetch_map["serving_status_code"] == grpc.StatusCode.DEADLINE_EXCEEDED:
......
...@@ -27,7 +27,7 @@ preprocess = Sequential([ ...@@ -27,7 +27,7 @@ preprocess = Sequential([
postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608]) postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608])
client = Client() client = Client()
client.connect(['127.0.0.1:9393']) client.connect(['127.0.0.1:9393'])
# client.set_rpc_timeout_ms(10000) client.set_rpc_timeout_ms(15000)
im = preprocess(sys.argv[1]) im = preprocess(sys.argv[1])
fetch_map = client.predict( fetch_map = client.predict(
...@@ -35,7 +35,8 @@ fetch_map = client.predict( ...@@ -35,7 +35,8 @@ fetch_map = client.predict(
"image": im, "image": im,
"im_size": np.array(list(im.shape[1:])), "im_size": np.array(list(im.shape[1:])),
}, },
fetch=["save_infer_model/scale_0.tmp_0"]) fetch=["save_infer_model/scale_0.tmp_0"], batch=False)
print(fetch_map)
fetch_map.pop("serving_status_code") fetch_map.pop("serving_status_code")
fetch_map["image"] = sys.argv[1] fetch_map["image"] = sys.argv[1]
postprocess(fetch_map) postprocess(fetch_map)
...@@ -522,20 +522,15 @@ class MultiLangClient(object): ...@@ -522,20 +522,15 @@ class MultiLangClient(object):
req.fetch_var_names.extend(fetch) req.fetch_var_names.extend(fetch)
req.is_python = is_python req.is_python = is_python
req.log_id = log_id req.log_id = log_id
feed_batch = None feed_var_names = []
if isinstance(feed, dict): for key in feed.keys():
feed_batch = [feed] if '.lod' not in key:
elif isinstance(feed, list): feed_var_names.append(key)
feed_batch = feed req.feed_var_names.extend(feed_var_names)
else:
raise Exception("{} not support".format(type(feed)))
req.feed_var_names.extend(feed_batch[0].keys())
init_feed_names = False
for feed_data in feed_batch:
inst = multi_lang_general_model_service_pb2.FeedInst() inst = multi_lang_general_model_service_pb2.FeedInst()
for name in req.feed_var_names: for name in req.feed_var_names:
tensor = multi_lang_general_model_service_pb2.Tensor() tensor = multi_lang_general_model_service_pb2.Tensor()
var = feed_data[name] var = feed[name]
v_type = self.feed_types_[name] v_type = self.feed_types_[name]
if is_python: if is_python:
data = None data = None
...@@ -564,34 +559,9 @@ class MultiLangClient(object): ...@@ -564,34 +559,9 @@ class MultiLangClient(object):
else: else:
raise Exception("var must be list or ndarray.") raise Exception("var must be list or ndarray.")
tensor.data = data.tobytes() tensor.data = data.tobytes()
else:
if isinstance(var, np.ndarray):
if v_type == 0: # int64
tensor.int64_data.extend(
var.reshape(-1).astype("int64").tolist())
elif v_type == 1:
tensor.float_data.extend(
var.reshape(-1).astype('float32').tolist())
elif v_type == 2:
tensor.int_data.extend(
var.reshape(-1).astype('int32').tolist())
else:
raise Exception("error tensor value type.")
elif isinstance(var, list):
if v_type == 0:
tensor.int64_data.extend(self._flatten_list(var))
elif v_type == 1:
tensor.float_data.extend(self._flatten_list(var))
elif v_type == 2:
tensor.int_data.extend(self._flatten_list(var))
else:
raise Exception("error tensor value type.")
else:
raise Exception("var must be list or ndarray.")
if isinstance(var, np.ndarray):
tensor.shape.extend(list(var.shape)) tensor.shape.extend(list(var.shape))
else: if "{}.lod".format(name) in feed.keys():
tensor.shape.extend(self.feed_shapes_[name]) tensor.lod.extend(feed["{}.lod".format(name)])
inst.tensor_array.append(tensor) inst.tensor_array.append(tensor)
req.insts.append(inst) req.insts.append(inst)
return req return req
...@@ -652,10 +622,17 @@ class MultiLangClient(object): ...@@ -652,10 +622,17 @@ class MultiLangClient(object):
def predict(self, def predict(self,
feed, feed,
fetch, fetch,
batch=True,
need_variant_tag=False, need_variant_tag=False,
asyn=False, asyn=False,
is_python=True, is_python=True,
log_id=0): log_id=0):
if isinstance(feed, dict) is False:
raise ValueError("Type Error. grpc feed must be dict.")
if batch is False:
for key in feed:
if ".lod" not in key:
feed[key] = feed[key][np.newaxis, :]
if not asyn: if not asyn:
try: try:
self.profile_.record('py_prepro_0') self.profile_.record('py_prepro_0')
......
...@@ -523,9 +523,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. ...@@ -523,9 +523,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
fetch_names = list(request.fetch_var_names) fetch_names = list(request.fetch_var_names)
is_python = request.is_python is_python = request.is_python
log_id = request.log_id log_id = request.log_id
feed_batch = []
for feed_inst in request.insts:
feed_dict = {} feed_dict = {}
feed_inst = request.insts[0]
for idx, name in enumerate(feed_names): for idx, name in enumerate(feed_names):
var = feed_inst.tensor_array[idx] var = feed_inst.tensor_array[idx]
v_type = self.feed_types_[name] v_type = self.feed_types_[name]
...@@ -539,19 +538,11 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. ...@@ -539,19 +538,11 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
data = np.frombuffer(var.data, dtype="int32") data = np.frombuffer(var.data, dtype="int32")
else: else:
raise Exception("error type.") raise Exception("error type.")
else:
if v_type == 0: # int64
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2: # int32
data = np.array(list(var.int_data), dtype="int32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape) data.shape = list(feed_inst.tensor_array[idx].shape)
feed_dict[name] = data feed_dict[name] = data
feed_batch.append(feed_dict) if len(var.lod) > 0:
return feed_batch, fetch_names, is_python, log_id feed_dict["{}.lod".format()] = var.lod
return feed_dict, fetch_names, is_python, log_id
def _pack_inference_response(self, ret, fetch_names, is_python): def _pack_inference_response(self, ret, fetch_names, is_python):
resp = multi_lang_general_model_service_pb2.InferenceResponse() resp = multi_lang_general_model_service_pb2.InferenceResponse()
...@@ -608,6 +599,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. ...@@ -608,6 +599,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
ret = self.bclient_.predict( ret = self.bclient_.predict(
feed=feed_dict, feed=feed_dict,
fetch=fetch_names, fetch=fetch_names,
batch=True,
need_variant_tag=True, need_variant_tag=True,
log_id=log_id) log_id=log_id)
return self._pack_inference_response(ret, fetch_names, is_python) return self._pack_inference_response(ret, fetch_names, is_python)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册