提交 f9b4e0f0 编写于 作者: W weishengyu

dbg

上级 d954f738
...@@ -335,37 +335,37 @@ class Trainer(object): ...@@ -335,37 +335,37 @@ class Trainer(object):
name='gallery') name='gallery')
query_feas, query_img_id, query_camera_id = self._cal_feature( query_feas, query_img_id, query_camera_id = self._cal_feature(
name='query') name='query')
gallery_img_id = gallery_img_id.t() gallery_img_id = gallery_img_id
if gallery_camera_id is not None: # if gallery_camera_id is not None:
gallery_camera_id = gallery_camera_id.t() # gallery_camera_id = gallery_camera_id
# step2. do evaluation # step2. do evaluation
sim_block_size = self.config["Global"].get("sim_block_size", 1) sim_block_size = self.config["Global"].get("sim_block_size", 64)
sections = [sim_block_size] * (len(query_feas) // sim_block_size) sections = [sim_block_size] * (len(query_feas) // sim_block_size)
if not len(query_feas) % sim_block_size: if len(query_feas) % sim_block_size:
sections.append(len(query_feas) % sim_block_size) sections.append(len(query_feas) % sim_block_size)
fea_blocks = paddle.split(query_feas, num_or_sections=sections) fea_blocks = paddle.split(query_feas, num_or_sections=sections)
if query_camera_id is not None: # if query_camera_id is not None:
camera_id_blocks = paddle.split( # camera_id_blocks = paddle.split(
query_camera_id, num_or_sections=sections) # query_camera_id, num_or_sections=sections)
image_id_blocks = paddle.split(query_img_id, num_or_sections=sections) # image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
metric_key = None metric_key = None
for block_idx, block_fea in enumerate(fea_blocks): for block_idx, block_fea in enumerate(fea_blocks):
similarity_matrix = paddle.matmul( similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True) block_fea, gallery_feas, transpose_y=True)
image_id_block = image_id_blocks[block_idx] # image_id_block = image_id_blocks[block_idx]
image_id_mask = (image_id_block != gallery_img_id) # image_id_mask = (image_id_block != gallery_img_id)
similarity_matrix = similarity_matrix.masked_select(image_id_mask) # similarity_matrix = similarity_matrix.masked_select(image_id_mask)
if query_camera_id is not None: # if query_camera_id is not None:
camera_id_block = camera_id_blocks[block_idx] # camera_id_block = camera_id_blocks[block_idx]
camera_id_mask = (camera_id_block != gallery_camera_id) # camera_id_mask = (camera_id_block != gallery_camera_id)
similarity_matrix = similarity_matrix.masked_select( # similarity_matrix = similarity_matrix.masked_select(
camera_id_mask) # camera_id_mask)
if similarity_matrix is None: if similarity_matrix is None:
cum_similarity_matrix = similarity_matrix cum_similarity_matrix = similarity_matrix
else: else:
cum_similarity_matrix = paddle.concat(cum_similarity_matrix, cum_similarity_matrix = paddle.concat(
similarity_matrix) [cum_similarity_matrix, similarity_matrix], axis=0)
# calc metric # calc metric
if self.eval_metric_func is not None: if self.eval_metric_func is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册