提交 457bcb85 编写于 作者: V Vighnesh Birodkar 提交者: TF Object Detection Team

Refactor DeepMAC to process full batch and return mask logits with predict()

PiperOrigin-RevId: 426181961
上级 c3f2134b
......@@ -1050,6 +1050,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
else:
raise ValueError(f'Unknown heatmap type - {self._box_heatmap_type}')
heatmap = tf.stop_gradient(heatmap)
heatmaps.append(heatmap)
# Return the stacked heatmaps over the batch.
......
......@@ -403,7 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 24
// Next ID 25
message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1;
......@@ -485,6 +485,14 @@ message CenterNet {
optional int32 color_consistency_warmup_start = 23 [default=0];
// DeepMAC has been refactored to process the entire batch at once,
// instead of the previous (simple) approach of processing one sample at
// a time. Because of this, the memory consumption has increased and
// it's crucial to only feed the mask head the last stage outputs
// from the hourglass. Doing so halves the memory requirement of the
// mask head and does not cause a drop in evaluation metrics.
optional bool use_only_last_stage = 24 [default=false];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册