提交 284d7783 编写于 作者: A Abdullah Rashwan 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 339275945
上级 05a95e16
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Image segmentation configuration definition.""" """Semantic segmentation configuration definition."""
import os import os
from typing import List, Union, Optional from typing import List, Union, Optional
import dataclasses import dataclasses
...@@ -50,8 +50,8 @@ class SegmentationHead(hyperparams.Config): ...@@ -50,8 +50,8 @@ class SegmentationHead(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class ImageSegmentationModel(hyperparams.Config): class SemanticSegmentationModel(hyperparams.Config):
"""Image segmentation model config.""" """Semantic segmentation model config."""
num_classes: int = 0 num_classes: int = 0
input_size: List[int] = dataclasses.field(default_factory=list) input_size: List[int] = dataclasses.field(default_factory=list)
min_level: int = 3 min_level: int = 3
...@@ -73,9 +73,9 @@ class Losses(hyperparams.Config): ...@@ -73,9 +73,9 @@ class Losses(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class ImageSegmentationTask(cfg.TaskConfig): class SemanticSegmentationTask(cfg.TaskConfig):
"""The model config.""" """The model config."""
model: ImageSegmentationModel = ImageSegmentationModel() model: SemanticSegmentationModel = SemanticSegmentationModel()
train_data: DataConfig = DataConfig(is_training=True) train_data: DataConfig = DataConfig(is_training=True)
validation_data: DataConfig = DataConfig(is_training=False) validation_data: DataConfig = DataConfig(is_training=False)
losses: Losses = Losses() losses: Losses = Losses()
...@@ -89,7 +89,7 @@ class ImageSegmentationTask(cfg.TaskConfig): ...@@ -89,7 +89,7 @@ class ImageSegmentationTask(cfg.TaskConfig):
def semantic_segmentation() -> cfg.ExperimentConfig: def semantic_segmentation() -> cfg.ExperimentConfig:
"""Semantic segmentation general.""" """Semantic segmentation general."""
return cfg.ExperimentConfig( return cfg.ExperimentConfig(
task=ImageSegmentationModel(), task=SemanticSegmentationModel(),
trainer=cfg.TrainerConfig(), trainer=cfg.TrainerConfig(),
restrictions=[ restrictions=[
'task.train_data.is_training != None', 'task.train_data.is_training != None',
...@@ -109,8 +109,8 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig: ...@@ -109,8 +109,8 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
eval_batch_size = 8 eval_batch_size = 8
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
task=ImageSegmentationTask( task=SemanticSegmentationTask(
model=ImageSegmentationModel( model=SemanticSegmentationModel(
num_classes=21, num_classes=21,
# TODO(arashwan): test changing size to 513 to match deeplab. # TODO(arashwan): test changing size to 513 to match deeplab.
input_size=[512, 512, 3], input_size=[512, 512, 3],
......
...@@ -31,9 +31,9 @@ class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase): ...@@ -31,9 +31,9 @@ class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):
def test_semantic_segmentation_configs(self, config_name): def test_semantic_segmentation_configs(self, config_name):
config = exp_factory.get_exp_config(config_name) config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig) self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.ImageSegmentationTask) self.assertIsInstance(config.task, exp_cfg.SemanticSegmentationTask)
self.assertIsInstance(config.task.model, self.assertIsInstance(config.task.model,
exp_cfg.ImageSegmentationModel) exp_cfg.SemanticSegmentationModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig) self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.task.train_data.is_training = None config.task.train_data.is_training = None
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
......
...@@ -241,7 +241,7 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec, ...@@ -241,7 +241,7 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec,
def build_segmentation_model( def build_segmentation_model(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: segmentation_cfg.ImageSegmentationModel, model_config: segmentation_cfg.SemanticSegmentationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None): l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds Segmentation model.""" """Builds Segmentation model."""
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
......
...@@ -28,9 +28,9 @@ from official.vision.beta.losses import segmentation_losses ...@@ -28,9 +28,9 @@ from official.vision.beta.losses import segmentation_losses
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
@task_factory.register_task_cls(exp_cfg.ImageSegmentationTask) @task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
class ImageSegmentationTask(base_task.Task): class SemanticSegmentationTask(base_task.Task):
"""A task for image classification.""" """A task for semantic classification."""
def build_model(self): def build_model(self):
"""Builds classification model.""" """Builds classification model."""
...@@ -219,8 +219,12 @@ class ImageSegmentationTask(base_task.Task): ...@@ -219,8 +219,12 @@ class ImageSegmentationTask(base_task.Task):
outputs = self.inference_step(features, model) outputs = self.inference_step(features, model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs) outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(model_outputs=outputs, labels=labels,
aux_losses=model.losses) if self.task_config.validation_data.resize_eval_groundtruth:
loss = self.build_losses(model_outputs=outputs, labels=labels,
aux_losses=model.losses)
else:
loss = 0
logs = {self.loss: loss} logs = {self.loss: loss}
logs.update({self.miou_metric.name: (labels, outputs)}) logs.update({self.miou_metric.name: (labels, outputs)})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册