Commit aa2b5e42 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 373084407
Parent 3c575a70
......@@ -67,6 +67,8 @@ class SpineNet(hyperparams.Config):
"""SpineNet config."""
model_id: str = '49'
stochastic_depth_drop_rate: float = 0.0
min_level: int = 3
max_level: int = 7
@dataclasses.dataclass
......@@ -76,6 +78,8 @@ class SpineNetMobile(hyperparams.Config):
stochastic_depth_drop_rate: float = 0.0
se_ratio: float = 0.2
expand_ratio: int = 6
min_level: int = 3
max_level: int = 7
@dataclasses.dataclass
......
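With `min_level` and `max_level` now defined on the backbone configs themselves, the feature-pyramid range is specified where the backbone is configured instead of being read off the parent model config. A minimal sketch of building the updated dataclasses (a standalone illustration; the values mirror the defaults added above):

```python
from official.vision.beta.configs import backbones

# The SpineNet backbone config now carries its own pyramid range.
spinenet_cfg = backbones.SpineNet(
    model_id='49',
    stochastic_depth_drop_rate=0.2,
    min_level=3,  # new field, default 3
    max_level=7,  # new field, default 7
)

# The mobile variant gains the same two fields.
spinenet_mobile_cfg = backbones.SpineNetMobile(
    model_id='49',
    se_ratio=0.2,
    expand_ratio=6,
    min_level=3,
    max_level=7,
)

backbone = backbones.Backbone(type='spinenet', spinenet=spinenet_cfg)
```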
......@@ -437,7 +437,12 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
'instances_val2017.json'),
model=MaskRCNN(
backbone=backbones.Backbone(
type='spinenet', spinenet=backbones.SpineNet(model_id='49')),
type='spinenet',
spinenet=backbones.SpineNet(
model_id='49',
min_level=3,
max_level=7,
)),
decoder=decoders.Decoder(
type='identity', identity=decoders.Identity()),
anchor=Anchor(anchor_size=3),
......@@ -491,6 +496,8 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
'task.validation_data.is_training != None',
'task.model.min_level == task.model.backbone.spinenet.min_level',
'task.model.max_level == task.model.backbone.spinenet.max_level',
])
return config
......@@ -248,7 +248,10 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig:
backbone=backbones.Backbone(
type='spinenet',
spinenet=backbones.SpineNet(
model_id='49', stochastic_depth_drop_rate=0.2)),
model_id='49',
stochastic_depth_drop_rate=0.2,
min_level=3,
max_level=7)),
decoder=decoders.Decoder(
type='identity', identity=decoders.Identity()),
anchor=Anchor(anchor_size=3),
......@@ -306,7 +309,9 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig:
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
'task.validation_data.is_training != None',
'task.model.min_level == task.model.backbone.spinenet.min_level',
'task.model.max_level == task.model.backbone.spinenet.max_level',
])
return config
......@@ -329,7 +334,10 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
backbone=backbones.Backbone(
type='spinenet_mobile',
spinenet_mobile=backbones.SpineNetMobile(
model_id='49', stochastic_depth_drop_rate=0.2)),
model_id='49',
stochastic_depth_drop_rate=0.2,
min_level=3,
max_level=7)),
decoder=decoders.Decoder(
type='identity', identity=decoders.Identity()),
head=RetinaNetHead(num_filters=48, use_separable_conv=True),
......@@ -388,7 +396,9 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
'task.validation_data.is_training != None',
'task.model.min_level == task.model.backbone.spinenet_mobile.min_level',
'task.model.max_level == task.model.backbone.spinenet_mobile.max_level',
])
return config
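The added restrictions tie the model-level pyramid range to the backbone config's own range. As a rough illustration of what they require (not of how the config framework validates them), the default experiment built above should satisfy:

```python
config = retinanet_spinenet_mobile_coco()

# Mirrors the restriction strings: the model-level levels must equal the
# backbone config's levels.
assert (config.task.model.min_level ==
        config.task.model.backbone.spinenet_mobile.min_level)
assert (config.task.model.max_level ==
        config.task.model.backbone.spinenet_mobile.max_level)
```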
......@@ -297,12 +297,12 @@ class EfficientNet(tf.keras.Model):
@factory.register_backbone_builder('efficientnet')
def build_efficientnet(
input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds EfficientNet backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'efficientnet', (f'Inconsistent backbone type '
f'{backbone_type}')
......
......@@ -42,6 +42,8 @@ in place that uses it.
"""
from typing import Sequence, Union
# Import libraries
import tensorflow as tf
......@@ -81,22 +83,31 @@ def register_backbone_builder(key: str):
return registry.register(_REGISTERED_BACKBONE_CLS, key)
def build_backbone(
input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
def build_backbone(input_specs: Union[tf.keras.layers.InputSpec,
Sequence[tf.keras.layers.InputSpec]],
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None,
**kwargs) -> tf.keras.Model:
"""Builds backbone from a config.
Args:
input_specs: A `tf.keras.layers.InputSpec` of input.
model_config: A `OneOfConfig` of model config.
input_specs: A (sequence of) `tf.keras.layers.InputSpec` of input.
backbone_config: A `OneOfConfig` of backbone config.
norm_activation_config: A config for normalization/activation layer.
l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to
None.
**kwargs: Additional keyword args to be passed to backbone builder.
Returns:
A `tf.keras.Model` instance of the backbone.
"""
backbone_builder = registry.lookup(_REGISTERED_BACKBONE_CLS,
model_config.backbone.type)
return backbone_builder(input_specs, model_config, l2_regularizer)
backbone_config.type)
return backbone_builder(
input_specs=input_specs,
backbone_config=backbone_config,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer,
**kwargs)
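The factory entry point now receives the backbone and normalization/activation configs directly rather than the whole model config, so a backbone can be built without constructing a full task or model config. A minimal usage sketch against the updated signature (the ResNet choice and regularizer value are illustrative, not part of this change):

```python
import tensorflow as tf

from official.vision.beta.configs import backbones as backbones_cfg
from official.vision.beta.configs import common as common_cfg
from official.vision.beta.modeling.backbones import factory

backbone_config = backbones_cfg.Backbone(
    type='resnet', resnet=backbones_cfg.ResNet(model_id=50))
norm_activation_config = common_cfg.NormActivation(use_sync_bn=False)

# The old `model_config` argument is gone; the backbone and norm/activation
# configs are passed explicitly, plus an optional regularizer.
backbone = factory.build_backbone(
    input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
    backbone_config=backbone_config,
    norm_activation_config=norm_activation_config,
    l2_regularizer=tf.keras.regularizers.l2(1e-4))
```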
......@@ -22,7 +22,6 @@ from tensorflow.python.distribute import combinations
from official.vision.beta.configs import backbones as backbones_cfg
from official.vision.beta.configs import backbones_3d as backbones_3d_cfg
from official.vision.beta.configs import common as common_cfg
from official.vision.beta.configs import retinanet as retinanet_cfg
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling.backbones import factory
......@@ -42,12 +41,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
resnet=backbones_cfg.ResNet(model_id=model_id, se_ratio=0.0))
norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
factory_network = factory.build_backbone(
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
model_config=model_config)
backbone_config=backbone_config,
norm_activation_config=norm_activation_config)
network_config = network.get_config()
factory_network_config = factory_network.get_config()
......@@ -74,12 +72,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
model_id=model_id, se_ratio=se_ratio))
norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
factory_network = factory.build_backbone(
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
model_config=model_config)
backbone_config=backbone_config,
norm_activation_config=norm_activation_config)
network_config = network.get_config()
factory_network_config = factory_network.get_config()
......@@ -108,12 +105,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
model_id=model_id, filter_size_scale=filter_size_scale))
norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
factory_network = factory.build_backbone(
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
model_config=model_config)
backbone_config=backbone_config,
norm_activation_config=norm_activation_config)
network_config = network.get_config()
factory_network_config = factory_network.get_config()
......@@ -141,13 +137,12 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
spinenet=backbones_cfg.SpineNet(model_id=model_id))
norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
factory_network = factory.build_backbone(
input_specs=tf.keras.layers.InputSpec(
shape=[None, input_size, input_size, 3]),
model_config=model_config)
backbone_config=backbone_config,
norm_activation_config=norm_activation_config)
network_config = network.get_config()
factory_network_config = factory_network.get_config()
......@@ -166,12 +161,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
revnet=backbones_cfg.RevNet(model_id=model_id))
norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
factory_network = factory.build_backbone(
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
model_config=model_config)
backbone_config=backbone_config,
norm_activation_config=norm_activation_config)
network_config = network.get_config()
factory_network_config = factory_network.get_config()
......
......@@ -766,12 +766,12 @@ class MobileNet(tf.keras.Model):
@factory.register_backbone_builder('mobilenet')
def build_mobilenet(
input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds MobileNet backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'mobilenet', (f'Inconsistent backbone type '
f'{backbone_type}')
......
......@@ -372,12 +372,12 @@ class ResNet(tf.keras.Model):
@factory.register_backbone_builder('resnet')
def build_resnet(
input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'resnet', (f'Inconsistent backbone type '
f'{backbone_type}')
......
......@@ -378,11 +378,11 @@ class ResNet3D(tf.keras.Model):
@factory.register_backbone_builder('resnet_3d')
def build_resnet3d(
input_specs: tf.keras.layers.InputSpec,
model_config,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet 3d backbone from a config."""
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_cfg = backbone_config.get()
# Flatten configs before passing to the backbone.
temporal_strides = []
......@@ -416,11 +416,11 @@ def build_resnet3d(
@factory.register_backbone_builder('resnet_3d_rs')
def build_resnet3d_rs(
input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet-3D-RS backbone from a config."""
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_cfg = backbone_config.get()
# Flatten configs before passing to the backbone.
temporal_strides = []
......
......@@ -18,6 +18,7 @@ from typing import Callable, Optional, Tuple, List
import numpy as np
import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks
......@@ -340,12 +341,12 @@ class DilatedResNet(tf.keras.Model):
@factory.register_backbone_builder('dilated_resnet')
def build_dilated_resnet(
input_specs: tf.keras.layers.InputSpec,
model_config,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'dilated_resnet', (f'Inconsistent backbone type '
f'{backbone_type}')
......
......@@ -18,6 +18,7 @@
from typing import Any, Callable, Dict, Optional
# Import libraries
import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks
......@@ -213,12 +214,12 @@ class RevNet(tf.keras.Model):
@factory.register_backbone_builder('revnet')
def build_revnet(
input_specs: tf.keras.layers.InputSpec,
model_config,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds RevNet backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'revnet', (f'Inconsistent backbone type '
f'{backbone_type}')
......
......@@ -22,6 +22,7 @@ from typing import Any, List, Optional, Tuple
from absl import logging
import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks
......@@ -527,12 +528,12 @@ class SpineNet(tf.keras.Model):
@factory.register_backbone_builder('spinenet')
def build_spinenet(
input_specs: tf.keras.layers.InputSpec,
model_config,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds SpineNet backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'spinenet', (f'Inconsistent backbone type '
f'{backbone_type}')
......@@ -544,8 +545,8 @@ def build_spinenet(
return SpineNet(
input_specs=input_specs,
min_level=model_config.min_level,
max_level=model_config.max_level,
min_level=backbone_cfg.min_level,
max_level=backbone_cfg.max_level,
endpoints_num_filters=scaling_params['endpoints_num_filters'],
resample_alpha=scaling_params['resample_alpha'],
block_repeats=scaling_params['block_repeats'],
......
......@@ -36,6 +36,7 @@ from typing import Any, List, Optional, Tuple
from absl import logging
import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks
......@@ -501,12 +502,12 @@ class SpineNetMobile(tf.keras.Model):
@factory.register_backbone_builder('spinenet_mobile')
def build_spinenet_mobile(
input_specs: tf.keras.layers.InputSpec,
model_config,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds Mobile SpineNet backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'spinenet_mobile', (f'Inconsistent backbone type '
f'{backbone_type}')
......@@ -518,8 +519,8 @@ def build_spinenet_mobile(
return SpineNetMobile(
input_specs=input_specs,
min_level=model_config.min_level,
max_level=model_config.max_level,
min_level=backbone_cfg.min_level,
max_level=backbone_cfg.max_level,
endpoints_num_filters=scaling_params['endpoints_num_filters'],
block_repeats=scaling_params['block_repeats'],
filter_size_scale=scaling_params['filter_size_scale'],
......
......@@ -44,12 +44,13 @@ def build_classification_model(
l2_regularizer: tf.keras.regularizers.Regularizer = None,
skip_logits_layer: bool = False) -> tf.keras.Model:
"""Builds the classification model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
norm_activation_config = model_config.norm_activation
model = classification_model.ClassificationModel(
backbone=backbone,
num_classes=model_config.num_classes,
......@@ -69,9 +70,11 @@ def build_maskrcnn(
model_config: maskrcnn_cfg.MaskRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds Mask R-CNN model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
decoder = decoder_factory.build_decoder(
......@@ -85,7 +88,6 @@ def build_maskrcnn(
roi_aligner_config = model_config.roi_aligner
detection_head_config = model_config.detection_head
generator_config = model_config.detection_generator
norm_activation_config = model_config.norm_activation
num_anchors_per_location = (
len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)
......@@ -242,9 +244,11 @@ def build_retinanet(
model_config: retinanet_cfg.RetinaNet,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds RetinaNet model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
decoder = decoder_factory.build_decoder(
......@@ -254,7 +258,6 @@ def build_retinanet(
head_config = model_config.head
generator_config = model_config.detection_generator
norm_activation_config = model_config.norm_activation
num_anchors_per_location = (
len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)
......@@ -301,9 +304,11 @@ def build_segmentation_model(
model_config: segmentation_cfg.SemanticSegmentationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds Segmentation model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
decoder = decoder_factory.build_decoder(
......@@ -312,7 +317,6 @@ def build_segmentation_model(
l2_regularizer=l2_regularizer)
head_config = model_config.head
norm_activation_config = model_config.norm_activation
head = segmentation_heads.SegmentationHead(
num_classes=model_config.num_classes,
......
......@@ -85,9 +85,11 @@ def build_video_classification_model(
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds the video classification model."""
input_specs_dict = {'image': input_specs}
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
model = video_classification_model.VideoClassificationModel(
......
......@@ -54,6 +54,7 @@ from absl import logging
import numpy as np
import tensorflow as tf
from official.modeling import hyperparams
from official.vision.beta.modeling import factory_3d as model_factory
from official.vision.beta.modeling.backbones import factory as backbone_factory
from official.vision.beta.projects.assemblenet.configs import assemblenet as cfg
......@@ -1015,14 +1016,14 @@ def assemblenet_v1(assemblenet_depth: int,
@backbone_factory.register_backbone_builder('assemblenet')
def build_assemblenet_v1(
input_specs: tf.keras.layers.InputSpec,
model_config: cfg.Backbone3D,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds assemblenet backbone."""
del l2_regularizer
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'assemblenet'
assemblenet_depth = int(backbone_cfg.model_id)
......@@ -1060,7 +1061,8 @@ def build_assemblenet_model(
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds assemblenet model."""
input_specs_dict = {'image': input_specs}
backbone = build_assemblenet_v1(input_specs, model_config, l2_regularizer)
backbone = build_assemblenet_v1(input_specs, model_config.backbone,
model_config.norm_activation, l2_regularizer)
backbone_cfg = model_config.backbone.get()
model_structure, _ = cfg.blocks_to_flat_lists(backbone_cfg.blocks)
model = AssembleNetModel(
......
......@@ -37,9 +37,11 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
model_config: deep_mask_head_rcnn_config.DeepMaskHeadRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds Mask R-CNN model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
decoder = decoder_factory.build_decoder(
......@@ -53,7 +55,6 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
roi_aligner_config = model_config.roi_aligner
detection_head_config = model_config.detection_head
generator_config = model_config.detection_generator
norm_activation_config = model_config.norm_activation
num_anchors_per_location = (
len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)
......
......@@ -110,7 +110,8 @@ class SimCLRPretrainTask(base_task.Task):
# Build backbone
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
backbone_config=model_config.backbone,
norm_activation_config=model_config.norm_activation,
l2_regularizer=l2_regularizer)
# Build projection head
......
......@@ -40,6 +40,7 @@ import collections
import tensorflow as tf
from official.modeling import hyperparams
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
......@@ -428,12 +429,12 @@ class Darknet(tf.keras.Model):
@factory.register_backbone_builder("darknet")
def build_darknet(
input_specs: tf.keras.layers.InputSpec,
model_config,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds darknet backbone."""
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_cfg = backbone_config.get()
model = Darknet(
model_id=backbone_cfg.model_id,
input_shape=input_specs,
......