提交 a77c3f73 编写于 作者: M Matt Watson 提交者: TensorFlower Gardener

Fix the lookup layer's OOV token check when num_oov_indices > len(vocabulary tokens)

This check was always broken for vocabularies passed as NumPy arrays, and it recently broke for vocabularies passed as Python lists as well.

PiperOrigin-RevId: 380946490
上级 586c4ad6
......@@ -411,11 +411,7 @@ class IndexLookup(base_preprocessing_layer.PreprocessingLayer):
should_have_oov = (self.num_oov_indices > 0)
expected_oov = [self.oov_token] * self.num_oov_indices
found_oov = vocabulary[oov_start:token_start]
has_oov = should_have_oov and found_oov == expected_oov
# If we get a numpy array, then has_oov may end up being a numpy array
# instead of a bool. Fix this by collapsing the variable if it's not bool.
if not isinstance(has_oov, bool):
has_oov = any(has_oov)
has_oov = should_have_oov and np.array_equal(found_oov, expected_oov)
if all([should_have_mask, has_mask, should_have_oov]) and not has_oov:
raise ValueError(
......
......@@ -1535,6 +1535,31 @@ class IndexLookupVocabularyTest(keras_parameterized.TestCase,
self.assertAllEqual(returned_vocab, ["wind", "and", "fire"])
self.assertAllEqual(layer.vocabulary_size(), 5)
def test_vocab_multi_oov(self):
  """A vocabulary that already lists both OOV tokens is returned unchanged.

  The layer is configured with two OOV indices, and the vocabulary handed to
  `set_vocabulary` starts with the mask token followed by two OOV tokens.
  `get_vocabulary` is expected to return exactly that list.
  """
  # Mask token first, then one "[OOV]" entry per configured OOV index.
  full_vocab = ["", "[OOV]", "[OOV]", "wind", "and", "fire"]
  lookup = index_lookup.IndexLookup(
      max_tokens=None,
      num_oov_indices=2,
      mask_token="",
      oov_token="[OOV]",
      dtype=tf.string)
  lookup.set_vocabulary(full_vocab)
  self.assertAllEqual(lookup.get_vocabulary(), full_vocab)
def test_vocab_multi_oov_not_present(self):
  """A bare token list gains the mask token and ten OOV tokens on read-back.

  The layer is configured with ten OOV indices, which is more than the three
  tokens passed to `set_vocabulary`. `get_vocabulary` is expected to return
  the mask token, ten OOV tokens, and then the three supplied tokens.
  """
  lookup = index_lookup.IndexLookup(
      max_tokens=None,
      num_oov_indices=10,
      mask_token="",
      oov_token="[OOV]",
      dtype=tf.string)
  # Only real tokens are supplied; no mask or OOV entries are included.
  lookup.set_vocabulary(["wind", "and", "fire"])
  actual_vocab = lookup.get_vocabulary()
  expected_vocab = [""] + ["[OOV]"] * 10 + ["wind", "and", "fire"]
  self.assertAllEqual(actual_vocab, expected_vocab)
def test_vocab_with_max_cap(self):
vocab_data = ["", "[OOV]", "wind", "and", "fire"]
layer = index_lookup.IndexLookup(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册