提交 934de965 编写于 作者: W weishengyu

rename cam -> unique

上级 16718f08
...@@ -329,22 +329,22 @@ class Trainer(object): ...@@ -329,22 +329,22 @@ class Trainer(object):
self.model.eval() self.model.eval()
cum_similarity_matrix = None cum_similarity_matrix = None
# step1. build gallery # step1. build gallery
gallery_feas, gallery_img_id, gallery_camera_id = self._cal_feature( gallery_feas, gallery_img_id, gallery_unique_id = self._cal_feature(
name='gallery') name='gallery')
query_feas, query_img_id, query_camera_id = self._cal_feature( query_feas, query_img_id, query_query_id = self._cal_feature(
name='query') name='query')
gallery_img_id = gallery_img_id gallery_img_id = gallery_img_id
# if gallery_camera_id is not None: # if gallery_unique_id is not None:
# gallery_camera_id = gallery_camera_id # gallery_unique_id = gallery_unique_id
# step2. do evaluation # step2. do evaluation
sim_block_size = self.config["Global"].get("sim_block_size", 64) sim_block_size = self.config["Global"].get("sim_block_size", 64)
sections = [sim_block_size] * (len(query_feas) // sim_block_size) sections = [sim_block_size] * (len(query_feas) // sim_block_size)
if len(query_feas) % sim_block_size: if len(query_feas) % sim_block_size:
sections.append(len(query_feas) % sim_block_size) sections.append(len(query_feas) % sim_block_size)
fea_blocks = paddle.split(query_feas, num_or_sections=sections) fea_blocks = paddle.split(query_feas, num_or_sections=sections)
if query_camera_id is not None: if query_query_id is not None:
camera_id_blocks = paddle.split( query_id_blocks = paddle.split(
query_camera_id, num_or_sections=sections) query_query_id, num_or_sections=sections)
image_id_blocks = paddle.split( image_id_blocks = paddle.split(
query_img_id, num_or_sections=sections) query_img_id, num_or_sections=sections)
metric_key = None metric_key = None
...@@ -352,14 +352,14 @@ class Trainer(object): ...@@ -352,14 +352,14 @@ class Trainer(object):
for block_idx, block_fea in enumerate(fea_blocks): for block_idx, block_fea in enumerate(fea_blocks):
similarity_matrix = paddle.matmul( similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True) block_fea, gallery_feas, transpose_y=True)
if query_camera_id is not None: if query_query_id is not None:
camera_id_block = camera_id_blocks[block_idx] query_id_block = query_id_blocks[block_idx]
camera_id_mask = (camera_id_block != gallery_camera_id.t()) query_id_mask = (query_id_block != gallery_unique_id.t())
image_id_block = image_id_blocks[block_idx] image_id_block = image_id_blocks[block_idx]
image_id_mask = (image_id_block != gallery_img_id.t()) image_id_mask = (image_id_block != gallery_img_id.t())
keep_mask = paddle.logical_or(camera_id_mask, image_id_mask) keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
similarity_matrix = similarity_matrix * keep_mask.astype( similarity_matrix = similarity_matrix * keep_mask.astype(
"float32") "float32")
if cum_similarity_matrix is None: if cum_similarity_matrix is None:
...@@ -388,7 +388,7 @@ class Trainer(object): ...@@ -388,7 +388,7 @@ class Trainer(object):
def _cal_feature(self, name='gallery'): def _cal_feature(self, name='gallery'):
all_feas = None all_feas = None
all_image_id = None all_image_id = None
all_camera_id = None all_unique_id = None
if name == 'gallery': if name == 'gallery':
dataloader = self.gallery_dataloader dataloader = self.gallery_dataloader
elif name == 'query': elif name == 'query':
...@@ -396,13 +396,13 @@ class Trainer(object): ...@@ -396,13 +396,13 @@ class Trainer(object):
else: else:
raise RuntimeError("Only support gallery or query dataset") raise RuntimeError("Only support gallery or query dataset")
has_cam_id = False has_unique_id = False
for idx, batch in enumerate(dataloader( for idx, batch in enumerate(dataloader(
)): # load is very time-consuming )): # load is very time-consuming
batch = [paddle.to_tensor(x) for x in batch] batch = [paddle.to_tensor(x) for x in batch]
batch[1] = batch[1].reshape([-1, 1]) batch[1] = batch[1].reshape([-1, 1])
if len(batch) == 3: if len(batch) == 3:
has_cam_id = True has_unique_id = True
batch[2] = batch[2].reshape([-1, 1]) batch[2] = batch[2].reshape([-1, 1])
out = self.model(batch[0], batch[1]) out = self.model(batch[0], batch[1])
batch_feas = out["features"] batch_feas = out["features"]
...@@ -416,30 +416,30 @@ class Trainer(object): ...@@ -416,30 +416,30 @@ class Trainer(object):
if all_feas is None: if all_feas is None:
all_feas = batch_feas all_feas = batch_feas
if has_cam_id: if has_unique_id:
all_camera_id = batch[2] all_unique_id = batch[2]
all_image_id = batch[1] all_image_id = batch[1]
else: else:
all_feas = paddle.concat([all_feas, batch_feas]) all_feas = paddle.concat([all_feas, batch_feas])
all_image_id = paddle.concat([all_image_id, batch[1]]) all_image_id = paddle.concat([all_image_id, batch[1]])
if has_cam_id: if has_unique_id:
all_camera_id = paddle.concat([all_camera_id, batch[2]]) all_unique_id = paddle.concat([all_unique_id, batch[2]])
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
feat_list = [] feat_list = []
img_id_list = [] img_id_list = []
cam_id_list = [] unique_id_list = []
paddle.distributed.all_gather(feat_list, all_feas) paddle.distributed.all_gather(feat_list, all_feas)
paddle.distributed.all_gather(img_id_list, all_image_id) paddle.distributed.all_gather(img_id_list, all_image_id)
all_feas = paddle.concat(feat_list, axis=0) all_feas = paddle.concat(feat_list, axis=0)
all_image_id = paddle.concat(img_id_list, axis=0) all_image_id = paddle.concat(img_id_list, axis=0)
if has_cam_id: if has_unique_id:
paddle.distributed.all_gather(cam_id_list, all_camera_id) paddle.distributed.all_gather(unique_id_list, all_unique_id)
all_camera_id = paddle.concat(cam_id_list, axis=0) all_unique_id = paddle.concat(unique_id_list, axis=0)
logger.info("Build {} done, all feat shape: {}, begin to eval..". logger.info("Build {} done, all feat shape: {}, begin to eval..".
format(name, all_feas.shape)) format(name, all_feas.shape))
return all_feas, all_image_id, all_camera_id return all_feas, all_image_id, all_unique_id
@paddle.no_grad() @paddle.no_grad()
def infer(self, ): def infer(self, ):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册