未验证 提交 20cc2190 编写于 作者: P pyoung2778 提交者: GitHub

Check in seq_flow_lite (#10750)

上级 fdecf385
......@@ -16,14 +16,11 @@ http_archive(
http_archive(
name = "org_tensorflow",
sha256 = "40d3203ab5f246d83bae328288a24209a2b85794f1b3e2cd0329458d8e7c1985",
strip_prefix = "tensorflow-2.6.0",
urls = [
"https://github.com/tensorflow/tensorflow/archive/v2.6.0.zip",
],
strip_prefix = "tensorflow-2.9.1",
sha256 = "9f2dac244e5af6c6a13a7dad6481e390174ac989931942098e7a4373f1bccfc2",
urls = ["https://github.com/tensorflow/tensorflow/archive/v2.9.1.zip"],
)
http_archive(
name = "org_tflite_support",
strip_prefix = "tflite-support-0861599711ef31de58f62ed3ff6bbcc1e4817ef6",
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""A tool to export TFLite model."""
import importlib
......@@ -22,7 +21,7 @@ import os
from absl import app
from absl import flags
import tensorflow.compat.v1 as tf
import tensorflow_text as tftext
from layers import base_layers # import seq_flow_lite module
from layers import projection_layers # import seq_flow_lite module
from utils import tflite_utils # import seq_flow_lite module
......@@ -48,25 +47,33 @@ def main(_):
with tf.Graph().as_default() as graph:
with tf.Session(graph=graph) as session:
text = tf.placeholder(tf.string, shape=[1], name="Input")
prxlayer = projection_layers.ProjectionLayer(model_config,
base_layers.TFLITE)
encoder = model.Encoder(model_config, base_layers.TFLITE)
projection, seq_lengh = prxlayer(text)
logits = encoder(projection, seq_lengh)
inputs = [text]
if "pqrnn" in runner_config["name"]:
prxlayer = projection_layers.ProjectionLayer(model_config,
base_layers.TFLITE)
encoder = model.Encoder(model_config, base_layers.TFLITE)
projection, seq_length = prxlayer(text)
logits = encoder(projection, seq_length)
else:
byte_int = tftext.ByteSplitter().split(text)
token_ids = tf.cast(byte_int, tf.int32).to_tensor()
token_ids = tf.reshape(token_ids, [1, -1])
token_ids += 3
encoder = model.Encoder(model_config, base_layers.TFLITE)
logits = encoder(token_ids, None)
if FLAGS.output == "logits":
outputs = logits
outputs = [logits]
elif FLAGS.output == "sigmoid":
outputs = tf.math.sigmoid(logits)
outputs = [tf.math.sigmoid(logits)]
else:
assert FLAGS.output == "softmax", "Unexpected output"
outputs = tf.nn.softmax(logits)
outputs = [tf.nn.softmax(logits)]
session.run(tf.global_variables_initializer())
session.run(tf.local_variables_initializer())
saver = tf.train.Saver()
saver.restore(session, tf.train.latest_checkpoint(FLAGS.output_dir))
tflite_fb = tflite_utils.generate_tflite(session, graph, [text],
[outputs])
tflite_fb = tflite_utils.generate_tflite(session, graph, inputs, outputs)
output_file_name = os.path.join(FLAGS.output_dir, "tflite.fb")
with tf.gfile.Open(output_file_name, "wb") as f:
f.write(tflite_fb)
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Methods related to input datasets and readers."""
import functools
......@@ -21,6 +20,7 @@ import sys
from absl import logging
import tensorflow as tf
from tensorflow import estimator as tf_estimator
import tensorflow_datasets as tfds
import tensorflow_text as tftext
......@@ -83,13 +83,13 @@ def create_input_fn(runner_config, mode, drop_remainder):
def _input_fn(params):
"""Method to be used for reading the data."""
assert mode != tf.estimator.ModeKeys.PREDICT
split = "train" if mode == tf.estimator.ModeKeys.TRAIN else "test"
assert mode != tf_estimator.ModeKeys.PREDICT
split = "train" if mode == tf_estimator.ModeKeys.TRAIN else "test"
ds = tfds.load(runner_config["dataset"], split=split)
ds = ds.batch(params["batch_size"], drop_remainder=drop_remainder)
ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
ds = ds.shuffle(buffer_size=100)
ds = ds.repeat(count=1 if mode == tf.estimator.ModeKeys.EVAL else None)
ds = ds.repeat(count=1 if mode == tf_estimator.ModeKeys.EVAL else None)
ds = ds.map(
functools.partial(_post_processor, batch_size=params["batch_size"]),
num_parallel_calls=tf.data.experimental.AUTOTUNE,
......
......@@ -82,12 +82,12 @@ py_strict_library(
srcs = ["misc_layers.py"],
srcs_version = "PY3",
deps = [
# package tensorflow
":embedding_layers",
# package tensorflow
"//layers:base_layers", # sequence projection
"//layers:conv_layers",
"//layers:conv_layers", # sequence projection
"//layers:dense_layers", # sequence projection
"//layers:normalization_layers",
"//layers:normalization_layers", # sequence projection
"//layers:quantization_layers", # sequence projection
],
)
......@@ -112,8 +112,8 @@ py_strict_library(
srcs_version = "PY3",
deps = [
# package tensorflow
"//layers:base_layers",
"//layers:quantization_layers",
"//layers:base_layers", # sequence projection
"//layers:quantization_layers", # sequence projection
],
)
......@@ -124,11 +124,11 @@ py_strict_library(
deps = [
":embedding_layers",
# package tensorflow
"//layers:base_layers",
"//layers:dense_layers",
"//layers:normalization_layers",
"//layers:quantization_layers",
"//tf_ops:tf_custom_ops",
"//tf_ops:tf_custom_ops_py",
"//layers:base_layers", # sequence projection
"//layers:dense_layers", # sequence projection
"//layers:normalization_layers", # sequence projection
"//layers:quantization_layers", # sequence projection
# "//tf_ops:tf_custom_ops" # sequence projection
"//tf_ops:tf_custom_ops_py", # sequence projection
],
)
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Base layer for building models trained with quantization."""
import tensorflow as tf
......@@ -57,7 +56,7 @@ class BaseLayer(tf.keras.layers.Layer):
def add_weight_wrapper(self, shape):
"""Return a weight variable for the given shape."""
if self.parameters.initializer is not None:
initializer = self.parameters.initializer
initializer = clone_initializer(self.parameters.initializer)
else:
initializer = tf.keras.initializers.GlorotUniform()
weight = self.add_weight(
......@@ -136,3 +135,9 @@ class BaseLayer(tf.keras.layers.Layer):
maxval=(1.0 - zero_probability),
dtype=tensor.dtype)
return tf.math.ceil(rnd)
def clone_initializer(initializer):
  """Returns an independent copy of a Keras initializer.

  Keras `Initializer` instances are rebuilt from their own config so that the
  caller receives a separate object rather than sharing the one passed in.
  Anything that is not a `tf.keras.initializers.Initializer` instance is
  returned unchanged.

  Args:
    initializer: The initializer (or other value) to clone.

  Returns:
    A new initializer built via `from_config(get_config())`, or `initializer`
    itself when it is not a Keras `Initializer` instance.
  """
  if not isinstance(initializer, tf.keras.initializers.Initializer):
    return initializer
  initializer_cls = initializer.__class__
  return initializer_cls.from_config(initializer.get_config())
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Base layer for convolution."""
import tensorflow as tf
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Basic dense layers."""
import tensorflow as tf
......@@ -30,6 +29,7 @@ class BaseQDense(base_layers.BaseLayer):
bias=True,
rank=2,
normalize=True,
quantize_output=True,
**kwargs):
self.units = units
self.rank = rank
......@@ -37,7 +37,9 @@ class BaseQDense(base_layers.BaseLayer):
self.activation = activation
self.bias = bias
self.normalize = normalize
self.qoutput = quantization_layers.ActivationQuantization(**kwargs)
self.quantize_output = quantize_output
if quantize_output:
self.qoutput = quantization_layers.ActivationQuantization(**kwargs)
self._create_normalizer(**kwargs)
super(BaseQDense, self).__init__(**kwargs)
......@@ -62,7 +64,10 @@ class BaseQDense(base_layers.BaseLayer):
outputs = normalize_method(outputs)
if self.activation:
outputs = self.activation(outputs)
return self.qoutput(outputs)
if self.quantize_output:
return self.qoutput(outputs)
else:
return outputs
def _dense_r34(self, inputs, normalize_method):
bsz = self.get_batch_dimension(inputs)
......
......@@ -15,8 +15,8 @@
"""Layers for embedding."""
import tensorflow as tf
from layers import base_layers
from layers import quantization_layers
from layers import base_layers # import seq_flow_lite module
from layers import quantization_layers # import seq_flow_lite module
class EmbeddingLayer(base_layers.BaseLayer):
......
......@@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Layers for embedding."""
import math
import tensorflow as tf
from layers import base_layers # import seq_flow_lite module
from layers import conv_layers
from layers import conv_layers # import seq_flow_lite module
from layers import dense_layers # import seq_flow_lite module
from layers import embedding_layers
from layers import embedding_layers # import seq_flow_lite module
from layers import quantization_layers # import seq_flow_lite module
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Layers for normalization."""
import tensorflow as tf
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Layers for QRNN."""
import tensorflow as tf
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Layers for quantization."""
import tensorflow as tf
......
......@@ -16,12 +16,12 @@
# pylint: disable=arguments-renamed
import tensorflow as tf
from layers import base_layers
from layers import dense_layers
from layers import embedding_layers
from layers import normalization_layers
from layers import quantization_layers
from tf_ops import tf_custom_ops_py
from layers import base_layers # import seq_flow_lite module
from layers import dense_layers # import seq_flow_lite module
from layers import embedding_layers # import seq_flow_lite module
from layers import normalization_layers # import seq_flow_lite module
from layers import quantization_layers # import seq_flow_lite module
from tf_ops import tf_custom_ops_py # import seq_flow_lite module
class SelfAttention(base_layers.BaseLayer):
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Metric functions."""
import tensorflow.compat.v1 as tf
......
......@@ -45,13 +45,13 @@ py_library(
srcs_version = "PY3",
deps = [
# package tensorflow
"//layers:base_layers",
"//layers:dense_layers",
"//layers:embedding_layers",
"//layers:misc_layers",
"//layers:qrnn_layers",
# //tf_ops:tf_custom_ops",
"//tf_ops:tf_custom_ops_py",
"//layers:base_layers", # sequence projection
"//layers:dense_layers", # sequence projection
"//layers:embedding_layers", # sequence projection
"//layers:misc_layers", # sequence projection
"//layers:qrnn_layers", # sequence projection
# "//tf_ops:tf_custom_ops" # sequence projection
"//tf_ops:tf_custom_ops_py", # sequence projection
],
)
......@@ -62,13 +62,13 @@ py_library(
deps = [
":transformer_encoder",
# package tensorflow
"//layers:base_layers",
"//layers:embedding_layers",
"//layers:misc_layers",
"//layers:normalization_layers",
"//layers:quantization_layers",
# "//tf_ops:tf_custom_ops",
"//tf_ops:tf_custom_ops_py",
"//layers:base_layers", # sequence projection
"//layers:embedding_layers", # sequence projection
"//layers:misc_layers", # sequence projection
"//layers:normalization_layers", # sequence projection
"//layers:quantization_layers", # sequence projection
# "//tf_ops:tf_custom_ops" # sequence projection
"//tf_ops:tf_custom_ops_py", # sequence projection
],
)
......@@ -79,11 +79,11 @@ py_library(
deps = [
# package absl/logging
# package tensorflow
"//layers:base_layers",
"//layers:embedding_layers",
"//layers:transformer_layers",
# "//tf_ops:tf_custom_ops",
"//tf_ops:tf_custom_ops_py",
"//layers:base_layers", # sequence projection
"//layers:embedding_layers", # sequence projection
"//layers:transformer_layers", # sequence projection
# "//tf_ops:tf_custom_ops" # sequence projection
"//tf_ops:tf_custom_ops_py", # sequence projection
],
)
......@@ -93,13 +93,13 @@ py_library(
srcs_version = "PY3",
deps = [
# package absl/logging
# package tensor2tensor/utils:beam_search
# package tensorflow
# tensor2tensor/utils:beam_search",
"//layers:base_layers",
"//layers:embedding_layers",
"//layers:misc_layers",
"//layers:transformer_layers",
"//tf_ops:tf_custom_ops",
"//tf_ops:tf_custom_ops_py",
"//layers:base_layers", # sequence projection
"//layers:embedding_layers", # sequence projection
"//layers:misc_layers", # sequence projection
"//layers:transformer_layers", # sequence projection
# "//tf_ops:tf_custom_ops" # sequence projection
"//tf_ops:tf_custom_ops_py", # sequence projection
],
)
......@@ -33,11 +33,11 @@ Sample model params:
from absl import logging
import tensorflow as tf
from layers import base_layers
from layers import dense_layers
from layers import embedding_layers
from layers import misc_layers
from layers import qrnn_layers
from layers import base_layers # import seq_flow_lite module
from layers import dense_layers # import seq_flow_lite module
from layers import embedding_layers # import seq_flow_lite module
from layers import misc_layers # import seq_flow_lite module
from layers import qrnn_layers # import seq_flow_lite module
class Encoder(tf.keras.layers.Layer):
......
......@@ -16,13 +16,13 @@
from absl import logging
import tensorflow as tf
from layers import base_layers
from layers import dense_layers
from layers import embedding_layers
from layers import misc_layers
from layers import normalization_layers
from layers import quantization_layers
from models import transformer_encoder
from layers import base_layers # import seq_flow_lite module
from layers import dense_layers # import seq_flow_lite module
from layers import embedding_layers # import seq_flow_lite module
from layers import misc_layers # import seq_flow_lite module
from layers import normalization_layers # import seq_flow_lite module
from layers import quantization_layers # import seq_flow_lite module
from models import transformer_encoder # import seq_flow_lite module
class Encoder(tf.keras.layers.Layer):
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Implementation of pQRNN model."""
from absl import logging
......@@ -43,7 +42,7 @@ class Encoder(tf.keras.layers.Layer):
_get_params("qrnn_kernel_width", 3)
_get_params("qrnn_zoneout_probability")
_get_params("number_qrnn_layers")
_get_params("labels")
_get_params("labels", [])
_get_params("regularizer_scale")
_get_params("quantize")
......@@ -66,11 +65,12 @@ class Encoder(tf.keras.layers.Layer):
self.attention_pool = misc_layers.AttentionPooling(
parameters=self.parameters)
self.final_fc = dense_layers.BaseQDense(
units=self.num_classes,
rank=2,
parameters=self.parameters,
activation=None)
if self.num_classes:
self.final_fc = dense_layers.BaseQDense(
units=self.num_classes,
rank=2,
parameters=self.parameters,
activation=None)
def call(self, projection, seq_length):
mask = tf.sequence_mask(
......@@ -82,7 +82,11 @@ class Encoder(tf.keras.layers.Layer):
bottleneck = self.bottleneck_layer(projection, maskr3, inverse_normalizer)
outputs = self.qrnn_stack(bottleneck, maskr3, inverse_normalizer)
pre_logits = self.attention_pool(outputs, maskr3, inverse_normalizer)
return self.final_fc(pre_logits)
if self.num_classes:
return self.final_fc(pre_logits)
else:
return pre_logits
class Model(Encoder):
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Implementation of PRADO model."""
import copy
......
......@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for seq_flow_lite.sgnn."""
import tensorflow as tf
......
......@@ -18,8 +18,8 @@
from absl import logging
import tensorflow as tf
from layers import base_layers
from layers import transformer_layers
from layers import base_layers # import seq_flow_lite module
from layers import transformer_layers # import seq_flow_lite module
class Model(tf.keras.layers.Layer):
......
......@@ -20,12 +20,12 @@ from absl import logging
from tensor2tensor.utils import beam_search
import tensorflow as tf
from layers import base_layers
from layers import dense_layers
from layers import embedding_layers
from layers import normalization_layers
from layers import quantization_layers
from layers import transformer_layers
from layers import base_layers # import seq_flow_lite module
from layers import dense_layers # import seq_flow_lite module
from layers import embedding_layers # import seq_flow_lite module
from layers import normalization_layers # import seq_flow_lite module
from layers import quantization_layers # import seq_flow_lite module
from layers import transformer_layers # import seq_flow_lite module
class TransformerUniformAttnDecoder(base_layers.BaseLayer):
......
......@@ -11,20 +11,23 @@ package(
)
cc_library(
name = "sequence_string_projection_op",
srcs = [
"sequence_string_projection.cc",
name = "projection_normalizer_util",
srcs = ["projection_normalizer_util.cc"],
hdrs = ["projection_normalizer_util.h"],
deps = [
":projection_util",
"@utf_archive//:utf",
],
)
cc_library(
name = "projection_tokenizer_util",
srcs = ["projection_tokenizer_util.cc"],
hdrs = ["projection_tokenizer_util.h"],
deps = [
":projection_normalizer_util",
":projection_tokenizer_util",
":projection_util",
":text_distorter",
"@com_google_absl//absl/container:flat_hash_map",
"@tensorflow_includes//:includes",
"@tensorflow_solib//:framework_lib",
"@utf_archive//:utf",
],
alwayslink = 1,
)
cc_library(
......@@ -37,22 +40,46 @@ cc_library(
)
cc_library(
name = "projection_tokenizer_util",
srcs = ["projection_tokenizer_util.cc"],
hdrs = ["projection_tokenizer_util.h"],
name = "skipgram_finder",
srcs = ["skipgram_finder.cc"],
hdrs = ["skipgram_finder.h"],
deps = [
":projection_util",
"@utf_archive//:utf",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@icu4c//:icu4c",
],
)
cc_test(
name = "skipgram_finder_test",
srcs = ["skipgram_finder_test.cc"],
deps = [
":skipgram_finder",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@icu4c//:icu4c",
],
)
cc_library(
name = "projection_normalizer_util",
srcs = ["projection_normalizer_util.cc"],
hdrs = ["projection_normalizer_util.h"],
name = "subsequence_finder",
srcs = ["subsequence_finder.cc"],
hdrs = ["subsequence_finder.h"],
deps = [
":projection_util",
"@utf_archive//:utf",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@icu4c//:icu4c",
],
)
cc_test(
name = "subsequence_finder_test",
srcs = ["subsequence_finder_test.cc"],
deps = [
":subsequence_finder",
"@com_google_googletest//:gtest_main",
],
)
......@@ -67,6 +94,55 @@ cc_library(
],
)
cc_library(
name = "denylist_op",
srcs = ["denylist_op.cc"],
deps = [
":skipgram_finder",
":subsequence_finder",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@tensorflow_includes//:includes",
"@tensorflow_solib//:framework_lib",
],
alwayslink = 1,
)
gen_op_wrapper_py(
name = "denylist_op_py",
out = "denylist_op.py",
kernel_lib = ":denylist_op",
)
py_test(
name = "denylist_op_py_test",
srcs = ["denylist_op_test.py"],
main = "denylist_op_test.py",
python_version = "PY3",
srcs_version = "PY3",
deps = [
":denylist_op_py",
],
)
cc_library(
name = "sequence_string_projection_op",
srcs = [
"sequence_string_projection.cc",
],
deps = [
":projection_normalizer_util",
":projection_tokenizer_util",
":projection_util",
":text_distorter",
"@com_google_absl//absl/container:flat_hash_map",
"@tensorflow_includes//:includes",
"@tensorflow_solib//:framework_lib",
],
alwayslink = 1,
)
cc_test(
name = "sequence_string_projection_test",
size = "small",
......@@ -78,6 +154,12 @@ cc_test(
],
)
gen_op_wrapper_py(
name = "sequence_string_projection_op_py",
out = "sequence_string_projection_op.py",
kernel_lib = ":sequence_string_projection_op",
)
cc_library(
name = "sequence_string_projection_op_v2",
srcs = [
......@@ -111,12 +193,6 @@ gen_op_wrapper_py(
kernel_lib = ":sequence_string_projection_op_v2",
)
gen_op_wrapper_py(
name = "sequence_string_projection_op_py",
out = "sequence_string_projection_op.py",
kernel_lib = ":sequence_string_projection_op",
)
cc_library(
name = "tf_custom_ops",
srcs = ["tf_custom_ops.cc"],
......
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tf_ops/skipgram_finder.h" // seq_flow_lite
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
namespace seq_flow_lite {
using ::tensorflow::OpKernel;
using ::tensorflow::OpKernelConstruction;
using ::tensorflow::OpKernelContext;
using ::tensorflow::Status;
using ::tensorflow::Tensor;
using ::tensorflow::TensorShape;
using ::tensorflow::errors::InvalidArgument;
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;
// Documentation text describing the output tensor and the attributes of the
// Denylist ops. It is appended to the op-specific summary passed to .Doc()
// when the ops are registered, so its wording is part of the generated op
// documentation.
const char kDescription[] = R"(
output: A floating point tensor that contains a prediction vector for each
input string. The vector will either be:
* [1, 1, ..., 0, 0, ...] if no denylisted skipgrams are found.
(All negative categories are 1.0 and all positive categories are 0.0.)
* an indicator vector if any denylisted skipgrams are found.
(0.0 if no skipgrams belonging to the category were found and 1.0 otherwise)
max_skip_size: The maximum number of tokens that can be skipped when generating
skipgrams.
denylist: A string vector containing denylisted skipgrams.
denylist_category: An int32 vector containing the category of the corresponding
skipgram in the denylist.
categories: An int32 scalar. This is the total number of categories.
All categories in denylist_category must be in [0, categories).
negative_categories: An int32 scalar. The total number of categories that
should be set if no entries in the denylist are triggered. These
negative categories are assumed to be [0, negative_categories).
)";
// The base class for all Denylist ops. It does two things:
// 1) It defines the output tensor of the op and it defines the attributes
//    needed to specify the denylist and convert denylist categories into
//    output vectors.
// 2) It defines a Compute() function. The compute function is responsible
//    for filling in the output tensor, while the subclass is responsible
//    for processing the input.
class DenylistOpBase : public OpKernel {
 public:
  // Reads and validates the "categories", "negative_categories",
  // "max_skip_size", "denylist" and "denylist_category" attributes. On any
  // failure the error is recorded on `context` and construction stops early;
  // members not yet read keep their zero/empty defaults so that subclass
  // constructors, which still run afterwards, never observe indeterminate
  // values.
  explicit DenylistOpBase(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("categories", &categories_));
    OP_REQUIRES_OK(context, context->GetAttr("negative_categories",
                                             &negative_categories_));
    OP_REQUIRES(context, categories_ > 0,
                InvalidArgument("Number of categories (", categories_,
                                ") must be positive."));
    OP_REQUIRES(
        context, negative_categories_ >= 0,
        InvalidArgument("Number of negative_categories (", negative_categories_,
                        ") must be non-negative."));
    OP_REQUIRES(context, negative_categories_ < categories_,
                InvalidArgument("Number of categories (", categories_,
                                ") must be greater than the "
                                "number of negative_categories (",
                                negative_categories_, ")."));

    OP_REQUIRES_OK(context, context->GetAttr("max_skip_size", &max_skip_size_));
    OP_REQUIRES_OK(context, context->GetAttr("denylist", &denylist_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("denylist_category", &denylist_category_));

    OP_REQUIRES(context, denylist_.size() == denylist_category_.size(),
                InvalidArgument("denylist length (", denylist_.size(),
                                ") != denylist_category length (",
                                denylist_category_.size(), ")"));

    // With an empty denylist there is no category to range-check, and the
    // iterators returned for an empty range must not be dereferenced
    // (undefined behavior), so the bounds checks only run when there is at
    // least one entry.
    if (!denylist_category_.empty()) {
      const auto min_max = std::minmax_element(denylist_category_.begin(),
                                               denylist_category_.end());
      const int max = *min_max.second;
      OP_REQUIRES(context, max < categories_,
                  InvalidArgument("max element of denylist_category (", max,
                                  ") >= categories (", categories_, ")"));
      const int min = *min_max.first;
      OP_REQUIRES(
          context, min >= 0,
          InvalidArgument("min element of denylist_category (", min, ") < 0"));
    }
  }

  // Fills the "output" tensor with one vector of `categories` floats per input
  // string. If GetCategories() returns an empty set for a string, the first
  // `negative_categories` entries are 1.0 and the rest 0.0; otherwise entry j
  // is 1.0 exactly when category j is in the returned set.
  void Compute(OpKernelContext* context) override {
    auto compute_context = InitializeComputeContext(context);
    if (compute_context == nullptr) {
      // The subclass has already reported the failure on `context`.
      return;
    }
    // Guarantees the subclass context is released on every exit path,
    // including the early return hidden inside OP_REQUIRES_OK below.
    auto context_cleaner = absl::MakeCleanup([this, compute_context] {
      this->FinalizeComputeContext(compute_context);
    });

    Tensor* output_tensor = nullptr;
    TensorShape output_shape = InputStringsShape(compute_context);
    output_shape.AddDim(categories_);
    OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
                                                     &output_tensor));
    auto output_values = output_tensor->flat<float>();

    // The number of input strings is fixed for this call; query it once
    // instead of making a virtual call per loop iteration.
    const int num_input_strings = NumInputStrings(compute_context);
    for (int i = 0; i < num_input_strings; i++) {
      const auto category = GetCategories(i, compute_context);
      const int base_index = i * categories_;
      if (category.empty()) {
        for (int j = 0; j < categories_; j++) {
          output_values(base_index + j) =
              j < negative_categories_ ? 1.0f : 0.0f;
        }
      } else {
        for (int j = 0; j < categories_; j++) {
          output_values(base_index + j) = category.contains(j) ? 1.0f : 0.0f;
        }
      }
    }
  }

 protected:
  // Read-only access to the denylist attributes for subclasses.
  int max_skip_size() const { return max_skip_size_; }
  int denylist_size() const { return static_cast<int>(denylist_.size()); }
  const std::string& denylist(int i) const { return denylist_[i]; }
  int32_t denylist_category(int i) const { return denylist_category_[i]; }

 private:
  // Called at the beginning of Compute(). This function should process
  // the input and return a context object that can be used to identify
  // the denylist categories of each input string. Returning nullptr makes
  // Compute() return without producing an output.
  virtual void* InitializeComputeContext(OpKernelContext* context) = 0;

  // Called at the end of Compute(). Frees the context object.
  virtual void FinalizeComputeContext(void* context) = 0;

  // Returns the shape of the input tensor, if it only consisted of strings.
  // If the input tensor is strings, this is the shape of the input tensor.
  // If the input tensor is tokens, this is the shape of the input tensor,
  // minus the innermost dimension.
  virtual TensorShape InputStringsShape(void* context) = 0;

  // Returns the number of strings in the input tensor.
  virtual int NumInputStrings(void* context) = 0;

  // Returns the denylist categories of the index-th string.
  virtual absl::flat_hash_set<int> GetCategories(int index, void* context) = 0;

  // Zero-initialized so the accessors stay well-defined even when attribute
  // parsing in the constructor fails part-way through.
  int32_t categories_ = 0;
  int32_t negative_categories_ = 0;
  int max_skip_size_ = 0;
  std::vector<std::string> denylist_;
  std::vector<int32_t> denylist_category_;
};
// A base class for Denylist ops that expect a string tensor input.
class StringDenylistOp : public DenylistOpBase {
public:
explicit StringDenylistOp(OpKernelConstruction* context)
: DenylistOpBase(context) {}
private:
void* InitializeComputeContext(OpKernelContext* context) override {
const Tensor* input_tensor;
auto status = context->input("input", &input_tensor);
if (!status.ok()) {
context->CtxFailureWithWarning(__FILE__, __LINE__, status);
return nullptr;
}
return new ComputeContext(input_tensor);
}
void FinalizeComputeContext(void* context) override {
delete static_cast<ComputeContext*>(context);
}
TensorShape InputStringsShape(void* context) override {
return static_cast<ComputeContext*>(context)->input_tensor->shape();
}
int NumInputStrings(void* context) override {
return static_cast<ComputeContext*>(context)->input_tensor_values.size();
}
absl::flat_hash_set<int> GetCategories(int index, void* context) override {
return FindTerms(
static_cast<ComputeContext*>(context)->input_tensor_values(index));
}
struct ComputeContext {
ComputeContext(const Tensor* input_tensor)
: input_tensor(input_tensor),
input_tensor_values(input_tensor->flat<::tensorflow::tstring>()) {}
const Tensor* input_tensor;
::tensorflow::TTypes<::tensorflow::tstring>::ConstFlat input_tensor_values;
};
// Returns the set of denylist categories for the input string.
virtual absl::flat_hash_set<int> FindTerms(const std::string& input) = 0;
};
// Denylist op for string inputs that matches denylist entries with a
// SkipgramFinder.
class SkipgramDenylistOp : public StringDenylistOp {
 public:
  // Builds the finder with the configured max skip size and registers every
  // denylist entry together with its category.
  explicit SkipgramDenylistOp(OpKernelConstruction* context)
      : StringDenylistOp(context),
        skipgram_finder_(std::make_unique<SkipgramFinder>(max_skip_size())) {
    for (int entry = 0; entry < denylist_size(); ++entry) {
      const std::string& skipgram = denylist(entry);
      skipgram_finder_->AddSkipgram(skipgram, denylist_category(entry));
    }
  }

 private:
  // Categories of all denylisted skipgrams found in `input`.
  absl::flat_hash_set<int> FindTerms(const std::string& input) override {
    return skipgram_finder_->FindSkipgrams(input);
  }

  std::unique_ptr<SkipgramFinder> skipgram_finder_;
};

REGISTER_KERNEL_BUILDER(
    Name("SkipgramDenylist").Device(::tensorflow::DEVICE_CPU),
    SkipgramDenylistOp);
// Shape inference for Denylist ops that take a string tensor: the output has
// the shape of input 0 with one extra trailing dimension of size `categories`.
Status StringDenylistShapeFn(InferenceContext* context) {
  int32_t num_categories;
  TF_RETURN_IF_ERROR(context->GetAttr("categories", &num_categories));

  const ShapeHandle category_dim = context->MakeShape({num_categories});
  ShapeHandle inferred_shape;
  TF_RETURN_IF_ERROR(
      context->Concatenate(context->input(0), category_dim, &inferred_shape));

  context->set_output(0, inferred_shape);
  return Status::OK();
}
// Op definition for SkipgramDenylist: one string input, one float output, and
// the denylist attributes. The output shape is computed by
// StringDenylistShapeFn, and the attribute documentation comes from
// kDescription.
// NOTE(review): absl::StrCat is used here, but absl/strings/str_cat.h is not
// among the includes visible in this file -- presumably it is pulled in
// transitively; confirm.
REGISTER_OP("SkipgramDenylist")
.Input("input: string")
.Output("output: float")
.Attr("max_skip_size: int")
.Attr("denylist: list(string)")
.Attr("denylist_category: list(int)")
.Attr("categories: int")
.Attr("negative_categories: int")
.SetShapeFn(StringDenylistShapeFn)
.Doc(absl::StrCat("Generates dense prediction vectors for input strings "
"using a skipgram denylist.",
"\n\n", "input: A string tensor.", "\n\n", kDescription));
// String denylist op backed by a SubsequenceFinder. Every denylist entry is
// registered with the finder as a subsequence at construction time, and each
// input string is later matched against those subsequences.
class SubsequenceDenylistOp : public StringDenylistOp {
 public:
  explicit SubsequenceDenylistOp(OpKernelConstruction* context)
      : StringDenylistOp(context),
        subsequence_finder_(
            std::make_unique<SubsequenceFinder>(max_skip_size())) {
    // Register each denylist entry together with its category.
    for (int entry = 0; entry < denylist_size(); ++entry) {
      subsequence_finder_->AddSubsequence(denylist(entry),
                                          denylist_category(entry));
    }
  }

 private:
  // Returns the categories of all registered subsequences found in |input|.
  absl::flat_hash_set<int> FindTerms(const std::string& input) override {
    return subsequence_finder_->FindSubsequences(input);
  }

  // Matcher holding the denylist subsequences.
  std::unique_ptr<SubsequenceFinder> subsequence_finder_;
};
// Registers SubsequenceDenylistOp as the CPU kernel for the
// "SubsequenceDenylist" op.
REGISTER_KERNEL_BUILDER(
Name("SubsequenceDenylist").Device(::tensorflow::DEVICE_CPU),
SubsequenceDenylistOp);
// Op interface for "SubsequenceDenylist": same inputs, outputs, attributes and
// shape function as "SkipgramDenylist"; only the registered name and doc text
// differ.
REGISTER_OP("SubsequenceDenylist")
.Input("input: string")
.Output("output: float")
.Attr("max_skip_size: int")
.Attr("denylist: list(string)")
.Attr("denylist_category: list(int)")
.Attr("categories: int")
.Attr("negative_categories: int")
.SetShapeFn(StringDenylistShapeFn)
.Doc(absl::StrCat("Generates dense prediction vectors for inputs using a "
"subsequence denylist.",
"\n\n", "input: A string tensor.", "\n\n", kDescription));
// A denylist op that uses the SkipgramFinder on tokenized string inputs.
// The inputs are a pair of tensors: a token tensor of type string and
// a token count tensor of type T.
//
// The last dimension of the token tensor holds the token slots of one input;
// the matching entry of the token count tensor says how many of those slots
// (counted from the start of the row) are actually used.
template <typename T>
class TokenizedDenylistOp : public DenylistOpBase {
 public:
  explicit TokenizedDenylistOp(OpKernelConstruction* context)
      : DenylistOpBase(context) {
    skipgram_finder_ = std::make_unique<SkipgramFinder>(max_skip_size());
    // Register each denylist entry together with its category.
    for (int i = 0; i < denylist_size(); i++) {
      skipgram_finder_->AddSkipgram(denylist(i), denylist_category(i));
    }
  }

 private:
  // Fetches the "input" and "token_count" tensors and wraps them in a
  // heap-allocated ComputeContext. On a failed lookup the failure is reported
  // on |context| and nullptr is returned.
  void* InitializeComputeContext(OpKernelContext* context) override {
    const Tensor* input_tensor;
    {
      auto status = context->input("input", &input_tensor);
      if (!status.ok()) {
        context->CtxFailureWithWarning(__FILE__, __LINE__, status);
        return nullptr;
      }
    }
    const Tensor* token_count_tensor;
    {
      auto status = context->input("token_count", &token_count_tensor);
      if (!status.ok()) {
        context->CtxFailureWithWarning(__FILE__, __LINE__, status);
        return nullptr;
      }
    }
    return new ComputeContext(input_tensor, token_count_tensor);
  }

  // Releases a context created by InitializeComputeContext().
  void FinalizeComputeContext(void* context) override {
    delete static_cast<ComputeContext*>(context);
  }

  // Shape of the token tensor with its last (token) dimension removed.
  TensorShape InputStringsShape(void* context) override {
    return static_cast<ComputeContext*>(context)->shape;
  }

  // Number of tokenized inputs, i.e. the element count of InputStringsShape().
  int NumInputStrings(void* context) override {
    return static_cast<ComputeContext*>(context)->size;
  }

  // Returns the categories of all skipgrams found in the tokens of input
  // number |index|.
  absl::flat_hash_set<int> GetCategories(int index, void* x) override {
    ComputeContext* context = static_cast<ComputeContext*>(x);

    // The token count comes from a caller-supplied tensor, so it is validated
    // before being used as a loop bound: a missing entry counts as zero
    // tokens, and the count is clamped to [0, max_tokens]. Without this a
    // negative count would be passed to reserve() as a huge size, and an
    // oversized count would read past the end of the token tensor.
    int64_t num_tokens = 0;
    if (index >= 0 && index < context->token_count_flat.size()) {
      num_tokens = static_cast<int64_t>(context->token_count_flat(index));
    }
    if (num_tokens < 0) {
      num_tokens = 0;
    }
    if (num_tokens > context->max_tokens) {
      num_tokens = context->max_tokens;
    }

    std::vector<absl::string_view> tokens;
    tokens.reserve(num_tokens);
    // Row |index| of the token tensor starts at this offset in the flat view.
    const int64_t start = static_cast<int64_t>(index) * context->max_tokens;
    for (int64_t i = start; i < start + num_tokens; i++) {
      tokens.emplace_back(context->token_flat(i).data(),
                          context->token_flat(i).size());
    }
    return skipgram_finder_->FindSkipgrams(tokens);
  }

  // Per-call state: flat views over both input tensors plus the derived
  // geometry of the token tensor.
  struct ComputeContext {
    ComputeContext(const Tensor* token_tensor, const Tensor* token_count_tensor)
        : token_flat(token_tensor->flat<::tensorflow::tstring>()),
          token_count_flat(token_count_tensor->flat<T>()) {
      shape = token_tensor->shape();
      // The last dimension is the number of token slots per input; the
      // remaining leading dimensions index the inputs themselves.
      max_tokens = shape.dim_size(shape.dims() - 1);
      shape.RemoveLastDims(1);
      size = 1;
      for (int64_t i = 0; i < shape.dims(); i++) {
        size = size * shape.dim_size(i);
      }
    }

    // Flat view over all tokens.
    const typename ::tensorflow::TTypes<::tensorflow::tstring>::ConstFlat
        token_flat;
    // Flat view over the per-input token counts.
    const typename ::tensorflow::TTypes<T>::ConstFlat token_count_flat;
    // Token tensor shape without its last dimension.
    TensorShape shape;
    // Product of the dimensions in |shape|.
    int64_t size;
    // Size of the token tensor's last dimension.
    int64_t max_tokens;
  };

  // Matcher holding the denylist skipgrams.
  std::unique_ptr<SkipgramFinder> skipgram_finder_;
};
// Registers the CPU kernels for "TokenizedDenylist", one instantiation per
// supported token count type (int32 and int64), selected by "Ttoken_count".
REGISTER_KERNEL_BUILDER(Name("TokenizedDenylist")
.Device(::tensorflow::DEVICE_CPU)
.TypeConstraint<int32_t>("Ttoken_count"),
TokenizedDenylistOp<int32_t>);
REGISTER_KERNEL_BUILDER(Name("TokenizedDenylist")
.Device(::tensorflow::DEVICE_CPU)
.TypeConstraint<int64_t>("Ttoken_count"),
TokenizedDenylistOp<int64_t>);
// Shape inference for denylist ops that take a tokenized string tensor: the
// output shape is the shape of input 0 with its last dimension dropped and a
// trailing dimension of size "categories" appended.
Status TokenizedDenylistShapeFn(InferenceContext* context) {
  int32_t num_categories;
  TF_RETURN_IF_ERROR(context->GetAttr("categories", &num_categories));

  // Everything but the last dimension of the token tensor.
  ShapeHandle leading_dims;
  TF_RETURN_IF_ERROR(
      context->Subshape(context->input(0), 0, -1, &leading_dims));

  ShapeHandle result_shape;
  TF_RETURN_IF_ERROR(context->Concatenate(
      leading_dims, context->MakeShape({num_categories}), &result_shape));
  context->set_output(0, result_shape);
  return ::tensorflow::Status::OK();
}
// Op interface for "TokenizedDenylist": a string token tensor plus a token
// count tensor whose element type is the "Ttoken_count" attribute (int32 or
// int64), one float output, and the denylist configuration as attributes.
// Output shape is computed by TokenizedDenylistShapeFn.
REGISTER_OP("TokenizedDenylist")
.Input("input: string")
.Input("token_count: Ttoken_count")
.Output("output: float")
.Attr("max_skip_size: int")
.Attr("denylist: list(string)")
.Attr("denylist_category: list(int)")
.Attr("categories: int")
.Attr("negative_categories: int")
.Attr("Ttoken_count: {int32, int64}")
.SetShapeFn(TokenizedDenylistShapeFn)
.Doc(absl::StrCat("Generates dense prediction vectors for tokens using a "
"skipgram denylist.",
"\n\n", "input: A string tensor of tokens.", "\n\n",
kDescription));
} // namespace seq_flow_lite
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.proto.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace seq_flow_lite {
namespace {
using ::tensorflow::DT_FLOAT;
using ::tensorflow::DT_INT32;
using ::tensorflow::DT_INT64;
using ::tensorflow::DT_STRING;
using ::tensorflow::NodeDefBuilder;
using ::tensorflow::OpsTestBase;
using ::tensorflow::Tensor;
using ::tensorflow::TensorShape;
using ::tensorflow::errors::InvalidArgument;
using ::tensorflow::test::ExpectTensorEqual;
using ::tensorflow::test::FillValues;
// Fixture for the "SkipgramDenylist" op tests; adds nothing to OpsTestBase.
class SkipgramDenylistOpTest : public OpsTestBase {};
// With max_skip_size 1 and denylist "a b c" in category 1, the first input
// (one token between consecutive terms) yields row {0, 1}; the second input
// (two tokens between "b" and "c") yields row {1, 0}.
TEST_F(SkipgramDenylistOpTest, Correct) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "SkipgramDenylist")
.Input({"input", 0, DT_STRING})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 2)
.Attr("negative_categories", 1)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<::tensorflow::tstring>(TensorShape({2}),
{"q a q b q c q", "q a b q q c"});
TF_ASSERT_OK(RunOpKernel());
const Tensor& output = *GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2}));
FillValues<float>(&expected, {0.0, 1.0, 1.0, 0.0});
ExpectTensorEqual<float>(expected, output);
}
// Same layout as the Correct test, but the middle denylist term is the prefix
// pattern "b.*" and the inputs contain "bq" in its place; the expected output
// rows are unchanged.
TEST_F(SkipgramDenylistOpTest, Prefix) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "SkipgramDenylist")
.Input({"input", 0, DT_STRING})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b.* c"})
.Attr("denylist_category", {1})
.Attr("categories", 2)
.Attr("negative_categories", 1)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<::tensorflow::tstring>(TensorShape({2}),
{"q a q bq q c q", "q a bq q q c"});
TF_ASSERT_OK(RunOpKernel());
const Tensor& output = *GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2}));
FillValues<float>(&expected, {0.0, 1.0, 1.0, 0.0});
ExpectTensorEqual<float>(expected, output);
}
// Kernel construction must fail with InvalidArgument when "categories" is 0.
TEST_F(SkipgramDenylistOpTest, ZeroCategories) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "SkipgramDenylist")
.Input({"input", 0, DT_STRING})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 0)
.Attr("negative_categories", 0)
.Finalize(node_def()));
EXPECT_EQ(InitOp(),
InvalidArgument("Number of categories (0) must be positive."));
}
// Kernel construction must fail with InvalidArgument when
// "negative_categories" is negative.
TEST_F(SkipgramDenylistOpTest, NegativeCategoriesLessThanZero) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "SkipgramDenylist")
.Input({"input", 0, DT_STRING})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 1)
.Attr("negative_categories", -1)
.Finalize(node_def()));
EXPECT_EQ(InitOp(),
InvalidArgument(
"Number of negative_categories (-1) must be non-negative."));
}
// Kernel construction must fail with InvalidArgument when "categories" is not
// greater than "negative_categories" (both are 1 here).
TEST_F(SkipgramDenylistOpTest, CategoriesEqualNegativeCategories) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "SkipgramDenylist")
.Input({"input", 0, DT_STRING})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 1)
.Attr("negative_categories", 1)
.Finalize(node_def()));
EXPECT_EQ(InitOp(),
InvalidArgument("Number of categories (1) must be greater than the "
"number of negative_categories (1)."));
}
// Fixture for the "SubsequenceDenylist" op tests; adds nothing to OpsTestBase.
class SubsequenceDenylistOpTest : public OpsTestBase {};
// With max_skip_size 1 and denylist "a b c" in category 1, the unspaced input
// "qaqbqcq" yields row {0, 1} and "qabqqc" yields row {1, 0}.
TEST_F(SubsequenceDenylistOpTest, Correct) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "SubsequenceDenylist")
.Input({"input", 0, DT_STRING})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 2)
.Attr("negative_categories", 1)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<::tensorflow::tstring>(TensorShape({2}),
{"qaqbqcq", "qabqqc"});
TF_ASSERT_OK(RunOpKernel());
const Tensor& output = *GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2}));
FillValues<float>(&expected, {0.0, 1.0, 1.0, 0.0});
ExpectTensorEqual<float>(expected, output);
}
// Kernel construction must fail with InvalidArgument when "categories" is 0.
TEST_F(SubsequenceDenylistOpTest, ZeroCategories) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "SubsequenceDenylist")
.Input({"input", 0, DT_STRING})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 0)
.Attr("negative_categories", 0)
.Finalize(node_def()));
EXPECT_EQ(InitOp(),
InvalidArgument("Number of categories (0) must be positive."));
}
// Kernel construction must fail with InvalidArgument when
// "negative_categories" is negative.
TEST_F(SubsequenceDenylistOpTest, NegativeCategoriesLessThanZero) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "SubsequenceDenylist")
.Input({"input", 0, DT_STRING})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 1)
.Attr("negative_categories", -1)
.Finalize(node_def()));
EXPECT_EQ(InitOp(),
InvalidArgument(
"Number of negative_categories (-1) must be non-negative."));
}
// Kernel construction must fail with InvalidArgument when "categories" is not
// greater than "negative_categories" (both are 1 here).
TEST_F(SubsequenceDenylistOpTest, CategoriesEqualNegativeCategories) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "SubsequenceDenylist")
.Input({"input", 0, DT_STRING})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 1)
.Attr("negative_categories", 1)
.Finalize(node_def()));
EXPECT_EQ(InitOp(),
InvalidArgument("Number of categories (1) must be greater than the "
"number of negative_categories (1)."));
}
// Fixture for the "TokenizedDenylist" op tests; adds nothing to OpsTestBase.
class TokenizedDenylistOpTest : public OpsTestBase {};
// Feeds a 2x7 token tensor with int64 token counts {7, 6}; the second row is
// padded with an empty string in its unused slot. Expected rows are {0, 1}
// and {1, 0}.
TEST_F(TokenizedDenylistOpTest, CorrectInt64TokenCount) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "TokenizedDenylist")
.Input({"input", 0, DT_STRING})
.Input({"token_count", 0, DT_INT64})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 2)
.Attr("negative_categories", 1)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<::tensorflow::tstring>(
TensorShape({2, 7}), {"q", "a", "q", "b", "q", "c", "q", //
"q", "a", "b", "q", "q", "c", ""});
AddInputFromArray<int64_t>(TensorShape({2}), {7, 6});
TF_ASSERT_OK(RunOpKernel());
const Tensor& output = *GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2}));
FillValues<float>(&expected, {0.0, 1.0, 1.0, 0.0});
ExpectTensorEqual<float>(expected, output);
}
// Same data and expectations as CorrectInt64TokenCount, but the token counts
// are supplied as int32 to exercise the other registered kernel.
TEST_F(TokenizedDenylistOpTest, CorrectInt32TokenCount) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "TokenizedDenylist")
.Input({"input", 0, DT_STRING})
.Input({"token_count", 0, DT_INT32})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 2)
.Attr("negative_categories", 1)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<::tensorflow::tstring>(
TensorShape({2, 7}), {"q", "a", "q", "b", "q", "c", "q", //
"q", "a", "b", "q", "q", "c", ""});
AddInputFromArray<int32_t>(TensorShape({2}), {7, 6});
TF_ASSERT_OK(RunOpKernel());
const Tensor& output = *GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2}));
FillValues<float>(&expected, {0.0, 1.0, 1.0, 0.0});
ExpectTensorEqual<float>(expected, output);
}
// Kernel construction must fail with InvalidArgument when "categories" is 0.
TEST_F(TokenizedDenylistOpTest, ZeroCategories) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "TokenizedDenylist")
.Input({"input", 0, DT_STRING})
.Input({"token_count", 0, DT_INT64})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 0)
.Attr("negative_categories", 0)
.Finalize(node_def()));
EXPECT_EQ(InitOp(),
InvalidArgument("Number of categories (0) must be positive."));
}
// Kernel construction must fail with InvalidArgument when
// "negative_categories" is negative.
TEST_F(TokenizedDenylistOpTest, NegativeCategoriesLessThanZero) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "TokenizedDenylist")
.Input({"input", 0, DT_STRING})
.Input({"token_count", 0, DT_INT64})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 1)
.Attr("negative_categories", -1)
.Finalize(node_def()));
EXPECT_EQ(InitOp(),
InvalidArgument(
"Number of negative_categories (-1) must be non-negative."));
}
// Kernel construction must fail with InvalidArgument when "categories" is not
// greater than "negative_categories" (both are 1 here).
TEST_F(TokenizedDenylistOpTest, CategoriesEqualNegativeCategories) {
TF_ASSERT_OK(NodeDefBuilder("test_op", "TokenizedDenylist")
.Input({"input", 0, DT_STRING})
.Input({"token_count", 0, DT_INT64})
.Attr("max_skip_size", 1)
.Attr("denylist", {"a b c"})
.Attr("denylist_category", {1})
.Attr("categories", 1)
.Attr("negative_categories", 1)
.Finalize(node_def()));
EXPECT_EQ(InitOp(),
InvalidArgument("Number of categories (1) must be greater than the "
"number of negative_categories (1)."));
}
} // namespace
} // namespace seq_flow_lite
# Copyright 2022 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test denylist op and show example usage from python wrapper."""
import tensorflow as tf
from tf_ops import denylist_op # import seq_flow_lite module
class SkipgramDenylistTest(tf.test.TestCase):
  """Example usage and check of the `skipgram_denylist` Python wrapper."""

  def test_correct(self):
    """Two space-separated inputs produce the expected 2x2 output."""
    sentences = ["q a q b q c q", "q a b q q c"]
    expected = [[0.0, 1.0], [1.0, 0.0]]

    predictions = denylist_op.skipgram_denylist(
        input=sentences,
        max_skip_size=1,
        denylist=["a b c"],
        denylist_category=[1],
        categories=2,
        negative_categories=1)

    self.assertAllEqual(predictions, expected)
class SubsequenceDenylistTest(tf.test.TestCase):
  """Example usage and check of the `subsequence_denylist` Python wrapper."""

  def test_correct(self):
    """Two unspaced inputs produce the expected 2x2 output."""
    strings = ["qaqbqcq", "qabqqc"]
    expected = [[0.0, 1.0], [1.0, 0.0]]

    predictions = denylist_op.subsequence_denylist(
        input=strings,
        max_skip_size=1,
        denylist=["a b c"],
        denylist_category=[1],
        categories=2,
        negative_categories=1)

    self.assertAllEqual(predictions, expected)
class TokenizedDenylistTest(tf.test.TestCase):
  """Example usage and check of the `tokenized_denylist` Python wrapper."""

  def test_correct(self):
    """Two pre-tokenized inputs produce the expected 2x2 output."""
    # Each row has seven token slots; the second row uses only six of them
    # and is padded with an empty string.
    token_rows = [
        ["q", "a", "q", "b", "q", "c", "q"],
        ["q", "a", "b", "q", "q", "c", ""],
    ]
    tokens_per_row = [7, 6]
    expected = [[0.0, 1.0], [1.0, 0.0]]

    predictions = denylist_op.tokenized_denylist(
        input=token_rows,
        token_count=tokens_per_row,
        max_skip_size=1,
        denylist=["a b c"],
        denylist_category=[1],
        categories=2,
        negative_categories=1)

    self.assertAllEqual(predictions, expected)
# Run every test case in this module when it is executed as a script.
if __name__ == "__main__":
  tf.test.main()
......@@ -26,7 +26,7 @@ limitations under the License.
bool IsDigit(const std::string& text) {
Rune rune;
for (size_t i = 0; i < text.length();) {
const int bytes_read = chartorune(&rune, const_cast<char *>(text.data()));
const int bytes_read = chartorune(&rune, const_cast<char*>(text.data()));
if (rune == Runeerror || bytes_read == 0) break;
if (rune >= static_cast<Rune>('0') && rune <= static_cast<Rune>('9')) {
return true;
......@@ -98,6 +98,29 @@ std::string ContractToken(const char* input_ptr, size_t len, size_t num_chars) {
return token;
}
// Rewrites |input| in place so that it has no leading or trailing spaces and
// every run of consecutive spaces is reduced to a single space. Only the
// ' ' character is treated as a space.
void NormalizeSpaces(std::string& input) {
  size_t write_pos = 0;
  // True when the most recently kept character was not a space; a space is
  // only kept when it directly follows such a character.
  bool last_kept_is_word = false;
  for (size_t read_pos = 0; read_pos < input.length(); ++read_pos) {
    const char ch = input[read_pos];
    const bool is_space = (ch == ' ');
    if (is_space && !last_kept_is_word) {
      continue;  // Leading space or a repeat within a run: drop it.
    }
    last_kept_is_word = !is_space;
    input[write_pos] = ch;
    ++write_pos;
  }
  // At most one trailing space can have been kept; remove it.
  if (write_pos > 0 && input[write_pos - 1] == ' ') {
    --write_pos;
  }
  input.resize(write_pos);
}
void ProjectionNormalizer::InitializeSeparators(const std::string& separators) {
for (size_t i = 0; i < separators.length(); ++i) {
if (separators[i] != ' ') {
......@@ -150,9 +173,14 @@ std::string ProjectionNormalizer::Normalize(const char* input_ptr, size_t len,
normalized = ContractToken(normalized.data(), normalized.length(), 3);
}
if (normalize_spaces_) {
NormalizeSpaces(normalized);
}
if (!separators_.empty()) {
// Add space around separators_.
normalized = NormalizeInternal(normalized.data(), normalized.length());
}
return normalized;
}
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#include <string>
#include <unordered_set>
......@@ -24,14 +24,17 @@ limitations under the License.
// Normalizes the input with the given |separators| by adding a space before and
// after each separator. When |normalize_repetition| is true, it removes the
// repeated characters (except numbers) which consecutively appeared more than
// twice in a word.
// twice in a word. When |normalize_spaces| is true, it removes spaces from
// the beginning and ending of the input, as well as repeated spaces.
// Examples: arwwwww -> arww, good!!!!! -> good!!, hahaha => haha.
class ProjectionNormalizer {
public:
explicit ProjectionNormalizer(const std::string& separators,
bool normalize_repetition = false) {
bool normalize_repetition = false,
bool normalize_spaces = false)
: normalize_repetition_(normalize_repetition),
normalize_spaces_(normalize_spaces) {
InitializeSeparators(separators);
normalize_repetition_ = normalize_repetition;
}
// Normalizes the repeated characters (except numbers) which consecutively
......@@ -49,6 +52,7 @@ class ProjectionNormalizer {
std::unordered_set<char> separators_;
bool normalize_repetition_;
bool normalize_spaces_;
};
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#include <string>
#include <unordered_set>
......@@ -55,4 +55,4 @@ class ProjectionTokenizer {
std::unordered_set<char> separators_;
};
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_UTIL_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_UTIL_H_
#include <memory>
#include <string>
#include <unordered_map>
......@@ -156,4 +156,4 @@ std::vector<std::string> SplitByChar(const char* input_ptr, size_t len,
std::string JoinPairsBySpace(std::vector<std::pair<const char*, size_t>> words);
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_UTIL_H_
......@@ -109,11 +109,14 @@ class SequenceStringProjectionOp : public OpKernel {
bool normalize_repetition;
OP_REQUIRES_OK(context, context->GetAttr("normalize_repetition",
&normalize_repetition));
bool normalize_spaces;
OP_REQUIRES_OK(context,
context->GetAttr("normalize_spaces", &normalize_spaces));
std::string separators;
OP_REQUIRES_OK(context, context->GetAttr("token_separators", &separators));
if (!separators.empty() || normalize_repetition) {
if (!separators.empty() || normalize_repetition || normalize_spaces) {
projection_normalizer_ = absl::make_unique<ProjectionNormalizer>(
separators, normalize_repetition);
separators, normalize_repetition, normalize_spaces);
}
OP_REQUIRES_OK(context, context->GetAttr("add_first_cap_feature",
......@@ -326,6 +329,7 @@ REGISTER_OP("SequenceStringProjection")
.Attr("split_on_space: bool = True")
.Attr("token_separators: string = ''")
.Attr("normalize_repetition: bool = false")
.Attr("normalize_spaces: bool = false")
.SetShapeFn([](InferenceContext* c) {
DimensionHandle size;
......@@ -384,6 +388,10 @@ Attribute(s):
- add_all_caps_feature: Specifies the probability with which a feature to the
resulting projection tensor that helps discriminate if the input token is
ALLCAPS will be added.
- normalize_repetition: When true normalizes repetition in text tokens before
fingerprinting.
- normalize_spaces: When true strips leading and trailing spaces and removes
repeated spaces.
Output(s):
- projection: Floating point tensor with ternary values of shape
......
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/skipgram_finder.h" // seq_flow_lite
#include <cctype>
#include <deque>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace seq_flow_lite {
namespace {
// Normalizes |token| in place: code points for which u_ispunct() is true are
// dropped, and the remaining code points are lowercased with u_tolower().
// Decoding stops at the first code point that U8_NEXT reports as negative;
// the bytes after the position where U8_NEXT stopped are kept unchanged.
void PreprocessToken(std::string& token) {
// The string's own buffer is rewritten; |out| never gets ahead of |in|
// because a code point is only replaced by an encoding that is not longer.
char* s = const_cast<char*>(token.data());
int32_t size = token.size();
int32_t in = 0;
int32_t out = 0;
while (in < size) {
UChar32 c;
int32_t old_in = in;
U8_NEXT(s, in, size, c);
// NOTE(review): per the ICU docs a negative |c| signals an ill-formed
// sequence, and |in| has already been advanced past it at this point, so
// those bytes are not part of the tail copied below -- confirm intended.
if (c < 0) {
break;
}
if (u_ispunct(c)) continue;
UChar32 cl = u_tolower(c);
// This is a hack, but there are exactly two unicode characters whose
// lowercase versions have longer UTF-8 encodings (0x23a to 0x2c65,
// 0x23e to 0x2c66). So, to avoid sizing issues, they're not lowercased.
if (U8_LENGTH(cl) > (in - old_in)) {
cl = c;
}
U8_APPEND_UNSAFE(s, out, cl);
}
// Move whatever was not decoded (non-empty only after the early break) down
// to directly follow the rewritten prefix.
size_t remaining = token.size() - in;
if (remaining > 0) {
memmove(s + out, s + in, remaining);
out += remaining;
}
token.resize(out);
}
} // namespace
// Registers |skipgram| under |category|. The skipgram is split on single
// spaces; a token ending in ".*" is stored as a prefix pattern (without the
// ".*"), any other token as an exact token. Every token is passed through
// PreprocessToken() before being stored.
void SkipgramFinder::AddSkipgram(absl::string_view skipgram, int category) {
  std::vector<std::string> tokens = absl::StrSplit(skipgram, ' ');

  // The skipgrams live in a trie-like structure whose edges are labelled with
  // whole tokens rather than characters. A node stands for the skipgram
  // spelled by the edges leading to it and records that skipgram's
  // categories.
  TrieNode* node = &skipgram_trie_;
  for (std::string& token : tokens) {
    const bool is_prefix = absl::EndsWith(token, ".*");
    if (is_prefix) {
      token.resize(token.size() - 2);  // Strip the ".*" marker.
    }
    PreprocessToken(token);

    // Prefix patterns and exact tokens are kept in separate edge maps.
    auto& children = is_prefix ? node->prefix_to_node : node->token_to_node;
    auto child = children.find(token);
    if (child == children.end()) {
      child = children
                  .emplace(std::piecewise_construct,
                           std::forward_as_tuple(token), std::make_tuple<>())
                  .first;
    }
    node = &child->second;
  }
  node->categories.insert(category);
}
// Splits |input| on single spaces, normalizes every piece with
// PreprocessToken(), and returns the categories found by the token-vector
// overload for the resulting tokens.
absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
    absl::string_view input) const {
  // Owning copies are required because PreprocessToken() edits in place.
  std::vector<std::string> words = absl::StrSplit(input, ' ');

  std::vector<absl::string_view> word_views;
  word_views.reserve(words.size());
  for (std::string& word : words) {
    PreprocessToken(word);
    word_views.push_back(absl::string_view(word.data(), word.size()));
  }
  return FindSkipgrams(word_views);
}
// Returns the categories of every registered skipgram that occurs in
// |tokens|, where consecutive skipgram terms may be separated by at most
// `max_skip_size_` other tokens. Tokens are used exactly as given; no
// normalization is applied here.
absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
    const std::vector<absl::string_view>& tokens) const {
  absl::flat_hash_set<int> categories;

  // Tracks skipgram prefixes and the index of their last token. Entries are
  // appended in token order, so the oldest ones are always at the front.
  std::deque<std::pair<int, const TrieNode*>> indices_and_skipgrams;

  const int num_tokens = static_cast<int>(tokens.size());
  for (int token_i = 0; token_i < num_tokens; token_i++) {
    const absl::string_view& token = tokens[token_i];

    // All leading substrings of the token that end on a code point boundary,
    // shortest first. These are the candidates for matching prefix patterns.
    std::vector<absl::string_view> token_prefixes;
    {
      const char* s = token.data();
      int32_t l = token.size();
      int32_t n = 0;
      while (n < l) {
        int32_t n_old = n;
        U8_FWD_1(s, n, l);
        if (n == n_old) break;  // Guard against a non-advancing decode.
        token_prefixes.emplace_back(s, n);
      }
    }

    // Drop any skipgram prefixes which would skip more than `max_skip_size_`
    // tokens between the end of the prefix and the current token. The number
    // of skipped tokens is computed by subtraction: every stored index is
    // smaller than `token_i`, so this cannot overflow, whereas adding
    // `max_skip_size_` to the stored index overflows `int` for large limits.
    while (!indices_and_skipgrams.empty() &&
           token_i - indices_and_skipgrams.front().first - 1 >
               max_skip_size_) {
      indices_and_skipgrams.pop_front();
    }

    // Check if we can form a valid skipgram prefix (or skipgram) by adding
    // the current token to any of the existing skipgram prefixes, or
    // if the current token is a valid skipgram prefix (or skipgram).
    // The size is captured up front so entries appended for this token are
    // not themselves extended by this same token.
    const size_t size = indices_and_skipgrams.size();
    for (size_t skipgram_i = 0; skipgram_i <= size; skipgram_i++) {
      // The extra iteration (skipgram_i == size) starts from the trie root.
      const TrieNode& node = skipgram_i < size
                                 ? *indices_and_skipgrams[skipgram_i].second
                                 : skipgram_trie_;

      auto exact_match = node.token_to_node.find(token);
      if (exact_match != node.token_to_node.end()) {
        categories.insert(exact_match->second.categories.begin(),
                          exact_match->second.categories.end());
        indices_and_skipgrams.push_back(
            std::make_pair(token_i, &exact_match->second));
      }

      // Try the prefix patterns, longest candidate first.
      for (auto token_prefix = token_prefixes.rbegin();
           token_prefix != token_prefixes.rend(); token_prefix++) {
        auto prefix_match = node.prefix_to_node.find(*token_prefix);
        if (prefix_match != node.prefix_to_node.end()) {
          categories.insert(prefix_match->second.categories.begin(),
                            prefix_match->second.categories.end());
          indices_and_skipgrams.push_back(
              std::make_pair(token_i, &prefix_match->second));
        }
      }
    }
  }
  return categories;
}
} // namespace seq_flow_lite
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
namespace seq_flow_lite {
// SkipgramFinder finds skipgrams in strings.
//
// To use: First, add skipgrams using AddSkipgram() - each skipgram is
// associated with some category. Then, call FindSkipgrams() on a string,
// which will return the set of categories of the skipgrams in the string.
//
// Both the skipgrams and the input strings will be tokenized by splitting
// on spaces. Additionally, the tokens will be lowercased and have their
// punctuation characters removed. The FindSkipgrams() overload that takes
// a token vector uses the tokens as given, without this preprocessing.
class SkipgramFinder {
public:
// |max_skip_size| is the largest number of tokens that may be skipped
// between two consecutive terms of a skipgram.
explicit SkipgramFinder(int max_skip_size) : max_skip_size_(max_skip_size) {}
// Adds a skipgram that SkipgramFinder should look for in input strings.
// Tokens may use the regex '.*' as a suffix.
void AddSkipgram(absl::string_view skipgram, int category);
// Find all of the skipgrams in `input`, and return their categories.
absl::flat_hash_set<int> FindSkipgrams(absl::string_view input) const;
// Find all of the skipgrams in `tokens`, and return their categories.
absl::flat_hash_set<int> FindSkipgrams(
const std::vector<absl::string_view>& tokens) const;
private:
// One node of the token-labelled trie that stores the skipgrams.
struct TrieNode {
// Categories of the skipgrams that end at this node.
absl::flat_hash_set<int> categories;
// Maps tokens to the next node in the trie.
absl::flat_hash_map<std::string, TrieNode> token_to_node;
// Maps token prefixes (<prefix>.*) to the next node in the trie.
absl::flat_hash_map<std::string, TrieNode> prefix_to_node;
};
// Root of the trie; it corresponds to the empty skipgram.
TrieNode skipgram_trie_;
int max_skip_size_;
};
} // namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/skipgram_finder.h" // seq_flow_lite
#include <string>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace seq_flow_lite {
namespace {
using ::testing::UnorderedElementsAreArray;
void TestFindSkipgrams(const SkipgramFinder& skipgram_finder,
const std::vector<std::string>& tokens,
const std::vector<int>& categories,
const std::vector<int>& token_categories) {
EXPECT_THAT(skipgram_finder.FindSkipgrams(absl::StrJoin(tokens, " ")),
UnorderedElementsAreArray(categories));
std::vector<absl::string_view> sv_tokens;
sv_tokens.reserve(tokens.size());
for (const auto& token : tokens) {
sv_tokens.emplace_back(token.data(), token.size());
}
EXPECT_THAT(skipgram_finder.FindSkipgrams(sv_tokens),
UnorderedElementsAreArray(token_categories));
}
// Verifies that, for every code point below 0x10000, u_tolower() never
// produces a code point whose UTF-8 encoding is longer than the original's.
// The two code points 0x23a and 0x23e are the known exceptions and are
// excluded from the check.
TEST(SkipgramFinderTest, UCharToLower) {
  for (UChar32 original = 0; original < 0x10000; ++original) {
    const bool known_exception = (original == 0x23a || original == 0x23e);
    if (!known_exception) {
      const UChar32 lowered = u_tolower(original);
      EXPECT_GE(U8_LENGTH(original), U8_LENGTH(lowered))
          << original << " lowercases to " << lowered;
    }
  }
}
// A single skipgram "q r s" (category 0) with a max skip size of 1.
// For each call, the first expected list is for the joined-string overload
// and the second is for the pre-tokenized overload.
TEST(SkipgramFinderTest, SingleExists) {
  const std::vector<int> found = {0};
  const std::vector<int> not_found;

  SkipgramFinder finder(1);
  finder.AddSkipgram("q r s", 0);

  // Tokens match exactly and are adjacent: both overloads find the skipgram.
  TestFindSkipgrams(finder, {"a", "q", "r", "s", "c"}, found, found);
  // Tokens carry upper case / punctuation ("R!"): only the string overload
  // is expected to match (presumably because it normalizes its input).
  TestFindSkipgrams(finder, {"a", "q", "xyz", "R!", "xy", "s", "c"}, found,
                    not_found);
  TestFindSkipgrams(finder, {"a", "q", "r", "q", "R", "s.", "c"}, found,
                    not_found);
}
// A single skipgram "q r s" (category 0) with a max skip size of 1; none of
// the inputs below is expected to match with either overload.
TEST(SkipgramFinderTest, SingleNotExists) {
  const std::vector<int> not_found;

  SkipgramFinder finder(1);
  finder.AddSkipgram("q r s", 0);

  // Two tokens sit between "q" and "r", which exceeds the skip size of 1.
  TestFindSkipgrams(finder, {"a", "q", "x", "x", "r", "x", "s", "c"},
                    not_found, not_found);
  // The final token "s" never appears.
  TestFindSkipgrams(finder, {"a", "q", "x", "r", "x", "c"}, not_found,
                    not_found);
  // The tokens appear, but "q" comes after "r" and "s".
  TestFindSkipgrams(finder, {"a", "r", "x", "s", "q", "c"}, not_found,
                    not_found);
}
// A single skipgram "q.* r s" (category 0) whose first token is a prefix
// pattern, with a max skip size of 1. For each call, the first expected list
// is for the joined-string overload and the second is for the pre-tokenized
// overload.
TEST(SkipgramFinderTest, SinglePrefixExists) {
  const std::vector<int> found = {0};
  const std::vector<int> not_found;

  SkipgramFinder finder(1);
  finder.AddSkipgram("q.* r s", 0);

  // "qa" starts with the prefix "q": both overloads match.
  TestFindSkipgrams(finder, {"a", "qa", "r", "s", "c"}, found, found);
  // Upper case / punctuation in the tokens: only the string overload is
  // expected to match.
  TestFindSkipgrams(finder, {"a", "q", "xyz", "R!", "xy", "s", "c"}, found,
                    not_found);
  TestFindSkipgrams(finder, {"a", "qc", "r", "qd", "R", "s.", "c"}, found,
                    not_found);
}
// A single skipgram "q.* r s" (category 0) with a max skip size of 1; none of
// the inputs below is expected to match with either overload.
TEST(SkipgramFinderTest, SinglePrefixNotExists) {
  const std::vector<int> not_found;

  SkipgramFinder finder(1);
  finder.AddSkipgram("q.* r s", 0);

  // "aq" / "aqc" contain "q" but do not start with it.
  TestFindSkipgrams(finder, {"a", "aq", "r", "s", "c"}, not_found, not_found);
  TestFindSkipgrams(finder, {"a", "aqc", "xyz", "R!", "xy", "s", "c"},
                    not_found, not_found);
  // "ar" / "aR" are not the plain token "r".
  TestFindSkipgrams(finder, {"a", "q", "ar", "q", "aR", "s.", "c"}, not_found,
                    not_found);
}
// A skipgram containing punctuation, "a-b-c def" (category 0), with a max
// skip size of 1. For each call, the first expected list is for the
// joined-string overload and the second is for the pre-tokenized overload.
TEST(SkipgramFinderTest, Punctuation) {
  const std::vector<int> found = {0};
  const std::vector<int> not_found;

  SkipgramFinder finder(1);
  finder.AddSkipgram("a-b-c def", 0);

  // Punctuation inside or around the input tokens: only the string overload
  // is expected to match.
  TestFindSkipgrams(finder, {"q", "abc", "q", "d-e-f", "q"}, found, not_found);
  TestFindSkipgrams(finder, {"a", "'abc'", "q", "'def'", "q"}, found,
                    not_found);
  // Punctuation-free tokens "abc" / "def": both overloads match.
  TestFindSkipgrams(finder, {"q", "abc", "q", "def", "q"}, found, found);
}
// Smoke test: registers a skipgram that contains non-ASCII bytes. Nothing is
// asserted; the test only requires that AddSkipgram() returns normally.
TEST(SkipgramFinderTest, HandlesMultibyteInput) {
  SkipgramFinder finder(1);
  finder.AddSkipgram("hello\363\243\243\243!", 0);
}
// Registers six skipgrams under distinct categories (max skip size 1) and
// checks which categories are reported for several interleaved inputs. Every
// input here is expected to give the same result from both overloads.
TEST(SkipgramFinderTest, Multiple) {
  struct Registration {
    const char* skipgram;
    int category;
  };
  const Registration registrations[] = {
      {"a b c", 0},       {"D e. F!", 2},  {"ghi jkl mno", 4},
      {"S T U", 6},       {"x. y, z!", 8}, {"d.* e f", 10},
  };

  SkipgramFinder finder(1);
  // Registered in the order listed above.
  for (const Registration& registration : registrations) {
    finder.AddSkipgram(registration.skipgram, registration.category);
  }

  TestFindSkipgrams(finder, {"a", "d", "b", "e", "c", "f"}, {0, 2, 10},
                    {0, 2, 10});
  // "dq" only satisfies the prefix pattern "d.*", not the plain token "d".
  TestFindSkipgrams(finder, {"a", "dq", "b", "e", "c", "f"}, {0, 10}, {0, 10});
  // "eq" breaks both the "d e f" and the "d.* e f" skipgrams.
  TestFindSkipgrams(finder, {"a", "d", "b", "eq", "c", "f"}, {0}, {0});
  // Two tokens separate "jkl" from "mno", exceeding the skip size.
  TestFindSkipgrams(finder, {"a", "ghi", "b", "jkl", "c", "x", "mno"}, {0},
                    {0});
  TestFindSkipgrams(finder, {"ghi", "d", "jkl", "e", "mno", "f"}, {2, 4, 10},
                    {2, 4, 10});
  TestFindSkipgrams(finder, {"s", "x", "t", "y", "u", "z"}, {6, 8}, {6, 8});
}
// Uses an upper-case letter whose lower-case form has a shorter UTF-8
// encoding, and checks matching against the upper-case form, the lower-case
// form, and an unrelated ASCII letter.
TEST(SkipgramFinderTest, UnicodeLowercase) {
  // Precondition: lowercasing really does shrink the UTF-8 encoding.
  UChar32 upper;
  U8_GET_UNSAFE("Ɦ", 0, upper);
  const UChar32 lower = u_tolower(upper);
  EXPECT_GT(U8_LENGTH(upper), U8_LENGTH(lower));

  const std::vector<int> found = {0};
  const std::vector<int> not_found;

  SkipgramFinder finder(1);
  finder.AddSkipgram("Ɦ", 0);

  // Upper-case input: expected to match via the string overload only.
  TestFindSkipgrams(finder, {"Ɦ"}, found, not_found);
  // Lower-case input: expected to match via both overloads.
  TestFindSkipgrams(finder, {"ɦ"}, found, found);
  // A different letter: no match.
  TestFindSkipgrams(finder, {"h"}, not_found, not_found);
}
} // namespace
} // namespace seq_flow_lite
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
#include <deque>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace seq_flow_lite {
// Inserts `subsequence` into the trie and tags the node reached by its last
// character with `category`.
//
// The text is decoded as UTF-8 one code point at a time and lowercased with
// u_tolower(). A space ends the current token; the first code point after a
// space (or at the very start) follows a `next_token` edge, every other code
// point follows a `continue_token` edge. If a decoding error is reported the
// function returns immediately without recording `category`.
void SubsequenceFinder::AddSubsequence(absl::string_view subsequence,
                                       int category) {
  const char* bytes = subsequence.data();
  int32_t num_bytes = subsequence.length();
  TrieNode* node = &subsequence_trie_;
  bool at_token_start = true;

  for (int32_t offset = 0; offset < num_bytes;) {
    UChar32 code_point;
    U8_NEXT(bytes, offset, num_bytes, code_point);
    if (code_point < 0) return;  // Invalid UTF-8: give up on this entry.
    code_point = u_tolower(code_point);

    if (code_point == ' ') {
      at_token_start = true;
      continue;
    }
    // Pick the edge set by whether this code point opens a new token.
    auto& edges = at_token_start ? node->next_token : node->continue_token;
    node = &edges[code_point];
    at_token_start = false;
  }
  node->categories.insert(category);
}
// Tries to advance one in-progress subsequence by the code point `c`.
//
// `token_map` is the set of outgoing edges to test. If it has no edge for `c`
// nothing happens. Otherwise, for the node that edge leads to:
//   * its categories are added to `categories`;
//   * if it has `continue_token` edges, it is appended to `continue_tokens`;
//   * if it has `next_token` edges, it is appended to `next_tokens` together
//     with `index`, the position of `c` in the input.
void SubsequenceFinder::ProcessUChar32AndTrieNode(
    int index, UChar32 c,
    const absl::flat_hash_map<UChar32, TrieNode>& token_map,
    absl::flat_hash_set<int>* categories,
    std::deque<std::pair<int, const TrieNode*>>* next_tokens,
    std::vector<const TrieNode*>* continue_tokens) const {
  const auto match = token_map.find(c);
  if (match == token_map.end()) return;

  const TrieNode& node = match->second;
  categories->insert(node.categories.begin(), node.categories.end());
  if (!node.continue_token.empty()) {
    continue_tokens->push_back(&node);
  }
  if (!node.next_token.empty()) {
    next_tokens->emplace_back(index, &node);
  }
}
// Scans `input` one lowercased code point at a time and returns the
// categories of every registered subsequence that is completed along the way.
// If a UTF-8 decoding error is reported, scanning stops and the categories
// collected so far are returned.
absl::flat_hash_set<int> SubsequenceFinder::FindSubsequences(
    absl::string_view input) const {
  absl::flat_hash_set<int> found;
  // Partial matches that have finished a token and are waiting for the first
  // character of their next token, paired with the position of the last
  // character they consumed. Entries are appended in position order.
  std::deque<std::pair<int, const TrieNode*>> awaiting_next_token;
  // Partial matches that are in the middle of a token and need the very next
  // character to continue it. `active` holds the ones to test against the
  // current character; `upcoming` collects the ones to test against the
  // following character.
  std::vector<const TrieNode*> active;
  std::vector<const TrieNode*> upcoming;

  const char* bytes = input.data();
  int32_t num_bytes = input.length();
  int32_t offset = 0;
  for (int position = 0; offset < num_bytes; ++position) {
    UChar32 c;
    U8_NEXT(bytes, offset, num_bytes, c);
    if (c < 0) return found;
    c = u_tolower(c);

    // Discard partial matches that would have to skip more than
    // `max_skip_size_` characters to reach the current one.
    while (!awaiting_next_token.empty() &&
           awaiting_next_token.front().first + max_skip_size_ + 1 < position) {
      awaiting_next_token.pop_front();
    }

    // Partial matches that may start their next token here. The count is
    // taken up front because the calls below can append to the deque, and
    // those new entries must not be tested against this same character.
    const size_t waiting_count = awaiting_next_token.size();
    for (size_t i = 0; i < waiting_count; ++i) {
      ProcessUChar32AndTrieNode(position, c,
                                awaiting_next_token[i].second->next_token,
                                &found, &awaiting_next_token, &upcoming);
    }

    // Partial matches that may continue their current token here.
    for (const TrieNode* node : active) {
      ProcessUChar32AndTrieNode(position, c, node->continue_token, &found,
                                &awaiting_next_token, &upcoming);
    }

    // A brand-new match may begin at this character.
    ProcessUChar32AndTrieNode(position, c, subsequence_trie_.next_token,
                              &found, &awaiting_next_token, &upcoming);

    // Rotate the buffers for the next character.
    active.swap(upcoming);
    upcoming.clear();
  }
  return found;
}
} // namespace seq_flow_lite
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
#include <deque>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
namespace seq_flow_lite {
// SubsequenceFinder finds subsequences in UTF-8 strings.
//
// Specifically, given a subsequence t_1 t_2 ... t_n, we will check if a
// string matches '.*t_1.{0,N}t_2.{0,N} ... .{0,N}t_n.*', where N is the
// maximum skip size.
//
// To use: First, add subsequences using AddSubsequence() - each subsequence
// is associated with some category. Then call FindSubsequences() on a string,
// which will return the set of categories of the subsequences in the string.
//
// The subsequences will be tokenized by splitting on spaces. Both subsequences
// and input strings will be normalized by lowercasing.
class SubsequenceFinder {
public:
// `max_skip_size` is the largest number of characters that may lie between
// the end of one token of a subsequence and the start of its next token.
explicit SubsequenceFinder(int max_skip_size)
: max_skip_size_(max_skip_size) {}
// Adds a subsequence that SubsequenceFinder should look for in input strings.
// `subsequence` is split into tokens on spaces and lowercased; `category` is
// the value FindSubsequences() reports when the subsequence is found. If
// `subsequence` contains invalid UTF-8, `category` is not registered.
void AddSubsequence(absl::string_view subsequence, int category);
// Find all of the subsequences in `input`, and return their categories.
// `input` is lowercased before matching. Scanning stops at the first invalid
// UTF-8 sequence; categories found before that point are still returned.
absl::flat_hash_set<int> FindSubsequences(absl::string_view input) const;
private:
// This trie tracks the next character needed to:
// * continue the current token
// * start the next token
struct TrieNode {
// Categories of the subsequences whose last character leads to this node.
absl::flat_hash_set<int> categories;
// Edges taken by a character that extends the token currently being matched.
absl::flat_hash_map<UChar32, TrieNode> continue_token;
// Edges taken by the first character of the following token.
absl::flat_hash_map<UChar32, TrieNode> next_token;
};
// Looks up `c` in `token_map`. On a hit, adds the reached node's categories
// to `categories`, appends the node to `continue_tokens` if it has
// `continue_token` edges, and appends (`index`, node) to `next_tokens` if it
// has `next_token` edges. Does nothing on a miss.
void ProcessUChar32AndTrieNode(
int index, UChar32 c,
const absl::flat_hash_map<UChar32, TrieNode>& token_map,
absl::flat_hash_set<int>* categories,
std::deque<std::pair<int, const TrieNode*>>* next_tokens,
std::vector<const TrieNode*>* continue_tokens) const;
// Root of the trie; its `next_token` edges are the first characters of all
// registered subsequences.
TrieNode subsequence_trie_;
// Maximum number of characters that may be skipped between two tokens.
int max_skip_size_;
};
} // namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace seq_flow_lite {
namespace {
using ::testing::UnorderedElementsAre;
// One subsequence, "ab cd" (category 0), with a max skip size of 3; every
// input below contains it.
TEST(SubsequenceFinderTest, SingleExists) {
  SubsequenceFinder finder(3);
  finder.AddSubsequence("ab cd", 0);

  // The two tokens are directly adjacent.
  EXPECT_THAT(finder.FindSubsequences("abcd"), UnorderedElementsAre(0));
  // Exactly three characters are skipped between the tokens.
  EXPECT_THAT(finder.FindSubsequences("ab012cd"), UnorderedElementsAre(0));
  // Upper-case input is lowercased; the space counts as one skipped character.
  EXPECT_THAT(finder.FindSubsequences("AB CD"), UnorderedElementsAre(0));
}
// One subsequence, "ab cd" (category 0), with a max skip size of 3; none of
// the inputs below contains it.
TEST(SubsequenceFinderTest, SingleNotExists) {
  SubsequenceFinder finder(3);
  finder.AddSubsequence("ab cd", 0);

  // The characters of the token "ab" are not consecutive.
  EXPECT_THAT(finder.FindSubsequences("a bcd"), UnorderedElementsAre());
  // Four characters separate the tokens, one more than the skip size allows.
  EXPECT_THAT(finder.FindSubsequences("ab0123cd"), UnorderedElementsAre());
  // The characters of the second token appear in the wrong order.
  EXPECT_THAT(finder.FindSubsequences("abdc"), UnorderedElementsAre());
}
// Three overlapping subsequences under categories 0, 2 and 4, with a max skip
// size of 3; each input completes a different pair of them.
TEST(SubsequenceFinderTest, Multiple) {
  SubsequenceFinder finder(3);
  finder.AddSubsequence("a b c d", 0);
  finder.AddSubsequence("q r s", 2);
  finder.AddSubsequence("b c d e", 4);

  // Contains a..d and b..e, but no q/r/s.
  EXPECT_THAT(finder.FindSubsequences("a__b__c__d__e"),
              UnorderedElementsAre(0, 4));
  // Contains a..d interleaved with q/r/s, but no trailing e.
  EXPECT_THAT(finder.FindSubsequences("aqbrcsd"), UnorderedElementsAre(0, 2));
  // Contains b..e interleaved with q/r/s, but no leading a.
  EXPECT_THAT(finder.FindSubsequences("b q c r d s e"),
              UnorderedElementsAre(2, 4));
}
// Multi-byte characters: with a max skip size of 3, skips are counted in
// decoded characters rather than in bytes.
TEST(SubsequenceFinderTest, Utf8) {
  SubsequenceFinder finder(3);
  finder.AddSubsequence("一二 三四 五六", 0);

  // Three multi-byte characters are skipped between each pair of tokens.
  EXPECT_THAT(finder.FindSubsequences("一二おはよ三四こんに五六"),
              UnorderedElementsAre(0));
  // The space splits the token "三四", so the subsequence is not present.
  EXPECT_THAT(finder.FindSubsequences("一二三 四五六"),
              UnorderedElementsAre());
}
} // namespace
} // namespace seq_flow_lite
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_TEXT_DISTORTER_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_TEXT_DISTORTER_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_TEXT_DISTORTER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_TEXT_DISTORTER_H_
#include <assert.h>
......@@ -40,4 +40,4 @@ class TextDistorter {
UChar32 random_char_ = 0;
};
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_TEXT_DISTORTER_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_TEXT_DISTORTER_H_
......@@ -122,4 +122,4 @@ REGISTER_OP("UniformCausalAttn")
})
.Doc(R"doc(
Dummy uniform causal attn op.
)doc";
)doc");
......@@ -121,9 +121,9 @@ cc_library(
hdrs = ["tflite_qrnn_pooling.h"],
copts = tflite_copts(),
deps = [
"//third_party/absl/base:core_headers",
"//third_party/tensorflow/lite/kernels:builtin_ops",
"//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"//tflite_ops:quantization_util", # sequence projection
"@com_google_absl//absl/base:core_headers",
],
alwayslink = 1,
)
......@@ -132,7 +132,7 @@ cc_library(
name = "tflite_decoder_cache",
hdrs = ["tflite_decoder_cache.h"],
deps = [
"//third_party/tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/c:common",
],
alwayslink = 1,
)
......@@ -144,12 +144,12 @@ cc_library(
copts = tflite_copts(),
deps = [
":tflite_decoder_cache",
"//third_party/flatbuffers",
"//third_party/tensorflow/lite/c:common",
"//third_party/tensorflow/lite/kernels:builtin_ops",
"//third_party/tensorflow/lite/kernels:kernel_util",
"//third_party/tensorflow/lite/kernels/internal:tensor",
"//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
"//tflite_ops:quantization_util", # sequence projection
"@flatbuffers",
],
alwayslink = 1,
)
......@@ -160,11 +160,11 @@ cc_test(
srcs = ["tflite_decoder_handler_test.cc"],
deps = [
":tflite_decoder_handler",
"//testing/base/public:gunit",
"//third_party/flatbuffers",
"//third_party/tensorflow/lite:framework",
"//third_party/tensorflow/lite/c:common",
"//third_party/tensorflow/lite/kernels:test_util",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
)
......@@ -176,10 +176,10 @@ cc_library(
deps = [
"//base",
"//third_party/absl/strings",
"//third_party/tensorflow/lite/c:common",
"//third_party/tensorflow/lite/kernels/internal:tensor",
"//third_party/tensorflow/lite/kernels/internal:types",
"//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
"@org_tensorflow//tensorflow/lite/kernels/internal:types",
"//tflite_ops:quantization_util", # sequence projection
],
)
......@@ -189,14 +189,14 @@ cc_test(
copts = tflite_copts(),
deps = [
":beam_search",
"//testing/base/public:gunit_main",
"//third_party/absl/strings",
"//third_party/tensorflow/lite/c:c_api_types",
"//third_party/tensorflow/lite/c:common",
"//third_party/tensorflow/lite/kernels/internal:legacy_reference_base",
"//third_party/tensorflow/lite/kernels/internal:optimized_base",
"//third_party/tensorflow/lite/kernels/internal:tensor",
"//third_party/tensorflow/lite/kernels/internal:types",
"//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels/internal:legacy_reference_base",
"@org_tensorflow//tensorflow/lite/kernels/internal:optimized_base",
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
"@org_tensorflow//tensorflow/lite/kernels/internal:types",
"//tflite_ops:quantization_util", # sequence projection
"@com_google_googletest//:gtest_main",
],
)
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/beam_search.h"
#include "tflite_ops/beam_search.h" // seq_flow_lite
#include <algorithm>
#include <cstdint>
......@@ -21,10 +21,10 @@ limitations under the License.
#include <vector>
#include "base/logging.h"
#include "third_party/absl/strings/str_join.h"
#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "third_party/tensorflow/lite/kernels/internal/types.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include "absl/strings/str_join.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tflite_ops/quantization_util.h" // seq_flow_lite
namespace seq_flow_lite {
namespace ops {
......@@ -86,6 +86,7 @@ void SequenceTracker::AddSequence(const int32_t *begin, const int32_t *end,
std::vector<std::vector<int32_t>> SequenceTracker::GetTopBeams() {
std::vector<std::vector<int32_t>> return_value;
return_value.reserve(terminated_topk_.size());
for (const auto &v : terminated_topk_) {
return_value.push_back(v.second);
}
......@@ -255,8 +256,8 @@ void BeamSearch::FindTopKQuantizedFromLogitsV1(const TfLiteTensor &tensor,
}
}
// Updating topk across all beams.
for (uint32_t k = 0; k < std::min(topk_k, num_classes_); ++k) {
const uint32_t curr_beam_index = curr_beam_topk[k] & kClassIndexMask;
for (uint32_t curr_beam : curr_beam_topk) {
const uint32_t curr_beam_index = curr_beam & kClassIndexMask;
const uint32_t index = j * num_classes_ + curr_beam_index;
const float log_prob =
tensor.params.scale * beam_logits[curr_beam_index] - precomputed;
......
......@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#include <cstdint>
#include <functional>
......@@ -23,7 +23,7 @@ limitations under the License.
#include <set>
#include <vector>
#include "third_party/tensorflow/lite/c/common.h"
#include "tensorflow/lite/c/common.h"
namespace seq_flow_lite {
namespace ops {
......@@ -110,4 +110,4 @@ class BeamSearch {
} // namespace custom
} // namespace ops
} // namespace seq_flow_lite
#endif // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/beam_search.h"
#include "tflite_ops/beam_search.h" // seq_flow_lite
#include <cstdint>
#include <functional>
......@@ -21,17 +21,17 @@ limitations under the License.
#include <memory>
#include <vector>
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
#include "third_party/absl/strings/str_join.h"
#include "third_party/tensorflow/lite/c/c_api_types.h"
#include "third_party/tensorflow/lite/c/common.h"
#include "third_party/tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "third_party/tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "third_party/tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "third_party/tensorflow/lite/kernels/internal/types.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_join.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tflite_ops/quantization_util.h" // seq_flow_lite
namespace seq_flow_lite {
namespace ops {
......@@ -76,7 +76,7 @@ class BeamSearchImpl : public BeamSearch {
cur_cache + (selected_beams[beam] * NumClasses());
for (int j = 0; j < NumClasses(); ++j, index++) {
next_cache[index] = (selected[j] + next_cache[index]) / 2;
data_ptr[index] = ::seq_flow_lite::PodQuantize(
data_ptr[index] = PodQuantize(
next_cache[index], decoder_output_->params.zero_point,
1.0f / decoder_output_->params.scale);
}
......
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_EXPECTED_VALUE_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_EXPECTED_VALUE_H_
#include "tensorflow/lite/kernels/register.h"
......@@ -27,4 +27,4 @@ TfLiteRegistration* Register_EXPECTED_VALUE();
} // namespace ops
} // namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_EXPECTED_VALUE_H_
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
#define LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
#include "tensorflow/lite/kernels/register.h"
......@@ -27,4 +27,4 @@ TfLiteRegistration* Register_LAYER_NORM();
} // namespace ops
} // namespace seq_flow_lite
#endif // LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
......@@ -87,7 +87,7 @@ TEST(LayerNormModelTest, RegularInput) {
/*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
/*scale=*/1.0, /*offset=*/0.0, /*axes=*/{2});
m.SetInput(input);
m.Invoke();
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(
m.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
......@@ -106,7 +106,7 @@ TEST(LayerNormModelTest, NegativeScale) {
/*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
/*scale=*/-1.0, /*offset=*/0.0, /*axes=*/{2});
m.SetInput(input);
m.Invoke();
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(
m.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
......@@ -125,7 +125,7 @@ TEST(LayerNormModelTest, NegativeOffset) {
/*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
/*scale=*/1.0, /*offset=*/-1.0, /*axes=*/{2});
m.SetInput(input);
m.Invoke();
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(
m.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
......@@ -144,7 +144,7 @@ TEST(LayerNormModelTest, NegativeScaleAndOffset) {
/*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
/*scale=*/-1.0, /*offset=*/-1.0, /*axes=*/{2});
m.SetInput(input);
m.Invoke();
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(
m.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
......@@ -163,7 +163,7 @@ TEST(LayerNormModelTest, MultipleAxis) {
/*input_max=*/3, /*output_min=*/-3, /*output_max=*/3,
/*scale=*/1.0, /*offset=*/0.0, /*axes=*/{1, 3});
m.SetInput(input);
m.Invoke();
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(
m.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
......@@ -182,7 +182,7 @@ TEST(LayerNormModelTest, MultipleNegativeAxis) {
/*input_max=*/3, /*output_min=*/-3, /*output_max=*/3,
/*scale=*/1.0, /*offset=*/0.0, /*axes=*/{-3, -1});
m.SetInput(input);
m.Invoke();
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(
m.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
......@@ -204,7 +204,7 @@ TEST(LayerNormModelTest, MultipleAxisWithLargeDepth) {
/*input_max=*/1.0, /*output_min=*/-3.0, /*output_max=*/3.0,
/*scale=*/1.0, /*offset=*/0.0, /*axes=*/{1, 3});
m.SetInput(input);
m.Invoke();
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(
m.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
......
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
#include <algorithm>
#include <cmath>
......@@ -50,4 +50,4 @@ inline uint8_t PodQuantize(float value, int32_t zero_point,
} // namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
......@@ -101,7 +101,7 @@ class ProjectionParams {
bool exclude_nonalphaspace_unicodes,
const std::string& token_separators,
bool normalize_repetition, bool add_first_cap_feature,
bool add_all_caps_feature)
bool add_all_caps_feature, bool normalize_spaces)
: feature_size_(feature_size),
unicode_handler_(vocabulary, exclude_nonalphaspace_unicodes),
hasher_(Hasher::CreateHasher(feature_size, hashtype)),
......@@ -130,9 +130,9 @@ class ProjectionParams {
}
word_novelty_offset_ = 2.0f / (1 << word_novelty_bits_);
if (!token_separators.empty() || normalize_repetition) {
if (!token_separators.empty() || normalize_repetition || normalize_spaces) {
projection_normalizer_ = std::make_unique<ProjectionNormalizer>(
token_separators, normalize_repetition);
token_separators, normalize_repetition, normalize_spaces);
}
}
virtual ~ProjectionParams() {}
......@@ -242,7 +242,8 @@ class ProjectionParamsV2 : public ProjectionParams {
/*exclude_nonalphaspace_unicodes = */ false,
/*token_separators = */ "", normalize_repetition,
/*add_first_cap_feature = */ false,
/*add_all_caps_feature = */ false) {}
/*add_all_caps_feature = */ false,
/*normalize_spaces = */ false) {}
~ProjectionParamsV2() override {}
TfLiteStatus PreprocessInput(TfLiteTensor* input_t,
......@@ -341,6 +342,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const std::string token_separators =
m["token_separators"].IsNull() ? "" : m["token_separators"].ToString();
const bool normalize_repetition = m["normalize_repetition"].AsBool();
const bool normalize_spaces = m["normalize_spaces"].AsBool();
if (!Hasher::SupportedHashType(hashtype)) {
context->ReportError(context, "Unsupported hashtype %s\n",
hashtype.c_str());
......@@ -354,7 +356,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
add_bos_tag ? BosTag::kGenerate : BosTag::kNone,
add_eos_tag ? EosTag::kGenerate : EosTag::kNone,
exclude_nonalphaspace_unicodes, token_separators, normalize_repetition,
add_first_cap_feature == 1.0f, add_all_caps_feature == 1.0f);
add_first_cap_feature == 1.0f, add_all_caps_feature == 1.0f,
normalize_spaces);
}
void* InitV2(TfLiteContext* context, const char* buffer, size_t length) {
......
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#include "tensorflow/lite/kernels/register.h"
namespace seq_flow_lite {
......@@ -27,8 +27,9 @@ TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION();
extern const char kSequenceStringProjectionV2[];
TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION_V2();
} // namespace custom
} // namespace ops
} // namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
......@@ -39,6 +39,7 @@ using ::seq_flow_lite::testing::OpEquivTestCase;
using ::seq_flow_lite::testing::StringTensor;
using ::seq_flow_lite::testing::TensorflowTfLiteOpTest;
using ::testing::ElementsAreArray;
using ::testing::Not;
using ::tflite::TensorType_FLOAT32;
using ::tflite::TensorType_STRING;
using ::tflite::TensorType_UINT8;
......@@ -50,7 +51,8 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
int doc_size_levels, bool add_eos_tag, ::tflite::TensorType output_type,
const std::string& token_separators = "",
bool normalize_repetition = false, float add_first_cap = 0.0,
float add_all_caps = 0.0, const std::string& hashtype = kMurmurHash) {
float add_all_caps = 0.0, const std::string& hashtype = kMurmurHash,
bool normalize_spaces = false) {
flexbuffers::Builder fbb;
fbb.Map([&] {
fbb.Int("feature_size", 4);
......@@ -65,6 +67,7 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
fbb.Bool("normalize_repetition", normalize_repetition);
fbb.Float("add_first_cap_feature", add_first_cap);
fbb.Float("add_all_caps_feature", add_all_caps);
fbb.Bool("normalize_spaces", normalize_spaces);
});
fbb.Finish();
output_ = AddOutput({output_type, {}});
......@@ -76,13 +79,13 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
PopulateStringTensor(input_, {input});
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
SingleOpModel::Invoke();
CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
}
TfLiteStatus InvokeFailable(const std::string& input) {
PopulateStringTensor(input_, {input});
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
return SingleOpModel::InvokeUnchecked();
return SingleOpModel::Invoke();
}
template <typename T>
......@@ -335,6 +338,32 @@ TEST(SequenceStringProjectionTest, NormalizeRepetition) {
EXPECT_THAT(output1, ElementsAreArray(output2));
}
// Verifies the effect of the normalize_spaces option: with it enabled, an
// input padded with extra spaces projects to the same output as the unpadded
// input; with it disabled, the padded input projects to a different output.
TEST(SequenceStringProjectionTest, NormalizeSpaces) {
// The two models are constructed with identical arguments except for the
// final one (normalize_spaces): false here, true below.
SequenceStringProjectionModel model_nonormalize(false, -1, 0, 0, false,
TensorType_UINT8, "", false,
0.0, 0.0, kMurmurHash, false);
SequenceStringProjectionModel model_normalize(false, -1, 0, 0, false,
TensorType_UINT8, "", false,
0.0, 0.0, kMurmurHash, true);
const char kNoExtraSpaces[] = "Hello there.";
const char kExtraSpaces[] = " Hello there. ";
// Run both inputs through both models, capturing the output after each
// invocation since the next Invoke() on the same model replaces it.
model_nonormalize.Invoke(kNoExtraSpaces);
auto output_noextra_nonorm = model_nonormalize.GetOutput<uint8_t>();
model_nonormalize.Invoke(kExtraSpaces);
auto output_extra_nonorm = model_nonormalize.GetOutput<uint8_t>();
model_normalize.Invoke(kNoExtraSpaces);
auto output_noextra_norm = model_normalize.GetOutput<uint8_t>();
model_normalize.Invoke(kExtraSpaces);
auto output_extra_norm = model_normalize.GetOutput<uint8_t>();
// Normalization must not change the result for input without extra spaces.
EXPECT_THAT(output_noextra_nonorm, ElementsAreArray(output_noextra_norm));
// With normalization, the padded input must match the unpadded baseline.
EXPECT_THAT(output_noextra_nonorm, ElementsAreArray(output_extra_norm));
// Without normalization, the extra spaces must alter the output.
EXPECT_THAT(output_noextra_nonorm,
Not(ElementsAreArray(output_extra_nonorm)));
}
class SequenceStringProjectionTest : public TensorflowTfLiteOpTest {
std::function<TfLiteRegistration*()> TfLiteOpRegistration() override {
return ops::custom::Register_SEQUENCE_STRING_PROJECTION;
......@@ -710,6 +739,7 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NormalizeRepetition";
......@@ -794,6 +824,20 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
test_cases.push_back(test_case);
}
// Test case that enables the normalize_spaces attribute on an input string
// carrying leading and trailing spaces.
{
OpEquivTestCase test_case;
test_case.test_name = "NormalizeSpaces";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(false);
test_case.attributes["add_bos_tag"] = AttrValue(false);
test_case.attributes["normalize_spaces"] = AttrValue(true);
test_case.input_tensors.push_back(StringTensor({1}, {" Hello there. "}));
// NOTE(review): the output tensor is declared empty here; presumably the
// expected values come from running the reference op — confirm.
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
return test_cases;
}
......@@ -822,13 +866,13 @@ class SequenceStringProjectionV2Model : public ::tflite::SingleOpModel {
PopulateStringTensor(input_, input);
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
ASSERT_EQ(SingleOpModel::InvokeUnchecked(), expected);
ASSERT_EQ(SingleOpModel::Invoke(), expected);
}
TfLiteStatus InvokeFailable(const std::string& input) {
PopulateStringTensor(input_, {input});
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
return SingleOpModel::InvokeUnchecked();
return SingleOpModel::Invoke();
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
......
......@@ -309,7 +309,7 @@ void TensorflowTfLiteOpTest::RunTfLiteOp() {
input_index++;
}
tflite_op_.Invoke();
ASSERT_EQ(tflite_op_.Invoke(), kTfLiteOk);
}
void TensorflowTfLiteOpTest::CompareOpOutput() {
......
......@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// Tests equivalence between TF and TFLite versions of an op.
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#include <string>
#include <vector>
......@@ -146,4 +146,4 @@ class TensorflowTfLiteOpTest
} // namespace testing
} // namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
......@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#include <memory>
#include "third_party/tensorflow/lite/c/common.h"
#include "tensorflow/lite/c/common.h"
namespace seq_flow_lite {
namespace ops {
namespace custom {
......@@ -113,4 +114,4 @@ class DynamicCacheOp {
} // namespace custom
} // namespace ops
} // namespace seq_flow_lite
#endif // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
......@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_handler.h"
#include "tflite_ops/tflite_decoder_handler.h" // seq_flow_lite
#include <cstdint>
#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "third_party/tensorflow/lite/kernels/kernel_util.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_cache.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include "flatbuffers/flexbuffers.h" // flatbuffer
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tflite_ops/quantization_util.h" // seq_flow_lite
#include "tflite_ops/tflite_decoder_cache.h" // seq_flow_lite
namespace seq_flow_lite {
namespace ops {
......
......@@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#include "third_party/tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register.h"
namespace seq_flow_lite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_UNIFORM_CAUSAL_ATTENTION();
}
} // namespace custom
} // namespace ops
} // namespace seq_flow_lite
#endif // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
......@@ -13,17 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_handler.h"
#include "tflite_ops/tflite_decoder_handler.h" // seq_flow_lite
#include <cstdint>
#include <cstdlib>
#include <vector>
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
#include "third_party/tensorflow/lite/c/common.h"
#include "third_party/tensorflow/lite/kernels/test_util.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flexbuffers.h" // flatbuffer
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/test_util.h"
namespace {
......
......@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include "tflite_ops/quantization_util.h" // seq_flow_lite
#include "tflite_ops/tflite_qrnn_pooling.h" // seq_flow_lite
namespace seq_flow_lite {
namespace ops {
namespace custom {
namespace {
......@@ -126,9 +127,9 @@ TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
return QRNNPooling(context, multiplier, constant, outputs, final_state,
(direction->data.uint8[0] == kPoolingForward));
}
} // namespace
namespace custom {
const char kPoolingOp[] = "PoolingOp";
void RegisterQRNNPooling(::tflite::ops::builtin::BuiltinOpResolver* resolver) {
......@@ -141,4 +142,5 @@ TfLiteRegistration* Register_QRNN_POOLING() {
}
} // namespace custom
} // namespace ops
} // namespace seq_flow_lite
......@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#include "third_party/absl/base/macros.h"
#include "third_party/tensorflow/lite/kernels/register.h"
#include "absl/base/macros.h"
#include "tensorflow/lite/kernels/register.h"
namespace seq_flow_lite {
namespace ops {
namespace custom {
extern const char kPoolingOp[];
......@@ -27,7 +27,7 @@ extern const char kPoolingOp[];
TfLiteRegistration* Register_QRNN_POOLING();
} // namespace custom
} // namespace ops
} // namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""A utility for PRADO model to do train, eval, inference and model export."""
import importlib
......@@ -22,6 +21,7 @@ from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import input_fn_reader # import root module
import metric_functions # import root module
......@@ -48,14 +48,14 @@ def load_runner_config():
return json.loads(f.read())
def create_model(model, model_config, features, mode):
def create_model(model, model_config, features, mode, model_name):
"""Creates a sequence labeling model."""
keras_model = model.Encoder(model_config, mode)
if "pqrnn" in model_name:
logits = keras_model(features["projection"], features["seq_length"])
else:
logits = keras_model(features["token_ids"], features["token_len"])
if mode != tf.estimator.ModeKeys.PREDICT:
if mode != tf_estimator.ModeKeys.PREDICT:
if not model_config["multilabel"]:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=features["label"], logits=logits)
......@@ -94,33 +94,33 @@ def model_fn_builder(runner_config):
def model_fn(features, mode, params):
"""The `model_fn` for TPUEstimator."""
label_ids = None
if mode != tf.estimator.ModeKeys.PREDICT:
if mode != tf_estimator.ModeKeys.PREDICT:
label_ids = features["label"]
model_config = runner_config["model_config"]
loss, logits = create_model(model, model_config, features, mode)
loss, logits = create_model(model, model_config, features, mode,
runner_config["name"])
if mode == tf.estimator.ModeKeys.TRAIN:
if mode == tf_estimator.ModeKeys.TRAIN:
train_op = create_optimizer(loss, runner_config, params)
return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
return tf_estimator.tpu.TPUEstimatorSpec(
mode=mode, loss=loss, train_op=train_op)
elif mode == tf.estimator.ModeKeys.EVAL:
elif mode == tf_estimator.ModeKeys.EVAL:
if not runner_config["model_config"]["multilabel"]:
metric_fn = metric_functions.classification_metric
else:
metric_fn = metric_functions.labeling_metric
eval_metrics = (metric_fn, [loss, label_ids, logits])
return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
return tf_estimator.tpu.TPUEstimatorSpec(
mode=mode, loss=loss, eval_metrics=eval_metrics)
elif mode == tf.estimator.ModeKeys.PREDICT:
elif mode == tf_estimator.ModeKeys.PREDICT:
predictions = {"logits": logits}
if not runner_config["model_config"]["multilabel"]:
predictions["predictions"] = tf.nn.softmax(logits)
else:
predictions["predictions"] = tf.math.sigmoid(logits)
return tf.compat.v1.estimator.EstimatorSpec(
mode=mode, predictions=predictions)
return tf_estimator.EstimatorSpec(mode=mode, predictions=predictions)
else:
assert False, "Expected to be called in TRAIN, EVAL, or PREDICT mode."
......@@ -133,13 +133,13 @@ def main(_):
if FLAGS.output_dir:
tf.gfile.MakeDirs(FLAGS.output_dir)
is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf.estimator.tpu.RunConfig(
is_per_host = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf_estimator.tpu.RunConfig(
master=FLAGS.master,
model_dir=FLAGS.output_dir,
save_checkpoints_steps=runner_config["save_checkpoints_steps"],
keep_checkpoint_max=20,
tpu_config=tf.estimator.tpu.TPUConfig(
tpu_config=tf_estimator.tpu.TPUConfig(
iterations_per_loop=runner_config["iterations_per_loop"],
num_shards=FLAGS.num_tpu_cores,
per_host_input_for_training=is_per_host))
......@@ -149,7 +149,7 @@ def main(_):
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
batch_size = runner_config["batch_size"]
estimator = tf.estimator.tpu.TPUEstimator(
estimator = tf_estimator.tpu.TPUEstimator(
use_tpu=FLAGS.use_tpu,
model_fn=model_fn,
config=run_config,
......@@ -160,7 +160,7 @@ def main(_):
if FLAGS.runner_mode == "train":
train_input_fn = input_fn_reader.create_input_fn(
runner_config=runner_config,
mode=tf.estimator.ModeKeys.TRAIN,
mode=tf_estimator.ModeKeys.TRAIN,
drop_remainder=True)
estimator.train(
input_fn=train_input_fn, max_steps=runner_config["train_steps"])
......@@ -168,7 +168,7 @@ def main(_):
# TPU needs fixed shapes, so if the last batch is smaller, we drop it.
eval_input_fn = input_fn_reader.create_input_fn(
runner_config=runner_config,
mode=tf.estimator.ModeKeys.EVAL,
mode=tf_estimator.ModeKeys.EVAL,
drop_remainder=True)
for _ in tf.train.checkpoints_iterator(FLAGS.output_dir, timeout=600):
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Binary to train PRADO model with TF 2.0."""
import importlib
......@@ -23,6 +22,7 @@ from absl import flags
from absl import logging
import tensorflow as tf
from tensorflow import estimator as tf_estimator
import input_fn_reader # import root module
......@@ -48,7 +48,7 @@ def load_runner_config():
def compute_loss(logits, labels, model_config, mode):
"""Computes the loss for the given logits and labels."""
if mode != tf.estimator.ModeKeys.PREDICT:
if mode != tf_estimator.ModeKeys.PREDICT:
if not model_config["multilabel"]:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
......@@ -77,11 +77,11 @@ def main(_):
if FLAGS.output_dir:
tf.io.gfile.makedirs(FLAGS.output_dir)
train_model = model_fn_builder(runner_config, tf.estimator.ModeKeys.TRAIN)
train_model = model_fn_builder(runner_config, tf_estimator.ModeKeys.TRAIN)
optimizer = tf.keras.optimizers.Adam()
train_input_fn = input_fn_reader.create_input_fn(
runner_config=runner_config,
mode=tf.estimator.ModeKeys.TRAIN,
mode=tf_estimator.ModeKeys.TRAIN,
drop_remainder=True)
params = {"batch_size": runner_config["batch_size"]}
train_ds = train_input_fn(params)
......@@ -93,7 +93,7 @@ def main(_):
logits = train_model(features["projection"], features["seq_length"])
loss = compute_loss(logits, features["label"],
runner_config["model_config"],
tf.estimator.ModeKeys.TRAIN)
tf_estimator.ModeKeys.TRAIN)
gradients = tape.gradient(loss, train_model.trainable_variables)
optimizer.apply_gradients(zip(gradients, train_model.trainable_variables))
train_loss(loss)
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""A module for miscellaneous utils."""
import tensorflow as tf
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Utils to convert to a TFLite model."""
import tensorflow.compat.v1 as tf
......@@ -65,9 +64,14 @@ def get_mean_stddev_values(min_value_of_features, max_value_of_features):
class InterpreterWithCustomOps(tf.lite.Interpreter):
"""Extended tf.lite.Interpreter."""
def __init__(self, model_content, custom_op_registerers=None):
def __init__(self,
model_content,
custom_op_registerers=None,
experimental_preserve_all_tensors=False):
self._custom_op_registerers = custom_op_registerers or []
super(InterpreterWithCustomOps, self).__init__(model_content=model_content)
super(InterpreterWithCustomOps, self).__init__(
model_content=model_content,
experimental_preserve_all_tensors=experimental_preserve_all_tensors)
def op_details(self):
op_details = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册