提交 d2e24b60 编写于 作者: Y Yifei Feng 提交者: Gunhan Gulsoy

Don't assign device for the keras part of _saved_first_checkpoint. Fix #14504. (#17231)

PiperOrigin-RevId: 186526175
上级 0f52f44b
......@@ -221,18 +221,18 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects,
Returns:
The model_fn for a keras Estimator.
"""
with ops.Graph().as_default() as g, g.device(estimator._device_fn):
random_seed.set_random_seed(estimator.config.tf_random_seed)
training_util.create_global_step()
model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
custom_objects)
if isinstance(model, models.Sequential):
model = model.model
# Load weights and save to checkpoint if there is no checkpoint
latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
if not latest_path:
with session.Session() as sess:
# Load weights and save to checkpoint if there is no checkpoint
latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
if not latest_path:
with ops.Graph().as_default():
random_seed.set_random_seed(estimator.config.tf_random_seed)
training_util.create_global_step()
model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
custom_objects)
if isinstance(model, models.Sequential):
model = model.model
# save to checkpoint
with session.Session(config=estimator._session_config) as sess:
model.set_weights(keras_weights)
# Make update ops and initialize all variables.
if not model.train_function:
......
......@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from math import log10
import os
import tempfile
......@@ -62,7 +63,7 @@ def simple_functional_model():
return model
def get_resource_for_simple_model(is_sequential, is_evaluate):
def get_resource_for_simple_model(is_sequential=True, is_evaluate=False):
model = simple_sequential_model(
) if is_sequential else simple_functional_model()
if is_sequential:
......@@ -352,6 +353,30 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
model_dir=tempfile.mkdtemp(dir=self._base_dir),
custom_objects=custom_objects)
def test_tf_config(self):
keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
tf_config = json.dumps({
'cluster': {
run_config_lib.TaskType.PS: ['localhost:1234'],
run_config_lib.TaskType.WORKER: ['localhost:1236'],
run_config_lib.TaskType.MASTER: ['localhost:1238']
},
'task': {
'type': run_config_lib.TaskType.MASTER,
'index': 0
}
})
with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
with self.test_session():
keras.estimator.model_to_estimator(
keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
if __name__ == '__main__':
test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册