Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
14a9701d
M
Models
项目概览
曾经的那一瞬间
/
Models
12 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
14a9701d
编写于
6月 28, 2022
作者:
G
Gunho Park
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use backbone factory
上级
94220a58
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
73 addition
and
38 deletion
+73
-38
official/projects/detr/configs/detr.py
official/projects/detr/configs/detr.py
+36
-17
official/projects/detr/do_train.sh
official/projects/detr/do_train.sh
+1
-1
official/projects/detr/modeling/detr.py
official/projects/detr/modeling/detr.py
+5
-3
official/projects/detr/tasks/detection.py
official/projects/detr/tasks/detection.py
+31
-17
未找到文件。
official/projects/detr/configs/detr.py
浏览文件 @
14a9701d
...
...
@@ -15,11 +15,15 @@
"""DETR configurations."""
import
dataclasses
import
os
from
typing
import
List
,
Optional
,
Union
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.projects.detr
import
optimization
import
os
from
official.vision.configs
import
common
from
official.vision.configs
import
backbones
# pylint: disable=missing-class-docstring
...
...
@@ -53,32 +57,41 @@ class DataConfig(cfg.DataConfig):
file_type
:
str
=
'tfrecord'
@
dataclasses
.
dataclass
class
DetectionConfig
(
cfg
.
TaskConfig
):
"""The translation task config."""
annotation_file
:
str
=
''
train_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
class
Losses
(
hyperparams
.
Config
):
lambda_cls
:
float
=
1.0
lambda_box
:
float
=
5.0
lambda_giou
:
float
=
2.0
background_cls_weight
:
float
=
0.1
#init_ckpt: str = ''
init_checkpoint
:
str
=
'gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet'
init_checkpoint_modules
:
str
=
'backbone'
#num_classes: int = 81 # 0: background
@
dataclasses
.
dataclass
class
Detr
(
hyperparams
.
Config
):
num_queries
:
int
=
100
hidden_size
:
int
=
256
num_classes
:
int
=
91
# 0: background
background_cls_weight
:
float
=
0.1
num_encoder_layers
:
int
=
6
num_decoder_layers
:
int
=
6
input_size
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
backbone
:
backbones
.
Backbone
=
backbones
.
Backbone
(
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
(
model_id
=
101
,
bn_trainable
=
False
))
norm_activation
:
common
.
NormActivation
=
common
.
NormActivation
()
# Make DETRConfig.
num_queries
:
int
=
100
num_hidden
:
int
=
256
@
dataclasses
.
dataclass
class
DetrTask
(
cfg
.
TaskConfig
):
model
:
Detr
=
Detr
()
train_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
losses
:
Losses
=
Losses
()
init_checkpoint
:
Optional
[
str
]
=
None
init_checkpoint_modules
:
Union
[
str
,
List
[
str
]]
=
'all'
# all, backbone
annotation_file
:
Optional
[
str
]
=
None
per_category_metrics
:
bool
=
False
COCO_INPUT_PATH_BASE
=
'gs://ghpark-tfrecords/coco'
#
COCO_TRAIN_EXAMPLES = 118287
COCO_TRAIN_EXAMPLES
=
96
0
COCO_TRAIN_EXAMPLES
=
118287
#COCO_TRAIN_EXAMPLES = 960
0
COCO_VAL_EXAMPLES
=
5000
@
exp_factory
.
register_config_factory
(
'detr_coco'
)
...
...
@@ -91,9 +104,15 @@ def detr_coco() -> cfg.ExperimentConfig:
train_steps
=
300
*
steps_per_epoch
# 500 epochs
decay_at
=
train_steps
-
100
*
steps_per_epoch
# 400 epochs
config
=
cfg
.
ExperimentConfig
(
task
=
DetectionConfig
(
task
=
DetrTask
(
init_checkpoint
=
'gs://ghpark-imagenet-tfrecord/ckpt/resnet101_imagenet'
,
init_checkpoint_modules
=
'backbone'
,
annotation_file
=
os
.
path
.
join
(
COCO_INPUT_PATH_BASE
,
'instances_val2017.json'
),
model
=
Detr
(
input_size
=
[
1333
,
1333
,
3
],
norm_activation
=
common
.
NormActivation
(
use_sync_bn
=
False
)),
losses
=
Losses
(),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
COCO_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
...
...
official/projects/detr/do_train.sh
浏览文件 @
14a9701d
...
...
@@ -2,6 +2,6 @@
python3 train.py
\
--experiment
=
detr_coco
\
--mode
=
train_and_eval
\
--model_dir
=
gs://ghpark-ckpts/detr/detr_coco/ckpt_03_
test
\
--model_dir
=
gs://ghpark-ckpts/detr/detr_coco/ckpt_03_
detr_coco_resnet101
\
--tpu
=
postech-tpu
\
--params_override
=
runtime.distribution_strategy
=
'tpu'
\ No newline at end of file
official/projects/detr/modeling/detr.py
浏览文件 @
14a9701d
...
...
@@ -24,7 +24,7 @@ import tensorflow as tf
from
official.modeling
import
tf_utils
from
official.projects.detr.modeling
import
transformer
from
official.vision.modeling.backbones
import
resnet
#
from official.vision.modeling.backbones import resnet
def
position_embedding_sine
(
attention_mask
,
...
...
@@ -100,7 +100,7 @@ class DETR(tf.keras.Model):
class and box heads.
"""
def
__init__
(
self
,
num_queries
,
hidden_size
,
num_classes
,
def
__init__
(
self
,
backbone
,
num_queries
,
hidden_size
,
num_classes
,
num_encoder_layers
=
6
,
num_decoder_layers
=
6
,
dropout_rate
=
0.1
,
...
...
@@ -116,7 +116,9 @@ class DETR(tf.keras.Model):
raise
ValueError
(
"hidden_size must be a multiple of 2."
)
# TODO(frederickliu): Consider using the backbone factory.
# TODO(frederickliu): Add to factory once we get skeleton code in.
self
.
_backbone
=
resnet
.
ResNet
(
50
,
bn_trainable
=
False
)
#self._backbone = resnet.ResNet(50, bn_trainable=False)
# (gunho) use backbone factory
self
.
_backbone
=
backbone
def
build
(
self
,
input_shape
=
None
):
self
.
_input_proj
=
tf
.
keras
.
layers
.
Conv2D
(
...
...
official/projects/detr/tasks/detection.py
浏览文件 @
14a9701d
...
...
@@ -31,8 +31,9 @@ from official.vision.dataloaders import tf_example_decoder
from
official.vision.dataloaders
import
tfds_factory
from
official.vision.dataloaders
import
tf_example_label_map_decoder
from
official.projects.detr.dataloaders
import
detr_input
from
official.vision.modeling
import
backbones
@
task_factory
.
register_task_cls
(
detr_cfg
.
Det
ectionConfig
)
@
task_factory
.
register_task_cls
(
detr_cfg
.
Det
rTask
)
class
DectectionTask
(
base_task
.
Task
):
"""A single-replica view of training procedure.
...
...
@@ -43,12 +44,23 @@ class DectectionTask(base_task.Task):
def
build_model
(
self
):
"""Build DETR model."""
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
]
+
self
.
_task_config
.
model
.
input_size
)
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
backbone_config
=
self
.
_task_config
.
model
.
backbone
,
norm_activation_config
=
self
.
_task_config
.
model
.
norm_activation
)
model
=
detr
.
DETR
(
self
.
_task_config
.
num_queries
,
self
.
_task_config
.
num_hidden
,
self
.
_task_config
.
num_classes
,
self
.
_task_config
.
num_encoder_layers
,
self
.
_task_config
.
num_decoder_layers
)
backbone
,
self
.
_task_config
.
model
.
num_queries
,
self
.
_task_config
.
model
.
hidden_size
,
self
.
_task_config
.
model
.
num_classes
,
self
.
_task_config
.
model
.
num_encoder_layers
,
self
.
_task_config
.
model
.
num_decoder_layers
)
return
model
def
initialize
(
self
,
model
:
tf
.
keras
.
Model
):
...
...
@@ -99,7 +111,9 @@ class DectectionTask(base_task.Task):
raise
ValueError
(
'Unknown decoder type: {}!'
.
format
(
params
.
decoder
.
type
))
parser
=
detr_input
.
Parser
()
parser
=
detr_input
.
Parser
(
output_size
=
self
.
_task_config
.
model
.
input_size
[:
2
],
)
reader
=
input_reader_factory
.
input_reader_generator
(
params
,
...
...
@@ -114,24 +128,24 @@ class DectectionTask(base_task.Task):
# Approximate classification cost with 1 - prob[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
# background: 0
cls_cost
=
self
.
_task_config
.
lambda_cls
*
tf
.
gather
(
cls_cost
=
self
.
_task_config
.
l
osses
.
l
ambda_cls
*
tf
.
gather
(
-
tf
.
nn
.
softmax
(
cls_outputs
),
cls_targets
,
batch_dims
=
1
,
axis
=-
1
)
# Compute the L1 cost between boxes,
paired_differences
=
self
.
_task_config
.
lambda_box
*
tf
.
abs
(
paired_differences
=
self
.
_task_config
.
l
osses
.
l
ambda_box
*
tf
.
abs
(
tf
.
expand_dims
(
box_outputs
,
2
)
-
tf
.
expand_dims
(
box_targets
,
1
))
box_cost
=
tf
.
reduce_sum
(
paired_differences
,
axis
=-
1
)
# Compute the giou cost betwen boxes
giou_cost
=
self
.
_task_config
.
lambda_giou
*
-
box_ops
.
bbox_generalized_overlap
(
giou_cost
=
self
.
_task_config
.
l
osses
.
l
ambda_giou
*
-
box_ops
.
bbox_generalized_overlap
(
box_ops
.
cycxhw_to_yxyx
(
box_outputs
),
box_ops
.
cycxhw_to_yxyx
(
box_targets
))
total_cost
=
cls_cost
+
box_cost
+
giou_cost
max_cost
=
(
self
.
_task_config
.
l
ambda_cls
*
0.0
+
self
.
_task_config
.
lambda_box
*
4.
+
self
.
_task_config
.
lambda_giou
*
0.0
)
self
.
_task_config
.
l
osses
.
lambda_cls
*
0.0
+
self
.
_task_config
.
losses
.
lambda_box
*
4.
+
self
.
_task_config
.
l
osses
.
l
ambda_giou
*
0.0
)
# Set pads to large constant
valid
=
tf
.
expand_dims
(
...
...
@@ -170,20 +184,20 @@ class DectectionTask(base_task.Task):
# Down-weight background to account for class imbalance.
xentropy
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
cls_targets
,
logits
=
cls_assigned
)
cls_loss
=
self
.
_task_config
.
lambda_cls
*
tf
.
where
(
cls_loss
=
self
.
_task_config
.
l
osses
.
l
ambda_cls
*
tf
.
where
(
background
,
self
.
_task_config
.
background_cls_weight
*
xentropy
,
self
.
_task_config
.
losses
.
background_cls_weight
*
xentropy
,
xentropy
)
cls_weights
=
tf
.
where
(
background
,
self
.
_task_config
.
background_cls_weight
*
tf
.
ones_like
(
cls_loss
),
self
.
_task_config
.
losses
.
background_cls_weight
*
tf
.
ones_like
(
cls_loss
),
tf
.
ones_like
(
cls_loss
)
)
# Box loss is only calculated on non-background class.
l_1
=
tf
.
reduce_sum
(
tf
.
abs
(
box_assigned
-
box_targets
),
axis
=-
1
)
box_loss
=
self
.
_task_config
.
lambda_box
*
tf
.
where
(
box_loss
=
self
.
_task_config
.
l
osses
.
l
ambda_box
*
tf
.
where
(
background
,
tf
.
zeros_like
(
l_1
),
l_1
...
...
@@ -194,7 +208,7 @@ class DectectionTask(base_task.Task):
box_ops
.
cycxhw_to_yxyx
(
box_assigned
),
box_ops
.
cycxhw_to_yxyx
(
box_targets
)
))
giou_loss
=
self
.
_task_config
.
lambda_giou
*
tf
.
where
(
giou_loss
=
self
.
_task_config
.
l
osses
.
l
ambda_giou
*
tf
.
where
(
background
,
tf
.
zeros_like
(
giou
),
giou
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录