Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
6e2a1d5e
M
Models
项目概览
曾经的那一瞬间
/
Models
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6e2a1d5e
编写于
5月 17, 2023
作者:
A
A. Unique TensorFlower
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Enable async checkpoint by default in Tensorflow model garden.
PiperOrigin-RevId: 532860554
上级
8f17df9a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
33 addition
and
9 deletion
+33
-9
official/core/train_lib.py
official/core/train_lib.py
+20
-8
official/nlp/train.py
official/nlp/train.py
+7
-1
official/vision/train.py
official/vision/train.py
+6
-0
未找到文件。
official/core/train_lib.py
浏览文件 @
6e2a1d5e
...
...
@@ -71,6 +71,7 @@ class OrbitExperimentRunner:
controller_cls
=
orbit
.
Controller
,
summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
eval_summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
enable_async_checkpointing
:
bool
=
False
,
):
"""Constructor.
...
...
@@ -94,6 +95,8 @@ class OrbitExperimentRunner:
summary manager.
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
enable_async_checkpointing: Optional boolean indicating whether to enable
async checkpoint saving.
"""
self
.
strategy
=
distribution_strategy
or
tf
.
distribute
.
get_strategy
()
self
.
_params
=
params
...
...
@@ -115,7 +118,8 @@ class OrbitExperimentRunner:
save_summary
=
save_summary
,
train_actions
=
train_actions
,
eval_actions
=
eval_actions
,
controller_cls
=
controller_cls
)
controller_cls
=
controller_cls
,
enable_async_checkpointing
=
enable_async_checkpointing
)
@
property
def
params
(
self
)
->
config_definitions
.
ExperimentConfig
:
...
...
@@ -188,13 +192,16 @@ class OrbitExperimentRunner:
checkpoint_manager
=
None
return
checkpoint_manager
def
_build_controller
(
self
,
trainer
,
evaluator
,
save_summary
:
bool
=
True
,
train_actions
:
Optional
[
List
[
orbit
.
Action
]]
=
None
,
eval_actions
:
Optional
[
List
[
orbit
.
Action
]]
=
None
,
controller_cls
=
orbit
.
Controller
)
->
orbit
.
Controller
:
def
_build_controller
(
self
,
trainer
,
evaluator
,
save_summary
:
bool
=
True
,
train_actions
:
Optional
[
List
[
orbit
.
Action
]]
=
None
,
eval_actions
:
Optional
[
List
[
orbit
.
Action
]]
=
None
,
controller_cls
=
orbit
.
Controller
,
enable_async_checkpointing
:
bool
=
False
,
)
->
orbit
.
Controller
:
"""Builds a Orbit controler."""
train_actions
=
[]
if
not
train_actions
else
train_actions
if
trainer
:
...
...
@@ -223,6 +230,7 @@ class OrbitExperimentRunner:
global_step
=
self
.
trainer
.
global_step
,
steps_per_loop
=
self
.
params
.
trainer
.
steps_per_loop
,
checkpoint_manager
=
self
.
checkpoint_manager
,
enable_async_checkpointing
=
enable_async_checkpointing
,
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
'train'
)
if
(
save_summary
)
else
None
,
...
...
@@ -309,6 +317,7 @@ def run_experiment(
controller_cls
=
orbit
.
Controller
,
summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
eval_summary_manager
:
Optional
[
orbit
.
utils
.
SummaryManager
]
=
None
,
enable_async_checkpointing
:
bool
=
False
,
)
->
Tuple
[
tf
.
keras
.
Model
,
Mapping
[
str
,
Any
]]:
"""Runs train/eval configured by the experiment params.
...
...
@@ -332,6 +341,8 @@ def run_experiment(
manager.
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
enable_async_checkpointing: Optional boolean indicating whether to enable
async checkpoint saving.
Returns:
A 2-tuple of (model, eval_logs).
...
...
@@ -353,5 +364,6 @@ def run_experiment(
controller_cls
=
controller_cls
,
summary_manager
=
summary_manager
,
eval_summary_manager
=
eval_summary_manager
,
enable_async_checkpointing
=
enable_async_checkpointing
,
)
return
runner
.
run
()
official/nlp/train.py
浏览文件 @
6e2a1d5e
...
...
@@ -38,6 +38,11 @@ flags.DEFINE_integer(
default
=
None
,
help
=
'The number of total training steps for the pretraining job.'
)
flags
.
DEFINE_bool
(
'enable_async_checkpointing'
,
default
=
True
,
help
=
'A boolean indicating whether to enable async checkpoint saving'
)
def
_run_experiment_with_preemption_recovery
(
params
,
model_dir
):
"""Runs experiment and tries to reconnect when encounting a preemption."""
...
...
@@ -62,7 +67,8 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
task
=
task
,
mode
=
FLAGS
.
mode
,
params
=
params
,
model_dir
=
model_dir
)
model_dir
=
model_dir
,
enable_async_checkpointing
=
FLAGS
.
enable_async_checkpointing
)
keep_training
=
False
except
tf
.
errors
.
OpError
as
e
:
...
...
official/vision/train.py
浏览文件 @
6e2a1d5e
...
...
@@ -32,6 +32,11 @@ from official.vision.utils import summary_manager
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_bool
(
'enable_async_checkpointing'
,
default
=
True
,
help
=
'A boolean indicating whether to enable async checkpoint saving'
)
def
_run_experiment_with_preemption_recovery
(
params
,
model_dir
):
"""Runs experiment and tries to reconnect when encounting a preemption."""
...
...
@@ -60,6 +65,7 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
eval_summary_manager
=
summary_manager
.
maybe_build_eval_summary_manager
(
params
=
params
,
model_dir
=
model_dir
),
enable_async_checkpointing
=
FLAGS
.
enable_async_checkpointing
,
)
keep_training
=
False
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录