提交 00c64d90 编写于 作者: W wangjiawei04

fix grpc impl, sync with brpc

上级 7ea59bda
......@@ -38,20 +38,9 @@ python test_asyn_client.py
python test_batch_client.py
```
### 通用 pb 预测
``` shell
python test_general_pb_client.py
```
### 预测超时
``` shell
python test_timeout_client.py
```
### List 输入
``` shell
python test_list_input_client.py
```
......@@ -18,7 +18,7 @@ import functools
import time
import threading
import grpc
import numpy as np
client = Client()
client.connect(["127.0.0.1:9393"])
......@@ -43,7 +43,8 @@ x = [
]
task_count = 0
for i in range(3):
future = client.predict(feed={"x": x}, fetch=["price"], asyn=True)
new_data = np.array(x).astype("float32").reshape((1,13))
future = client.predict(feed={"x": new_data}, fetch=["price"], batch=False, asyn=True)
task_count += 1
future.add_done_callback(functools.partial(call_back))
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import MultiLangClient as Client
import numpy as np
client = Client()
client.connect(["127.0.0.1:9393"])
......@@ -24,8 +24,11 @@ x = [
]
for i in range(3):
batch_feed = [{"x": x} for j in range(batch_size)]
fetch_map = client.predict(feed=batch_feed, fetch=["price"])
new_data = np.array(x).astype("float32").reshape((1, 1, 13))
batch_data = np.concatenate([new_data, new_data, new_data], axis=0)
print(batch_data.shape)
fetch_map = client.predict(feed={"x":batch_data}, fetch=["price"], batch=True)
if fetch_map["serving_status_code"] == 0:
print(fetch_map)
else:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import MultiLangClient as Client
client = Client()
client.connect(["127.0.0.1:9393"])
x = [
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283,
0.4919, 0.1856, 0.0795, -0.0332
]
for i in range(3):
fetch_map = client.predict(feed={"x": x}, fetch=["price"], is_python=False)
if fetch_map["serving_status_code"] == 0:
print(fetch_map)
else:
print(fetch_map["serving_status_code"])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import MultiLangClient as Client
import numpy as np
client = Client()
client.connect(["127.0.0.1:9393"])
x = [
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283,
0.4919, 0.1856, 0.0795, -0.0332
]
for i in range(3):
fetch_map = client.predict(feed={"x": np.array(x)}, fetch=["price"])
if fetch_map["serving_status_code"] == 0:
print(fetch_map)
else:
print(fetch_map["serving_status_code"])
......@@ -14,16 +14,27 @@
# pylint: disable=doc-string-missing
from paddle_serving_client import MultiLangClient as Client
import numpy as np
client = Client()
client.connect(["127.0.0.1:9393"])
"""
for data in test_reader():
new_data = np.zeros((1, 1, 13)).astype("float32")
new_data[0] = data[0][0]
fetch_map = client.predict(
feed={"x": new_data}, fetch=["price"], batch=True)
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
print(fetch_map)
"""
x = [
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283,
0.4919, 0.1856, 0.0795, -0.0332
]
for i in range(3):
fetch_map = client.predict(feed={"x": x}, fetch=["price"])
new_data = np.array(x).astype("float32").reshape((1,13))
fetch_map = client.predict(feed={"x": new_data}, fetch=["price"], batch=False)
if fetch_map["serving_status_code"] == 0:
print(fetch_map)
else:
......
......@@ -15,17 +15,18 @@
from paddle_serving_client import MultiLangClient as Client
import grpc
import numpy as np
client = Client()
client.connect(["127.0.0.1:9393"])
client.set_rpc_timeout_ms(1)
client.set_rpc_timeout_ms(40)
x = [
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283,
0.4919, 0.1856, 0.0795, -0.0332
]
for i in range(3):
fetch_map = client.predict(feed={"x": x}, fetch=["price"])
new_data = np.array(x).astype("float32").reshape((1,13))
fetch_map = client.predict(feed={"x": new_data}, fetch=["price"], batch=False)
if fetch_map["serving_status_code"] == 0:
print(fetch_map)
elif fetch_map["serving_status_code"] == grpc.StatusCode.DEADLINE_EXCEEDED:
......
......@@ -27,7 +27,7 @@ preprocess = Sequential([
postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608])
client = Client()
client.connect(['127.0.0.1:9393'])
# client.set_rpc_timeout_ms(10000)
client.set_rpc_timeout_ms(15000)
im = preprocess(sys.argv[1])
fetch_map = client.predict(
......@@ -35,7 +35,8 @@ fetch_map = client.predict(
"image": im,
"im_size": np.array(list(im.shape[1:])),
},
fetch=["save_infer_model/scale_0.tmp_0"])
fetch=["save_infer_model/scale_0.tmp_0"], batch=False)
print(fetch_map)
fetch_map.pop("serving_status_code")
fetch_map["image"] = sys.argv[1]
postprocess(fetch_map)
......@@ -522,20 +522,15 @@ class MultiLangClient(object):
req.fetch_var_names.extend(fetch)
req.is_python = is_python
req.log_id = log_id
feed_batch = None
if isinstance(feed, dict):
feed_batch = [feed]
elif isinstance(feed, list):
feed_batch = feed
else:
raise Exception("{} not support".format(type(feed)))
req.feed_var_names.extend(feed_batch[0].keys())
init_feed_names = False
for feed_data in feed_batch:
feed_var_names = []
for key in feed.keys():
if '.lod' not in key:
feed_var_names.append(key)
req.feed_var_names.extend(feed_var_names)
inst = multi_lang_general_model_service_pb2.FeedInst()
for name in req.feed_var_names:
tensor = multi_lang_general_model_service_pb2.Tensor()
var = feed_data[name]
var = feed[name]
v_type = self.feed_types_[name]
if is_python:
data = None
......@@ -564,34 +559,9 @@ class MultiLangClient(object):
else:
raise Exception("var must be list or ndarray.")
tensor.data = data.tobytes()
else:
if isinstance(var, np.ndarray):
if v_type == 0: # int64
tensor.int64_data.extend(
var.reshape(-1).astype("int64").tolist())
elif v_type == 1:
tensor.float_data.extend(
var.reshape(-1).astype('float32').tolist())
elif v_type == 2:
tensor.int_data.extend(
var.reshape(-1).astype('int32').tolist())
else:
raise Exception("error tensor value type.")
elif isinstance(var, list):
if v_type == 0:
tensor.int64_data.extend(self._flatten_list(var))
elif v_type == 1:
tensor.float_data.extend(self._flatten_list(var))
elif v_type == 2:
tensor.int_data.extend(self._flatten_list(var))
else:
raise Exception("error tensor value type.")
else:
raise Exception("var must be list or ndarray.")
if isinstance(var, np.ndarray):
tensor.shape.extend(list(var.shape))
else:
tensor.shape.extend(self.feed_shapes_[name])
if "{}.lod".format(name) in feed.keys():
tensor.lod.extend(feed["{}.lod".format(name)])
inst.tensor_array.append(tensor)
req.insts.append(inst)
return req
......@@ -652,10 +622,17 @@ class MultiLangClient(object):
def predict(self,
feed,
fetch,
batch=True,
need_variant_tag=False,
asyn=False,
is_python=True,
log_id=0):
if isinstance(feed, dict) is False:
raise ValueError("Type Error. grpc feed must be dict.")
if batch is False:
for key in feed:
if ".lod" not in key:
feed[key] = feed[key][np.newaxis, :]
if not asyn:
try:
self.profile_.record('py_prepro_0')
......
......@@ -523,9 +523,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
fetch_names = list(request.fetch_var_names)
is_python = request.is_python
log_id = request.log_id
feed_batch = []
for feed_inst in request.insts:
feed_dict = {}
feed_inst = request.insts[0]
for idx, name in enumerate(feed_names):
var = feed_inst.tensor_array[idx]
v_type = self.feed_types_[name]
......@@ -539,19 +538,11 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
data = np.frombuffer(var.data, dtype="int32")
else:
raise Exception("error type.")
else:
if v_type == 0: # int64
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2: # int32
data = np.array(list(var.int_data), dtype="int32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
feed_dict[name] = data
feed_batch.append(feed_dict)
return feed_batch, fetch_names, is_python, log_id
if len(var.lod) > 0:
feed_dict["{}.lod".format()] = var.lod
return feed_dict, fetch_names, is_python, log_id
def _pack_inference_response(self, ret, fetch_names, is_python):
resp = multi_lang_general_model_service_pb2.InferenceResponse()
......@@ -608,6 +599,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
ret = self.bclient_.predict(
feed=feed_dict,
fetch=fetch_names,
batch=True,
need_variant_tag=True,
log_id=log_id)
return self._pack_inference_response(ret, fetch_names, is_python)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册