Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
59393a8d
M
Models
项目概览
曾经的那一瞬间
/
Models
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
59393a8d
编写于
12月 05, 2019
作者:
Y
Yeqing Li
提交者:
A. Unique TensorFlower
12月 05, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Imports the mask-rcnn config.
PiperOrigin-RevId: 283987800
上级
c115444f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
321 addition
and
1 deletion
+321
-1
official/vision/detection/configs/base_config.py
official/vision/detection/configs/base_config.py
+160
-0
official/vision/detection/configs/factory.py
official/vision/detection/configs/factory.py
+4
-0
official/vision/detection/configs/maskrcnn_config.py
official/vision/detection/configs/maskrcnn_config.py
+156
-0
official/vision/detection/configs/retinanet_config.py
official/vision/detection/configs/retinanet_config.py
+1
-1
未找到文件。
official/vision/detection/configs/base_config.py
0 → 100644
浏览文件 @
59393a8d
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base config template."""
# pylint: disable=line-too-long
# For ResNet, this freezes the variables of the first conv1 and conv2_x
# layers [1], which leads to higher training speed and slightly better testing
# accuracy. The intuition is that the low-level architecture (e.g., ResNet-50)
# is able to capture low-level features such as edges; therefore, it does not
# need to be fine-tuned for the detection task.
# Note that we need the trailing `/` to avoid incorrect matches.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
# Regex for the variable-name prefixes to freeze: a `resnet<digits>/` scope
# followed by `conv2d` or `batch_normalization`, either unsuffixed or with a
# `_1` .. `_10` suffix, and then a trailing `/`.
RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'

# Regex matching variable names that end in `kernel:0` or `weight:0`; used
# below as the default for `train.regularization_variable_regex`.
REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'

# Base configuration template: a plain nested dict of default values, grouped
# by section (train / eval / predict / anchor / backbone / feature pyramid /
# postprocess).
BASE_CFG = {
    'model_dir': '',
    'use_tpu': True,
    'isolate_session_state': False,
    'train': {
        'iterations_per_loop': 100,
        'train_batch_size': 64,
        'total_steps': 22500,
        'num_cores_per_replica': None,
        'input_partition_dims': None,
        'optimizer': {
            'type': 'momentum',
            'momentum': 0.9,
        },
        'learning_rate': {
            'type': 'step',
            'warmup_learning_rate': 0.0067,
            'warmup_steps': 500,
            'init_learning_rate': 0.08,
            'learning_rate_levels': [0.008, 0.0008],
            'learning_rate_steps': [15000, 20000],
            # Same value as `train.total_steps` above.
            'total_steps': 22500,
        },
        'checkpoint': {
            'path': '',
            'prefix': '',
        },
        'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX,
        'train_file_pattern': '',
        'train_dataset_type': 'tfrecord',
        'transpose_input': True,
        'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
        'l2_weight_decay': 0.0001,
        'gradient_clip_norm': 0.0,
    },
    'eval': {
        'eval_batch_size': 8,
        'eval_samples': 5000,
        'min_eval_interval': 180,
        'eval_timeout': None,
        'num_steps_per_eval': 1000,
        'type': 'box',
        'use_json_file': True,
        'val_json_file': '',
        'eval_file_pattern': '',
        'eval_dataset_type': 'tfrecord',
    },
    'predict': {
        'predict_batch_size': 8,
    },
    'anchor': {
        'min_level': 3,
        'max_level': 7,
        'num_scales': 3,
        'aspect_ratios': [1.0, 2.0, 0.5],
        'anchor_size': 4.0,
    },
    'resnet': {
        'resnet_depth': 50,
        'dropblock': {
            'dropblock_keep_prob': None,
            'dropblock_size': None,
        },
        'batch_norm': {
            'batch_norm_momentum': 0.997,
            'batch_norm_epsilon': 1e-4,
            'batch_norm_trainable': True,
            'use_sync_bn': False,
        },
    },
    'fpn': {
        'min_level': 3,
        'max_level': 7,
        'fpn_feat_dims': 256,
        'use_separable_conv': False,
        'use_batch_norm': True,
        'batch_norm': {
            'batch_norm_momentum': 0.997,
            'batch_norm_epsilon': 1e-4,
            'batch_norm_trainable': True,
            'use_sync_bn': False,
        },
    },
    'nasfpn': {
        'min_level': 3,
        'max_level': 7,
        'fpn_feat_dims': 256,
        'num_repeats': 5,
        'use_separable_conv': False,
        'dropblock': {
            'dropblock_keep_prob': None,
            'dropblock_size': None,
        },
        'batch_norm': {
            'batch_norm_momentum': 0.997,
            'batch_norm_epsilon': 1e-4,
            'batch_norm_trainable': True,
            'use_sync_bn': False,
        },
    },
    # tunable_nasfpn:strip_begin
    'tunable_nasfpn_v1': {
        'min_level': 3,
        'max_level': 7,
        'fpn_feat_dims': 256,
        'num_repeats': 5,
        'use_separable_conv': False,
        'dropblock': {
            'dropblock_keep_prob': None,
            'dropblock_size': None,
        },
        'batch_norm': {
            'batch_norm_momentum': 0.997,
            'batch_norm_epsilon': 1e-4,
            'batch_norm_trainable': True,
            'use_sync_bn': False,
        },
        'nodes': None,
    },
    # tunable_nasfpn:strip_end
    'postprocess': {
        'use_batched_nms': False,
        'max_total_size': 100,
        'nms_iou_threshold': 0.5,
        'score_threshold': 0.05,
        'pre_nms_num_boxes': 5000,
    },
    'enable_summary': False,
}
# pylint: enable=line-too-long
official/vision/detection/configs/factory.py
浏览文件 @
59393a8d
...
...
@@ -14,6 +14,7 @@
# ==============================================================================
"""Factory to provide model configs."""
from
official.vision.detection.configs
import
maskrcnn_config
from
official.vision.detection.configs
import
retinanet_config
from
official.modeling.hyperparams
import
params_dict
...
...
@@ -23,6 +24,9 @@ def config_generator(model):
if
model
==
'retinanet'
:
default_config
=
retinanet_config
.
RETINANET_CFG
restrictions
=
retinanet_config
.
RETINANET_RESTRICTIONS
elif
model
==
'mask_rcnn'
:
default_config
=
maskrcnn_config
.
MASKRCNN_CFG
restrictions
=
maskrcnn_config
.
MASKRCNN_RESTRICTIONS
else
:
raise
ValueError
(
'Model %s is not supported.'
%
model
)
...
...
official/vision/detection/configs/maskrcnn_config.py
0 → 100644
浏览文件 @
59393a8d
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config template to train Mask R-CNN."""
from
official.vision.detection.configs
import
base_config
from
official.modeling.hyperparams
import
params_dict
# pylint: disable=line-too-long
# Mask R-CNN config: start from the shared base template, then layer the
# Mask R-CNN specific values on top. `is_strict=False` is passed to
# `override` because several keys below (e.g. 'type', 'architecture',
# 'maskrcnn_parser', the head/loss/sampling sections) are not present in
# `BASE_CFG`.
MASKRCNN_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
MASKRCNN_CFG.override({
    'type': 'mask_rcnn',
    'eval': {
        'type': 'box_and_mask',
    },
    'architecture': {
        'parser': 'maskrcnn_parser',
        'backbone': 'resnet',
        'multilevel_features': 'fpn',
        'use_bfloat16': True,
        'include_mask': True,
    },
    'maskrcnn_parser': {
        'use_bfloat16': True,
        'output_size': [1024, 1024],
        'rpn_match_threshold': 0.7,
        'rpn_unmatched_threshold': 0.3,
        'rpn_batch_size_per_im': 256,
        'rpn_fg_fraction': 0.5,
        'aug_rand_hflip': True,
        'aug_scale_min': 1.0,
        'aug_scale_max': 1.0,
        'skip_crowd_during_training': True,
        'max_num_instances': 100,
        'include_mask': True,
        'mask_crop_size': 112,
    },
    'anchor': {
        'min_level': 2,
        'max_level': 6,
        'num_scales': 1,
        'anchor_size': 8,
    },
    'fpn': {
        'min_level': 2,
        'max_level': 6,
    },
    'nasfpn': {
        'min_level': 2,
        'max_level': 6,
    },
    # tunable_nasfpn:strip_begin
    'tunable_nasfpn_v1': {
        'min_level': 2,
        'max_level': 6,
    },
    # tunable_nasfpn:strip_end
    'rpn_head': {
        'min_level': 2,
        'max_level': 6,
        'anchors_per_location': 3,
        'use_batch_norm': False,
        'batch_norm': {
            'batch_norm_momentum': 0.997,
            'batch_norm_epsilon': 1e-4,
            'batch_norm_trainable': True,
            'use_sync_bn': False,
        },
    },
    'frcnn_head': {
        # Note that `num_classes` is the total number of classes, including
        # one background class whose index is 0.
        'num_classes': 91,
        'fast_rcnn_mlp_head_dim': 1024,
        'use_batch_norm': False,
        'batch_norm': {
            'batch_norm_momentum': 0.997,
            'batch_norm_epsilon': 1e-4,
            'batch_norm_trainable': True,
            'use_sync_bn': False,
        },
    },
    'mrcnn_head': {
        'num_classes': 91,
        'mask_target_size': 28,
        'use_batch_norm': False,
        'batch_norm': {
            'batch_norm_momentum': 0.997,
            'batch_norm_epsilon': 1e-4,
            'batch_norm_trainable': True,
            'use_sync_bn': False,
        },
    },
    'rpn_score_loss': {
        'rpn_batch_size_per_im': 256,
    },
    'rpn_box_loss': {
        'huber_loss_delta': 1.0 / 9.0,
    },
    'frcnn_box_loss': {
        'huber_loss_delta': 1.0,
    },
    'roi_proposal': {
        'rpn_pre_nms_top_k': 2000,
        'rpn_post_nms_top_k': 1000,
        'rpn_nms_threshold': 0.7,
        'rpn_score_threshold': 0.0,
        'rpn_min_size_threshold': 0.0,
        'test_rpn_pre_nms_top_k': 1000,
        'test_rpn_post_nms_top_k': 1000,
        'test_rpn_nms_threshold': 0.7,
        'test_rpn_score_threshold': 0.0,
        'test_rpn_min_size_threshold': 0.0,
        'use_batched_nms': False,
    },
    'roi_sampling': {
        'num_samples_per_image': 512,
        'fg_fraction': 0.25,
        'fg_iou_thresh': 0.5,
        'bg_iou_thresh_hi': 0.5,
        'bg_iou_thresh_lo': 0.0,
        'mix_gt_boxes': True,
    },
    'mask_sampling': {
        # Typically = `num_samples_per_image` * `fg_fraction`
        # (here 512 * 0.25, from `roi_sampling` above).
        'num_mask_samples_per_image': 128,
        'mask_target_size': 28,
    },
    'postprocess': {
        'use_batched_nms': False,
        'max_total_size': 100,
        'nms_iou_threshold': 0.5,
        'score_threshold': 0.05,
        'pre_nms_num_boxes': 1000,
    },
}, is_strict=False)
# Consistency restrictions for the Mask R-CNN config. Each entry is an
# equality expression between two dotted config paths whose values are
# expected to stay in sync.
MASKRCNN_RESTRICTIONS = [
    'architecture.use_bfloat16 == maskrcnn_parser.use_bfloat16',
    'architecture.include_mask == maskrcnn_parser.include_mask',
    'anchor.min_level == rpn_head.min_level',
    'anchor.max_level == rpn_head.max_level',
    'mrcnn_head.mask_target_size == mask_sampling.mask_target_size',
]
# pylint: enable=line-too-long
official/vision/detection/configs/retinanet_config.py
浏览文件 @
59393a8d
...
...
@@ -39,7 +39,7 @@ RETINANET_CFG = {
'optimizer'
:
{
'type'
:
'momentum'
,
'momentum'
:
0.9
,
'nesterov'
:
True
,
'nesterov'
:
True
,
# `False` is better for TPU v3-128.
},
'learning_rate'
:
{
'type'
:
'step'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录