提交 284d7783 编写于 作者: A Abdullah Rashwan 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 339275945
上级 05a95e16
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image segmentation configuration definition."""
"""Semantic segmentation configuration definition."""
import os
from typing import List, Union, Optional
import dataclasses
......@@ -50,8 +50,8 @@ class SegmentationHead(hyperparams.Config):
@dataclasses.dataclass
class ImageSegmentationModel(hyperparams.Config):
"""Image segmentation model config."""
class SemanticSegmentationModel(hyperparams.Config):
"""Semantic segmentation model config."""
num_classes: int = 0
input_size: List[int] = dataclasses.field(default_factory=list)
min_level: int = 3
......@@ -73,9 +73,9 @@ class Losses(hyperparams.Config):
@dataclasses.dataclass
class ImageSegmentationTask(cfg.TaskConfig):
class SemanticSegmentationTask(cfg.TaskConfig):
"""The model config."""
model: ImageSegmentationModel = ImageSegmentationModel()
model: SemanticSegmentationModel = SemanticSegmentationModel()
train_data: DataConfig = DataConfig(is_training=True)
validation_data: DataConfig = DataConfig(is_training=False)
losses: Losses = Losses()
......@@ -89,7 +89,7 @@ class ImageSegmentationTask(cfg.TaskConfig):
def semantic_segmentation() -> cfg.ExperimentConfig:
"""Semantic segmentation general."""
return cfg.ExperimentConfig(
task=ImageSegmentationModel(),
task=SemanticSegmentationModel(),
trainer=cfg.TrainerConfig(),
restrictions=[
'task.train_data.is_training != None',
......@@ -109,8 +109,8 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
eval_batch_size = 8
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=ImageSegmentationTask(
model=ImageSegmentationModel(
task=SemanticSegmentationTask(
model=SemanticSegmentationModel(
num_classes=21,
# TODO(arashwan): test changing size to 513 to match deeplab.
input_size=[512, 512, 3],
......
......@@ -31,9 +31,9 @@ class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):
def test_semantic_segmentation_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.ImageSegmentationTask)
self.assertIsInstance(config.task, exp_cfg.SemanticSegmentationTask)
self.assertIsInstance(config.task.model,
exp_cfg.ImageSegmentationModel)
exp_cfg.SemanticSegmentationModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
......
......@@ -241,7 +241,7 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec,
def build_segmentation_model(
input_specs: tf.keras.layers.InputSpec,
model_config: segmentation_cfg.ImageSegmentationModel,
model_config: segmentation_cfg.SemanticSegmentationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds Segmentation model."""
backbone = backbones.factory.build_backbone(
......
......@@ -28,9 +28,9 @@ from official.vision.beta.losses import segmentation_losses
from official.vision.beta.modeling import factory
@task_factory.register_task_cls(exp_cfg.ImageSegmentationTask)
class ImageSegmentationTask(base_task.Task):
"""A task for image classification."""
@task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
class SemanticSegmentationTask(base_task.Task):
"""A task for semantic classification."""
def build_model(self):
"""Builds classification model."""
......@@ -219,8 +219,12 @@ class ImageSegmentationTask(base_task.Task):
outputs = self.inference_step(features, model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(model_outputs=outputs, labels=labels,
aux_losses=model.losses)
if self.task_config.validation_data.resize_eval_groundtruth:
loss = self.build_losses(model_outputs=outputs, labels=labels,
aux_losses=model.losses)
else:
loss = 0
logs = {self.loss: loss}
logs.update({self.miou_metric.name: (labels, outputs)})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册