未验证 提交 0eeb077e 编写于 作者: shangliang Xu 提交者: GitHub

[dino] update code for inverse sigmoid (#7715)

上级 1463f210
...@@ -257,7 +257,7 @@ class DINOTransformerDecoder(nn.Layer): ...@@ -257,7 +257,7 @@ class DINOTransformerDecoder(nn.Layer):
def forward(self, def forward(self,
tgt, tgt,
reference_points, ref_points_unact,
memory, memory,
memory_spatial_shapes, memory_spatial_shapes,
memory_level_start_index, memory_level_start_index,
...@@ -272,9 +272,9 @@ class DINOTransformerDecoder(nn.Layer): ...@@ -272,9 +272,9 @@ class DINOTransformerDecoder(nn.Layer):
output = tgt output = tgt
intermediate = [] intermediate = []
inter_ref_bboxes = [] inter_ref_bboxes_unact = []
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
reference_points_input = reference_points.unsqueeze( reference_points_input = F.sigmoid(ref_points_unact).unsqueeze(
2) * valid_ratios.tile([1, 1, 2]).unsqueeze(1) 2) * valid_ratios.tile([1, 1, 2]).unsqueeze(1)
query_pos_embed = get_sine_pos_embed( query_pos_embed = get_sine_pos_embed(
reference_points_input[..., 0, :], self.hidden_dim // 2) reference_points_input[..., 0, :], self.hidden_dim // 2)
...@@ -284,19 +284,19 @@ class DINOTransformerDecoder(nn.Layer): ...@@ -284,19 +284,19 @@ class DINOTransformerDecoder(nn.Layer):
memory_spatial_shapes, memory_level_start_index, memory_spatial_shapes, memory_level_start_index,
attn_mask, memory_mask, query_pos_embed) attn_mask, memory_mask, query_pos_embed)
inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid( inter_ref_bbox_unact = bbox_head[i](output) + ref_points_unact
reference_points))
if self.return_intermediate: if self.return_intermediate:
intermediate.append(self.norm(output)) intermediate.append(self.norm(output))
inter_ref_bboxes.append(inter_ref_bbox) inter_ref_bboxes_unact.append(inter_ref_bbox_unact)
reference_points = inter_ref_bbox.detach() ref_points_unact = inter_ref_bbox_unact.detach()
if self.return_intermediate: if self.return_intermediate:
return paddle.stack(intermediate), paddle.stack(inter_ref_bboxes) return paddle.stack(intermediate), paddle.stack(
inter_ref_bboxes_unact)
return output, reference_points return output, ref_points_unact
@register @register
...@@ -417,7 +417,8 @@ class DINOTransformer(nn.Layer): ...@@ -417,7 +417,8 @@ class DINOTransformer(nn.Layer):
linear_init_(self.enc_output[0]) linear_init_(self.enc_output[0])
xavier_uniform_(self.enc_output[0].weight) xavier_uniform_(self.enc_output[0].weight)
normal_(self.level_embed.weight) normal_(self.level_embed.weight)
xavier_uniform_(self.tgt_embed.weight) if self.learnt_init_query:
xavier_uniform_(self.tgt_embed.weight)
xavier_uniform_(self.query_pos_head.layers[0].weight) xavier_uniform_(self.query_pos_head.layers[0].weight)
xavier_uniform_(self.query_pos_head.layers[1].weight) xavier_uniform_(self.query_pos_head.layers[1].weight)
for l in self.input_proj: for l in self.input_proj:
...@@ -523,7 +524,7 @@ class DINOTransformer(nn.Layer): ...@@ -523,7 +524,7 @@ class DINOTransformer(nn.Layer):
# prepare denoising training # prepare denoising training
if self.training: if self.training:
denoising_class, denoising_bbox, attn_mask, dn_meta = \ denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
get_contrastive_denoising_training_group(gt_meta, get_contrastive_denoising_training_group(gt_meta,
self.num_classes, self.num_classes,
self.num_queries, self.num_queries,
...@@ -532,18 +533,18 @@ class DINOTransformer(nn.Layer): ...@@ -532,18 +533,18 @@ class DINOTransformer(nn.Layer):
self.label_noise_ratio, self.label_noise_ratio,
self.box_noise_scale) self.box_noise_scale)
else: else:
denoising_class, denoising_bbox, attn_mask, dn_meta = None, None, None, None denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
target, init_ref_points, enc_topk_bboxes, enc_topk_logits = \ target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
self._get_decoder_input( self._get_decoder_input(
memory, spatial_shapes, mask_flatten, denoising_class, memory, spatial_shapes, mask_flatten, denoising_class,
denoising_bbox) denoising_bbox_unact)
# decoder # decoder
inter_feats, inter_ref_bboxes = self.decoder( inter_feats, inter_ref_bboxes_unact = self.decoder(
target, init_ref_points, memory, spatial_shapes, level_start_index, target, init_ref_points_unact, memory, spatial_shapes,
self.dec_bbox_head, self.query_pos_head, valid_ratios, attn_mask, level_start_index, self.dec_bbox_head, self.query_pos_head,
mask_flatten) valid_ratios, attn_mask, mask_flatten)
out_bboxes = [] out_bboxes = []
out_logits = [] out_logits = []
for i in range(self.num_decoder_layers): for i in range(self.num_decoder_layers):
...@@ -551,11 +552,11 @@ class DINOTransformer(nn.Layer): ...@@ -551,11 +552,11 @@ class DINOTransformer(nn.Layer):
if i == 0: if i == 0:
out_bboxes.append( out_bboxes.append(
F.sigmoid(self.dec_bbox_head[i](inter_feats[i]) + F.sigmoid(self.dec_bbox_head[i](inter_feats[i]) +
inverse_sigmoid(init_ref_points))) init_ref_points_unact))
else: else:
out_bboxes.append( out_bboxes.append(
F.sigmoid(self.dec_bbox_head[i](inter_feats[i]) + F.sigmoid(self.dec_bbox_head[i](inter_feats[i]) +
inverse_sigmoid(inter_ref_bboxes[i - 1]))) inter_ref_bboxes_unact[i - 1]))
out_bboxes = paddle.stack(out_bboxes) out_bboxes = paddle.stack(out_bboxes)
out_logits = paddle.stack(out_logits) out_logits = paddle.stack(out_logits)
...@@ -611,7 +612,7 @@ class DINOTransformer(nn.Layer): ...@@ -611,7 +612,7 @@ class DINOTransformer(nn.Layer):
spatial_shapes, spatial_shapes,
memory_mask=None, memory_mask=None,
denoising_class=None, denoising_class=None,
denoising_bbox=None): denoising_bbox_unact=None):
bs, _, _ = memory.shape bs, _, _ = memory.shape
# prepare input for decoder # prepare input for decoder
output_memory, output_anchors = self._get_encoder_output_anchors( output_memory, output_anchors = self._get_encoder_output_anchors(
...@@ -626,12 +627,12 @@ class DINOTransformer(nn.Layer): ...@@ -626,12 +627,12 @@ class DINOTransformer(nn.Layer):
batch_ind = paddle.arange(end=bs, dtype=topk_ind.dtype) batch_ind = paddle.arange(end=bs, dtype=topk_ind.dtype)
batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_queries]) batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_queries])
topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1) topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1)
topk_coords_unact = paddle.gather_nd(enc_outputs_coord_unact, reference_points_unact = paddle.gather_nd(enc_outputs_coord_unact,
topk_ind) # unsigmoided. topk_ind) # unsigmoided.
reference_points = enc_topk_bboxes = F.sigmoid(topk_coords_unact) enc_topk_bboxes = F.sigmoid(reference_points_unact)
if denoising_bbox is not None: if denoising_bbox_unact is not None:
reference_points = paddle.concat([denoising_bbox, enc_topk_bboxes], reference_points_unact = paddle.concat(
1) [denoising_bbox_unact, reference_points_unact], 1)
enc_topk_logits = paddle.gather_nd(enc_outputs_class, topk_ind) enc_topk_logits = paddle.gather_nd(enc_outputs_class, topk_ind)
# extract region features # extract region features
...@@ -642,5 +643,5 @@ class DINOTransformer(nn.Layer): ...@@ -642,5 +643,5 @@ class DINOTransformer(nn.Layer):
if denoising_class is not None: if denoising_class is not None:
target = paddle.concat([denoising_class, target], 1) target = paddle.concat([denoising_class, target], 1)
return target, reference_points.detach( return target, reference_points_unact.detach(
), enc_topk_bboxes, enc_topk_logits ), enc_topk_bboxes, enc_topk_logits
...@@ -194,7 +194,7 @@ def get_contrastive_denoising_training_group(targets, ...@@ -194,7 +194,7 @@ def get_contrastive_denoising_training_group(targets,
known_bbox += rand_part * diff known_bbox += rand_part * diff
known_bbox.clip_(min=0.0, max=1.0) known_bbox.clip_(min=0.0, max=1.0)
input_query_bbox = bbox_xyxy_to_cxcywh(known_bbox) input_query_bbox = bbox_xyxy_to_cxcywh(known_bbox)
input_query_bbox.clip_(min=0.0, max=1.0) input_query_bbox = inverse_sigmoid(input_query_bbox)
class_embed = paddle.concat( class_embed = paddle.concat(
[class_embed, paddle.zeros([1, class_embed.shape[-1]])]) [class_embed, paddle.zeros([1, class_embed.shape[-1]])])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册