未验证 提交 578e4d74 编写于 作者: D dyning 提交者: GitHub

Merge pull request #917 from littletomatodonkey/dev/fix_qs_demo

fix rec post and visualizer
...@@ -4,8 +4,8 @@ Global: ...@@ -4,8 +4,8 @@ Global:
rec_inference_model_dir: "./models/cartoon_rec_ResNet50_iCartoon_v1.0_infer/" rec_inference_model_dir: "./models/cartoon_rec_ResNet50_iCartoon_v1.0_infer/"
batch_size: 1 batch_size: 1
image_shape: [3, 640, 640] image_shape: [3, 640, 640]
threshold: 0.5 threshold: 0.2
max_det_results: 1 max_det_results: 5
labe_list: labe_list:
- foreground - foreground
...@@ -53,3 +53,4 @@ IndexProcess: ...@@ -53,3 +53,4 @@ IndexProcess:
search_budget: 100 search_budget: 100
return_k: 5 return_k: 5
dist_type: "IP" dist_type: "IP"
score_thres: 0.5
...@@ -4,8 +4,8 @@ Global: ...@@ -4,8 +4,8 @@ Global:
rec_inference_model_dir: "./models/logo_rec_ResNet50_Logo3K_v1.0_infer/" rec_inference_model_dir: "./models/logo_rec_ResNet50_Logo3K_v1.0_infer/"
batch_size: 1 batch_size: 1
image_shape: [3, 640, 640] image_shape: [3, 640, 640]
threshold: 0.5 threshold: 0.2
max_det_results: 1 max_det_results: 5
labe_list: labe_list:
- foreground - foreground
...@@ -52,3 +52,4 @@ IndexProcess: ...@@ -52,3 +52,4 @@ IndexProcess:
search_budget: 100 search_budget: 100
return_k: 5 return_k: 5
dist_type: "IP" dist_type: "IP"
score_thres: 0.5
...@@ -5,7 +5,7 @@ Global: ...@@ -5,7 +5,7 @@ Global:
batch_size: 1 batch_size: 1
image_shape: [3, 640, 640] image_shape: [3, 640, 640]
threshold: 0.2 threshold: 0.2
max_det_results: 1 max_det_results: 5
labe_list: labe_list:
- foreground - foreground
...@@ -52,3 +52,4 @@ IndexProcess: ...@@ -52,3 +52,4 @@ IndexProcess:
search_budget: 100 search_budget: 100
return_k: 5 return_k: 5
dist_type: "IP" dist_type: "IP"
score_thres: 0.5
...@@ -4,8 +4,8 @@ Global: ...@@ -4,8 +4,8 @@ Global:
rec_inference_model_dir: "./models/vehicle_cls_ResNet50_CompCars_v1.0_infer/" rec_inference_model_dir: "./models/vehicle_cls_ResNet50_CompCars_v1.0_infer/"
batch_size: 1 batch_size: 1
image_shape: [3, 640, 640] image_shape: [3, 640, 640]
threshold: 0.5 threshold: 0.2
max_det_results: 1 max_det_results: 5
labe_list: labe_list:
- foreground - foreground
...@@ -54,3 +54,4 @@ IndexProcess: ...@@ -54,3 +54,4 @@ IndexProcess:
search_budget: 100 search_budget: 100
return_k: 5 return_k: 5
dist_type: "IP" dist_type: "IP"
score_thres: 0.5
...@@ -42,7 +42,7 @@ def split_datafile(data_file, image_root, delimiter="\t"): ...@@ -42,7 +42,7 @@ def split_datafile(data_file, image_root, delimiter="\t"):
for i, line in enumerate(lines): for i, line in enumerate(lines):
line = line.strip().split(delimiter) line = line.strip().split(delimiter)
image_file = os.path.join(image_root, line[0]) image_file = os.path.join(image_root, line[0])
image_doc = line[1] image_doc = line[1]
gallery_images.append(image_file) gallery_images.append(image_file)
gallery_docs.append(image_doc) gallery_docs.append(image_doc)
...@@ -57,28 +57,34 @@ class GalleryBuilder(object): ...@@ -57,28 +57,34 @@ class GalleryBuilder(object):
assert 'IndexProcess' in config.keys(), "Index config not found ... " assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.build(config['IndexProcess']) self.build(config['IndexProcess'])
def build(self, config): def build(self, config):
''' '''
build index from scratch build index from scratch
''' '''
gallery_images, gallery_docs = split_datafile(config['data_file'], gallery_images, gallery_docs = split_datafile(
config['image_root'], config['delimiter']) config['data_file'], config['image_root'], config['delimiter'])
# extract gallery features # extract gallery features
gallery_features = np.zeros([len(gallery_images), gallery_features = np.zeros(
config['embedding_size']], dtype=np.float32) [len(gallery_images), config['embedding_size']], dtype=np.float32)
for i, image_file in enumerate(tqdm(gallery_images)): for i, image_file in enumerate(tqdm(gallery_images)):
img = cv2.imread(image_file)[:, :, ::-1] img = cv2.imread(image_file)
if img is None:
logger.error("img empty, please check {}".format(image_file))
exit()
img = img[:, :, ::-1]
rec_feat = self.rec_predictor.predict(img) rec_feat = self.rec_predictor.predict(img)
gallery_features[i,:] = rec_feat gallery_features[i, :] = rec_feat
# train index # train index
self.Searcher = Graph_Index(dist_type=config['dist_type']) self.Searcher = Graph_Index(dist_type=config['dist_type'])
self.Searcher.build(gallery_vectors=gallery_features, gallery_docs=gallery_docs, self.Searcher.build(
pq_size=config['pq_size'], index_path=config['index_path']) gallery_vectors=gallery_features,
gallery_docs=gallery_docs,
pq_size=config['pq_size'],
index_path=config['index_path'])
def main(config): def main(config):
system_builder = GalleryBuilder(config) system_builder = GalleryBuilder(config)
......
...@@ -46,22 +46,38 @@ class SystemPredictor(object): ...@@ -46,22 +46,38 @@ class SystemPredictor(object):
dist_type=config['IndexProcess']['dist_type']) dist_type=config['IndexProcess']['dist_type'])
self.Searcher.load(config['IndexProcess']['index_path']) self.Searcher.load(config['IndexProcess']['index_path'])
def append_self(self, results, shape):
results.append({
"class_id": 0,
"score": 1.0,
"bbox": np.array([0, 0, shape[1], shape[0]]),
"label_name": "foreground",
})
return results
def predict(self, img): def predict(self, img):
output = [] output = []
results = self.det_predictor.predict(img) results = self.det_predictor.predict(img)
# add the whole image for recognition
results = self.append_self(results, img.shape)
for result in results: for result in results:
preds = {} preds = {}
xmin, ymin, xmax, ymax = result["bbox"].astype("int") xmin, ymin, xmax, ymax = result["bbox"].astype("int")
crop_img = img[ymin:ymax, xmin:xmax, :].copy() crop_img = img[ymin:ymax, xmin:xmax, :].copy()
rec_results = self.rec_predictor.predict(crop_img) rec_results = self.rec_predictor.predict(crop_img)
#preds["feature"] = rec_results
preds["bbox"] = [xmin, ymin, xmax, ymax] preds["bbox"] = [xmin, ymin, xmax, ymax]
scores, docs = self.Searcher.search( scores, docs = self.Searcher.search(
query=rec_results, query=rec_results,
return_k=self.return_k, return_k=self.return_k,
search_budget=self.search_budget) search_budget=self.search_budget)
preds["rec_docs"] = docs # just top-1 result will be returned for the final
preds["rec_scores"] = scores if scores[0] >= self.config["IndexProcess"]["score_thres"]:
preds["rec_docs"] = docs[0]
preds["rec_scores"] = scores[0]
else:
preds["rec_docs"] = None
preds["rec_scores"] = 0.0
output.append(preds) output.append(preds)
return output return output
...@@ -75,7 +91,7 @@ def main(config): ...@@ -75,7 +91,7 @@ def main(config):
for idx, image_file in enumerate(image_list): for idx, image_file in enumerate(image_list):
img = cv2.imread(image_file)[:, :, ::-1] img = cv2.imread(image_file)[:, :, ::-1]
output = system_predictor.predict(img) output = system_predictor.predict(img)
draw_bbox_results(img[:, :, ::-1], output, image_file) draw_bbox_results(img, output, image_file)
print(output) print(output)
return return
......
...@@ -15,18 +15,45 @@ ...@@ -15,18 +15,45 @@
import os import os
import numpy as np import numpy as np
import cv2 import cv2
from PIL import Image, ImageDraw, ImageFont
def draw_bbox_results(image, results, input_path, save_dir=None): def draw_bbox_results(image,
results,
input_path,
font_path="./utils/simfang.ttf",
save_dir=None):
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
draw = ImageDraw.Draw(image)
font = ImageFont.truetype(font_path, 20, encoding="utf-8")
color = (0, 255, 0)
for result in results: for result in results:
[xmin, ymin, xmax, ymax] = result["bbox"] # empty results
if result["rec_docs"] is None:
continue
xmin, ymin, xmax, ymax = result["bbox"]
text = "{}, {:.2f}".format(result["rec_docs"], result["rec_scores"])
th = 20
tw = int(len(result["rec_docs"]) * 20) + 60
start_y = max(0, ymin - th)
draw.rectangle(
[(xmin + 1, start_y), (xmin + tw + 1, start_y + th)],
outline=color)
image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 255, 0), draw.text((xmin + 1, start_y), text, fill=color, font=font)
2)
draw.rectangle(
[(xmin, ymin), (xmax, ymax)], outline=(255, 0, 0), width=2)
image_name = os.path.basename(input_path) image_name = os.path.basename(input_path)
if save_dir is None: if save_dir is None:
save_dir = "output" save_dir = "output"
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
output_path = os.path.join(save_dir, image_name) output_path = os.path.join(save_dir, image_name)
cv2.imwrite(output_path, image)
image.save(output_path, quality=95)
return np.array(image)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册