diff --git a/python/examples/imdb/test_py_client.py b/python/examples/imdb/test_py_client.py index a4a2c3194089cc94edae8b95fd823d45a9776a01..9c3b314133370d5a4d503b11e27f81bc20b43509 100644 --- a/python/examples/imdb/test_py_client.py +++ b/python/examples/imdb/test_py_client.py @@ -11,41 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=doc-string-missing -import grpc -import general_python_service_pb2 -import general_python_service_pb2_grpc +from pyclient import PyClient import numpy as np -channel = grpc.insecure_channel('localhost:8080') -stub = general_python_service_pb2_grpc.GeneralPythonServiceStub(channel) -req = general_python_service_pb2.Request() -""" -# line = "i am very sad | 0" -word_ids = np.array([8, 233, 52, 601], dtype='int64') -# word_ids = np.array([8, 233, 52, 601]) -print(word_ids) -data = np.ndarray.tobytes(word_ids) -print(data) -# xx = np.frombuffer(data) -xx = np.frombuffer(data, dtype='int64') -print (xx) -req.feed_var_names.append("words") -req.feed_insts.append(data) -""" +client = PyClient() +client.connect('localhost:8080') + x = np.array( [ 0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332 ], dtype='float') -data = np.ndarray.tobytes(x) -req.feed_var_names.append("x") -req.feed_insts.append(data) -for i in range(100): - resp = stub.inference(req) - for idx, name in enumerate(resp.fetch_var_names): - print('{}: {}'.format( - name, np.frombuffer( - resp.fetch_insts[idx], dtype='float'))) +for i in range(5): + fetch_map = client.predict( + feed={"x": x}, fetch_with_type={"combine_op_output": "float"}) + print(fetch_map) diff --git a/python/examples/imdb/test_py_server.py b/python/examples/imdb/test_py_server.py index fa1fa0d763497e05e4f21a11a5f4e29c0b3b1869..dee03d564c87627ee6b22a3d64c2fa1527be9d74 100644 --- a/python/examples/imdb/test_py_server.py +++ b/python/examples/imdb/test_py_server.py @@ -37,7 +37,7 @@ class CombineOp(Op): data = python_service_channel_pb2.ChannelData() inst = python_service_channel_pb2.Inst() inst.data = np.ndarray.tobytes(cnt) - inst.name = "resp" + inst.name = "combine_op_output" data.insts.append(inst) return data diff --git a/python/paddle_serving_client/general_python_service.proto b/python/paddle_serving_client/general_python_service.proto new file mode 100644 index 0000000000000000000000000000000000000000..5613558663f3f61b3ad8e9f1aed35cda6afa98d3 --- /dev/null +++ b/python/paddle_serving_client/general_python_service.proto @@ -0,0 +1,29 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +service GeneralPythonService { + rpc inference(Request) returns (Response) {} +} + +message Request { + repeated bytes feed_insts = 1; + repeated string feed_var_names = 2; +} + +message Response { + repeated bytes fetch_insts = 1; + repeated string fetch_var_names = 2; +} diff --git a/python/paddle_serving_client/pyclient.py b/python/paddle_serving_client/pyclient.py new file mode 100644 index 0000000000000000000000000000000000000000..29df85f045210d49703cc07c720a66f2b81697c0 --- /dev/null +++ b/python/paddle_serving_client/pyclient.py @@ -0,0 +1,56 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=doc-string-missing +import grpc +import general_python_service_pb2 +import general_python_service_pb2_grpc +import numpy as np + + +class PyClient(object): + def __init__(self): + self._channel = None + + def connect(self, endpoint): + self._channel = grpc.insecure_channel(endpoint) + self._stub = general_python_service_pb2_grpc.GeneralPythonServiceStub( + self._channel) + + def _pack_data_for_infer(self, feed_data): + req = general_python_service_pb2.Request() + for name, data in feed_data.items(): + if not isinstance(data, np.ndarray): + raise TypeError( + "only numpy array type is supported temporarily.") + data2bytes = np.ndarray.tobytes(data) + req.feed_var_names.append(name) + req.feed_insts.append(data2bytes) + return req + + def predict(self, feed, fetch_with_type): + if not isinstance(feed, dict): + raise TypeError( + "feed must be dict type with format: {name: value}.") + if not isinstance(fetch_with_type, dict): + raise TypeError( + "fetch_with_type must be dict type with format: {name : type}.") + req = self._pack_data_for_infer(feed) + resp = self._stub.inference(req) + fetch_map = {} + for idx, name in enumerate(resp.fetch_var_names): + if name not in fetch_with_type: + continue + fetch_map[name] = np.frombuffer( + resp.fetch_insts[idx], dtype=fetch_with_type[name]) + return fetch_map