提交 868f22ad 编写于 作者: P Pavithra Vijay 提交者: TensorFlower Gardener

Deserializing loss class in hdf5 format.

PiperOrigin-RevId: 251579891
上级 69b121d7
......@@ -26,6 +26,7 @@ import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.saving import model_config as model_config_lib
from tensorflow.python.keras.saving import saving_utils
......@@ -202,7 +203,10 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # pylint
optimizer_config, custom_objects=custom_objects)
# Recover loss functions and metrics.
loss = convert_custom_objects(training_config['loss'])
loss_config = training_config['loss'] # Deserialize loss class.
if isinstance(loss_config, dict) and 'class_name' in loss_config:
loss_config = losses.get(loss_config)
loss = convert_custom_objects(loss_config)
metrics = convert_custom_objects(training_config['metrics'])
weighted_metrics = convert_custom_objects(
training_config.get('weighted_metrics', None))
......
......@@ -662,7 +662,8 @@ class TestWholeModelSaving(test.TestCase):
for i in range(4):
f = keras.layers.Dense(2, name='dense_%d' % (i,))(f)
model = keras.Model(inputs=[x], outputs=[f])
model.compile(loss='mse', optimizer='adam', metrics=['acc'])
model.compile(
'adam', loss=keras.losses.MeanSquaredError(), metrics=['acc'])
x = np.random.random((1, 2))
y = np.random.random((1, 2))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册