Commit d73b5f03 authored by Matt Watson, committed by François Chollet

Fix serialization of hashing layer

We would switch hash types when saving and loading because of a config
error.

PiperOrigin-RevId: 382784449
Parent 85029239
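The root cause can be illustrated with a minimal stand-in that mirrors only the pre-fix salt handling visible in the diff below (the class `OldHashing` and its fields are hypothetical, not the real Keras layer): because the default SipHash salt was stored on the instance and then serialized, rebuilding the layer from its config silently switched it from FarmHash to SipHash.

# Minimal illustrative sketch of the pre-fix behavior; `OldHashing` is a
# stand-in, not keras.layers.Hashing.
_DEFAULT_SALT_KEY = [0xDECAFCAFFE, 0xDECAFCAFFE]

class OldHashing:
  def __init__(self, num_bins, salt=None):
    self.num_bins = num_bins
    # SipHash ("strong hash") is used only when a salt is passed explicitly.
    self.strong_hash = salt is not None
    # Bug: with no salt given, the default salt was stored on the layer...
    self.salt = salt if salt is not None else _DEFAULT_SALT_KEY

  def get_config(self):
    # ...and then serialized, so the saved config always carried a non-None salt.
    return {'num_bins': self.num_bins, 'salt': self.salt}

layer = OldHashing(num_bins=3)                  # no salt -> FarmHash
restored = OldHashing(**layer.get_config())     # salt now set -> SipHash
print(layer.strong_hash, restored.strong_hash)  # False True: hash type switched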
keras/layers/preprocessing/hashing.py
@@ -23,9 +23,6 @@ from keras.engine import base_layer
 from keras.engine import base_preprocessing_layer
 from tensorflow.python.util.tf_export import keras_export
-# Default key from tf.sparse.cross_hashed
-_DEFAULT_SALT_KEY = [0xDECAFCAFFE, 0xDECAFCAFFE]
 @keras_export('keras.layers.Hashing',
               'keras.layers.experimental.preprocessing.Hashing')
@@ -128,11 +125,12 @@ class Hashing(base_layer.Layer):
   def __init__(self, num_bins, mask_value=None, salt=None, **kwargs):
     if num_bins is None or num_bins <= 0:
       raise ValueError('`num_bins` cannot be `None` or non-positive values.')
-    super(Hashing, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     base_preprocessing_layer.keras_kpl_gauge.get_cell('Hashing').set(True)
     self.num_bins = num_bins
     self.mask_value = mask_value
     self.strong_hash = True if salt is not None else False
+    self.salt = None
     if salt is not None:
       if isinstance(salt, (tuple, list)) and len(salt) == 2:
         self.salt = salt
@@ -141,16 +139,10 @@ class Hashing(base_layer.Layer):
       else:
         raise ValueError('`salt can only be a tuple of size 2 integers, or a '
                          'single integer, given {}'.format(salt))
-    else:
-      self.salt = _DEFAULT_SALT_KEY
-  def _preprocess_input(self, inp):
-    if isinstance(inp, (list, tuple, np.ndarray)):
-      inp = tf.convert_to_tensor(inp)
-    return inp
   def call(self, inputs):
-    inputs = self._preprocess_input(inputs)
+    if isinstance(inputs, (list, tuple, np.ndarray)):
+      inputs = tf.convert_to_tensor(inputs)
     if isinstance(inputs, tf.SparseTensor):
       return tf.SparseTensor(
           indices=inputs.indices,
@@ -199,10 +191,10 @@ class Hashing(base_layer.Layer):
     return tf.TensorSpec(shape=output_shape, dtype=output_dtype)
   def get_config(self):
-    config = {
+    config = super().get_config()
+    config.update({
        'num_bins': self.num_bins,
        'salt': self.salt,
        'mask_value': self.mask_value,
-    }
-    base_config = super(Hashing, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
+    })
+    return config
keras/layers/preprocessing/hashing_test.py
@@ -14,16 +14,17 @@
 # ==============================================================================
 """Tests for hashing layer."""
+import os
 from absl.testing import parameterized
+import tensorflow.compat.v2 as tf
+import numpy as np
 import keras
 from keras import keras_parameterized
 from keras import testing_utils
 from keras.engine import input_layer
 from keras.engine import training
 from keras.layers.preprocessing import hashing
-import numpy as np
-import tensorflow.compat.v2 as tf
 @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@@ -193,8 +194,7 @@ class HashingTest(keras_parameterized.TestCase):
   def test_hash_ragged_int_input_farmhash(self):
     layer = hashing.Hashing(num_bins=3)
-    inp_data = tf.ragged.constant([[0, 1, 3, 4], [2, 1, 0]],
-                                  dtype=tf.int64)
+    inp_data = tf.ragged.constant([[0, 1, 3, 4], [2, 1, 0]], dtype=tf.int64)
     out_data = layer(inp_data)
     # Same hashed output as test_hash_sparse_input_farmhash
     expected_output = [[1, 0, 0, 2], [1, 0, 1]]
@@ -231,8 +231,7 @@ class HashingTest(keras_parameterized.TestCase):
   def test_hash_ragged_int_input_siphash(self):
     layer = hashing.Hashing(num_bins=3, salt=[133, 137])
-    inp_data = tf.ragged.constant([[0, 1, 3, 4], [2, 1, 0]],
-                                  dtype=tf.int64)
+    inp_data = tf.ragged.constant([[0, 1, 3, 4], [2, 1, 0]], dtype=tf.int64)
     out_data = layer(inp_data)
     # Same hashed output as test_hash_sparse_input_farmhash
     expected_output = [[1, 1, 0, 1], [2, 1, 1]]
@@ -270,6 +269,27 @@ class HashingTest(keras_parameterized.TestCase):
     layer_1 = hashing.Hashing.from_config(config)
     self.assertEqual(layer_1.name, layer.name)
+  def test_saved_model(self):
+    input_data = np.array(['omar', 'stringer', 'marlo', 'wire', 'skywalker'])
+    inputs = keras.Input(shape=(None,), dtype=tf.string)
+    outputs = hashing.Hashing(num_bins=100)(inputs)
+    model = keras.Model(inputs=inputs, outputs=outputs)
+    original_output_data = model(input_data)
+    # Save the model to disk.
+    output_path = os.path.join(self.get_temp_dir(), 'tf_keras_saved_model')
+    model.save(output_path, save_format='tf')
+    loaded_model = keras.models.load_model(output_path)
+    # Ensure that the loaded model is unique (so that the save/load is real)
+    self.assertIsNot(model, loaded_model)
+    # Validate correctness of the new model.
+    new_output_data = loaded_model(input_data)
+    self.assertAllClose(new_output_data, original_output_data)
+  @parameterized.named_parameters(
+      (
+          'list_input',
...
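With the fix, a config round trip no longer changes the hash type. A quick sanity check in the spirit of the new tests (this assumes a Keras build that already contains this commit; `strong_hash` is the internal flag set in `__init__` above):

from keras.layers.preprocessing import hashing

layer = hashing.Hashing(num_bins=3)             # no salt -> FarmHash
config = layer.get_config()
assert config['salt'] is None                   # the default salt is no longer serialized

restored = hashing.Hashing.from_config(config)
assert restored.strong_hash is False            # restored layer still uses FarmHash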