提交 d6e8cf19 编写于 作者: B barrierye

add PyClient

上级 02f9966c
......@@ -11,41 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
from pyclient import PyClient
import numpy as np
channel = grpc.insecure_channel('localhost:8080')
stub = general_python_service_pb2_grpc.GeneralPythonServiceStub(channel)
req = general_python_service_pb2.Request()
"""
# line = "i am very sad | 0"
word_ids = np.array([8, 233, 52, 601], dtype='int64')
# word_ids = np.array([8, 233, 52, 601])
print(word_ids)
data = np.ndarray.tobytes(word_ids)
print(data)
# xx = np.frombuffer(data)
xx = np.frombuffer(data, dtype='int64')
print (xx)
req.feed_var_names.append("words")
req.feed_insts.append(data)
"""
client = PyClient()
client.connect('localhost:8080')
x = np.array(
[
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584,
0.6283, 0.4919, 0.1856, 0.0795, -0.0332
],
dtype='float')
data = np.ndarray.tobytes(x)
req.feed_var_names.append("x")
req.feed_insts.append(data)
for i in range(100):
resp = stub.inference(req)
for idx, name in enumerate(resp.fetch_var_names):
print('{}: {}'.format(
name, np.frombuffer(
resp.fetch_insts[idx], dtype='float')))
for i in range(5):
fetch_map = client.predict(
feed={"x": x}, fetch_with_type={"combine_op_output": "float"})
print(fetch_map)
......@@ -37,7 +37,7 @@ class CombineOp(Op):
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
inst.data = np.ndarray.tobytes(cnt)
inst.name = "resp"
inst.name = "combine_op_output"
data.insts.append(inst)
return data
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
service GeneralPythonService {
rpc inference(Request) returns (Response) {}
}
message Request {
repeated bytes feed_insts = 1;
repeated string feed_var_names = 2;
}
message Response {
repeated bytes fetch_insts = 1;
repeated string fetch_var_names = 2;
}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
import numpy as np
class PyClient(object):
def __init__(self):
self._channel = None
def connect(self, endpoint):
self._channel = grpc.insecure_channel(endpoint)
self._stub = general_python_service_pb2_grpc.GeneralPythonServiceStub(
self._channel)
def _pack_data_for_infer(self, feed_data):
req = general_python_service_pb2.Request()
for name, data in feed_data.items():
if not isinstance(data, np.ndarray):
raise TypeError(
"only numpy array type is supported temporarily.")
data2bytes = np.ndarray.tobytes(data)
req.feed_var_names.append(name)
req.feed_insts.append(data2bytes)
return req
def predict(self, feed, fetch_with_type):
if not isinstance(feed, dict):
raise TypeError(
"feed must be dict type with format: {name: value}.")
if not isinstance(fetch_with_type, dict):
raise TypeError(
"fetch_with_type must be dict type with format: {name : type}.")
req = self._pack_data_for_infer(feed)
resp = self._stub.inference(req)
fetch_map = {}
for idx, name in enumerate(resp.fetch_var_names):
if name not in fetch_with_type:
continue
fetch_map[name] = np.frombuffer(
resp.fetch_insts[idx], dtype=fetch_with_type[name])
return fetch_map
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册