Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
7313d7a9
M
Models
项目概览
曾经的那一瞬间
/
Models
大约 1 年 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
7313d7a9
编写于
5月 18, 2023
作者:
Z
Zihan Wang
提交者:
GitHub
5月 18, 2023
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'tensorflow:master' into master
上级
a4b4f3e3
d8df5103
变更
10
显示空白变更内容
内联
并排
Showing
10 changed files
with
178 additions
and
38 deletions
+178
-38
docs/vision/image_classification.ipynb
docs/vision/image_classification.ipynb
+1
-1
official/core/train_lib.py
official/core/train_lib.py
+20
-8
official/core/train_utils.py
official/core/train_utils.py
+44
-0
official/nlp/train.py
official/nlp/train.py
+10
-2
official/projects/maxvit/configs/__init__.py
official/projects/maxvit/configs/__init__.py
+1
-1
official/projects/maxvit/registry_imports.py
official/projects/maxvit/registry_imports.py
+21
-0
official/projects/maxvit/train.py
official/projects/maxvit/train.py
+2
-3
official/vision/evaluation/coco_utils.py
official/vision/evaluation/coco_utils.py
+20
-18
official/vision/evaluation/coco_utils_test.py
official/vision/evaluation/coco_utils_test.py
+50
-4
official/vision/train.py
official/vision/train.py
+9
-1
未找到文件。
docs/vision/image_classification.ipynb
浏览文件 @
7313d7a9
...
@@ -186,7 +186,7 @@
...
@@ -186,7 +186,7 @@
"exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n",
"exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n",
"tfds_name = 'cifar10'\n",
"tfds_name = 'cifar10'\n",
"ds,ds_info = tfds.load(\n",
"ds,ds_info = tfds.load(\n",
"tfds_name\n",
"tfds_name
,
\n",
"with_info=True)\n",
"with_info=True)\n",
"ds_info"
"ds_info"
]
]
...
...
official/core/train_lib.py
浏览文件 @
7313d7a9
...
@@ -71,6 +71,7 @@ class OrbitExperimentRunner:
...
@@ -71,6 +71,7 @@ class OrbitExperimentRunner:
controller_cls
=
orbit
.
Controller
,
controller_cls
=
orbit
.
Controller
,
summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
eval_summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
eval_summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
enable_async_checkpointing
:
bool
=
False
,
):
):
"""Constructor.
"""Constructor.
...
@@ -94,6 +95,8 @@ class OrbitExperimentRunner:
...
@@ -94,6 +95,8 @@ class OrbitExperimentRunner:
summary manager.
summary manager.
eval_summary_manager: Instance of the eval summary manager to override
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
default eval summary manager.
enable_async_checkpointing: Optional boolean indicating whether to enable
async checkpoint saving.
"""
"""
self
.
strategy
=
distribution_strategy
or
tf
.
distribute
.
get_strategy
()
self
.
strategy
=
distribution_strategy
or
tf
.
distribute
.
get_strategy
()
self
.
_params
=
params
self
.
_params
=
params
...
@@ -115,7 +118,8 @@ class OrbitExperimentRunner:
...
@@ -115,7 +118,8 @@ class OrbitExperimentRunner:
save_summary
=
save_summary
,
save_summary
=
save_summary
,
train_actions
=
train_actions
,
train_actions
=
train_actions
,
eval_actions
=
eval_actions
,
eval_actions
=
eval_actions
,
controller_cls
=
controller_cls
)
controller_cls
=
controller_cls
,
enable_async_checkpointing
=
enable_async_checkpointing
)
@
property
@
property
def
params
(
self
)
->
config_definitions
.
ExperimentConfig
:
def
params
(
self
)
->
config_definitions
.
ExperimentConfig
:
...
@@ -188,13 +192,16 @@ class OrbitExperimentRunner:
...
@@ -188,13 +192,16 @@ class OrbitExperimentRunner:
checkpoint_manager
=
None
checkpoint_manager
=
None
return
checkpoint_manager
return
checkpoint_manager
def
_build_controller
(
self
,
def
_build_controller
(
self
,
trainer
,
trainer
,
evaluator
,
evaluator
,
save_summary
:
bool
=
True
,
save_summary
:
bool
=
True
,
train_actions
:
Optional
[
List
[
orbit
.
Action
]]
=
None
,
train_actions
:
Optional
[
List
[
orbit
.
Action
]]
=
None
,
eval_actions
:
Optional
[
List
[
orbit
.
Action
]]
=
None
,
eval_actions
:
Optional
[
List
[
orbit
.
Action
]]
=
None
,
controller_cls
=
orbit
.
Controller
)
->
orbit
.
Controller
:
controller_cls
=
orbit
.
Controller
,
enable_async_checkpointing
:
bool
=
False
,
)
->
orbit
.
Controller
:
"""Builds a Orbit controler."""
"""Builds a Orbit controler."""
train_actions
=
[]
if
not
train_actions
else
train_actions
train_actions
=
[]
if
not
train_actions
else
train_actions
if
trainer
:
if
trainer
:
...
@@ -223,6 +230,7 @@ class OrbitExperimentRunner:
...
@@ -223,6 +230,7 @@ class OrbitExperimentRunner:
global_step
=
self
.
trainer
.
global_step
,
global_step
=
self
.
trainer
.
global_step
,
steps_per_loop
=
self
.
params
.
trainer
.
steps_per_loop
,
steps_per_loop
=
self
.
params
.
trainer
.
steps_per_loop
,
checkpoint_manager
=
self
.
checkpoint_manager
,
checkpoint_manager
=
self
.
checkpoint_manager
,
enable_async_checkpointing
=
enable_async_checkpointing
,
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
'train'
)
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
'train'
)
if
(
save_summary
)
if
(
save_summary
)
else
None
,
else
None
,
...
@@ -309,6 +317,7 @@ def run_experiment(
...
@@ -309,6 +317,7 @@ def run_experiment(
controller_cls
=
orbit
.
Controller
,
controller_cls
=
orbit
.
Controller
,
summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
eval_summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
eval_summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
enable_async_checkpointing
:
bool
=
False
,
)
->
Tuple
[
tf
.
keras
.
Model
,
Mapping
[
str
,
Any
]]:
)
->
Tuple
[
tf
.
keras
.
Model
,
Mapping
[
str
,
Any
]]:
"""Runs train/eval configured by the experiment params.
"""Runs train/eval configured by the experiment params.
...
@@ -332,6 +341,8 @@ def run_experiment(
...
@@ -332,6 +341,8 @@ def run_experiment(
manager.
manager.
eval_summary_manager: Instance of the eval summary manager to override
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
default eval summary manager.
enable_async_checkpointing: Optional boolean indicating whether to enable
async checkpoint saving.
Returns:
Returns:
A 2-tuple of (model, eval_logs).
A 2-tuple of (model, eval_logs).
...
@@ -353,5 +364,6 @@ def run_experiment(
...
@@ -353,5 +364,6 @@ def run_experiment(
controller_cls
=
controller_cls
,
controller_cls
=
controller_cls
,
summary_manager
=
summary_manager
,
summary_manager
=
summary_manager
,
eval_summary_manager
=
eval_summary_manager
,
eval_summary_manager
=
eval_summary_manager
,
enable_async_checkpointing
=
enable_async_checkpointing
,
)
)
return
runner
.
run
()
return
runner
.
run
()
official/core/train_utils.py
浏览文件 @
7313d7a9
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
"""Training utils."""
"""Training utils."""
import
dataclasses
import
dataclasses
import
inspect
import
inspect
import
json
import
json
...
@@ -22,10 +23,12 @@ from typing import Any, Callable, Dict, List, Optional, Union
...
@@ -22,10 +23,12 @@ from typing import Any, Callable, Dict, List, Optional, Union
from
absl
import
logging
from
absl
import
logging
import
gin
import
gin
import
numpy
as
np
import
orbit
import
orbit
import
tensorflow
as
tf
import
tensorflow
as
tf
# pylint: disable=g-direct-tensorflow-import
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework.convert_to_constants
import
convert_variables_to_constants_v2_as_graph
from
tensorflow.python.framework.convert_to_constants
import
convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
# pylint: enable=g-direct-tensorflow-import
from
official.core
import
base_task
from
official.core
import
base_task
...
@@ -564,3 +567,44 @@ def try_count_flops(model: Union[tf.Module, tf.keras.Model],
...
@@ -564,3 +567,44 @@ def try_count_flops(model: Union[tf.Module, tf.keras.Model],
'reached before this run.'
,
e
)
'reached before this run.'
,
e
)
return
None
return
None
return
None
return
None
@ops.RegisterStatistics('Einsum', 'flops')
def _einsum_flops(graph, node):
  """Calculates the compute resources needed for Einsum.

  Registered as the 'flops' statistic for the 'Einsum' op. The count is
  2 * (product of the output index sizes) * (product of the contracted index
  sizes), i.e. two flops per multiply-add.

  Args:
    graph: Graph that contains `node`; used to look up the static shapes of
      the node's inputs.
    node: The Einsum node definition. Must have exactly two inputs and an
      'equation' attribute of the form 'ab,bc->ac'.

  Returns:
    An `ops.OpStats` of type 'flops'.

  Raises:
    AssertionError: If the node does not have exactly two inputs.
    ValueError: If either input shape is not fully defined.
  """
  # NOTE(review): equations using an ellipsis ('...') are not expanded here;
  # each '.' is treated as an ordinary index label, so the count for such
  # equations is presumably unreliable -- confirm before relying on it.
  assert len(node.input) == 2
  x_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
      graph, node.input[0])
  y_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  x_shape.assert_is_fully_defined()
  y_shape.assert_is_fully_defined()
  x_dims = x_shape.as_list()
  y_dims = y_shape.as_list()

  # Read the equation straight from the attribute's string field rather than
  # scrubbing the protobuf text dump (`s: "..."`) with chained replaces.
  equation = node.attr['equation'].s.decode('utf-8')
  equation = equation.replace(' ', '').replace('\n', '')

  # Split 'x,y->r' once into its three index strings.
  operand_parts = equation.split(',')
  x_str = operand_parts[0]
  output_parts = operand_parts[1].split('->')
  y_str = output_parts[0]
  r_str = output_parts[1]

  # Map every index label to its size. The first operand takes priority and,
  # within an operand, the first occurrence of a label wins.
  shape_dic = {}
  for labels, dims in ((x_str, x_dims), (y_str, y_dims)):
    for position, indice in enumerate(labels):
      if indice not in shape_dic:
        shape_dic[indice] = dims[position]

  # Labels absent from the output are summed over (contracted).
  contracted = {indice for indice in shape_dic if indice not in r_str}

  madds = np.prod([shape_dic[indice] for indice in r_str]) * (
      np.prod([shape_dic[indice] for indice in contracted]))
  flops = 2 * madds
  return ops.OpStats('flops', flops)
official/nlp/train.py
浏览文件 @
7313d7a9
...
@@ -38,6 +38,11 @@ flags.DEFINE_integer(
...
@@ -38,6 +38,11 @@ flags.DEFINE_integer(
default
=
None
,
default
=
None
,
help
=
'The number of total training steps for the pretraining job.'
)
help
=
'The number of total training steps for the pretraining job.'
)
flags
.
DEFINE_bool
(
'enable_async_checkpointing'
,
default
=
True
,
help
=
'A boolean indicating whether to enable async checkpoint saving'
)
def
_run_experiment_with_preemption_recovery
(
params
,
model_dir
):
def
_run_experiment_with_preemption_recovery
(
params
,
model_dir
):
"""Runs experiment and tries to reconnect when encounting a preemption."""
"""Runs experiment and tries to reconnect when encounting a preemption."""
...
@@ -53,14 +58,17 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
...
@@ -53,14 +58,17 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
**
params
.
runtime
.
model_parallelism
())
**
params
.
runtime
.
model_parallelism
())
with
distribution_strategy
.
scope
():
with
distribution_strategy
.
scope
():
task
=
task_factory
.
get_task
(
params
.
task
,
logging_dir
=
model_dir
)
task
=
task_factory
.
get_task
(
params
.
task
,
logging_dir
=
model_dir
)
preemption_watcher
=
tf
.
distribute
.
experimental
.
PreemptionWatcher
()
# pylint: disable=line-too-long
preemption_watcher
=
None
# copybara-replace
# pylint: enable=line-too-long
train_lib
.
run_experiment
(
train_lib
.
run_experiment
(
distribution_strategy
=
distribution_strategy
,
distribution_strategy
=
distribution_strategy
,
task
=
task
,
task
=
task
,
mode
=
FLAGS
.
mode
,
mode
=
FLAGS
.
mode
,
params
=
params
,
params
=
params
,
model_dir
=
model_dir
)
model_dir
=
model_dir
,
enable_async_checkpointing
=
FLAGS
.
enable_async_checkpointing
)
keep_training
=
False
keep_training
=
False
except
tf
.
errors
.
OpError
as
e
:
except
tf
.
errors
.
OpError
as
e
:
...
...
official/projects/maxvit/configs/__init__.py
浏览文件 @
7313d7a9
...
@@ -19,4 +19,4 @@ from official.projects.maxvit.configs import backbones # pylint:disable=unused-
...
@@ -19,4 +19,4 @@ from official.projects.maxvit.configs import backbones # pylint:disable=unused-
from
official.projects.maxvit.configs
import
rcnn
# pylint:disable=unused-import
from
official.projects.maxvit.configs
import
rcnn
# pylint:disable=unused-import
from
official.projects.maxvit.configs
import
retinanet
# pylint:disable=unused-import
from
official.projects.maxvit.configs
import
retinanet
# pylint:disable=unused-import
from
official.projects.maxvit.configs
import
semantic_segmentation
# pylint:disable=unused-import
from
official.projects.maxvit.configs
import
semantic_segmentation
# pylint:disable=unused-import
from
official.projects.maxvit.configs
.google
import
image_classification
# pylint:disable=unused-import
from
official.projects.maxvit.configs
import
image_classification
# pylint:disable=unused-import
official/projects/maxvit/registry_imports.py
0 → 100644
浏览文件 @
7313d7a9
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
# pylint: disable=g-bad-import-order
from
official.vision
import
registry_imports
from
official.projects.maxvit
import
configs
# pylint: disable=unused-import
from
official.projects.maxvit.modeling
import
maxvit
# pylint: disable=unused-import
official/projects/maxvit/train.py
浏览文件 @
7313d7a9
...
@@ -12,12 +12,11 @@
...
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""TensorFlow Model Garden Vision training driver, including ViT configs.."""
"""TensorFlow Model Garden Vision training driver, including
Max
ViT configs.."""
from
absl
import
app
from
absl
import
app
from
official.common
import
flags
as
tfm_flags
from
official.common
import
flags
as
tfm_flags
from
official.projects.maxvit
import
configs
# pylint: disable=unused-import
from
official.projects.maxvit
import
registry_imports
# pylint: disable=unused-import
from
official.projects.maxvit.modeling
import
maxvit
# pylint: disable=unused-import
from
official.vision
import
train
from
official.vision
import
train
...
...
official/vision/evaluation/coco_utils.py
浏览文件 @
7313d7a9
...
@@ -114,7 +114,6 @@ def convert_predictions_to_coco_annotations(predictions):
...
@@ -114,7 +114,6 @@ def convert_predictions_to_coco_annotations(predictions):
Required fields:
Required fields:
- source_id: a list of numpy arrays of int or string of shape
- source_id: a list of numpy arrays of int or string of shape
[batch_size].
[batch_size].
- num_detections: a list of numpy arrays of int of shape [batch_size].
- detection_boxes: a list of numpy arrays of float of shape
- detection_boxes: a list of numpy arrays of float of shape
[batch_size, K, 4], where coordinates are in the original image
[batch_size, K, 4], where coordinates are in the original image
space (not the scaled image space).
space (not the scaled image space).
...
@@ -125,6 +124,8 @@ def convert_predictions_to_coco_annotations(predictions):
...
@@ -125,6 +124,8 @@ def convert_predictions_to_coco_annotations(predictions):
Optional fields:
Optional fields:
- detection_masks: a list of numpy arrays of float of shape
- detection_masks: a list of numpy arrays of float of shape
[batch_size, K, mask_height, mask_width].
[batch_size, K, mask_height, mask_width].
- detection_keypoints: a list of numpy arrays of float of shape
[batch_size, K, num_keypoints, 2]
Returns:
Returns:
coco_predictions: prediction in COCO annotation format.
coco_predictions: prediction in COCO annotation format.
...
@@ -144,17 +145,32 @@ def convert_predictions_to_coco_annotations(predictions):
...
@@ -144,17 +145,32 @@ def convert_predictions_to_coco_annotations(predictions):
mask_boxes
=
predictions
[
'detection_boxes'
]
mask_boxes
=
predictions
[
'detection_boxes'
]
batch_size
=
predictions
[
'source_id'
][
i
].
shape
[
0
]
batch_size
=
predictions
[
'source_id'
][
i
].
shape
[
0
]
if
'detection_keypoints'
in
predictions
:
# Adds extra ones to indicate the visibility for each keypoint as is
# recommended by MSCOCO. Also, convert keypoint from [y, x] to [x, y]
# as mandated by COCO.
num_keypoints
=
predictions
[
'detection_keypoints'
][
i
].
shape
[
2
]
coco_keypoints
=
np
.
concatenate
(
[
predictions
[
'detection_keypoints'
][
i
][...,
1
:],
predictions
[
'detection_keypoints'
][
i
][...,
:
1
],
np
.
ones
([
batch_size
,
max_num_detections
,
num_keypoints
,
1
]),
],
axis
=-
1
,
).
astype
(
int
)
for
j
in
range
(
batch_size
):
for
j
in
range
(
batch_size
):
if
'detection_masks'
in
predictions
:
if
'detection_masks'
in
predictions
:
image_masks
=
mask_ops
.
paste_instance_masks
(
image_masks
=
mask_ops
.
paste_instance_masks
(
predictions
[
'detection_masks'
][
i
][
j
],
predictions
[
'detection_masks'
][
i
][
j
],
mask_boxes
[
i
][
j
],
mask_boxes
[
i
][
j
],
int
(
predictions
[
'image_info'
][
i
][
j
,
0
,
0
]),
int
(
predictions
[
'image_info'
][
i
][
j
,
0
,
0
]),
int
(
predictions
[
'image_info'
][
i
][
j
,
0
,
1
]))
int
(
predictions
[
'image_info'
][
i
][
j
,
0
,
1
]),
)
binary_masks
=
(
image_masks
>
0.0
).
astype
(
np
.
uint8
)
binary_masks
=
(
image_masks
>
0.0
).
astype
(
np
.
uint8
)
encoded_masks
=
[
encoded_masks
=
[
mask_api
.
encode
(
np
.
asfortranarray
(
binary_mask
))
mask_api
.
encode
(
np
.
asfortranarray
(
binary_mask
))
for
binary_mask
in
list
(
binary_masks
)]
for
binary_mask
in
list
(
binary_masks
)
]
for
k
in
range
(
max_num_detections
):
for
k
in
range
(
max_num_detections
):
ann
=
{}
ann
=
{}
ann
[
'image_id'
]
=
predictions
[
'source_id'
][
i
][
j
]
ann
[
'image_id'
]
=
predictions
[
'source_id'
][
i
][
j
]
...
@@ -164,21 +180,7 @@ def convert_predictions_to_coco_annotations(predictions):
...
@@ -164,21 +180,7 @@ def convert_predictions_to_coco_annotations(predictions):
if
'detection_masks'
in
predictions
:
if
'detection_masks'
in
predictions
:
ann
[
'segmentation'
]
=
encoded_masks
[
k
]
ann
[
'segmentation'
]
=
encoded_masks
[
k
]
if
'detection_keypoints'
in
predictions
:
if
'detection_keypoints'
in
predictions
:
# Adds extra ones to indicate the visibility for each keypoint as is
ann
[
'keypoints'
]
=
coco_keypoints
[
j
,
k
].
flatten
().
tolist
()
# recommended by MSCOCO. Also, convert keypoint from [y, x] to [x, y]
# as mandated by COCO.
instance_keypoints
=
predictions
[
'detection_keypoints'
][
i
][
j
,
k
]
num_keypoints
=
len
(
instance_keypoints
)
instance_keypoints
=
np
.
concatenate
(
[
np
.
expand_dims
(
instance_keypoints
[:,
1
],
axis
=-
1
),
np
.
expand_dims
(
instance_keypoints
[:,
0
],
axis
=-
1
),
np
.
expand_dims
(
np
.
ones
(
num_keypoints
),
axis
=
1
),
],
axis
=
1
,
).
astype
(
int
)
instance_keypoints
=
instance_keypoints
.
flatten
().
tolist
()
ann
[
'keypoints'
]
=
instance_keypoints
coco_predictions
.
append
(
ann
)
coco_predictions
.
append
(
ann
)
for
i
,
ann
in
enumerate
(
coco_predictions
):
for
i
,
ann
in
enumerate
(
coco_predictions
):
...
...
official/vision/evaluation/coco_utils_test.py
浏览文件 @
7313d7a9
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
import
os
import
os
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.dataloaders
import
tfexample_utils
from
official.vision.dataloaders
import
tfexample_utils
...
@@ -27,11 +28,13 @@ class CocoUtilsTest(tf.test.TestCase):
...
@@ -27,11 +28,13 @@ class CocoUtilsTest(tf.test.TestCase):
def
test_scan_and_generator_annotation_file
(
self
):
def
test_scan_and_generator_annotation_file
(
self
):
num_samples
=
10
num_samples
=
10
example
=
tfexample_utils
.
create_detection_test_example
(
example
=
tfexample_utils
.
create_detection_test_example
(
image_height
=
512
,
image_width
=
512
,
image_channel
=
3
,
num_instances
=
10
)
image_height
=
512
,
image_width
=
512
,
image_channel
=
3
,
num_instances
=
10
)
tf_examples
=
[
example
]
*
num_samples
tf_examples
=
[
example
]
*
num_samples
data_file
=
os
.
path
.
join
(
self
.
create_tempdir
(),
'test.tfrecord'
)
data_file
=
os
.
path
.
join
(
self
.
create_tempdir
(),
'test.tfrecord'
)
tfexample_utils
.
dump_to_tfrecord
(
tfexample_utils
.
dump_to_tfrecord
(
record_file
=
data_file
,
tf_examples
=
tf_examples
)
record_file
=
data_file
,
tf_examples
=
tf_examples
)
annotation_file
=
os
.
path
.
join
(
self
.
create_tempdir
(),
'annotation.json'
)
annotation_file
=
os
.
path
.
join
(
self
.
create_tempdir
(),
'annotation.json'
)
coco_utils
.
scan_and_generator_annotation_file
(
coco_utils
.
scan_and_generator_annotation_file
(
...
@@ -39,10 +42,53 @@ class CocoUtilsTest(tf.test.TestCase):
...
@@ -39,10 +42,53 @@ class CocoUtilsTest(tf.test.TestCase):
file_type
=
'tfrecord'
,
file_type
=
'tfrecord'
,
num_samples
=
num_samples
,
num_samples
=
num_samples
,
include_mask
=
True
,
include_mask
=
True
,
annotation_file
=
annotation_file
)
annotation_file
=
annotation_file
,
)
self
.
assertTrue
(
self
.
assertTrue
(
tf
.
io
.
gfile
.
exists
(
annotation_file
),
tf
.
io
.
gfile
.
exists
(
annotation_file
),
msg
=
'Annotation file {annotation_file} does not exists.'
)
msg
=
'Annotation file {annotation_file} does not exist.'
,
)
def test_convert_keypoint_predictions_to_coco_annotations(self):
  """Checks the 'keypoints' field of converted COCO annotations.

  Builds a single-image batch of random detections with [y, x] keypoints and
  verifies that every annotation carries the keypoints flattened as integer
  [x, y, 1] triplets.
  """
  batch_size = 1
  max_num_detections = 3
  num_keypoints = 3
  image_size = 512

  per_detection_shape = [batch_size, max_num_detections]
  source_id = [np.array([[1]], dtype=int)]
  detection_boxes = [
      np.random.random(per_detection_shape + [4]) * image_size
  ]
  detection_class = [np.random.randint(1, 5, per_detection_shape)]
  detection_scores = [np.random.random(per_detection_shape)]
  detection_keypoints = [
      np.random.random(per_detection_shape + [num_keypoints, 2]) * image_size
  ]

  predictions = dict(
      source_id=source_id,
      detection_boxes=detection_boxes,
      detection_classes=detection_class,
      detection_scores=detection_scores,
      detection_keypoints=detection_keypoints,
  )

  anns = coco_utils.convert_predictions_to_coco_annotations(predictions)

  # Keypoints of the only image in the only batch: [detections, keypoints, 2].
  image_keypoints = detection_keypoints[0][0]
  for detection in range(max_num_detections):
    keypoints_yx = image_keypoints[detection]
    # Swap [y, x] -> [x, y] and append a column of ones, then truncate to int.
    expected_keypoint_ann = np.stack(
        [keypoints_yx[:, 1], keypoints_yx[:, 0], np.ones(num_keypoints)],
        axis=1,
    ).astype(int)
    self.assertAllEqual(
        anns[detection]['keypoints'],
        expected_keypoint_ann.flatten().tolist(),
    )
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
official/vision/train.py
浏览文件 @
7313d7a9
...
@@ -32,6 +32,11 @@ from official.vision.utils import summary_manager
...
@@ -32,6 +32,11 @@ from official.vision.utils import summary_manager
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_bool
(
'enable_async_checkpointing'
,
default
=
True
,
help
=
'A boolean indicating whether to enable async checkpoint saving'
)
def
_run_experiment_with_preemption_recovery
(
params
,
model_dir
):
def
_run_experiment_with_preemption_recovery
(
params
,
model_dir
):
"""Runs experiment and tries to reconnect when encounting a preemption."""
"""Runs experiment and tries to reconnect when encounting a preemption."""
...
@@ -46,7 +51,9 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
...
@@ -46,7 +51,9 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
tpu_address
=
params
.
runtime
.
tpu
)
tpu_address
=
params
.
runtime
.
tpu
)
with
distribution_strategy
.
scope
():
with
distribution_strategy
.
scope
():
task
=
task_factory
.
get_task
(
params
.
task
,
logging_dir
=
model_dir
)
task
=
task_factory
.
get_task
(
params
.
task
,
logging_dir
=
model_dir
)
preemption_watcher
=
tf
.
distribute
.
experimental
.
PreemptionWatcher
()
# pylint: disable=line-too-long
preemption_watcher
=
None
# copybara-replace
# pylint: enable=line-too-long
train_lib
.
run_experiment
(
train_lib
.
run_experiment
(
distribution_strategy
=
distribution_strategy
,
distribution_strategy
=
distribution_strategy
,
...
@@ -58,6 +65,7 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
...
@@ -58,6 +65,7 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
eval_summary_manager
=
summary_manager
.
maybe_build_eval_summary_manager
(
eval_summary_manager
=
summary_manager
.
maybe_build_eval_summary_manager
(
params
=
params
,
model_dir
=
model_dir
params
=
params
,
model_dir
=
model_dir
),
),
enable_async_checkpointing
=
FLAGS
.
enable_async_checkpointing
,
)
)
keep_training
=
False
keep_training
=
False
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录