Commit ca8c44d5 authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 424422082
Parent: a7894f9e
# End-to-End Object Detection with Transformers (DETR)
[![DETR](https://img.shields.io/badge/DETR-arXiv.2005.12872-B3181B?)](https://arxiv.org/abs/2005.12872)
TensorFlow 2 implementation of End-to-End Object Detection with Transformers.
⚠️ Disclaimer: None of the datasets hyperlinked from this page are owned or
distributed by Google. The datasets are made available by third parties.
Please review the terms and conditions made available by the third parties
before using the data.
## Scripts:
You can find the scripts to reproduce the following experiments in
detr/experiments.
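For example, `detr_r50_500epochs.sh` reduces to the following `train.py`
invocation (the 300-epoch variant additionally overrides
`trainer.train_steps`; both scripts are reproduced later in this commit):

```
python3 official/projects/detr/train.py \
  --experiment=detr_coco \
  --mode=train_and_eval \
  --model_dir=/tmp/logging_dir/ \
  --params_override=task.init_ckpt='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'
```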
## DETR [COCO](https://cocodataset.org) ([ImageNet](https://www.image-net.org) pretrained)
| Model | Resolution | Batch size | Epochs | Decay@ | Params (M) | Box AP | Dashboard | Checkpoint | Experiment |
| --------- | :--------: | ----------:| ------:| -----: | ---------: | -----: | --------: | ---------: | ---------: |
| DETR-ResNet-50 | 1333x1333 |64|300| 200 |41 | 40.6 | [tensorboard](https://tensorboard.dev/experiment/o2IEZnniRYu6pqViBeopIg/#scalars) | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/detr/detr_resnet_50_300.tar.gz) | detr_r50_300epochs.sh |
| DETR-ResNet-50 | 1333x1333 |64|500| 400 |41 | 42.0| [tensorboard](https://tensorboard.dev/experiment/YFMDKpESR4yjocPh5HgfRw/) | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/detr/detr_resnet_50_500.tar.gz) | detr_r50_500epochs.sh |
| DETR-ResNet-50 | 1333x1333 |64|300| 200 |41 | 40.6 | paper | NA | NA |
| DETR-ResNet-50 | 1333x1333 |64|500| 400 |41 | 42.0 | paper | NA | NA |
| DETR-DC5-ResNet-50 | 1333x1333 |64|500| 400 |41 | 43.3 | paper | NA | NA |
## Contributions needed:
* Add DC5 support and update experiment table.
## Citing TensorFlow Model Garden
If you find this codebase helpful in your research, please cite this repository.
```
@misc{tensorflowmodelgarden2020,
author = {Hongkun Yu and Chen Chen and Xianzhi Du and Yeqing Li and
Abdullah Rashwan and Le Hou and Pengchong Jin and Fan Yang and
Frederick Liu and Jaeyoun Kim and Jing Li},
title = {{TensorFlow Model Garden}},
howpublished = {\url{https://github.com/tensorflow/models}},
year = {2020}
}
```
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DETR configurations."""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.detr import optimization
from official.projects.detr.dataloaders import coco
@dataclasses.dataclass
class DetectionConfig(cfg.TaskConfig):
"""The translation task config."""
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
lambda_cls: float = 1.0
lambda_box: float = 5.0
lambda_giou: float = 2.0
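# Loss weights (1.0 cls / 5.0 L1 box / 2.0 GIoU) and the 0.1 background
# class weight below match the DETR paper.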
init_ckpt: str = ''
num_classes: int = 81 # 0: background
background_cls_weight: float = 0.1
num_encoder_layers: int = 6
num_decoder_layers: int = 6
# Make DETRConfig.
num_queries: int = 100
num_hidden: int = 256
per_category_metrics: bool = False
@exp_factory.register_config_factory('detr_coco')
def detr_coco() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 64
eval_batch_size = 64
num_train_data = 118287
num_steps_per_epoch = num_train_data // train_batch_size
train_steps = 500 * num_steps_per_epoch # 500 epochs
decay_at = train_steps - 100 * num_steps_per_epoch # 400 epochs
config = cfg.ExperimentConfig(
task=DetectionConfig(
train_data=coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='train',
is_training=True,
global_batch_size=train_batch_size,
shuffle_buffer_size=1000,
),
validation_data=coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False
)
),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=-1,
steps_per_loop=10000,
summary_interval=10000,
checkpoint_interval=10000,
validation_interval=10000,
max_to_keep=1,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'detr_adamw',
'detr_adamw': {
'weight_decay_rate': 1e-4,
'global_clipnorm': 0.1,
# Avoid AdamW legacy behavior.
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [decay_at],
'values': [0.0001, 1.0e-05]
}
},
})
),
restrictions=[
'task.train_data.is_training != None',
])
return config
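# A minimal usage sketch: resolving the experiment registered above by name,
# as train.py does via --experiment=detr_coco. The checkpoint path below is
# hypothetical.
if __name__ == '__main__':
  example = exp_factory.get_exp_config('detr_coco')
  example.task.init_ckpt = 'gs://your-bucket/resnet50/ckpt'  # hypothetical path
  example.validate()
  print(example.trainer.train_steps)  # 924000 = 500 epochs * 1848 steps/epoch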
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for detr."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.detr.configs import detr as exp_cfg
from official.projects.detr.dataloaders import coco
class DetrTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('detr_coco',))
def test_detr_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.DetectionConfig)
self.assertIsInstance(config.task.train_data, coco.COCODataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COCO data loader for DETR."""
import dataclasses
from typing import Optional, Tuple
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import input_reader
from official.vision.beta.ops import box_ops
from official.vision.beta.ops import preprocess_ops
@dataclasses.dataclass
class COCODataConfig(cfg.DataConfig):
"""Data config for COCO."""
output_size: Tuple[int, int] = (1333, 1333)
max_num_boxes: int = 100
resize_scales: Tuple[int, ...] = (
480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
class COCODataLoader():
"""A class to load dataset for COCO detection task."""
def __init__(self, params: COCODataConfig):
self._params = params
def preprocess(self, inputs):
"""Preprocess COCO for DETR."""
image = inputs['image']
boxes = inputs['objects']['bbox']
classes = inputs['objects']['label'] + 1
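# TFDS COCO labels are 0-based; shift by 1 so class 0 can denote background.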
is_crowd = inputs['objects']['is_crowd']
image = preprocess_ops.normalize_image(image)
if self._params.is_training:
image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)
do_crop = tf.greater(tf.random.uniform([]), 0.5)
if do_crop:
# Rescale
boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(image, short_side)
boxes = preprocess_ops.resize_and_crop_boxes(boxes,
image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Do cropping
shape = tf.cast(image_info[1], dtype=tf.int32)
h = tf.random.uniform(
[], 384, tf.math.minimum(shape[0], 600), dtype=tf.int32)
w = tf.random.uniform(
[], 384, tf.math.minimum(shape[1], 600), dtype=tf.int32)
i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
image = tf.image.crop_to_bounding_box(image, i, j, h, w)
boxes = tf.clip_by_value(
(boxes[..., :] * tf.cast(
tf.stack([shape[0], shape[1], shape[0], shape[1]]),
dtype=tf.float32) -
tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
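# The normalized boxes are shifted into the crop's coordinate frame and
# renormalized to [0, 1] with respect to the (h, w) crop.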
scales = tf.constant(
self._params.resize_scales,
dtype=tf.float32)
index = tf.random.categorical(tf.zeros([1, 11]), 1)[0]
scales = tf.gather(scales, index, axis=0)
else:
scales = tf.constant([self._params.resize_scales[-1]], tf.float32)
image_shape = tf.shape(image)[:2]
boxes = box_ops.denormalize_boxes(boxes, image_shape)
gt_boxes = boxes
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(
image,
short_side,
max(self._params.output_size))
boxes = preprocess_ops.resize_and_crop_boxes(boxes,
image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Filters out ground truth boxes that are all zeros.
indices = box_ops.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices)
is_crowd = tf.gather(is_crowd, indices)
boxes = box_ops.yxyx_to_cycxhw(boxes)
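# DETR regresses boxes as normalized (center_y, center_x, height, width).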
image = tf.image.pad_to_bounding_box(
image, 0, 0, self._params.output_size[0], self._params.output_size[1])
labels = {
'classes':
preprocess_ops.clip_or_pad_to_fixed_size(
classes, self._params.max_num_boxes),
'boxes':
preprocess_ops.clip_or_pad_to_fixed_size(
boxes, self._params.max_num_boxes)
}
if not self._params.is_training:
labels.update({
'id':
inputs['image/id'],
'image_info':
image_info,
'is_crowd':
preprocess_ops.clip_or_pad_to_fixed_size(
is_crowd, self._params.max_num_boxes),
'gt_boxes':
preprocess_ops.clip_or_pad_to_fixed_size(
gt_boxes, self._params.max_num_boxes),
})
return image, labels
def _transform_and_batch_fn(
self,
dataset,
input_context: Optional[tf.distribute.InputContext] = None):
"""Preprocess and batch."""
dataset = dataset.map(
self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
per_replica_batch_size = input_context.get_per_replica_batch_size(
self._params.global_batch_size
) if input_context else self._params.global_batch_size
dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._params.is_training)
return dataset
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params,
decoder_fn=None,
transform_and_batch_fn=self._transform_and_batch_fn)
return reader.read(input_context)
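# A minimal sketch of driving preprocess() directly on a synthetic example,
# mirroring the unit test below; shapes and values are illustrative
# assumptions, not part of the original commit.
if __name__ == '__main__':
  import numpy as np  # only needed for this sketch

  sketch_params = COCODataConfig(
      tfds_name='coco/2017',
      tfds_split='validation',
      is_training=False,
      global_batch_size=1)
  sketch_example = {
      'image': np.ones((480, 640, 3), dtype=np.uint8),
      'image/id': np.array(42),
      'objects': {
          'is_crowd': np.zeros((2,), dtype=bool),
          'bbox': np.array([[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.9, 0.9]],
                           dtype=np.float32),
          'label': np.array([0, 17], dtype=np.int64),
      },
  }
  image, labels = COCODataLoader(sketch_params).preprocess(sketch_example)
  # The image is padded to output_size; boxes are padded to max_num_boxes.
  print(image.shape, labels['boxes'].shape)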
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_models.official.projects.detr.dataloaders.coco."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from official.projects.detr.dataloaders import coco
def _gen_fn():
h = np.random.randint(0, 300)
w = np.random.randint(0, 300)
num_boxes = np.random.randint(0, 50)
return {
'image': np.ones(shape=(h, w, 3), dtype=np.uint8),
'image/id': np.random.randint(0, 100),
'image/filename': 'test',
'objects': {
'is_crowd': np.ones(shape=(num_boxes), dtype=bool),  # np.bool was removed in NumPy 1.24
'bbox': np.ones(shape=(num_boxes, 4), dtype=np.float32),
'label': np.ones(shape=(num_boxes), dtype=np.int64),
'id': np.ones(shape=(num_boxes), dtype=np.int64),
'area': np.ones(shape=(num_boxes), dtype=np.int64),
}
}
class CocoDataloaderTest(tf.test.TestCase, parameterized.TestCase):
def test_load_dataset(self):
output_size = 1280
max_num_boxes = 100
batch_size = 2
data_config = coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=False,
global_batch_size=batch_size,
output_size=(output_size, output_size),
max_num_boxes=max_num_boxes,
)
num_examples = 10
def as_dataset(self, *args, **kwargs):
del args
del kwargs
return tf.data.Dataset.from_generator(
lambda: (_gen_fn() for i in range(num_examples)),
output_types=self.info.features.dtype,
output_shapes=self.info.features.shape,
)
with tfds.testing.mock_data(num_examples=num_examples,
as_dataset_fn=as_dataset):
dataset = coco.COCODataLoader(data_config).load()
dataset_iter = iter(dataset)
images, labels = next(dataset_iter)
self.assertEqual(images.shape, (batch_size, output_size, output_size, 3))
self.assertEqual(labels['classes'].shape, (batch_size, max_num_boxes))
self.assertEqual(labels['boxes'].shape, (batch_size, max_num_boxes, 4))
self.assertEqual(labels['id'].shape, (batch_size,))
self.assertEqual(
labels['image_info'].shape, (batch_size, 4, 2))
self.assertEqual(labels['is_crowd'].shape, (batch_size, max_num_boxes))
@parameterized.named_parameters(
('training', True),
('validation', False))
def test_preprocess(self, is_training):
output_size = 1280
max_num_boxes = 100
batch_size = 2
data_config = coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=is_training,
global_batch_size=batch_size,
output_size=(output_size, output_size),
max_num_boxes=max_num_boxes,
)
dl = coco.COCODataLoader(data_config)
inputs = _gen_fn()
image, label = dl.preprocess(inputs)
self.assertEqual(image.shape, (output_size, output_size, 3))
self.assertEqual(label['classes'].shape, (max_num_boxes))
self.assertEqual(label['boxes'].shape, (max_num_boxes, 4))
if not is_training:
self.assertDTypeEqual(label['id'], int)
self.assertEqual(
label['image_info'].shape, (4, 2))
self.assertEqual(label['is_crowd'].shape, (max_num_boxes))
if __name__ == '__main__':
tf.test.main()
#!/bin/bash
python3 official/projects/detr/train.py \
--experiment=detr_coco \
--mode=train_and_eval \
--model_dir=/tmp/logging_dir/ \
--params_override=task.init_ckpt='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',trainer.train_steps=554400
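# Note: 554400 steps = 300 epochs * 1848 steps/epoch (118287 COCO train
# images // batch size 64); without this override the config runs 500 epochs.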
#!/bin/bash
python3 official/projects/detr/train.py \
--experiment=detr_coco \
--mode=train_and_eval \
--model_dir=/tmp/logging_dir/ \
--params_override=task.init_ckpt='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements End-to-End Object Detection with Transformers.
Model paper: https://arxiv.org/abs/2005.12872
This module does not support Keras de/serialization. Please use
tf.train.Checkpoint for object-based saving and loading, and tf.saved_model.save
for graph serialization.
"""
import math
import tensorflow as tf
from official.modeling import tf_utils
from official.projects.detr.modeling import transformer
from official.vision.beta.modeling.backbones import resnet
def position_embedding_sine(attention_mask,
num_pos_features=256,
temperature=10000.,
normalize=True,
scale=2 * math.pi):
"""Sine-based positional embeddings for 2D images.
Args:
attention_mask: a `bool` Tensor specifying the size of the input image to
the Transformer and which elements are padded, of size [batch_size,
height, width]
num_pos_features: an `int` specifying the number of positional features;
it should equal the hidden size of the Transformer network.
temperature: a `float` specifying the temperature of the positional
embedding. Any type that is convertible to a `float` is also accepted.
normalize: a `bool` determining whether the positional embeddings should be
normalized to [0, scale] before the sine and cosine functions are applied.
scale: a `float` used, when `normalize` is True, to scale the embeddings
before the embedding function is applied.
Returns:
embeddings: a `float` tensor of shape [batch_size, height, width,
num_pos_features] containing the sine-based positional embeddings.
"""
if num_pos_features % 2 != 0:
raise ValueError(
"Number of embedding features (num_pos_features) must be even when "
"column and row embeddings are concatenated.")
num_pos_features = num_pos_features // 2
# Produce row and column embeddings based on total size of the image
# <tf.float>[batch_size, height, width]
attention_mask = tf.cast(attention_mask, tf.float32)
row_embedding = tf.cumsum(attention_mask, 1)
col_embedding = tf.cumsum(attention_mask, 2)
if normalize:
eps = 1e-6
row_embedding = row_embedding / (row_embedding[:, -1:, :] + eps) * scale
col_embedding = col_embedding / (col_embedding[:, :, -1:] + eps) * scale
dim_t = tf.range(num_pos_features, dtype=row_embedding.dtype)
dim_t = tf.pow(temperature, 2 * (dim_t // 2) / num_pos_features)
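# Each consecutive feature pair shares a frequency:
# dim_t[i] = temperature^(2 * (i // 2) / num_pos_features), the sinusoidal
# scheme of "Attention Is All You Need" applied per row and per column.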
# Creates positional embeddings for each row and column position
# <tf.float>[batch_size, height, width, num_pos_features]
pos_row = tf.expand_dims(row_embedding, -1) / dim_t
pos_col = tf.expand_dims(col_embedding, -1) / dim_t
pos_row = tf.stack(
[tf.sin(pos_row[:, :, :, 0::2]),
tf.cos(pos_row[:, :, :, 1::2])], axis=4)
pos_col = tf.stack(
[tf.sin(pos_col[:, :, :, 0::2]),
tf.cos(pos_col[:, :, :, 1::2])], axis=4)
# final_shape = pos_row.shape.as_list()[:3] + [-1]
final_shape = tf_utils.get_shape_list(pos_row)[:3] + [-1]
pos_row = tf.reshape(pos_row, final_shape)
pos_col = tf.reshape(pos_col, final_shape)
output = tf.concat([pos_row, pos_col], -1)
embeddings = tf.cast(output, tf.float32)
return embeddings
class DETR(tf.keras.Model):
"""DETR model with Keras.
DETR consists of a backbone, query embeddings, a DETRTransformer, and
class and box heads.
"""
def __init__(self, num_queries, hidden_size, num_classes,
num_encoder_layers=6,
num_decoder_layers=6,
dropout_rate=0.1,
**kwargs):
super().__init__(**kwargs)
self._num_queries = num_queries
self._hidden_size = hidden_size
self._num_classes = num_classes
self._num_encoder_layers = num_encoder_layers
self._num_decoder_layers = num_decoder_layers
self._dropout_rate = dropout_rate
if hidden_size % 2 != 0:
raise ValueError("hidden_size must be a multiple of 2.")
# TODO(frederickliu): Consider using the backbone factory.
# TODO(frederickliu): Add to factory once we get skeleton code in.
self._backbone = resnet.ResNet(50, bn_trainable=False)
def build(self, input_shape=None):
self._input_proj = tf.keras.layers.Conv2D(
self._hidden_size, 1, name="detr/conv2d")
self._transformer = DETRTransformer(
num_encoder_layers=self._num_encoder_layers,
num_decoder_layers=self._num_decoder_layers,
dropout_rate=self._dropout_rate)
self._query_embeddings = self.add_weight(
"detr/query_embeddings",
shape=[self._num_queries, self._hidden_size],
initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1.),
dtype=tf.float32)
sqrt_k = math.sqrt(1.0 / self._hidden_size)
self._class_embed = tf.keras.layers.Dense(
self._num_classes,
kernel_initializer=tf.keras.initializers.RandomUniform(-sqrt_k, sqrt_k),
name="detr/cls_dense")
self._bbox_embed = [
tf.keras.layers.Dense(
self._hidden_size, activation="relu",
kernel_initializer=tf.keras.initializers.RandomUniform(
-sqrt_k, sqrt_k),
name="detr/box_dense_0"),
tf.keras.layers.Dense(
self._hidden_size, activation="relu",
kernel_initializer=tf.keras.initializers.RandomUniform(
-sqrt_k, sqrt_k),
name="detr/box_dense_1"),
tf.keras.layers.Dense(
4, kernel_initializer=tf.keras.initializers.RandomUniform(
-sqrt_k, sqrt_k),
name="detr/box_dense_2")]
self._sigmoid = tf.keras.layers.Activation("sigmoid")
super().build(input_shape)
@property
def backbone(self) -> tf.keras.Model:
return self._backbone
def get_config(self):
return {
"num_queries": self._num_queries,
"hidden_size": self._hidden_size,
"num_classes": self._num_classes,
"num_encoder_layers": self._num_encoder_layers,
"num_decoder_layers": self._num_decoder_layers,
"dropout_rate": self._dropout_rate,
}
@classmethod
def from_config(cls, config):
return cls(**config)
def call(self, inputs):
batch_size = tf.shape(inputs)[0]
mask = tf.expand_dims(
tf.cast(tf.not_equal(tf.reduce_sum(inputs, axis=-1), 0), inputs.dtype),
axis=-1)
features = self._backbone(inputs)["5"]
shape = tf.shape(features)
mask = tf.image.resize(
mask, shape[1:3], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
pos_embed = position_embedding_sine(
mask[:, :, :, 0], num_pos_features=self._hidden_size)
pos_embed = tf.reshape(pos_embed, [batch_size, -1, self._hidden_size])
features = tf.reshape(
self._input_proj(features), [batch_size, -1, self._hidden_size])
mask = tf.reshape(mask, [batch_size, -1])
decoded_list = self._transformer({
"inputs":
features,
"targets":
tf.tile(
tf.expand_dims(self._query_embeddings, axis=0),
(batch_size, 1, 1)),
"pos_embed": pos_embed,
"mask": mask,
})
out_list = []
for decoded in decoded_list:
decoded = tf.stack(decoded)
output_class = self._class_embed(decoded)
box_out = decoded
for layer in self._bbox_embed:
box_out = layer(box_out)
output_coord = self._sigmoid(box_out)
out = {"cls_outputs": output_class, "box_outputs": output_coord}
out_list.append(out)
return out_list
class DETRTransformer(tf.keras.layers.Layer):
"""Encoder and Decoder of DETR."""
def __init__(self, num_encoder_layers=6, num_decoder_layers=6,
dropout_rate=0.1, **kwargs):
super().__init__(**kwargs)
self._dropout_rate = dropout_rate
self._num_encoder_layers = num_encoder_layers
self._num_decoder_layers = num_decoder_layers
def build(self, input_shape=None):
self._encoder = transformer.TransformerEncoder(
attention_dropout_rate=self._dropout_rate,
dropout_rate=self._dropout_rate,
intermediate_dropout=self._dropout_rate,
norm_first=False,
num_layers=self._num_encoder_layers,
)
self._decoder = transformer.TransformerDecoder(
attention_dropout_rate=self._dropout_rate,
dropout_rate=self._dropout_rate,
intermediate_dropout=self._dropout_rate,
norm_first=False,
num_layers=self._num_decoder_layers)
super().build(input_shape)
def get_config(self):
return {
"num_encoder_layers": self._num_encoder_layers,
"num_decoder_layers": self._num_decoder_layers,
"dropout_rate": self._dropout_rate,
}
def call(self, inputs):
sources = inputs["inputs"]
targets = inputs["targets"]
pos_embed = inputs["pos_embed"]
mask = inputs["mask"]
input_shape = tf_utils.get_shape_list(sources)
source_attention_mask = tf.tile(
tf.expand_dims(mask, axis=1), [1, input_shape[1], 1])
memory = self._encoder(
sources, attention_mask=source_attention_mask, pos_embed=pos_embed)
target_shape = tf_utils.get_shape_list(targets)
cross_attention_mask = tf.tile(
tf.expand_dims(mask, axis=1), [1, target_shape[1], 1])
target_shape = tf.shape(targets)
decoded = self._decoder(
tf.zeros_like(targets),
memory,
# TODO(b/199545430): self_attention_mask could be set to None when this
# bug is resolved. Passing ones for now.
self_attention_mask=tf.ones(
(target_shape[0], target_shape[1], target_shape[1])),
cross_attention_mask=cross_attention_mask,
return_all_decoder_outputs=True,
input_pos_embed=targets,
memory_pos_embed=pos_embed)
return decoded
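# A minimal forward-pass sketch, mirroring the unit test below: DETR returns
# one {'cls_outputs', 'box_outputs'} dict per decoder layer (6 by default).
if __name__ == '__main__':
  sketch_model = DETR(num_queries=10, hidden_size=128, num_classes=81)
  sketch_outs = sketch_model(tf.ones((1, 640, 640, 3)))
  print(len(sketch_outs))  # 6
  print(sketch_outs[-1]['cls_outputs'].shape)  # (1, 10, 81)
  print(sketch_outs[-1]['box_outputs'].shape)  # (1, 10, 4)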
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_models.official.projects.detr.detr."""
import tensorflow as tf
from official.projects.detr.modeling import detr
class DetrTest(tf.test.TestCase):
def test_forward(self):
num_queries = 10
hidden_size = 128
num_classes = 10
image_size = 640
batch_size = 2
model = detr.DETR(num_queries, hidden_size, num_classes)
outs = model(tf.ones((batch_size, image_size, image_size, 3)))
self.assertLen(outs, 6) # intermediate decoded outputs.
for out in outs:
self.assertAllEqual(
tf.shape(out['cls_outputs']), (batch_size, num_queries, num_classes))
self.assertAllEqual(
tf.shape(out['box_outputs']), (batch_size, num_queries, 4))
def test_get_from_config_detr_transformer(self):
config = {
'num_encoder_layers': 1,
'num_decoder_layers': 2,
'dropout_rate': 0.5,
}
detr_model = detr.DETRTransformer.from_config(config)
retrieved_config = detr_model.get_config()
self.assertEqual(config, retrieved_config)
def test_get_from_config_detr(self):
config = {
'num_queries': 2,
'hidden_size': 4,
'num_classes': 10,
'num_encoder_layers': 4,
'num_decoder_layers': 5,
'dropout_rate': 0.5,
}
detr_model = detr.DETR.from_config(config)
retrieved_config = detr_model.get_config()
self.assertEqual(config, retrieved_config)
if __name__ == '__main__':
tf.test.main()
(This file's diff has been collapsed and is not shown.)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for transformer."""
import tensorflow as tf
from official.projects.detr.modeling import transformer
class TransformerTest(tf.test.TestCase):
def test_transformer_encoder_block(self):
batch_size = 2
sequence_length = 100
feature_size = 256
num_attention_heads = 2
inner_dim = 256
inner_activation = 'relu'
model = transformer.TransformerEncoderBlock(num_attention_heads, inner_dim,
inner_activation)
input_tensor = tf.ones((batch_size, sequence_length, feature_size))
attention_mask = tf.ones((batch_size, sequence_length, sequence_length),
dtype=tf.int64)
pos_embed = tf.ones((batch_size, sequence_length, feature_size))
out = model([input_tensor, attention_mask, pos_embed])
self.assertAllEqual(
tf.shape(out), (batch_size, sequence_length, feature_size))
def test_transformer_encoder_block_get_config(self):
num_attention_heads = 2
inner_dim = 256
inner_activation = 'relu'
model = transformer.TransformerEncoderBlock(num_attention_heads, inner_dim,
inner_activation)
config = model.get_config()
expected_config = {
'name': 'transformer_encoder_block',
'trainable': True,
'dtype': 'float32',
'num_attention_heads': 2,
'inner_dim': 256,
'inner_activation': 'relu',
'output_dropout': 0.0,
'attention_dropout': 0.0,
'output_range': None,
'kernel_initializer': {
'class_name': 'GlorotUniform',
'config': {
'seed': None}
},
'bias_initializer': {
'class_name': 'Zeros',
'config': {}
},
'kernel_regularizer': None,
'bias_regularizer': None,
'activity_regularizer': None,
'kernel_constraint': None,
'bias_constraint': None,
'use_bias': True,
'norm_first': False,
'norm_epsilon': 1e-12,
'inner_dropout': 0.0,
'attention_initializer': {
'class_name': 'GlorotUniform',
'config': {'seed': None}
},
'attention_axes': None}
self.assertAllEqual(expected_config, config)
def test_transformer_encoder(self):
batch_size = 2
sequence_length = 100
feature_size = 256
num_layers = 2
num_attention_heads = 2
intermediate_size = 256
model = transformer.TransformerEncoder(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size)
input_tensor = tf.ones((batch_size, sequence_length, feature_size))
attention_mask = tf.ones((batch_size, sequence_length, sequence_length),
dtype=tf.int64)
pos_embed = tf.ones((batch_size, sequence_length, feature_size))
out = model(input_tensor, attention_mask, pos_embed)
self.assertAllEqual(
tf.shape(out), (batch_size, sequence_length, feature_size))
def test_transformer_encoder_get_config(self):
num_layers = 2
num_attention_heads = 2
intermediate_size = 256
model = transformer.TransformerEncoder(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size)
config = model.get_config()
expected_config = {
'name': 'transformer_encoder',
'trainable': True,
'dtype': 'float32',
'num_layers': 2,
'num_attention_heads': 2,
'intermediate_size': 256,
'activation': 'relu',
'dropout_rate': 0.0,
'attention_dropout_rate': 0.0,
'use_bias': False,
'norm_first': True,
'norm_epsilon': 1e-06,
'intermediate_dropout': 0.0
}
self.assertAllEqual(expected_config, config)
def test_transformer_decoder_block(self):
batch_size = 2
sequence_length = 100
memory_length = 200
feature_size = 256
num_attention_heads = 2
intermediate_size = 256
intermediate_activation = 'relu'
model = transformer.TransformerDecoderBlock(num_attention_heads,
intermediate_size,
intermediate_activation)
input_tensor = tf.ones((batch_size, sequence_length, feature_size))
memory = tf.ones((batch_size, memory_length, feature_size))
attention_mask = tf.ones((batch_size, sequence_length, memory_length),
dtype=tf.int64)
self_attention_mask = tf.ones(
(batch_size, sequence_length, sequence_length), dtype=tf.int64)
input_pos_embed = tf.ones((batch_size, sequence_length, feature_size))
memory_pos_embed = tf.ones((batch_size, memory_length, feature_size))
out, _ = model([
input_tensor, memory, attention_mask, self_attention_mask,
input_pos_embed, memory_pos_embed
])
self.assertAllEqual(
tf.shape(out), (batch_size, sequence_length, feature_size))
def test_transformer_decoder_block_get_config(self):
num_attention_heads = 2
intermediate_size = 256
intermediate_activation = 'relu'
model = transformer.TransformerDecoderBlock(num_attention_heads,
intermediate_size,
intermediate_activation)
config = model.get_config()
expected_config = {
'name': 'transformer_decoder_block',
'trainable': True,
'dtype': 'float32',
'num_attention_heads': 2,
'intermediate_size': 256,
'intermediate_activation': 'relu',
'dropout_rate': 0.0,
'attention_dropout_rate': 0.0,
'kernel_initializer': {
'class_name': 'GlorotUniform',
'config': {
'seed': None
}
},
'bias_initializer': {
'class_name': 'Zeros',
'config': {}
},
'kernel_regularizer': None,
'bias_regularizer': None,
'activity_regularizer': None,
'kernel_constraint': None,
'bias_constraint': None,
'use_bias': True,
'norm_first': False,
'norm_epsilon': 1e-12,
'intermediate_dropout': 0.0,
'attention_initializer': {
'class_name': 'GlorotUniform',
'config': {
'seed': None
}
}
}
self.assertAllEqual(expected_config, config)
def test_transformer_decoder(self):
batch_size = 2
sequence_length = 100
memory_length = 200
feature_size = 256
num_layers = 2
num_attention_heads = 2
intermediate_size = 256
model = transformer.TransformerDecoder(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size)
input_tensor = tf.ones((batch_size, sequence_length, feature_size))
memory = tf.ones((batch_size, memory_length, feature_size))
attention_mask = tf.ones((batch_size, sequence_length, memory_length),
dtype=tf.int64)
self_attention_mask = tf.ones(
(batch_size, sequence_length, sequence_length), dtype=tf.int64)
input_pos_embed = tf.ones((batch_size, sequence_length, feature_size))
memory_pos_embed = tf.ones((batch_size, memory_length, feature_size))
outs = model(
input_tensor,
memory,
self_attention_mask,
attention_mask,
return_all_decoder_outputs=True,
input_pos_embed=input_pos_embed,
memory_pos_embed=memory_pos_embed)
self.assertLen(outs, 2) # intermediate decoded outputs.
for out in outs:
self.assertAllEqual(
tf.shape(out), (batch_size, sequence_length, feature_size))
def test_transformer_decoder_get_config(self):
num_layers = 2
num_attention_heads = 2
intermediate_size = 256
model = transformer.TransformerDecoder(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size)
config = model.get_config()
expected_config = {
'name': 'transformer_decoder',
'trainable': True,
'dtype': 'float32',
'num_layers': 2,
'num_attention_heads': 2,
'intermediate_size': 256,
'activation': 'relu',
'dropout_rate': 0.0,
'attention_dropout_rate': 0.0,
'use_bias': False,
'norm_first': True,
'norm_epsilon': 1e-06,
'intermediate_dropout': 0.0
}
self.assertAllEqual(expected_config, config)
if __name__ == '__main__':
tf.test.main()
(This file's diff has been collapsed and is not shown.)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_models.official.projects.detr.ops.matchers."""
import numpy as np
from scipy import optimize
import tensorflow as tf
from official.projects.detr.ops import matchers
class MatchersOpsTest(tf.test.TestCase):
def testLinearSumAssignment(self):
"""Check a simple 2D test case of the Linear Sum Assignment problem.
Ensures that the implementation of the matching algorithm is correct
and functional on TPUs.
"""
cost_matrix = np.array([[[4, 1, 3], [2, 0, 5], [3, 2, 2]]],
dtype=np.float32)
_, adjacency_matrix = matchers.hungarian_matching(tf.constant(cost_matrix))
adjacency_output = adjacency_matrix.numpy()
correct_output = np.array([
[0, 1, 0],
[1, 0, 0],
[0, 0, 1],
], dtype=bool)
self.assertAllEqual(adjacency_output[0], correct_output)
def testBatchedLinearSumAssignment(self):
"""Check a batched case of the Linear Sum Assignment Problem.
Ensures that a correct solution is found for every problem in the batch.
"""
cost_matrix = np.array([
[[4, 1, 3], [2, 0, 5], [3, 2, 2]],
[[1, 4, 3], [0, 2, 5], [2, 3, 2]],
[[1, 3, 4], [0, 5, 2], [2, 2, 3]],
],
dtype=np.float32)
_, adjacency_matrix = matchers.hungarian_matching(tf.constant(cost_matrix))
adjacency_output = adjacency_matrix.numpy()
# Hand solved correct output for the linear sum assignment problem
correct_output = np.array([
[[0, 1, 0], [1, 0, 0], [0, 0, 1]],
[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
[[1, 0, 0], [0, 0, 1], [0, 1, 0]],
],
dtype=bool)
self.assertAllClose(adjacency_output, correct_output)
def testMaximumBipartiteMatching(self):
"""Check that the maximum bipartite match assigns the correct numbers."""
adj_matrix = tf.cast([[
[1, 0, 0, 0, 1],
[0, 1, 0, 1, 0],
[0, 0, 1, 0, 0],
[0, 1, 0, 0, 0],
[1, 0, 0, 0, 0],
]], tf.bool)
_, assignment = matchers._maximum_bipartite_matching(adj_matrix)
self.assertEqual(np.sum(assignment.numpy()), 5)
def testAssignmentMatchesScipy(self):
"""Check that the Linear Sum Assignment matches the Scipy implementation."""
batch_size, num_elems = 2, 25
weights = tf.random.uniform((batch_size, num_elems, num_elems),
minval=0.,
maxval=1.)
weights, assignment = matchers.hungarian_matching(weights)
for idx in range(batch_size):
_, scipy_assignment = optimize.linear_sum_assignment(weights.numpy()[idx])
hungarian_assignment = np.where(assignment.numpy()[idx])[1]
self.assertAllEqual(hungarian_assignment, scipy_assignment)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized optimizer to match paper results."""
import dataclasses
import tensorflow as tf
from official.modeling import optimization
from official.nlp import optimization as nlp_optimization
@dataclasses.dataclass
class DETRAdamWConfig(optimization.AdamWeightDecayConfig):
pass
@dataclasses.dataclass
class OptimizerConfig(optimization.OptimizerConfig):
detr_adamw: DETRAdamWConfig = DETRAdamWConfig()
@dataclasses.dataclass
class OptimizationConfig(optimization.OptimizationConfig):
"""Configuration for optimizer and learning rate schedule.
Attributes:
optimizer: optimizer oneof config.
ema: optional exponential moving average optimizer config; if specified, the
EMA optimizer will be used.
learning_rate: learning rate oneof config.
warmup: warmup oneof config.
"""
optimizer: OptimizerConfig = OptimizerConfig()
# TODO(frederickliu): figure out how to make this configurable.
# TODO(frederickliu): Study if this is needed.
class _DETRAdamW(nlp_optimization.AdamWeightDecay):
"""Custom AdamW to support different lr scaling for backbone.
The code is copied from AdamWeightDecay and Adam with learning scaling.
"""
def _resource_apply_dense(self, grad, var, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
apply_state = kwargs['apply_state']
if 'detr' not in var.name:
lr_t *= 0.1
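# Variables without 'detr' in their name belong to the ResNet backbone and
# are updated at 0.1x the configured rate (1e-5 vs. 1e-4 in the paper).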
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype))
or self._fallback_apply_state(var_device, var_dtype))
m = self.get_slot(var, 'm')
v = self.get_slot(var, 'v')
lr = coefficients[
'lr_t'] * 0.1 if 'detr' not in var.name else coefficients['lr_t']
if not self.amsgrad:
return tf.raw_ops.ResourceApplyAdam(
var=var.handle,
m=m.handle,
v=v.handle,
beta1_power=coefficients['beta_1_power'],
beta2_power=coefficients['beta_2_power'],
lr=lr,
beta1=coefficients['beta_1_t'],
beta2=coefficients['beta_2_t'],
epsilon=coefficients['epsilon'],
grad=grad,
use_locking=self._use_locking)
else:
vhat = self.get_slot(var, 'vhat')
return tf.raw_ops.ResourceApplyAdamWithAmsgrad(
var=var.handle,
m=m.handle,
v=v.handle,
vhat=vhat.handle,
beta1_power=coefficients['beta_1_power'],
beta2_power=coefficients['beta_2_power'],
lr=lr,
beta1=coefficients['beta_1_t'],
beta2=coefficients['beta_2_t'],
epsilon=coefficients['epsilon'],
grad=grad,
use_locking=self._use_locking)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
apply_state = kwargs['apply_state']
if 'detr' not in var.name:
lr_t *= 0.1
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype))
or self._fallback_apply_state(var_device, var_dtype))
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, 'm')
m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
m_t = tf.compat.v1.assign(m, m * coefficients['beta_1_t'],
use_locking=self._use_locking)
with tf.control_dependencies([m_t]):
m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, 'v')
v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
v_t = tf.compat.v1.assign(v, v * coefficients['beta_2_t'],
use_locking=self._use_locking)
with tf.control_dependencies([v_t]):
v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
lr = coefficients[
'lr_t'] * 0.1 if 'detr' not in var.name else coefficients['lr_t']
if not self.amsgrad:
v_sqrt = tf.sqrt(v_t)
var_update = tf.compat.v1.assign_sub(
var, lr * m_t / (v_sqrt + coefficients['epsilon']),
use_locking=self._use_locking)
return tf.group(*[var_update, m_t, v_t])
else:
v_hat = self.get_slot(var, 'vhat')
v_hat_t = tf.maximum(v_hat, v_t)
with tf.control_dependencies([v_hat_t]):
v_hat_t = tf.compat.v1.assign(
v_hat, v_hat_t, use_locking=self._use_locking)
v_hat_sqrt = tf.sqrt(v_hat_t)
var_update = tf.compat.v1.assign_sub(
var,
lr* m_t / (v_hat_sqrt + coefficients['epsilon']),
use_locking=self._use_locking)
return tf.group(*[var_update, m_t, v_t, v_hat_t])
optimization.register_optimizer_cls('detr_adamw', _DETRAdamW)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DETR detection task definition."""
import tensorflow as tf
from official.core import base_task
from official.core import task_factory
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.modeling import detr
from official.projects.detr.ops import matchers
from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.ops import box_ops
@task_factory.register_task_cls(detr_cfg.DetectionConfig)
class DectectionTask(base_task.Task):
"""A single-replica view of training procedure.
DETR task provides artifacts for training/evalution procedures, including
loading/iterating over Datasets, initializing the model, calculating the loss,
post-processing, and customized metrics with reduction.
"""
def build_model(self):
"""Build DETR model."""
model = detr.DETR(
self._task_config.num_queries,
self._task_config.num_hidden,
self._task_config.num_classes,
self._task_config.num_encoder_layers,
self._task_config.num_decoder_layers)
return model
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.read(self._task_config.init_ckpt)
status.expect_partial().assert_existing_objects_matched()
def build_inputs(self, params, input_context=None):
"""Build input dataset."""
return coco.COCODataLoader(params).load(input_context)
def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
# Approximate classification cost with 1 - prob[target class].
# The 1 is a constant that doesn't change the matching, so it can be omitted.
# background: 0
cls_cost = self._task_config.lambda_cls * tf.gather(
-tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)
# Compute the L1 cost between boxes.
paired_differences = self._task_config.lambda_box * tf.abs(
tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
box_cost = tf.reduce_sum(paired_differences, axis=-1)
# Compute the GIoU cost between boxes.
giou_cost = self._task_config.lambda_giou * -box_ops.bbox_generalized_overlap(
box_ops.cycxhw_to_yxyx(box_outputs),
box_ops.cycxhw_to_yxyx(box_targets))
total_cost = cls_cost + box_cost + giou_cost
max_cost = (
self._task_config.lambda_cls * 0.0 + self._task_config.lambda_box * 4. +
self._task_config.lambda_giou * 0.0)
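# A large constant assigned to padded targets below; with boxes normalized
# to [0, 1], the L1 box term is bounded by lambda_box * 4.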
# Set pads to large constant
valid = tf.expand_dims(
tf.cast(tf.not_equal(cls_targets, 0), dtype=total_cost.dtype), axis=1)
total_cost = (1 - valid) * max_cost + valid * total_cost
# Set inf or nan entries to the large constant.
total_cost = tf.where(
tf.logical_or(tf.math.is_nan(total_cost), tf.math.is_inf(total_cost)),
max_cost * tf.ones_like(total_cost, dtype=total_cost.dtype),
total_cost)
return total_cost
def build_losses(self, outputs, labels, aux_losses=None):
"""Build DETR losses."""
cls_outputs = outputs['cls_outputs']
box_outputs = outputs['box_outputs']
cls_targets = labels['classes']
box_targets = labels['boxes']
cost = self._compute_cost(
cls_outputs, box_outputs, cls_targets, box_targets)
_, indices = matchers.hungarian_matching(cost)
indices = tf.stop_gradient(indices)
target_index = tf.math.argmax(indices, axis=1)
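# `indices` is a one-hot assignment of shape [batch, num_queries,
# max_num_boxes]; argmax over the query axis recovers, for each target
# slot, the index of its matched query.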
cls_assigned = tf.gather(cls_outputs, target_index, batch_dims=1, axis=1)
box_assigned = tf.gather(box_outputs, target_index, batch_dims=1, axis=1)
background = tf.equal(cls_targets, 0)
num_boxes = tf.reduce_sum(
tf.cast(tf.logical_not(background), tf.float32), axis=-1)
# Down-weight background to account for class imbalance.
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=cls_targets, logits=cls_assigned)
cls_loss = self._task_config.lambda_cls * tf.where(
background,
self._task_config.background_cls_weight * xentropy,
xentropy
)
cls_weights = tf.where(
background,
self._task_config.background_cls_weight * tf.ones_like(cls_loss),
tf.ones_like(cls_loss)
)
# Box loss is only calculated on non-background class.
l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
box_loss = self._task_config.lambda_box * tf.where(
background,
tf.zeros_like(l_1),
l_1
)
# Giou loss is only calculated on non-background class.
giou = tf.linalg.diag_part(1.0 - box_ops.bbox_generalized_overlap(
box_ops.cycxhw_to_yxyx(box_assigned),
box_ops.cycxhw_to_yxyx(box_targets)
))
giou_loss = self._task_config.lambda_giou * tf.where(
background,
tf.zeros_like(giou),
giou
)
# Consider doing all reduce once in train_step to speed up.
num_boxes_per_replica = tf.reduce_sum(num_boxes)
cls_weights_per_replica = tf.reduce_sum(cls_weights)
replica_context = tf.distribute.get_replica_context()
num_boxes_sum, cls_weights_sum = replica_context.all_reduce(
tf.distribute.ReduceOp.SUM,
[num_boxes_per_replica, cls_weights_per_replica])
cls_loss = tf.math.divide_no_nan(
tf.reduce_sum(cls_loss), cls_weights_sum)
box_loss = tf.math.divide_no_nan(
tf.reduce_sum(box_loss), num_boxes_sum)
giou_loss = tf.math.divide_no_nan(
tf.reduce_sum(giou_loss), num_boxes_sum)
aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0
total_loss = cls_loss + box_loss + giou_loss + aux_losses
return total_loss, cls_loss, box_loss, giou_loss
def build_metrics(self, training=True):
"""Build detection metrics."""
metrics = []
metric_names = ['cls_loss', 'box_loss', 'giou_loss']
for name in metric_names:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
if not training:
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file='',
include_mask=False,
need_rescale_bboxes=True,
per_category_metrics=self._task_config.per_category_metrics)
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
with tf.GradientTape() as tape:
outputs = model(features, training=True)
loss = 0.0
cls_loss = 0.0
box_loss = 0.0
giou_loss = 0.0
for output in outputs:
# Computes per-replica loss.
layer_loss, layer_cls_loss, layer_box_loss, layer_giou_loss = self.build_losses(
outputs=output, labels=labels, aux_losses=model.losses)
loss += layer_loss
cls_loss += layer_cls_loss
box_loss += layer_box_loss
giou_loss += layer_giou_loss
# Consider moving scaling logic from build_losses to here.
scaled_loss = loss
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient when LossScaleOptimizer is used.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
# Multiply for logging.
# Since we expect the gradient replica sum to happen in the optimizer,
# the loss is scaled with global num_boxes and weights.
# To make it more interpretable/comparable, we scale it back when logging.
num_replicas_in_sync = tf.distribute.get_strategy().num_replicas_in_sync
loss *= num_replicas_in_sync
cls_loss *= num_replicas_in_sync
box_loss *= num_replicas_in_sync
giou_loss *= num_replicas_in_sync
# Trainer class handles loss metric for you.
logs = {self.loss: loss}
all_losses = {
'cls_loss': cls_loss,
'box_loss': box_loss,
'giou_loss': giou_loss,
}
# Metric results will be added to logs for you.
if metrics:
for m in metrics:
m.update_state(all_losses[m.name])
return logs
def validation_step(self, inputs, model, metrics=None):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
outputs = model(features, training=False)[-1]
loss, cls_loss, box_loss, giou_loss = self.build_losses(
outputs=outputs, labels=labels, aux_losses=model.losses)
# Multiply for logging.
# Since we expect the gradient replica sum to happen in the optimizer,
# the loss is scaled with global num_boxes and weights.
# To make it more interpretable/comparable, we scale it back when logging.
num_replicas_in_sync = tf.distribute.get_strategy().num_replicas_in_sync
loss *= num_replicas_in_sync
cls_loss *= num_replicas_in_sync
box_loss *= num_replicas_in_sync
giou_loss *= num_replicas_in_sync
# Evaluator class handles loss metric for you.
logs = {self.loss: loss}
predictions = {
'detection_boxes':
box_ops.cycxhw_to_yxyx(outputs['box_outputs'])
* tf.expand_dims(
tf.concat([
labels['image_info'][:, 1:2, 0],
labels['image_info'][:, 1:2, 1],
labels['image_info'][:, 1:2, 0],
labels['image_info'][:, 1:2, 1]
],
axis=1),
axis=1),
'detection_scores':
tf.math.reduce_max(
tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1),
'detection_classes':
tf.math.argmax(outputs['cls_outputs'][:, :, 1:], axis=-1) + 1,
# TODO: fix num_detections; it is not used at the moment.
'num_detections': tf.reduce_sum(
tf.cast(
tf.math.greater(tf.math.reduce_max(
outputs['cls_outputs'], axis=-1), 0), tf.int32), axis=-1),
'source_id': labels['id'],
'image_info': labels['image_info']
}
ground_truths = {
'source_id': labels['id'],
'height': labels['image_info'][:, 0:1, 0],
'width': labels['image_info'][:, 0:1, 1],
'num_detections': tf.reduce_sum(
tf.cast(tf.math.greater(labels['classes'], 0), tf.int32), axis=-1),
'boxes': labels['gt_boxes'],
'classes': labels['classes'],
'is_crowds': labels['is_crowd']
}
logs.update({'predictions': predictions,
'ground_truths': ground_truths})
all_losses = {
'cls_loss': cls_loss,
'box_loss': box_loss,
'giou_loss': giou_loss,
}
# Metric results will be added to logs for you.
if metrics:
for m in metrics:
m.update_state(all_losses[m.name])
return logs
def aggregate_logs(self, state=None, step_outputs=None):
if state is None:
self.coco_metric.reset_states()
state = self.coco_metric
state.update_state(
step_outputs['ground_truths'],
step_outputs['predictions'])
return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
return aggregated_logs.result()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for detection."""
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from official.projects.detr import optimization
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.tasks import detection
_NUM_EXAMPLES = 10
def _gen_fn():
h = np.random.randint(0, 300)
w = np.random.randint(0, 300)
num_boxes = np.random.randint(0, 50)
return {
'image': np.ones(shape=(h, w, 3), dtype=np.uint8),
'image/id': np.random.randint(0, 100),
'image/filename': 'test',
'objects': {
'is_crowd': np.ones(shape=(num_boxes), dtype=bool),  # np.bool was removed in NumPy 1.24
'bbox': np.ones(shape=(num_boxes, 4), dtype=np.float32),
'label': np.ones(shape=(num_boxes), dtype=np.int64),
'id': np.ones(shape=(num_boxes), dtype=np.int64),
'area': np.ones(shape=(num_boxes), dtype=np.int64),
}
}
def _as_dataset(self, *args, **kwargs):
del args
del kwargs
return tf.data.Dataset.from_generator(
lambda: (_gen_fn() for i in range(_NUM_EXAMPLES)),
output_types=self.info.features.dtype,
output_shapes=self.info.features.shape,
)
class DetectionTest(tf.test.TestCase):
def test_train_step(self):
config = detr_cfg.DetectionConfig(
num_encoder_layers=1,
num_decoder_layers=1,
train_data=coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=True,
global_batch_size=2,
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
model = task.build_model()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
opt_cfg = optimization.OptimizationConfig({
'optimizer': {
'type': 'detr_adamw',
'detr_adamw': {
'weight_decay_rate': 1e-4,
'global_clipnorm': 0.1,
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [120000],
'values': [0.0001, 1.0e-05]
}
},
})
optimizer = detection.DectectionTask.create_optimizer(opt_cfg)
task.train_step(next(iterator), model, optimizer)
def test_validation_step(self):
config = detr_cfg.DetectionConfig(
num_encoder_layers=1,
num_decoder_layers=1,
validation_data=coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=False,
global_batch_size=2,
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
model = task.build_model()
metrics = task.build_metrics(training=False)
dataset = task.build_inputs(config.validation_data)
iterator = iter(dataset)
logs = task.validation_step(next(iterator), model, metrics)
state = task.aggregate_logs(step_outputs=logs)
task.reduce_aggregated_logs(state)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.projects.detr.configs import detr
from official.projects.detr.tasks import detection
# pylint: enable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files; otherwise a continuous eval job
# may race against the train job when writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have a significant impact on model speed by utilizing float16 on GPUs
# and bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main)