Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
986ffac4
M
Models
项目概览
曾经的那一瞬间
/
Models
大约 1 年 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
986ffac4
编写于
11月 22, 2019
作者:
Y
Yeqing Li
提交者:
A. Unique TensorFlower
11月 22, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Internal change
PiperOrigin-RevId: 282065024
上级
a9387332
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
41 addition
and
13 deletion
+41
-13
official/vision/detection/configs/retinanet_config.py
official/vision/detection/configs/retinanet_config.py
+6
-1
official/vision/detection/dataloader/input_reader.py
official/vision/detection/dataloader/input_reader.py
+11
-1
official/vision/detection/executor/detection_executor.py
official/vision/detection/executor/detection_executor.py
+8
-5
official/vision/detection/modeling/base_model.py
official/vision/detection/modeling/base_model.py
+10
-4
official/vision/detection/modeling/postprocess.py
official/vision/detection/modeling/postprocess.py
+2
-1
official/vision/detection/modeling/retinanet_model.py
official/vision/detection/modeling/retinanet_model.py
+4
-1
未找到文件。
official/vision/detection/configs/retinanet_config.py
浏览文件 @
986ffac4
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
# Note that we need to trailing `/` to avoid the incorrect match.
# Note that we need to trailing `/` to avoid the incorrect match.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET50_FROZEN_VAR_PREFIX
=
r
'(resnet\d+/)conv2d(|_([1-9]|10))\/'
RESNET50_FROZEN_VAR_PREFIX
=
r
'(resnet\d+/)conv2d(|_([1-9]|10))\/'
RESNET_FROZEN_VAR_PREFIX
=
r
'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
# pylint: disable=line-too-long
# pylint: disable=line-too-long
...
@@ -38,6 +39,7 @@ RETINANET_CFG = {
...
@@ -38,6 +39,7 @@ RETINANET_CFG = {
'optimizer'
:
{
'optimizer'
:
{
'type'
:
'momentum'
,
'type'
:
'momentum'
,
'momentum'
:
0.9
,
'momentum'
:
0.9
,
'nesterov'
:
False
,
},
},
'learning_rate'
:
{
'learning_rate'
:
{
'type'
:
'step'
,
'type'
:
'step'
,
...
@@ -56,6 +58,7 @@ RETINANET_CFG = {
...
@@ -56,6 +58,7 @@ RETINANET_CFG = {
# TODO(b/142174042): Support transpose_input option.
# TODO(b/142174042): Support transpose_input option.
'transpose_input'
:
False
,
'transpose_input'
:
False
,
'l2_weight_decay'
:
0.0001
,
'l2_weight_decay'
:
0.0001
,
'input_sharding'
:
False
,
},
},
'eval'
:
{
'eval'
:
{
'batch_size'
:
8
,
'batch_size'
:
8
,
...
@@ -65,6 +68,7 @@ RETINANET_CFG = {
...
@@ -65,6 +68,7 @@ RETINANET_CFG = {
'type'
:
'box'
,
'type'
:
'box'
,
'val_json_file'
:
''
,
'val_json_file'
:
''
,
'eval_file_pattern'
:
''
,
'eval_file_pattern'
:
''
,
'input_sharding'
:
True
,
},
},
'predict'
:
{
'predict'
:
{
'predict_batch_size'
:
8
,
'predict_batch_size'
:
8
,
...
@@ -165,7 +169,8 @@ RETINANET_CFG = {
...
@@ -165,7 +169,8 @@ RETINANET_CFG = {
'num_classes'
:
91
,
'num_classes'
:
91
,
'max_total_size'
:
100
,
'max_total_size'
:
100
,
'nms_iou_threshold'
:
0.5
,
'nms_iou_threshold'
:
0.5
,
'score_threshold'
:
0.05
'score_threshold'
:
0.05
,
'pre_nms_num_boxes'
:
5000
,
},
},
'enable_summary'
:
False
,
'enable_summary'
:
False
,
}
}
...
...
official/vision/detection/dataloader/input_reader.py
浏览文件 @
986ffac4
...
@@ -58,6 +58,15 @@ class InputFn(object):
...
@@ -58,6 +58,15 @@ class InputFn(object):
self
.
_parser_fn
=
factory
.
parser_generator
(
params
,
mode
)
self
.
_parser_fn
=
factory
.
parser_generator
(
params
,
mode
)
self
.
_dataset_fn
=
tf
.
data
.
TFRecordDataset
self
.
_dataset_fn
=
tf
.
data
.
TFRecordDataset
self
.
_input_sharding
=
(
not
self
.
_is_training
)
try
:
if
self
.
_is_training
:
self
.
_input_sharding
=
params
.
train
.
input_sharding
else
:
self
.
_input_sharding
=
params
.
eval
.
input_sharding
except
KeyError
:
pass
def
__call__
(
self
,
ctx
=
None
,
batch_size
:
int
=
None
):
def
__call__
(
self
,
ctx
=
None
,
batch_size
:
int
=
None
):
"""Provides tf.data.Dataset object.
"""Provides tf.data.Dataset object.
...
@@ -74,7 +83,7 @@ class InputFn(object):
...
@@ -74,7 +83,7 @@ class InputFn(object):
dataset
=
tf
.
data
.
Dataset
.
list_files
(
dataset
=
tf
.
data
.
Dataset
.
list_files
(
self
.
_file_pattern
,
shuffle
=
self
.
_is_training
)
self
.
_file_pattern
,
shuffle
=
self
.
_is_training
)
if
ctx
and
ctx
.
num_input_pipelines
>
1
:
if
self
.
_input_sharding
and
ctx
and
ctx
.
num_input_pipelines
>
1
:
dataset
=
dataset
.
shard
(
ctx
.
num_input_pipelines
,
ctx
.
input_pipeline_id
)
dataset
=
dataset
.
shard
(
ctx
.
num_input_pipelines
,
ctx
.
input_pipeline_id
)
if
self
.
_is_training
:
if
self
.
_is_training
:
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
repeat
()
...
@@ -82,6 +91,7 @@ class InputFn(object):
...
@@ -82,6 +91,7 @@ class InputFn(object):
dataset
=
dataset
.
interleave
(
dataset
=
dataset
.
interleave
(
map_func
=
lambda
file_name
:
self
.
_dataset_fn
(
file_name
),
cycle_length
=
32
,
map_func
=
lambda
file_name
:
self
.
_dataset_fn
(
file_name
),
cycle_length
=
32
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
dataset
=
dataset
.
cache
()
if
self
.
_is_training
:
if
self
.
_is_training
:
dataset
=
dataset
.
shuffle
(
64
)
dataset
=
dataset
.
shuffle
(
64
)
...
...
official/vision/detection/executor/detection_executor.py
浏览文件 @
986ffac4
...
@@ -58,6 +58,13 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
...
@@ -58,6 +58,13 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
trainable_variables
)
trainable_variables
)
logging
.
info
(
'Filter trainable variables from %d to %d'
,
logging
.
info
(
'Filter trainable variables from %d to %d'
,
len
(
model
.
trainable_variables
),
len
(
trainable_variables
))
len
(
model
.
trainable_variables
),
len
(
trainable_variables
))
_update_state
=
lambda
labels
,
outputs
:
None
if
isinstance
(
metric
,
tf
.
keras
.
metrics
.
Metric
):
_update_state
=
lambda
labels
,
outputs
:
metric
.
update_state
(
labels
,
outputs
)
else
:
logging
.
error
(
'Detection: train metric is not an instance of '
'tf.keras.metrics.Metric.'
)
def
_replicated_step
(
inputs
):
def
_replicated_step
(
inputs
):
"""Replicated training step."""
"""Replicated training step."""
...
@@ -71,11 +78,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
...
@@ -71,11 +78,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
v
=
tf
.
reduce_mean
(
v
)
/
strategy
.
num_replicas_in_sync
v
=
tf
.
reduce_mean
(
v
)
/
strategy
.
num_replicas_in_sync
losses
[
k
]
=
v
losses
[
k
]
=
v
loss
=
losses
[
'total_loss'
]
loss
=
losses
[
'total_loss'
]
if
isinstance
(
metric
,
tf
.
keras
.
metrics
.
Metric
):
_update_state
(
labels
,
outputs
)
metric
.
update_state
(
labels
,
outputs
)
else
:
logging
.
error
(
'train metric is not an instance of '
'tf.keras.metrics.Metric.'
)
grads
=
tape
.
gradient
(
loss
,
trainable_variables
)
grads
=
tape
.
gradient
(
loss
,
trainable_variables
)
optimizer
.
apply_gradients
(
zip
(
grads
,
trainable_variables
))
optimizer
.
apply_gradients
(
zip
(
grads
,
trainable_variables
))
...
...
official/vision/detection/modeling/base_model.py
浏览文件 @
986ffac4
...
@@ -36,8 +36,15 @@ class OptimizerFactory(object):
...
@@ -36,8 +36,15 @@ class OptimizerFactory(object):
def
__init__
(
self
,
params
):
def
__init__
(
self
,
params
):
"""Creates optimized based on the specified flags."""
"""Creates optimized based on the specified flags."""
if
params
.
type
==
'momentum'
:
if
params
.
type
==
'momentum'
:
nesterov
=
False
try
:
nesterov
=
params
.
nesterov
except
KeyError
:
pass
self
.
_optimizer
=
functools
.
partial
(
self
.
_optimizer
=
functools
.
partial
(
tf
.
keras
.
optimizers
.
SGD
,
momentum
=
0.9
,
nesterov
=
True
)
tf
.
keras
.
optimizers
.
SGD
,
momentum
=
params
.
momentum
,
nesterov
=
nesterov
)
elif
params
.
type
==
'adam'
:
elif
params
.
type
==
'adam'
:
self
.
_optimizer
=
tf
.
keras
.
optimizers
.
Adam
self
.
_optimizer
=
tf
.
keras
.
optimizers
.
Adam
elif
params
.
type
==
'adadelta'
:
elif
params
.
type
==
'adadelta'
:
...
@@ -133,11 +140,10 @@ class Model(object):
...
@@ -133,11 +140,10 @@ class Model(object):
"""
"""
return
_make_filter_trainable_variables_fn
(
self
.
_frozen_variable_prefix
)
return
_make_filter_trainable_variables_fn
(
self
.
_frozen_variable_prefix
)
def
weight_decay_loss
(
self
,
l2_weight_decay
,
keras_model
):
def
weight_decay_loss
(
self
,
l2_weight_decay
,
trainable_variables
):
# TODO(yeqing): Correct the filter according to cr/269707763.
return
l2_weight_decay
*
tf
.
add_n
([
return
l2_weight_decay
*
tf
.
add_n
([
tf
.
nn
.
l2_loss
(
v
)
tf
.
nn
.
l2_loss
(
v
)
for
v
in
self
.
_keras_model
.
trainable_variables
for
v
in
trainable_variables
if
'batch_normalization'
not
in
v
.
name
and
'bias'
not
in
v
.
name
if
'batch_normalization'
not
in
v
.
name
and
'bias'
not
in
v
.
name
])
])
...
...
official/vision/detection/modeling/postprocess.py
浏览文件 @
986ffac4
...
@@ -40,7 +40,8 @@ def generate_detections_factory(params):
...
@@ -40,7 +40,8 @@ def generate_detections_factory(params):
_generate_detections
,
_generate_detections
,
max_total_size
=
params
.
max_total_size
,
max_total_size
=
params
.
max_total_size
,
nms_iou_threshold
=
params
.
nms_iou_threshold
,
nms_iou_threshold
=
params
.
nms_iou_threshold
,
score_threshold
=
params
.
score_threshold
)
score_threshold
=
params
.
score_threshold
,
pre_nms_num_boxes
=
params
.
pre_nms_num_boxes
)
return
func
return
func
...
...
official/vision/detection/modeling/retinanet_model.py
浏览文件 @
986ffac4
...
@@ -120,6 +120,9 @@ class RetinanetModel(base_model.Model):
...
@@ -120,6 +120,9 @@ class RetinanetModel(base_model.Model):
if
self
.
_keras_model
is
None
:
if
self
.
_keras_model
is
None
:
raise
ValueError
(
'build_loss_fn() must be called after build_model().'
)
raise
ValueError
(
'build_loss_fn() must be called after build_model().'
)
filter_fn
=
self
.
make_filter_trainable_variables_fn
()
trainable_variables
=
filter_fn
(
self
.
_keras_model
.
trainable_variables
)
def
_total_loss_fn
(
labels
,
outputs
):
def
_total_loss_fn
(
labels
,
outputs
):
cls_loss
=
self
.
_cls_loss_fn
(
outputs
[
'cls_outputs'
],
cls_loss
=
self
.
_cls_loss_fn
(
outputs
[
'cls_outputs'
],
labels
[
'cls_targets'
],
labels
[
'cls_targets'
],
...
@@ -129,7 +132,7 @@ class RetinanetModel(base_model.Model):
...
@@ -129,7 +132,7 @@ class RetinanetModel(base_model.Model):
labels
[
'num_positives'
])
labels
[
'num_positives'
])
model_loss
=
cls_loss
+
self
.
_box_loss_weight
*
box_loss
model_loss
=
cls_loss
+
self
.
_box_loss_weight
*
box_loss
l2_regularization_loss
=
self
.
weight_decay_loss
(
self
.
_l2_weight_decay
,
l2_regularization_loss
=
self
.
weight_decay_loss
(
self
.
_l2_weight_decay
,
self
.
_keras_model
)
trainable_variables
)
total_loss
=
model_loss
+
l2_regularization_loss
total_loss
=
model_loss
+
l2_regularization_loss
return
{
return
{
'total_loss'
:
total_loss
,
'total_loss'
:
total_loss
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录