提交 ad3bd3a3 编写于 作者: L lubin10

fixbug bs=1 in build_gallery

上级 c68d411c
i have already fix the bug
...@@ -71,14 +71,26 @@ class GalleryBuilder(object): ...@@ -71,14 +71,26 @@ class GalleryBuilder(object):
gallery_features = np.zeros( gallery_features = np.zeros(
[len(gallery_images), config['embedding_size']], dtype=np.float32) [len(gallery_images), config['embedding_size']], dtype=np.float32)
#construct batch imgs and do inference
batch_size = config.get("batch_size", 32)
batch_img = []
for i, image_file in enumerate(tqdm(gallery_images)): for i, image_file in enumerate(tqdm(gallery_images)):
img = cv2.imread(image_file) img = cv2.imread(image_file)
if img is None: if img is None:
logger.error("img empty, please check {}".format(image_file)) logger.error("img empty, please check {}".format(image_file))
exit() exit()
img = img[:, :, ::-1] img = img[:, :, ::-1]
rec_feat = self.rec_predictor.predict(img) batch_img.append(img)
gallery_features[i, :] = rec_feat
if (i + 1) % batch_size == 0:
rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[i - batch_size + 1:i + 1, :] = rec_feat
batch_img = []
if len(batch_img) > 0:
rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[-len(batch_img):, :] = rec_feat
batch_img = []
# train index # train index
self.Searcher = Graph_Index(dist_type=config['dist_type']) self.Searcher = Graph_Index(dist_type=config['dist_type'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册