提交 b88bbbab 编写于 作者: Y Yanbo Liang 提交者: François Chollet

Numpy ndarray should be serialized as Python list. (#10727)

* Numpy ndarray should be serialized as Python list.

* Add test case.

* Fix annotation
上级 75114fee
......@@ -79,8 +79,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
# if obj is any numpy type
if type(obj).__module__ == np.__name__:
if isinstance(obj, np.ndarray):
return {'type': type(obj),
'value': obj.tolist()}
return obj.tolist()
else:
return obj.item()
......
......@@ -12,6 +12,7 @@ from keras.models import Model, Sequential
from keras.layers import Dense, Lambda, RepeatVector, TimeDistributed, Bidirectional, GRU, LSTM, CuDNNGRU, CuDNNLSTM
from keras.layers import Conv2D, Flatten
from keras.layers import Input, InputLayer
from keras.initializers import Constant
from keras import optimizers
from keras import losses
from keras import metrics
......@@ -641,6 +642,21 @@ def test_saving_recurrent_layer_without_bias():
os.remove(fname)
@keras_test
def test_saving_constant_initializer_with_numpy():
"""Test saving and loading model of constant initializer with numpy ndarray as input.
"""
model = Sequential()
model.add(Dense(2, input_shape=(3,), kernel_initializer=Constant(np.ones((3, 2)))))
model.add(Dense(3))
model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
_, fname = tempfile.mkstemp('.h5')
save_model(model, fname)
model = load_model(fname)
os.remove(fname)
@keras_test
@pytest.mark.parametrize('implementation', [1, 2], ids=['impl1', 'impl2'])
@pytest.mark.parametrize('bidirectional', [False, True], ids=['single', 'bidirectional'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册