From 92fb727f353ac7632b94dc4d2b5057e6327c586a Mon Sep 17 00:00:00 2001 From: barrierye Date: Sun, 28 Jun 2020 03:27:55 +0800 Subject: [PATCH] add client --- python/pipeline/pipeline_client.py | 55 ++++++++++++++++++++++++++++++ python/pipeline/proto/__init__.py | 13 +++++++ python/pipeline/util.py | 25 ++++++++++++++ 3 files changed, 93 insertions(+) create mode 100644 python/pipeline/pipeline_client.py create mode 100644 python/pipeline/proto/__init__.py create mode 100644 python/pipeline/util.py diff --git a/python/pipeline/pipeline_client.py b/python/pipeline/pipeline_client.py new file mode 100644 index 00000000..06b69443 --- /dev/null +++ b/python/pipeline/pipeline_client.py @@ -0,0 +1,55 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=doc-string-missing +import grpc +import numpy as np +from .proto import pipeline_service_pb2 +from .proto import pipeline_service_pb2_grpc + + +class PipelineClient(object): + def __init__(self): + self._channel = None + + def connect(self, endpoint): + self._channel = grpc.insecure_channel(endpoint) + self._stub = pipeline_service_pb2_grpc.PipelineServiceStub( + self._channel) + + def _pack_data_for_infer(self, feed_dict): + req = pipeline_service_pb2.Request() + for key, value in feed_dict.items(): + if not isinstance(value, str): + raise TypeError("only str type is supported.") + req.key.append(key) + req.value.append(value) + return req + + def predict(self, feed_dict, fetch): + if not isinstance(feed_dict, dict): + raise TypeError( + "feed must be dict type with format: {name: value}.") + if not isinstance(fetch, list): + raise TypeError( + "fetch_with_type must be list type with format: [name].") + req = self._pack_data_for_infer(feed_dict) + resp = self._stub.inference(req) + if resp.ecode != 0: + return {"ecode": resp.ecode, "error_info": resp.error_info} + fetch_map = {"ecode": resp.ecode} + for idx, key in enumerate(resp.key): + if key not in fetch: + continue + fetch_map[key] = resp.value[idx] + return fetch_map diff --git a/python/pipeline/proto/__init__.py b/python/pipeline/proto/__init__.py new file mode 100644 index 00000000..abf198b9 --- /dev/null +++ b/python/pipeline/proto/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/pipeline/util.py b/python/pipeline/util.py new file mode 100644 index 00000000..a24c1a05 --- /dev/null +++ b/python/pipeline/util.py @@ -0,0 +1,25 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + + +class NameGenerator(object): + def __init__(self, prefix): + self._idx = -1 + self._prefix = prefix + + def next(self): + self._idx += 1 + return "{}{}".format(self._prefix, self._idx) -- GitLab