diff --git a/paddlecv/configs/system/PP-ShiTu.yml b/paddlecv/configs/system/PP-ShiTu.yml index 772c3bfcfb28d7f0833ef221e9318f031b2eebc6..17b62c1b9e27f71874ddf82dd64dc60cf98af220 100644 --- a/paddlecv/configs/system/PP-ShiTu.yml +++ b/paddlecv/configs/system/PP-ShiTu.yml @@ -4,7 +4,6 @@ ENV: cpu_threads: 1 trt_use_static: False save_img: True - save_res: True return_res: True print_res: True @@ -62,9 +61,9 @@ MODEL: PostProcess: - NormalizeFeature: - Index: - index_method: "HNSW32" # supported: HNSW32, IVF, Flat dist_type: "IP" - index_dir: "./drink_dataset_v1.0/index" + vector_path: "paddlecv://dict/pp-shitu/drink_dataset_v1.0_vector.index" + id_map_path: "paddlecv://dict/pp-shitu/drink_dataset_v1.0_id_map.pkl" score_thres: 0.5 - NMS4Rec: thresh: 0.05 diff --git a/paddlecv/configs/system/PP-ShiTuV2.yml b/paddlecv/configs/system/PP-ShiTuV2.yml index 275decce82bf2168a073c1689daf038be1cfd6fa..e46437c3851db65d34c521f1de52000b92a2b717 100644 --- a/paddlecv/configs/system/PP-ShiTuV2.yml +++ b/paddlecv/configs/system/PP-ShiTuV2.yml @@ -4,7 +4,6 @@ ENV: cpu_threads: 1 trt_use_static: False save_img: True - save_res: True return_res: True print_res: True @@ -62,9 +61,9 @@ MODEL: PostProcess: - NormalizeFeature: - Index: - index_method: "HNSW32" # supported: HNSW32, IVF, Flat dist_type: "IP" - index_dir: "./drink_dataset_v2.0/index" + vector_path: "paddlecv://dict/pp-shitu/drink_dataset_v2.0_vector.index" + id_map_path: "paddlecv://dict/pp-shitu/drink_dataset_v2.0_id_map.pkl" score_thres: 0.5 - NMS4Rec: thresh: 0.05 diff --git a/paddlecv/configs/unittest/test_feature_extraction.yml b/paddlecv/configs/unittest/test_feature_extraction.yml index 7de7b984274f60420590d24eaa7f6321deb59d31..b5c12fb0d152f54425bb6e9c4ad66eb964dcb7e5 100644 --- a/paddlecv/configs/unittest/test_feature_extraction.yml +++ b/paddlecv/configs/unittest/test_feature_extraction.yml @@ -31,9 +31,9 @@ MODEL: PostProcess: - NormalizeFeature: - Index: - index_method: "HNSW32" # supported: HNSW32, IVF, Flat dist_type: "IP" - index_dir: "./drink_dataset_v2.0/index" + vector_path: "paddlecv://dict/pp-shitu/drink_dataset_v2.0_vector.index" + id_map_path: "paddlecv://dict/pp-shitu/drink_dataset_v2.0_id_map.pkl" score_thres: 0.5 Inputs: - input.image diff --git a/paddlecv/demo/drink_dataset_v2.0_test_100.jpeg b/paddlecv/demo/drink_dataset_v2.0_test_100.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..e5f845ed2842f67352ffda26380b59a35d0449f9 Binary files /dev/null and b/paddlecv/demo/drink_dataset_v2.0_test_100.jpeg differ diff --git a/paddlecv/ppcv/ops/models/feature_extraction/postprocess.py b/paddlecv/ppcv/ops/models/feature_extraction/postprocess.py index dd687bbb257d356295caa635c6ef4855d4ffcaf2..0713d81c2fd6433681ab81d3e8a740efb0a307f6 100644 --- a/paddlecv/ppcv/ops/models/feature_extraction/postprocess.py +++ b/paddlecv/ppcv/ops/models/feature_extraction/postprocess.py @@ -18,6 +18,7 @@ import numpy as np import faiss import pickle +from ppcv.utils.download import get_dict_path from ppcv.utils.logger import setup_logger logger = setup_logger('FeatureExtraction') @@ -37,17 +38,13 @@ class NormalizeFeature(object): class Index(object): def __init__(self, - index_method, - index_dir, + vector_path, + id_map_path, dist_type, hamming_radius=None, score_thres=None): - vector_path = os.path.join(index_dir, "vector.index") - id_map_path = os.path.join(index_dir, "id_map.pkl") - if not os.path.exists(vector_path) or not os.path.exists(id_map_path): - msg = "The directory \"index_dir\" must contain files \"vector.index\", and \"id_map.pkl\". Please check again!" - logger.error(msg) - raise Exception(msg) + vector_path = get_dict_path(vector_path) + id_map_path = get_dict_path(id_map_path) if dist_type == "hamming": self.searcher = faiss.read_index_binary(vector_path) diff --git a/paddlecv/ppcv/ops/output/feature_extraction.py b/paddlecv/ppcv/ops/output/feature_extraction.py index 6336fa0adc8410b1b3c7f6579968659b6d60920f..62c0259fe0f34fcc1062c14232dab3e847df1bf5 100644 --- a/paddlecv/ppcv/ops/output/feature_extraction.py +++ b/paddlecv/ppcv/ops/output/feature_extraction.py @@ -44,6 +44,9 @@ class FeatureOutput(OutputBaseOp): if self.print_res: msg = " ".join([f"{key}: {res[key]}" for key in res]) logger.info(msg) + if self.save_res: + msg = "The FeatureOutput op does not yet support to save prediction results." + logger.warning(msg) if self.return_res: total_res.append(res) if self.return_res: diff --git a/paddlecv/tests/test_feature_extraction.py b/paddlecv/tests/test_feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..3d26f1a91ce1ace566fac2853c80d0c9e3e067dc --- /dev/null +++ b/paddlecv/tests/test_feature_extraction.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +parent = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(parent, '../'))) + +import cv2 +import unittest +import yaml +import argparse + +from ppcv.core.workspace import global_config +from ppcv.core.config import ConfigParser + + +class TestFeatureExtraction(unittest.TestCase): + def setUp(self): + self.config = 'configs/unittest/test_feature_extraction.yml' + self.input = 'demo/drink_dataset_v2.0_test_100.jpeg' + self.cfg_dict = dict(config=self.config, input=self.input) + cfg = argparse.Namespace(**self.cfg_dict) + config = ConfigParser(cfg) + config.print_cfg() + self.model_cfg, self.env_cfg = config.parse() + + def test_classification(self): + img = cv2.imread(self.input)[:, :, ::-1] + inputs = [ + { + "input.image": img + }, + ] + op_name = list(self.model_cfg[0].keys())[0] + op = global_config[op_name](self.model_cfg[0][op_name], + self.env_cfg) + result = op(inputs) + + +if __name__ == '__main__': + unittest.main()