From a83cd7662a4d04f8b238710f9b3d4964253bfd06 Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 15 Jun 2020 21:01:18 +0800 Subject: [PATCH] multilang future add call back --- .../fit_a_line/test_multilang_client.py | 22 +++++++++++++++++-- python/paddle_serving_client/__init__.py | 12 +++++++++- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/python/examples/fit_a_line/test_multilang_client.py b/python/examples/fit_a_line/test_multilang_client.py index c2c58378..f85814a4 100644 --- a/python/examples/fit_a_line/test_multilang_client.py +++ b/python/examples/fit_a_line/test_multilang_client.py @@ -14,7 +14,10 @@ # pylint: disable=doc-string-missing from paddle_serving_client import MultiLangClient +import functools import sys +import time +import threading client = MultiLangClient() client.load_client_config(sys.argv[1]) @@ -26,7 +29,22 @@ test_reader = paddle.batch( paddle.dataset.uci_housing.test(), buf_size=500), batch_size=1) +complete_task_count = [0] +lock = threading.Lock() + + +def call_back(call_future, data): + fetch_map = call_future.result() + print("{} {}".format(fetch_map["price"][0], data[0][1][0])) + with lock: + complete_task_count[0] += 1 + + +task_count = 0 for data in test_reader(): future = client.predict(feed={"x": data[0][0]}, fetch=["price"], asyn=True) - fetch_map = future.result() - print("{} {}".format(fetch_map["price"][0], data[0][1][0])) + task_count += 1 + future.add_done_callback(functools.partial(call_back, data=data)) + +while complete_task_count[0] != task_count: + time.sleep(0.1) diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index 9e329267..58ae09bc 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -431,7 +431,6 @@ class MultiLangClient(object): def _pack_feed_data(self, feed, fetch, is_python): req = multi_lang_general_model_service_pb2.Request() req.fetch_var_names.extend(fetch) - req.feed_var_names.extend(feed.keys()) req.is_python = is_python feed_batch = None if isinstance(feed, dict): @@ -440,6 +439,7 @@ class MultiLangClient(object): feed_batch = feed else: raise Exception("{} not support".format(type(feed))) + req.feed_var_names.extend(feed_batch[0].keys()) init_feed_names = False for feed_data in feed_batch: inst = multi_lang_general_model_service_pb2.FeedInst() @@ -516,6 +516,9 @@ class MultiLangClient(object): return unpack_resp + def get_feed_names(self): + return self.feed_names_ + def predict(self, feed, fetch, @@ -548,3 +551,10 @@ class MultiLangPredictFuture(object): def result(self): resp = self.call_future_.result() return self.callback_func_(resp) + + def add_done_callback(self, fn): + def __fn__(call_future): + assert call_future == self.call_future_ + fn(self) + + self.call_future_.add_done_callback(__fn__) -- GitLab