From 2575940c37c44de0e162d6cb5ccc5dcf438fcb5d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Jul 2022 16:06:00 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 461738130 --- .../serving/export_saved_model.py | 31 ++++-- .../serving/panoptic_deeplab.py | 103 ++++++++++++++++++ ...c_segmentation.py => panoptic_maskrcnn.py} | 0 3 files changed, 126 insertions(+), 8 deletions(-) create mode 100644 official/vision/beta/projects/panoptic_maskrcnn/serving/panoptic_deeplab.py rename official/vision/beta/projects/panoptic_maskrcnn/serving/{panoptic_segmentation.py => panoptic_maskrcnn.py} (100%) diff --git a/official/vision/beta/projects/panoptic_maskrcnn/serving/export_saved_model.py b/official/vision/beta/projects/panoptic_maskrcnn/serving/export_saved_model.py index 2a0579e91..b95d8866b 100644 --- a/official/vision/beta/projects/panoptic_maskrcnn/serving/export_saved_model.py +++ b/official/vision/beta/projects/panoptic_maskrcnn/serving/export_saved_model.py @@ -39,14 +39,23 @@ import tensorflow as tf from official.core import exp_factory from official.modeling import hyperparams -from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as cfg # pylint: disable=unused-import +# pylint: disable=unused-import +from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_deeplab as panoptic_deeplab_cfg +from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg +# pylint: enable=unused-import from official.vision.beta.projects.panoptic_maskrcnn.modeling import factory -from official.vision.beta.projects.panoptic_maskrcnn.serving import panoptic_segmentation -from official.vision.beta.projects.panoptic_maskrcnn.tasks import panoptic_maskrcnn as task # pylint: disable=unused-import +from official.vision.beta.projects.panoptic_maskrcnn.serving import panoptic_deeplab +from official.vision.beta.projects.panoptic_maskrcnn.serving import panoptic_maskrcnn +# pylint: disable=unused-import +from official.vision.beta.projects.panoptic_maskrcnn.tasks import panoptic_deeplab as panoptic_deeplab_task +from official.vision.beta.projects.panoptic_maskrcnn.tasks import panoptic_maskrcnn as panoptic_maskrcnn_task +# pylint: enable=unused-import from official.vision.serving import export_saved_model_lib FLAGS = flags.FLAGS +flags.DEFINE_string('model', 'panoptic_maskrcnn', + 'model type, one of panoptic_maskrcnn and panoptic_deeplab') flags.DEFINE_string('experiment', 'panoptic_fpn_coco', 'experiment type, e.g. panoptic_fpn_coco') flags.DEFINE_string('export_dir', None, 'The export directory.') @@ -89,16 +98,23 @@ def main(_): input_image_size = [int(x) for x in FLAGS.input_image_size.split(',')] input_specs = tf.keras.layers.InputSpec( shape=[FLAGS.batch_size, *input_image_size, 3]) - model = factory.build_panoptic_maskrcnn( - input_specs=input_specs, model_config=params.task.model) - export_module = panoptic_segmentation.PanopticSegmentationModule( + if FLAGS.model == 'panoptic_deeplab': + build_model = factory.build_panoptic_deeplab + panoptic_module = panoptic_deeplab.PanopticSegmentationModule + elif FLAGS.model == 'panoptic_maskrcnn': + build_model = factory.build_panoptic_maskrcnn + panoptic_module = panoptic_maskrcnn.PanopticSegmentationModule + else: + raise ValueError('Unsupported model type: %s' % FLAGS.model) + + model = build_model(input_specs=input_specs, model_config=params.task.model) + export_module = panoptic_module( params=params, model=model, batch_size=FLAGS.batch_size, input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')], num_channels=3) - export_saved_model_lib.export_inference_graph( input_type=FLAGS.input_type, batch_size=FLAGS.batch_size, @@ -110,6 +126,5 @@ def main(_): export_checkpoint_subdir='checkpoint', export_saved_model_subdir='saved_model') - if __name__ == '__main__': app.run(main) diff --git a/official/vision/beta/projects/panoptic_maskrcnn/serving/panoptic_deeplab.py b/official/vision/beta/projects/panoptic_maskrcnn/serving/panoptic_deeplab.py new file mode 100644 index 000000000..da5ce6cf4 --- /dev/null +++ b/official/vision/beta/projects/panoptic_maskrcnn/serving/panoptic_deeplab.py @@ -0,0 +1,103 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Panoptic Segmentation input and model functions for serving/inference.""" + +from typing import List + +import tensorflow as tf + +from official.core import config_definitions as cfg +from official.vision.beta.projects.panoptic_maskrcnn.modeling import factory +from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model +from official.vision.serving import semantic_segmentation + + +class PanopticSegmentationModule( + semantic_segmentation.SegmentationModule): + """Panoptic Deeplab Segmentation Module.""" + + def __init__(self, + params: cfg.ExperimentConfig, + *, + model: tf.keras.Model, + batch_size: int, + input_image_size: List[int], + num_channels: int = 3): + """Initializes panoptic segmentation module for export.""" + + if batch_size is None: + raise ValueError('batch_size cannot be None for panoptic segmentation ' + 'model.') + if not isinstance(model, panoptic_deeplab_model.PanopticDeeplabModel): + raise ValueError('PanopticSegmentationModule module not ' + 'implemented for {} model.'.format(type(model))) + params.task.train_data.preserve_aspect_ratio = True + super(PanopticSegmentationModule, self).__init__( + params=params, + model=model, + batch_size=batch_size, + input_image_size=input_image_size, + num_channels=num_channels) + + def _build_model(self): + input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] + + self._input_image_size + [3]) + + return factory.build_panoptic_deeplab( + input_specs=input_specs, + model_config=self.params.task.model, + l2_regularizer=None) + + def serve(self, images: tf.Tensor): + """Cast image to float and run inference. + + Args: + images: uint8 Tensor of shape [batch_size, None, None, 3] + + Returns: + Tensor holding detection output logits. + """ + if self._input_type != 'tflite': + with tf.device('cpu:0'): + images = tf.cast(images, dtype=tf.float32) + images_spec = tf.TensorSpec( + shape=self._input_image_size + [3], dtype=tf.float32) + image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32) + + images, image_info = tf.nest.map_structure( + tf.identity, + tf.map_fn( + self._build_inputs, + elems=images, + fn_output_signature=(images_spec, image_info_spec), + parallel_iterations=32)) + + outputs = self.model.call( + inputs=images, image_info=image_info, training=False) + + masks = outputs['segmentation_outputs'] + masks = tf.image.resize(masks, self._input_image_size, method='bilinear') + classes = tf.math.argmax(masks, axis=-1) + scores = tf.nn.softmax(masks, axis=-1) + final_outputs = { + 'semantic_logits': masks, + 'semantic_scores': scores, + 'semantic_classes': classes, + 'image_info': image_info, + 'panoptic_category_mask': outputs['category_mask'], + 'panoptic_instance_mask': outputs['instance_mask'], + } + + return final_outputs diff --git a/official/vision/beta/projects/panoptic_maskrcnn/serving/panoptic_segmentation.py b/official/vision/beta/projects/panoptic_maskrcnn/serving/panoptic_maskrcnn.py similarity index 100% rename from official/vision/beta/projects/panoptic_maskrcnn/serving/panoptic_segmentation.py rename to official/vision/beta/projects/panoptic_maskrcnn/serving/panoptic_maskrcnn.py -- GitLab