Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
f7783e7a
M
Models
项目概览
曾经的那一瞬间
/
Models
12 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
f7783e7a
编写于
6月 30, 2022
作者:
G
Gunho Park
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use backbone factory
上级
14a9701d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
11 addition
and
5 deletion
+11
-5
official/projects/detr/configs/detr.py
official/projects/detr/configs/detr.py
+3
-2
official/projects/detr/modeling/detr.py
official/projects/detr/modeling/detr.py
+2
-2
official/projects/detr/tasks/detection.py
official/projects/detr/tasks/detection.py
+6
-1
未找到文件。
official/projects/detr/configs/detr.py
浏览文件 @
f7783e7a
...
...
@@ -62,6 +62,7 @@ class Losses(hyperparams.Config):
lambda_box
:
float
=
5.0
lambda_giou
:
float
=
2.0
background_cls_weight
:
float
=
0.1
l2_weight_decay
:
float
=
1e-4
@
dataclasses
.
dataclass
class
Detr
(
hyperparams
.
Config
):
...
...
@@ -73,7 +74,7 @@ class Detr(hyperparams.Config):
input_size
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
backbone
:
backbones
.
Backbone
=
backbones
.
Backbone
(
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
(
model_id
=
101
,
model_id
=
50
,
bn_trainable
=
False
))
norm_activation
:
common
.
NormActivation
=
common
.
NormActivation
()
...
...
@@ -105,7 +106,7 @@ def detr_coco() -> cfg.ExperimentConfig:
decay_at
=
train_steps
-
100
*
steps_per_epoch
# 400 epochs
config
=
cfg
.
ExperimentConfig
(
task
=
DetrTask
(
init_checkpoint
=
'gs://ghpark-imagenet-tfrecord/ckpt/resnet
101
_imagenet'
,
init_checkpoint
=
'gs://ghpark-imagenet-tfrecord/ckpt/resnet
50
_imagenet'
,
init_checkpoint_modules
=
'backbone'
,
annotation_file
=
os
.
path
.
join
(
COCO_INPUT_PATH_BASE
,
'instances_val2017.json'
),
...
...
official/projects/detr/modeling/detr.py
浏览文件 @
f7783e7a
...
...
@@ -24,7 +24,7 @@ import tensorflow as tf
from
official.modeling
import
tf_utils
from
official.projects.detr.modeling
import
transformer
#
from official.vision.modeling.backbones import resnet
from
official.vision.modeling.backbones
import
resnet
def
position_embedding_sine
(
attention_mask
,
...
...
@@ -116,7 +116,7 @@ class DETR(tf.keras.Model):
raise
ValueError
(
"hidden_size must be a multiple of 2."
)
# TODO(frederickliu): Consider using the backbone factory.
# TODO(frederickliu): Add to factory once we get skeleton code in.
#self._backbone = resnet.ResNet(
50
, bn_trainable=False)
#self._backbone = resnet.ResNet(
101
, bn_trainable=False)
# (gunho) use backbone factory
self
.
_backbone
=
backbone
...
...
official/projects/detr/tasks/detection.py
浏览文件 @
f7783e7a
...
...
@@ -48,12 +48,17 @@ class DectectionTask(base_task.Task):
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
]
+
self
.
_task_config
.
model
.
input_size
)
l2_weight_decay
=
self
.
task_config
.
losses
.
l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer
=
(
tf
.
keras
.
regularizers
.
l2
(
l2_weight_decay
/
2.0
)
if
l2_weight_decay
else
None
)
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
backbone_config
=
self
.
_task_config
.
model
.
backbone
,
norm_activation_config
=
self
.
_task_config
.
model
.
norm_activation
)
model
=
detr
.
DETR
(
backbone
,
self
.
_task_config
.
model
.
num_queries
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录