未验证 提交 089dea74 编写于 作者: T Tingquan Gao 提交者: GitHub

add feature extraction unittest (#5682)

* add feature extraction unittest

* support to download index files automatically

* warning when save_res is set to True in FeatureExtraction

* rename id map file
上级 872d7998
......@@ -4,7 +4,6 @@ ENV:
cpu_threads: 1
trt_use_static: False
save_img: True
save_res: True
return_res: True
print_res: True
......@@ -62,9 +61,9 @@ MODEL:
PostProcess:
- NormalizeFeature:
- Index:
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
dist_type: "IP"
index_dir: "./drink_dataset_v1.0/index"
vector_path: "paddlecv://dict/pp-shitu/drink_dataset_v1.0_vector.index"
id_map_path: "paddlecv://dict/pp-shitu/drink_dataset_v1.0_id_map.pkl"
score_thres: 0.5
- NMS4Rec:
thresh: 0.05
......
......@@ -4,7 +4,6 @@ ENV:
cpu_threads: 1
trt_use_static: False
save_img: True
save_res: True
return_res: True
print_res: True
......@@ -62,9 +61,9 @@ MODEL:
PostProcess:
- NormalizeFeature:
- Index:
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
dist_type: "IP"
index_dir: "./drink_dataset_v2.0/index"
vector_path: "paddlecv://dict/pp-shitu/drink_dataset_v2.0_vector.index"
id_map_path: "paddlecv://dict/pp-shitu/drink_dataset_v2.0_id_map.pkl"
score_thres: 0.5
- NMS4Rec:
thresh: 0.05
......
......@@ -31,9 +31,9 @@ MODEL:
PostProcess:
- NormalizeFeature:
- Index:
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
dist_type: "IP"
index_dir: "./drink_dataset_v2.0/index"
vector_path: "paddlecv://dict/pp-shitu/drink_dataset_v2.0_vector.index"
id_map_path: "paddlecv://dict/pp-shitu/drink_dataset_v2.0_id_map.pkl"
score_thres: 0.5
Inputs:
- input.image
......
......@@ -18,6 +18,7 @@ import numpy as np
import faiss
import pickle
from ppcv.utils.download import get_dict_path
from ppcv.utils.logger import setup_logger
logger = setup_logger('FeatureExtraction')
......@@ -37,17 +38,13 @@ class NormalizeFeature(object):
class Index(object):
def __init__(self,
index_method,
index_dir,
vector_path,
id_map_path,
dist_type,
hamming_radius=None,
score_thres=None):
vector_path = os.path.join(index_dir, "vector.index")
id_map_path = os.path.join(index_dir, "id_map.pkl")
if not os.path.exists(vector_path) or not os.path.exists(id_map_path):
msg = "The directory \"index_dir\" must contain files \"vector.index\", and \"id_map.pkl\". Please check again!"
logger.error(msg)
raise Exception(msg)
vector_path = get_dict_path(vector_path)
id_map_path = get_dict_path(id_map_path)
if dist_type == "hamming":
self.searcher = faiss.read_index_binary(vector_path)
......
......@@ -44,6 +44,9 @@ class FeatureOutput(OutputBaseOp):
if self.print_res:
msg = " ".join([f"{key}: {res[key]}" for key in res])
logger.info(msg)
if self.save_res:
msg = "The FeatureOutput op does not yet support to save prediction results."
logger.warning(msg)
if self.return_res:
total_res.append(res)
if self.return_res:
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
parent = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(parent, '../')))
import cv2
import unittest
import yaml
import argparse
from ppcv.core.workspace import global_config
from ppcv.core.config import ConfigParser
class TestFeatureExtraction(unittest.TestCase):
def setUp(self):
self.config = 'configs/unittest/test_feature_extraction.yml'
self.input = 'demo/drink_dataset_v2.0_test_100.jpeg'
self.cfg_dict = dict(config=self.config, input=self.input)
cfg = argparse.Namespace(**self.cfg_dict)
config = ConfigParser(cfg)
config.print_cfg()
self.model_cfg, self.env_cfg = config.parse()
def test_classification(self):
img = cv2.imread(self.input)[:, :, ::-1]
inputs = [
{
"input.image": img
},
]
op_name = list(self.model_cfg[0].keys())[0]
op = global_config[op_name](self.model_cfg[0][op_name],
self.env_cfg)
result = op(inputs)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册