From c9c7c18c72015f11ee67482cada5d8e08da8c01a Mon Sep 17 00:00:00 2001 From: Tingquan Gao Date: Thu, 6 May 2021 13:55:53 +0800 Subject: [PATCH] Fix HubServing demo (#709) --- deploy/hubserving/readme.md | 2 +- deploy/hubserving/readme_en.md | 2 +- tools/test_hubserving.py | 11 ++++++----- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/deploy/hubserving/readme.md b/deploy/hubserving/readme.md index 28dde431..c3748584 100644 --- a/deploy/hubserving/readme.md +++ b/deploy/hubserving/readme.md @@ -118,7 +118,7 @@ hub serving start -c deploy/hubserving/clas/config.json - **top_k**:[**可选**] 返回前 `top_k` 个 `score` ,默认为 `1`。 访问示例: -```python tools/test_hubserving.py http://127.0.0.1:8866/predict/clas_system ./deploy/hubserving/ILSVRC2012_val_00006666.JPEG 5``` +```python tools/test_hubserving.py --server_url http://127.0.0.1:8866/predict/clas_system --image_file ./deploy/hubserving/ILSVRC2012_val_00006666.JPEG 5``` ### 返回结果格式说明 返回结果为列表(list),包含top-k个分类结果,以及对应的得分,还有此图片预测耗时,具体如下: diff --git a/deploy/hubserving/readme_en.md b/deploy/hubserving/readme_en.md index 5be4fe6b..0f34fd34 100644 --- a/deploy/hubserving/readme_en.md +++ b/deploy/hubserving/readme_en.md @@ -121,7 +121,7 @@ Two parameters need to be passed to the script: **Eg.** ```shell -python tools/test_hubserving.py http://127.0.0.1:8866/predict/clas_system ./deploy/hubserving/ILSVRC2012_val_00006666.JPEG 5 +python tools/test_hubserving.py --server_url http://127.0.0.1:8866/predict/clas_system --image_file ./deploy/hubserving/ILSVRC2012_val_00006666.JPEG 5 ``` ### Returned result format diff --git a/tools/test_hubserving.py b/tools/test_hubserving.py index 5c2820df..41f1287c 100644 --- a/tools/test_hubserving.py +++ b/tools/test_hubserving.py @@ -87,11 +87,12 @@ def main(args): predict_time += elapse for number, result_list in enumerate(batch_result_list): - all_score += result_list[0]["score"] - result_str = ", ".join([ - "{}: {:.2f}".format(r["cls_id"], r["score"]) - for r in result_list - ]) + all_score += result_list["scores"][0] + result_str = "" + for i in range(len(result_list["clas_ids"])): + result_str += "{}: {:.2f}\t".format( + result_list["clas_ids"][i], + result_list["scores"][i]) logger.info("File:{}, The top-{} result(s): {}".format( img_name_list[number], args.top_k, result_str)) -- GitLab