提交 5a24ebbb 编写于 作者: F Fariz Rahman 提交者: François Chollet

Bug fix: model save when file already exists (#11289)

* use mode arg

* add loop test

* add stricter condition

* make dict behaviour similar
上级 a07253d8
......@@ -183,11 +183,13 @@ class H5Dict(object):
self.data = path
self._is_file = False
elif isinstance(path, str):
self.data = h5py.File(path,)
self.data = h5py.File(path, mode=mode)
self._is_file = True
elif isinstance(path, dict):
self.data = path
self._is_file = False
if mode == 'w':
self.data.clear()
# Flag to check if a dict is user defined data or a sub group:
self.data['_is_group'] = True
else:
......@@ -209,7 +211,7 @@ class H5Dict(object):
else:
self.data[attr] = val
return
if attr in self:
if isinstance(self.data, h5py.Group) and attr in self.data:
raise KeyError('Cannot set attribute. '
'Group with name "{}" exists.'.format(attr))
if is_np:
......
......@@ -632,6 +632,29 @@ def test_saving_recurrent_layer_without_bias():
os.remove(fname)
def test_loop_model_saving():
model = Sequential()
model.add(Dense(2, input_shape=(3,)))
model.compile(loss=losses.MSE,
optimizer=optimizers.RMSprop(lr=0.0001),
metrics=[metrics.categorical_accuracy])
x = np.random.random((1, 3))
y = np.random.random((1, 2))
_, fname = tempfile.mkstemp('.h5')
for _ in range(3):
model.train_on_batch(x, y)
save_model(model, fname, overwrite=True)
out = model.predict(x)
new_model = load_model(fname)
os.remove(fname)
out2 = new_model.predict(x)
assert_allclose(out, out2, atol=1e-05)
def test_saving_constant_initializer_with_numpy():
"""Test saving and loading model of constant initializer with numpy inputs.
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册