Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
05ef569f
M
Models
项目概览
曾经的那一瞬间
/
Models
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
05ef569f
编写于
9月 18, 2021
作者:
A
A. Unique TensorFlower
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Internal change
PiperOrigin-RevId: 397467113
上级
9189cf5e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
34 addition
and
7 deletion
+34
-7
official/vision/beta/modeling/retinanet_model.py
official/vision/beta/modeling/retinanet_model.py
+17
-6
official/vision/beta/modeling/retinanet_model_test.py
official/vision/beta/modeling/retinanet_model_test.py
+17
-1
未找到文件。
official/vision/beta/modeling/retinanet_model.py
浏览文件 @
05ef569f
...
...
@@ -77,6 +77,7 @@ class RetinaNetModel(tf.keras.Model):
images
:
tf
.
Tensor
,
image_shape
:
Optional
[
tf
.
Tensor
]
=
None
,
anchor_boxes
:
Optional
[
Mapping
[
str
,
tf
.
Tensor
]]
=
None
,
output_intermediate_features
:
bool
=
False
,
training
:
bool
=
None
)
->
Mapping
[
str
,
tf
.
Tensor
]:
"""Forward pass of the RetinaNet model.
...
...
@@ -92,6 +93,8 @@ class RetinaNetModel(tf.keras.Model):
- key: `str`, the level of the multilevel predictions.
- values: `Tensor`, the anchor coordinates of a particular feature
level, whose shape is [height_l, width_l, num_anchors_per_location].
output_intermediate_features: `bool` indicating whether to return the
intermediate feature maps generated by backbone and decoder.
training: `bool`, indicating whether it is in training mode.
Returns:
...
...
@@ -112,19 +115,26 @@ class RetinaNetModel(tf.keras.Model):
feature level, whose shape is
[batch, height_l, width_l, att_size * num_anchors_per_location].
"""
outputs
=
{}
# Feature extraction.
features
=
self
.
backbone
(
images
)
if
output_intermediate_features
:
outputs
.
update
(
{
'backbone_{}'
.
format
(
k
):
v
for
k
,
v
in
features
.
items
()})
if
self
.
decoder
:
features
=
self
.
decoder
(
features
)
if
output_intermediate_features
:
outputs
.
update
(
{
'decoder_{}'
.
format
(
k
):
v
for
k
,
v
in
features
.
items
()})
# Dense prediction. `raw_attributes` can be empty.
raw_scores
,
raw_boxes
,
raw_attributes
=
self
.
head
(
features
)
if
training
:
outputs
=
{
outputs
.
update
(
{
'cls_outputs'
:
raw_scores
,
'box_outputs'
:
raw_boxes
,
}
}
)
if
raw_attributes
:
outputs
.
update
({
'attribute_outputs'
:
raw_attributes
})
return
outputs
...
...
@@ -145,12 +155,13 @@ class RetinaNetModel(tf.keras.Model):
[
tf
.
shape
(
images
)[
0
],
1
,
1
,
1
])
# Post-processing.
final_results
=
self
.
detection_generator
(
raw_boxes
,
raw_scores
,
anchor_boxes
,
image_shape
,
raw_attributes
)
outputs
=
{
final_results
=
self
.
detection_generator
(
raw_boxes
,
raw_scores
,
anchor_boxes
,
image_shape
,
raw_attributes
)
outputs
.
update
({
'cls_outputs'
:
raw_scores
,
'box_outputs'
:
raw_boxes
,
}
}
)
if
self
.
detection_generator
.
get_config
()[
'apply_nms'
]:
outputs
.
update
({
'detection_boxes'
:
final_results
[
'detection_boxes'
],
...
...
official/vision/beta/modeling/retinanet_model_test.py
浏览文件 @
05ef569f
...
...
@@ -147,8 +147,10 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
],
training
=
[
True
,
False
],
has_att_heads
=
[
True
,
False
],
output_intermediate_features
=
[
True
,
False
],
))
def
test_forward
(
self
,
strategy
,
image_size
,
training
,
has_att_heads
):
def
test_forward
(
self
,
strategy
,
image_size
,
training
,
has_att_heads
,
output_intermediate_features
):
"""Test for creation of a R50-FPN RetinaNet."""
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
num_classes
=
3
...
...
@@ -202,6 +204,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
images
,
image_shape
,
anchor_boxes
,
output_intermediate_features
=
output_intermediate_features
,
training
=
training
)
if
training
:
...
...
@@ -247,6 +250,19 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertAllEqual
(
[
2
,
10
,
1
],
model_outputs
[
'detection_attributes'
][
'depth'
].
numpy
().
shape
)
if
output_intermediate_features
:
for
l
in
range
(
2
,
6
):
self
.
assertIn
(
'backbone_{}'
.
format
(
l
),
model_outputs
)
self
.
assertAllEqual
([
2
,
image_size
[
0
]
//
2
**
l
,
image_size
[
1
]
//
2
**
l
,
backbone
.
output_specs
[
str
(
l
)].
as_list
()[
-
1
]
],
model_outputs
[
'backbone_{}'
.
format
(
l
)].
numpy
().
shape
)
for
l
in
range
(
min_level
,
max_level
+
1
):
self
.
assertIn
(
'decoder_{}'
.
format
(
l
),
model_outputs
)
self
.
assertAllEqual
([
2
,
image_size
[
0
]
//
2
**
l
,
image_size
[
1
]
//
2
**
l
,
decoder
.
output_specs
[
str
(
l
)].
as_list
()[
-
1
]
],
model_outputs
[
'decoder_{}'
.
format
(
l
)].
numpy
().
shape
)
def
test_serialize_deserialize
(
self
):
"""Validate the network can be serialized and deserialized."""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录