提交 6859a01c 编写于 作者: M Monica Song 提交者: TensorFlower Gardener

Enable the Functional class' from_config method to be used for subclassed...

Enable the Functional class' from_config method to be used for subclassed models with Functional constructors.

PiperOrigin-RevId: 396047985
上级 6cf8b349
......@@ -671,14 +671,35 @@ class Functional(training_lib.Model):
Raises:
ValueError: In case of improperly formatted config dict.
TypeError: In case the config does match the cls constructor.
"""
with generic_utils.SharedObjectLoadingScope():
input_tensors, output_tensors, created_layers = reconstruct_from_config(
config, custom_objects)
model = cls(inputs=input_tensors, outputs=output_tensors,
name=config.get('name'))
connect_ancillary_layers(model, created_layers)
return model
if all(key in config for key in [
'name', 'layers', 'input_layers', 'output_layers']):
input_tensors, output_tensors, created_layers = reconstruct_from_config(
config, custom_objects)
model = cls(
inputs=input_tensors,
outputs=output_tensors,
name=config.get('name'))
connect_ancillary_layers(model, created_layers)
return model
# The config does not contain all the information necessary to revive a
# Functional model. This happens when the user creates subclassed models
# with a Functional constructor and has overriden the `get_config` method
# to return a completely new dictionary.
try:
return cls(**config)
except TypeError as e:
raise TypeError('Unable to revive model from config. When overriding '
'the `get_config`, make sure that the returned config '
'contains all items used as arguments in the '
f'constructor to {cls}, which is the default behavior. '
'You can override this default behavior by defining a '
'`from_config` method to specify how to create an '
f'instance of {cls.__name__} from the config. \n\n'
f'Error encountered during deserialization:\n{e}')
def _validate_graph_inputs_and_outputs(self):
"""Validates the inputs and outputs of a Graph Network."""
......
......@@ -21,7 +21,6 @@ SavedModel have the expected structure.
import tensorflow.compat.v2 as tf
# TODO(kathywu): Move relevant tests from saved_model_test to
import shutil
from absl.testing import parameterized
......@@ -174,6 +173,26 @@ class UnregisteredCustomSequentialModel(keras.Sequential):
self.add(keras.layers.InputLayer(input_shape=(2, 3)))
class FunctionalSubclassModel(keras.Model):
def __init__(self, units):
self.units = units
my_input = keras.Input(shape=(2, 3), name='inputs')
dense = keras.layers.Dense(self.units, activation='relu', name='dense')
output = dense(my_input)
outputs = {'output': output}
super().__init__(inputs=[my_input], outputs=outputs)
def get_config(self):
return {'units': self.units}
class FunctionalSubclassModelWrongConfig(FunctionalSubclassModel):
def get_config(self):
return {}
class ReviveTestBase(keras_parameterized.TestCase):
def setUp(self):
......@@ -346,6 +365,18 @@ class TestModelRevive(ReviveTestBase):
revived = keras_load.load(self.path, compile=False)
self._assert_revived_correctness(model, revived)
def test_functional_subclass(self):
model = FunctionalSubclassModel(32)
model.save(self.path, save_format='tf')
revived = keras_load.load(self.path, compile=False)
self._assert_revived_correctness(model, revived)
def test_functional_subclass_wrong_config(self):
model = FunctionalSubclassModelWrongConfig(32)
model.save(self.path, save_format='tf')
with self.assertRaisesRegex(TypeError, 'Unable to revive model'):
keras_load.load(self.path, compile=False)
def test_load_compiled_metrics(self):
model = testing_utils.get_small_sequential_mlp(1, 3)
......@@ -385,6 +416,8 @@ if __name__ == '__main__':
'CustomLayerWithConfig': CustomLayerWithConfig,
'CustomNetworkWithConfig': CustomNetworkWithConfig,
'CustomNetworkWithConfigName': CustomNetworkWithConfigName,
'SubclassedModelWithConfig': SubclassedModelWithConfig
'SubclassedModelWithConfig': SubclassedModelWithConfig,
'FunctionalSubclassModel': FunctionalSubclassModel,
'FunctionalSubclassModelWrongConfig': FunctionalSubclassModelWrongConfig
}):
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册