未验证 提交 38bdf650 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] fix deformable_detr src_mask bug (#7148)

上级 d08188d4
...@@ -448,7 +448,6 @@ class DeformableTransformer(nn.Layer): ...@@ -448,7 +448,6 @@ class DeformableTransformer(nn.Layer):
return {'backbone_num_channels': [i.channels for i in input_shape], } return {'backbone_num_channels': [i.channels for i in input_shape], }
def _get_valid_ratio(self, mask): def _get_valid_ratio(self, mask):
mask = mask.astype(paddle.float32)
_, H, W = mask.shape _, H, W = mask.shape
valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H
valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W
...@@ -477,18 +476,16 @@ class DeformableTransformer(nn.Layer): ...@@ -477,18 +476,16 @@ class DeformableTransformer(nn.Layer):
src = src.flatten(2).transpose([0, 2, 1]) src = src.flatten(2).transpose([0, 2, 1])
src_flatten.append(src) src_flatten.append(src)
if src_mask is not None: if src_mask is not None:
mask = F.interpolate( mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0]
src_mask.unsqueeze(0).astype(src.dtype),
size=(h, w))[0].astype('bool')
else: else:
mask = paddle.ones([bs, h, w], dtype='bool') mask = paddle.ones([bs, h, w])
valid_ratios.append(self._get_valid_ratio(mask)) valid_ratios.append(self._get_valid_ratio(mask))
pos_embed = self.position_embedding(mask).flatten(2).transpose( pos_embed = self.position_embedding(mask).flatten(2).transpose(
[0, 2, 1]) [0, 2, 1])
lvl_pos_embed = pos_embed + self.level_embed.weight[level].reshape( lvl_pos_embed = pos_embed + self.level_embed.weight[level].reshape(
[1, 1, -1]) [1, 1, -1])
lvl_pos_embed_flatten.append(lvl_pos_embed) lvl_pos_embed_flatten.append(lvl_pos_embed)
mask = mask.astype(src.dtype).flatten(1) mask = mask.flatten(1)
mask_flatten.append(mask) mask_flatten.append(mask)
src_flatten = paddle.concat(src_flatten, 1) src_flatten = paddle.concat(src_flatten, 1)
mask_flatten = paddle.concat(mask_flatten, 1) mask_flatten = paddle.concat(mask_flatten, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册