提交 f9b4e0f0 编写于 作者: W weishengyu

dbg

上级 d954f738
......@@ -335,37 +335,37 @@ class Trainer(object):
name='gallery')
query_feas, query_img_id, query_camera_id = self._cal_feature(
name='query')
gallery_img_id = gallery_img_id.t()
if gallery_camera_id is not None:
gallery_camera_id = gallery_camera_id.t()
gallery_img_id = gallery_img_id
# if gallery_camera_id is not None:
# gallery_camera_id = gallery_camera_id
# step2. do evaluation
sim_block_size = self.config["Global"].get("sim_block_size", 1)
sim_block_size = self.config["Global"].get("sim_block_size", 64)
sections = [sim_block_size] * (len(query_feas) // sim_block_size)
if not len(query_feas) % sim_block_size:
if len(query_feas) % sim_block_size:
sections.append(len(query_feas) % sim_block_size)
fea_blocks = paddle.split(query_feas, num_or_sections=sections)
if query_camera_id is not None:
camera_id_blocks = paddle.split(
query_camera_id, num_or_sections=sections)
image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
# if query_camera_id is not None:
# camera_id_blocks = paddle.split(
# query_camera_id, num_or_sections=sections)
# image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
metric_key = None
for block_idx, block_fea in enumerate(fea_blocks):
similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True)
image_id_block = image_id_blocks[block_idx]
image_id_mask = (image_id_block != gallery_img_id)
similarity_matrix = similarity_matrix.masked_select(image_id_mask)
if query_camera_id is not None:
camera_id_block = camera_id_blocks[block_idx]
camera_id_mask = (camera_id_block != gallery_camera_id)
similarity_matrix = similarity_matrix.masked_select(
camera_id_mask)
# image_id_block = image_id_blocks[block_idx]
# image_id_mask = (image_id_block != gallery_img_id)
# similarity_matrix = similarity_matrix.masked_select(image_id_mask)
# if query_camera_id is not None:
# camera_id_block = camera_id_blocks[block_idx]
# camera_id_mask = (camera_id_block != gallery_camera_id)
# similarity_matrix = similarity_matrix.masked_select(
# camera_id_mask)
if similarity_matrix is None:
cum_similarity_matrix = similarity_matrix
else:
cum_similarity_matrix = paddle.concat(cum_similarity_matrix,
similarity_matrix)
cum_similarity_matrix = paddle.concat(
[cum_similarity_matrix, similarity_matrix], axis=0)
# calc metric
if self.eval_metric_func is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册