提交 a67cdaa1 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix rec post and visualizer

上级 3579f5a6
......@@ -42,7 +42,7 @@ def split_datafile(data_file, image_root, delimiter="\t"):
for i, line in enumerate(lines):
line = line.strip().split(delimiter)
image_file = os.path.join(image_root, line[0])
image_doc = line[1]
image_doc = line[1]
gallery_images.append(image_file)
gallery_docs.append(image_doc)
......@@ -57,28 +57,34 @@ class GalleryBuilder(object):
assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.build(config['IndexProcess'])
def build(self, config):
'''
build index from scratch
'''
gallery_images, gallery_docs = split_datafile(config['data_file'],
config['image_root'], config['delimiter'])
gallery_images, gallery_docs = split_datafile(
config['data_file'], config['image_root'], config['delimiter'])
# extract gallery features
gallery_features = np.zeros([len(gallery_images),
config['embedding_size']], dtype=np.float32)
gallery_features = np.zeros(
[len(gallery_images), config['embedding_size']], dtype=np.float32)
for i, image_file in enumerate(tqdm(gallery_images)):
img = cv2.imread(image_file)[:, :, ::-1]
img = cv2.imread(image_file)
if img is None:
logger.error("img empty, please check {}".format(image_file))
exit()
img = img[:, :, ::-1]
rec_feat = self.rec_predictor.predict(img)
gallery_features[i,:] = rec_feat
gallery_features[i, :] = rec_feat
# train index
self.Searcher = Graph_Index(dist_type=config['dist_type'])
self.Searcher.build(gallery_vectors=gallery_features, gallery_docs=gallery_docs,
pq_size=config['pq_size'], index_path=config['index_path'])
self.Searcher = Graph_Index(dist_type=config['dist_type'])
self.Searcher.build(
gallery_vectors=gallery_features,
gallery_docs=gallery_docs,
pq_size=config['pq_size'],
index_path=config['index_path'])
def main(config):
system_builder = GalleryBuilder(config)
......
......@@ -46,22 +46,38 @@ class SystemPredictor(object):
dist_type=config['IndexProcess']['dist_type'])
self.Searcher.load(config['IndexProcess']['index_path'])
def append_self(self, results, shape):
results.append({
"class_id": 0,
"score": 1.0,
"bbox": np.array([0, 0, shape[1], shape[0]]),
"label_name": "foreground",
})
return results
def predict(self, img):
output = []
results = self.det_predictor.predict(img)
# add the whole image for recognition
results = self.append_self(results, img.shape)
for result in results:
preds = {}
xmin, ymin, xmax, ymax = result["bbox"].astype("int")
crop_img = img[ymin:ymax, xmin:xmax, :].copy()
rec_results = self.rec_predictor.predict(crop_img)
#preds["feature"] = rec_results
preds["bbox"] = [xmin, ymin, xmax, ymax]
scores, docs = self.Searcher.search(
query=rec_results,
return_k=self.return_k,
search_budget=self.search_budget)
preds["rec_docs"] = docs
preds["rec_scores"] = scores
# just top-1 result will be returned for the final
if scores[0] >= self.config["IndexProcess"]["score_thres"]:
preds["rec_docs"] = docs[0]
preds["rec_scores"] = scores[0]
else:
preds["rec_docs"] = None
preds["rec_scores"] = 0.0
output.append(preds)
return output
......@@ -75,7 +91,7 @@ def main(config):
for idx, image_file in enumerate(image_list):
img = cv2.imread(image_file)[:, :, ::-1]
output = system_predictor.predict(img)
draw_bbox_results(img[:, :, ::-1], output, image_file)
draw_bbox_results(img, output, image_file)
print(output)
return
......
......@@ -15,18 +15,45 @@
import os
import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageFont
def draw_bbox_results(image, results, input_path, save_dir=None):
def draw_bbox_results(image,
results,
input_path,
font_path="./utils/simfang.ttf",
save_dir=None):
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
draw = ImageDraw.Draw(image)
font = ImageFont.truetype(font_path, 20, encoding="utf-8")
color = (0, 255, 0)
for result in results:
[xmin, ymin, xmax, ymax] = result["bbox"]
# empty results
if result["rec_docs"] is None:
continue
xmin, ymin, xmax, ymax = result["bbox"]
text = "{}, {:.2f}".format(result["rec_docs"], result["rec_scores"])
th = 20
tw = int(len(result["rec_docs"]) * 20) + 60
start_y = max(0, ymin - th)
draw.rectangle(
[(xmin + 1, start_y), (xmin + tw + 1, start_y + th)],
outline=color)
image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 255, 0),
2)
draw.text((xmin + 1, start_y), text, fill=color, font=font)
draw.rectangle(
[(xmin, ymin), (xmax, ymax)], outline=(255, 0, 0), width=2)
image_name = os.path.basename(input_path)
if save_dir is None:
save_dir = "output"
os.makedirs(save_dir, exist_ok=True)
output_path = os.path.join(save_dir, image_name)
cv2.imwrite(output_path, image)
image.save(output_path, quality=95)
return np.array(image)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册