提交 ad3bd3a3 编写于 作者: L lubin10

fixbug bs=1 in build_gallery

上级 c68d411c
i have already fix the bug
......@@ -71,14 +71,26 @@ class GalleryBuilder(object):
gallery_features = np.zeros(
[len(gallery_images), config['embedding_size']], dtype=np.float32)
#construct batch imgs and do inference
batch_size = config.get("batch_size", 32)
batch_img = []
for i, image_file in enumerate(tqdm(gallery_images)):
img = cv2.imread(image_file)
if img is None:
logger.error("img empty, please check {}".format(image_file))
exit()
img = img[:, :, ::-1]
rec_feat = self.rec_predictor.predict(img)
gallery_features[i, :] = rec_feat
batch_img.append(img)
if (i + 1) % batch_size == 0:
rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[i - batch_size + 1:i + 1, :] = rec_feat
batch_img = []
if len(batch_img) > 0:
rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[-len(batch_img):, :] = rec_feat
batch_img = []
# train index
self.Searcher = Graph_Index(dist_type=config['dist_type'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册