未验证 提交 52663edf 编写于 作者: B Bin Lu 提交者: GitHub

Update retrieval.py

上级 35d2be15
...@@ -124,6 +124,13 @@ def cal_feature(evaler, name='gallery'): ...@@ -124,6 +124,13 @@ def cal_feature(evaler, name='gallery'):
feas_norm = paddle.sqrt( feas_norm = paddle.sqrt(
paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True)) paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
batch_feas = paddle.divide(batch_feas, feas_norm) batch_feas = paddle.divide(batch_feas, feas_norm)
# do binarize
if self.config["Global"].get("feature_binarize") == "round":
batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0
if self.config["Global"].get("feature_binarize") == "sign":
batch_feas = paddle.sign(batch_feas).astype("float32")
if all_feas is None: if all_feas is None:
all_feas = batch_feas all_feas = batch_feas
...@@ -135,8 +142,10 @@ def cal_feature(evaler, name='gallery'): ...@@ -135,8 +142,10 @@ def cal_feature(evaler, name='gallery'):
all_image_id = paddle.concat([all_image_id, batch[1]]) all_image_id = paddle.concat([all_image_id, batch[1]])
if has_unique_id: if has_unique_id:
all_unique_id = paddle.concat([all_unique_id, batch[2]]) all_unique_id = paddle.concat([all_unique_id, batch[2]])
if evaler.use_dali: if evaler.use_dali:
dataloader_tmp.reset() dataloader_tmp.reset()
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
feat_list = [] feat_list = []
img_id_list = [] img_id_list = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册