提交 a4e50484 编写于 作者: J Jiageng Zhang 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 488859759
上级 aa5b35b3
......@@ -297,16 +297,18 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(
images,
image_info=labels['image_info'],
anchor_boxes=labels['anchor_boxes'],
gt_boxes=labels['gt_boxes'],
gt_outer_boxes=labels['gt_outer_boxes'],
gt_classes=labels['gt_classes'],
gt_masks=(labels['gt_masks'] if self.task_config.model.include_mask
else None),
training=True)
model_kwargs = {
'image_info': labels['image_info'],
'anchor_boxes': labels['anchor_boxes'],
'gt_boxes': labels['gt_boxes'],
'gt_classes': labels['gt_classes'],
'training': True,
}
if self.task_config.model.include_mask:
model_kwargs['gt_masks'] = labels['gt_masks']
if self.task_config.model.outer_boxes_scale > 1.0:
model_kwargs['gt_outer_boxes'] = labels['gt_outer_boxes']
outputs = model(images, **model_kwargs)
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
......
......@@ -363,19 +363,19 @@ class MaskRCNNTask(base_task.Task):
images, labels = inputs
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
gt_outer_boxes, gt_masks = None, None
model_kwargs = {
'image_shape': labels['image_info'][:, 1, :],
'anchor_boxes': labels['anchor_boxes'],
'gt_boxes': labels['gt_boxes'],
'gt_classes': labels['gt_classes'],
'training': True,
}
if self.task_config.model.include_mask:
gt_outer_boxes = labels['gt_outer_boxes']
gt_masks = labels['gt_masks']
model_kwargs['gt_masks'] = labels['gt_masks']
if self.task_config.model.outer_boxes_scale > 1.0:
model_kwargs['gt_outer_boxes'] = labels['gt_outer_boxes']
outputs = model(
images,
image_shape=labels['image_info'][:, 1, :],
anchor_boxes=labels['anchor_boxes'],
gt_boxes=labels['gt_boxes'],
gt_outer_boxes=gt_outer_boxes,
gt_classes=labels['gt_classes'],
gt_masks=gt_masks,
training=True)
images, **model_kwargs)
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册