Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
d2e24b60
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d2e24b60
编写于
2月 28, 2018
作者:
Y
Yifei Feng
提交者:
Gunhan Gulsoy
2月 28, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Don't assign device for the keras part of _saved_first_checkpoint. Fix #14504. (#17231)
PiperOrigin-RevId: 186526175
上级
0f52f44b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
38 addition
and
13 deletion
+38
-13
tensorflow/python/keras/_impl/keras/estimator.py
tensorflow/python/keras/_impl/keras/estimator.py
+12
-12
tensorflow/python/keras/_impl/keras/estimator_test.py
tensorflow/python/keras/_impl/keras/estimator_test.py
+26
-1
未找到文件。
tensorflow/python/keras/_impl/keras/estimator.py
浏览文件 @
d2e24b60
...
...
@@ -221,18 +221,18 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects,
Returns:
The model_fn for a keras Estimator.
"""
with
ops
.
Graph
().
as_default
()
as
g
,
g
.
device
(
estimator
.
_device_fn
):
random_seed
.
set_random_seed
(
estimator
.
config
.
tf_random_seed
)
training_util
.
create_global_step
()
model
=
_clone_and_build_model
(
model_fn_lib
.
ModeKeys
.
TRAIN
,
keras_model
,
custom_objects
)
if
isinstance
(
model
,
models
.
Sequential
):
model
=
model
.
model
# Load weights and save to checkpoint if there is no checkpoint
latest_path
=
saver_lib
.
latest_checkpoint
(
estimator
.
model_dir
)
if
not
latest_path
:
with
session
.
Session
()
as
sess
:
# Load weights and save to checkpoint if there is no checkpoint
latest_path
=
saver_lib
.
latest_checkpoint
(
estimator
.
model_dir
)
if
not
latest_path
:
with
ops
.
Graph
().
as_default
():
random_seed
.
set_random_seed
(
estimator
.
config
.
tf_random_seed
)
training_util
.
create_global_step
()
model
=
_clone_and_build_model
(
model_fn_lib
.
ModeKeys
.
TRAIN
,
keras_model
,
custom_objects
)
if
isinstance
(
model
,
models
.
Sequential
):
model
=
model
.
model
# save to checkpoint
with
session
.
Session
(
config
=
estimator
.
_session_config
)
as
sess
:
model
.
set_weights
(
keras_weights
)
# Make update ops and initialize all variables.
if
not
model
.
train_function
:
...
...
tensorflow/python/keras/_impl/keras/estimator_test.py
浏览文件 @
d2e24b60
...
...
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
json
from
math
import
log10
import
os
import
tempfile
...
...
@@ -62,7 +63,7 @@ def simple_functional_model():
return
model
def
get_resource_for_simple_model
(
is_sequential
,
is_evaluat
e
):
def
get_resource_for_simple_model
(
is_sequential
=
True
,
is_evaluate
=
Fals
e
):
model
=
simple_sequential_model
(
)
if
is_sequential
else
simple_functional_model
()
if
is_sequential
:
...
...
@@ -352,6 +353,30 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
model_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
_base_dir
),
custom_objects
=
custom_objects
)
def
test_tf_config
(
self
):
keras_model
,
(
_
,
_
),
(
_
,
_
),
_
,
_
=
get_resource_for_simple_model
()
keras_model
.
compile
(
loss
=
'categorical_crossentropy'
,
optimizer
=
'rmsprop'
,
metrics
=
[
'mse'
,
keras
.
metrics
.
categorical_accuracy
])
tf_config
=
json
.
dumps
({
'cluster'
:
{
run_config_lib
.
TaskType
.
PS
:
[
'localhost:1234'
],
run_config_lib
.
TaskType
.
WORKER
:
[
'localhost:1236'
],
run_config_lib
.
TaskType
.
MASTER
:
[
'localhost:1238'
]
},
'task'
:
{
'type'
:
run_config_lib
.
TaskType
.
MASTER
,
'index'
:
0
}
})
with
test
.
mock
.
patch
.
dict
(
'os.environ'
,
{
'TF_CONFIG'
:
tf_config
}):
with
self
.
test_session
():
keras
.
estimator
.
model_to_estimator
(
keras_model
=
keras_model
,
model_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
_base_dir
))
if
__name__
==
'__main__'
:
test
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录