diff --git a/python/examples/imdb/README.md b/python/examples/imdb/README.md
index 036bd9d79d216c44db1adbbb867daff94aa5584b..9106dd872ae6ea0a6bd3ee6f732dacf6e8b44498 100644
--- a/python/examples/imdb/README.md
+++ b/python/examples/imdb/README.md
@@ -10,3 +10,4 @@ cat test.data | python test_client.py > result
 ```
 python test_client_multithread.py inference.conf test.data 4 > result
 ```
+batch client
diff --git a/python/examples/imdb/test_client_batch.py b/python/examples/imdb/test_client_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..6686df272d62fe62315417470d2ce747a4b3c9ff
--- /dev/null
+++ b/python/examples/imdb/test_client_batch.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle_serving import Client
+import sys
+import time
+
+
+def predict_for_batch(conf_file, batch_size=4):
+    # Connect to the local serving instance started for this example.
+    client = Client()
+    client.load_client_config(conf_file)
+    client.connect(["127.0.0.1:8010"])
+    start = time.time()
+    feed_batch = []
+    fetch = ["acc", "cost", "prediction"]
+    for line in sys.stdin:
+        # Each input line is formatted as: <word count> <word ids ...> <label>.
+        group = line.strip().split()
+        words = [int(x) for x in group[1:int(group[0])]]
+        label = [int(group[-1])]
+        feed = {"words": words, "label": label}
+        feed_batch.append(feed)
+        if len(feed_batch) == batch_size:
+            # Send a full batch in one request, then print <prediction, label> pairs.
+            fetch_batch = client.predict_for_batch(
+                feed_batch=feed_batch, fetch=fetch)
+            for i in range(batch_size):
+                print("{} {}".format(fetch_batch[i]["prediction"][1],
+                                     feed_batch[i]["label"][0]))
+            feed_batch = []
+    # Note: a trailing batch smaller than batch_size is left unsent.
+    cost = time.time() - start
+    print("total cost : {}".format(cost))
+    print(time.time())
+
+
+if __name__ == '__main__':
+    conf_file = sys.argv[1]
+    batch_size = int(sys.argv[2])
+    predict_for_batch(conf_file, batch_size)
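A usage sketch for the new batch client, assuming the serving instance from this README is running on 127.0.0.1:8010 (the address hard-coded in the script) and reusing the same inference.conf and test.data files as the other clients; the trailing 4 is the batch size:
```
cat test.data | python test_client_batch.py inference.conf 4 > result
```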