提交 f7783e7a 编写于 作者: G Gunho Park

Use backbone factory

上级 14a9701d
......@@ -62,6 +62,7 @@ class Losses(hyperparams.Config):
lambda_box: float = 5.0
lambda_giou: float = 2.0
background_cls_weight: float = 0.1
l2_weight_decay: float = 1e-4
@dataclasses.dataclass
class Detr(hyperparams.Config):
......@@ -73,7 +74,7 @@ class Detr(hyperparams.Config):
input_size: List[int] = dataclasses.field(default_factory=list)
backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet(
model_id=101,
model_id=50,
bn_trainable=False))
norm_activation: common.NormActivation = common.NormActivation()
......@@ -105,7 +106,7 @@ def detr_coco() -> cfg.ExperimentConfig:
decay_at = train_steps - 100 * steps_per_epoch # 400 epochs
config = cfg.ExperimentConfig(
task=DetrTask(
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet101_imagenet',
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet',
init_checkpoint_modules='backbone',
annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
......
......@@ -24,7 +24,7 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.projects.detr.modeling import transformer
#from official.vision.modeling.backbones import resnet
from official.vision.modeling.backbones import resnet
def position_embedding_sine(attention_mask,
......@@ -116,7 +116,7 @@ class DETR(tf.keras.Model):
raise ValueError("hidden_size must be a multiple of 2.")
# TODO(frederickliu): Consider using the backbone factory.
# TODO(frederickliu): Add to factory once we get skeleton code in.
#self._backbone = resnet.ResNet(50, bn_trainable=False)
#self._backbone = resnet.ResNet(101, bn_trainable=False)
# (gunho) use backbone factory
self._backbone = backbone
......
......@@ -48,12 +48,17 @@ class DectectionTask(base_task.Task):
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self._task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=self._task_config.model.backbone,
norm_activation_config=self._task_config.model.norm_activation)
model = detr.DETR(
backbone,
self._task_config.model.num_queries,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册