提交 14a9701d 编写于 作者: G Gunho Park

Use backbone factory

上级 94220a58
......@@ -15,11 +15,15 @@
"""DETR configurations."""
import dataclasses
import os
from typing import List, Optional, Union
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.detr import optimization
import os
from official.vision.configs import common
from official.vision.configs import backbones
# pylint: disable=missing-class-docstring
......@@ -53,32 +57,41 @@ class DataConfig(cfg.DataConfig):
file_type: str = 'tfrecord'
@dataclasses.dataclass
class DetectionConfig(cfg.TaskConfig):
"""The translation task config."""
annotation_file: str = ''
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
class Losses(hyperparams.Config):
lambda_cls: float = 1.0
lambda_box: float = 5.0
lambda_giou: float = 2.0
background_cls_weight: float = 0.1
#init_ckpt: str = ''
init_checkpoint: str = 'gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet'
init_checkpoint_modules: str = 'backbone'
#num_classes: int = 81 # 0: background
@dataclasses.dataclass
class Detr(hyperparams.Config):
num_queries: int = 100
hidden_size: int = 256
num_classes: int = 91 # 0: background
background_cls_weight: float = 0.1
num_encoder_layers: int = 6
num_decoder_layers: int = 6
input_size: List[int] = dataclasses.field(default_factory=list)
backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet(
model_id=101,
bn_trainable=False))
norm_activation: common.NormActivation = common.NormActivation()
# Make DETRConfig.
num_queries: int = 100
num_hidden: int = 256
@dataclasses.dataclass
class DetrTask(cfg.TaskConfig):
model: Detr = Detr()
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
losses: Losses = Losses()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone
annotation_file: Optional[str] = None
per_category_metrics: bool = False
COCO_INPUT_PATH_BASE = 'gs://ghpark-tfrecords/coco'
#COCO_TRAIN_EXAMPLES = 118287
COCO_TRAIN_EXAMPLES = 960
COCO_TRAIN_EXAMPLES = 118287
#COCO_TRAIN_EXAMPLES = 9600
COCO_VAL_EXAMPLES = 5000
@exp_factory.register_config_factory('detr_coco')
......@@ -91,9 +104,15 @@ def detr_coco() -> cfg.ExperimentConfig:
train_steps = 300 * steps_per_epoch # 500 epochs
decay_at = train_steps - 100 * steps_per_epoch # 400 epochs
config = cfg.ExperimentConfig(
task=DetectionConfig(
task=DetrTask(
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet101_imagenet',
init_checkpoint_modules='backbone',
annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
model=Detr(
input_size=[1333, 1333, 3],
norm_activation=common.NormActivation(use_sync_bn=False)),
losses=Losses(),
train_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
......
......@@ -2,6 +2,6 @@
python3 train.py \
--experiment=detr_coco \
--mode=train_and_eval \
--model_dir=gs://ghpark-ckpts/detr/detr_coco/ckpt_03_test \
--model_dir=gs://ghpark-ckpts/detr/detr_coco/ckpt_03_detr_coco_resnet101 \
--tpu=postech-tpu \
--params_override=runtime.distribution_strategy='tpu'
\ No newline at end of file
......@@ -24,7 +24,7 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.projects.detr.modeling import transformer
from official.vision.modeling.backbones import resnet
#from official.vision.modeling.backbones import resnet
def position_embedding_sine(attention_mask,
......@@ -100,7 +100,7 @@ class DETR(tf.keras.Model):
class and box heads.
"""
def __init__(self, num_queries, hidden_size, num_classes,
def __init__(self, backbone, num_queries, hidden_size, num_classes,
num_encoder_layers=6,
num_decoder_layers=6,
dropout_rate=0.1,
......@@ -116,7 +116,9 @@ class DETR(tf.keras.Model):
raise ValueError("hidden_size must be a multiple of 2.")
# TODO(frederickliu): Consider using the backbone factory.
# TODO(frederickliu): Add to factory once we get skeleton code in.
self._backbone = resnet.ResNet(50, bn_trainable=False)
#self._backbone = resnet.ResNet(50, bn_trainable=False)
# (gunho) use backbone factory
self._backbone = backbone
def build(self, input_shape=None):
self._input_proj = tf.keras.layers.Conv2D(
......
......@@ -31,8 +31,9 @@ from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.projects.detr.dataloaders import detr_input
from official.vision.modeling import backbones
@task_factory.register_task_cls(detr_cfg.DetectionConfig)
@task_factory.register_task_cls(detr_cfg.DetrTask)
class DectectionTask(base_task.Task):
"""A single-replica view of training procedure.
......@@ -43,12 +44,23 @@ class DectectionTask(base_task.Task):
def build_model(self):
"""Build DETR model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self._task_config.model.input_size)
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=self._task_config.model.backbone,
norm_activation_config=self._task_config.model.norm_activation)
model = detr.DETR(
self._task_config.num_queries,
self._task_config.num_hidden,
self._task_config.num_classes,
self._task_config.num_encoder_layers,
self._task_config.num_decoder_layers)
backbone,
self._task_config.model.num_queries,
self._task_config.model.hidden_size,
self._task_config.model.num_classes,
self._task_config.model.num_encoder_layers,
self._task_config.model.num_decoder_layers)
return model
def initialize(self, model: tf.keras.Model):
......@@ -99,7 +111,9 @@ class DectectionTask(base_task.Task):
raise ValueError('Unknown decoder type: {}!'.format(
params.decoder.type))
parser = detr_input.Parser()
parser = detr_input.Parser(
output_size=self._task_config.model.input_size[:2],
)
reader = input_reader_factory.input_reader_generator(
params,
......@@ -114,24 +128,24 @@ class DectectionTask(base_task.Task):
# Approximate classification cost with 1 - prob[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
# background: 0
cls_cost = self._task_config.lambda_cls * tf.gather(
cls_cost = self._task_config.losses.lambda_cls * tf.gather(
-tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)
# Compute the L1 cost between boxes,
paired_differences = self._task_config.lambda_box * tf.abs(
paired_differences = self._task_config.losses.lambda_box * tf.abs(
tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
box_cost = tf.reduce_sum(paired_differences, axis=-1)
# Compute the giou cost betwen boxes
giou_cost = self._task_config.lambda_giou * -box_ops.bbox_generalized_overlap(
giou_cost = self._task_config.losses.lambda_giou * -box_ops.bbox_generalized_overlap(
box_ops.cycxhw_to_yxyx(box_outputs),
box_ops.cycxhw_to_yxyx(box_targets))
total_cost = cls_cost + box_cost + giou_cost
max_cost = (
self._task_config.lambda_cls * 0.0 + self._task_config.lambda_box * 4. +
self._task_config.lambda_giou * 0.0)
self._task_config.losses.lambda_cls * 0.0 + self._task_config.losses.lambda_box * 4. +
self._task_config.losses.lambda_giou * 0.0)
# Set pads to large constant
valid = tf.expand_dims(
......@@ -170,20 +184,20 @@ class DectectionTask(base_task.Task):
# Down-weight background to account for class imbalance.
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=cls_targets, logits=cls_assigned)
cls_loss = self._task_config.lambda_cls * tf.where(
cls_loss = self._task_config.losses.lambda_cls * tf.where(
background,
self._task_config.background_cls_weight * xentropy,
self._task_config.losses.background_cls_weight * xentropy,
xentropy
)
cls_weights = tf.where(
background,
self._task_config.background_cls_weight * tf.ones_like(cls_loss),
self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
tf.ones_like(cls_loss)
)
# Box loss is only calculated on non-background class.
l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
box_loss = self._task_config.lambda_box * tf.where(
box_loss = self._task_config.losses.lambda_box * tf.where(
background,
tf.zeros_like(l_1),
l_1
......@@ -194,7 +208,7 @@ class DectectionTask(base_task.Task):
box_ops.cycxhw_to_yxyx(box_assigned),
box_ops.cycxhw_to_yxyx(box_targets)
))
giou_loss = self._task_config.lambda_giou * tf.where(
giou_loss = self._task_config.losses.lambda_giou * tf.where(
background,
tf.zeros_like(giou),
giou
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册