Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
20cc2190
M
Models
项目概览
曾经的那一瞬间
/
Models
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
20cc2190
编写于
8月 24, 2022
作者:
P
pyoung2778
提交者:
GitHub
8月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Check in seq_flow_lite (#10750)
上级
fdecf385
变更
62
隐藏空白更改
内联
并排
Showing
62 changed file
with
1983 addition
and
298 deletion
+1983
-298
research/seq_flow_lite/WORKSPACE
research/seq_flow_lite/WORKSPACE
+3
-6
research/seq_flow_lite/export_to_tflite.py
research/seq_flow_lite/export_to_tflite.py
+19
-12
research/seq_flow_lite/input_fn_reader.py
research/seq_flow_lite/input_fn_reader.py
+4
-4
research/seq_flow_lite/layers/BUILD
research/seq_flow_lite/layers/BUILD
+11
-11
research/seq_flow_lite/layers/base_layers.py
research/seq_flow_lite/layers/base_layers.py
+7
-2
research/seq_flow_lite/layers/conv_layers.py
research/seq_flow_lite/layers/conv_layers.py
+0
-1
research/seq_flow_lite/layers/dense_layers.py
research/seq_flow_lite/layers/dense_layers.py
+8
-3
research/seq_flow_lite/layers/embedding_layers.py
research/seq_flow_lite/layers/embedding_layers.py
+2
-2
research/seq_flow_lite/layers/misc_layers.py
research/seq_flow_lite/layers/misc_layers.py
+2
-3
research/seq_flow_lite/layers/normalization_layers.py
research/seq_flow_lite/layers/normalization_layers.py
+0
-1
research/seq_flow_lite/layers/qrnn_layers.py
research/seq_flow_lite/layers/qrnn_layers.py
+0
-1
research/seq_flow_lite/layers/quantization_layers.py
research/seq_flow_lite/layers/quantization_layers.py
+0
-1
research/seq_flow_lite/layers/transformer_layers.py
research/seq_flow_lite/layers/transformer_layers.py
+6
-6
research/seq_flow_lite/metric_functions.py
research/seq_flow_lite/metric_functions.py
+0
-1
research/seq_flow_lite/models/BUILD
research/seq_flow_lite/models/BUILD
+26
-26
research/seq_flow_lite/models/byteqrnn.py
research/seq_flow_lite/models/byteqrnn.py
+5
-5
research/seq_flow_lite/models/charformer.py
research/seq_flow_lite/models/charformer.py
+7
-7
research/seq_flow_lite/models/pqrnn.py
research/seq_flow_lite/models/pqrnn.py
+12
-8
research/seq_flow_lite/models/prado.py
research/seq_flow_lite/models/prado.py
+0
-1
research/seq_flow_lite/models/sgnn/sgnn_test.py
research/seq_flow_lite/models/sgnn/sgnn_test.py
+0
-1
research/seq_flow_lite/models/transformer_encoder.py
research/seq_flow_lite/models/transformer_encoder.py
+2
-2
research/seq_flow_lite/models/transformer_uniform_attn_decoder.py
.../seq_flow_lite/models/transformer_uniform_attn_decoder.py
+6
-6
research/seq_flow_lite/tf_ops/BUILD
research/seq_flow_lite/tf_ops/BUILD
+102
-26
research/seq_flow_lite/tf_ops/denylist_op.cc
research/seq_flow_lite/tf_ops/denylist_op.cc
+438
-0
research/seq_flow_lite/tf_ops/denylist_op_test.cc
research/seq_flow_lite/tf_ops/denylist_op_test.cc
+292
-0
research/seq_flow_lite/tf_ops/denylist_op_test.py
research/seq_flow_lite/tf_ops/denylist_op_test.py
+63
-0
research/seq_flow_lite/tf_ops/projection_normalizer_util.cc
research/seq_flow_lite/tf_ops/projection_normalizer_util.cc
+29
-1
research/seq_flow_lite/tf_ops/projection_normalizer_util.h
research/seq_flow_lite/tf_ops/projection_normalizer_util.h
+10
-6
research/seq_flow_lite/tf_ops/projection_tokenizer_util.h
research/seq_flow_lite/tf_ops/projection_tokenizer_util.h
+3
-3
research/seq_flow_lite/tf_ops/projection_util.h
research/seq_flow_lite/tf_ops/projection_util.h
+3
-3
research/seq_flow_lite/tf_ops/sequence_string_projection.cc
research/seq_flow_lite/tf_ops/sequence_string_projection.cc
+10
-2
research/seq_flow_lite/tf_ops/skipgram_finder.cc
research/seq_flow_lite/tf_ops/skipgram_finder.cc
+183
-0
research/seq_flow_lite/tf_ops/skipgram_finder.h
research/seq_flow_lite/tf_ops/skipgram_finder.h
+66
-0
research/seq_flow_lite/tf_ops/skipgram_finder_test.cc
research/seq_flow_lite/tf_ops/skipgram_finder_test.cc
+160
-0
research/seq_flow_lite/tf_ops/subsequence_finder.cc
research/seq_flow_lite/tf_ops/subsequence_finder.cc
+143
-0
research/seq_flow_lite/tf_ops/subsequence_finder.h
research/seq_flow_lite/tf_ops/subsequence_finder.h
+76
-0
research/seq_flow_lite/tf_ops/subsequence_finder_test.cc
research/seq_flow_lite/tf_ops/subsequence_finder_test.cc
+81
-0
research/seq_flow_lite/tf_ops/text_distorter.h
research/seq_flow_lite/tf_ops/text_distorter.h
+3
-3
research/seq_flow_lite/tf_ops/tf_custom_ops.cc
research/seq_flow_lite/tf_ops/tf_custom_ops.cc
+1
-1
research/seq_flow_lite/tflite_ops/BUILD
research/seq_flow_lite/tflite_ops/BUILD
+27
-27
research/seq_flow_lite/tflite_ops/beam_search.cc
research/seq_flow_lite/tflite_ops/beam_search.cc
+8
-7
research/seq_flow_lite/tflite_ops/beam_search.h
research/seq_flow_lite/tflite_ops/beam_search.h
+4
-4
research/seq_flow_lite/tflite_ops/beam_search_test.cc
research/seq_flow_lite/tflite_ops/beam_search_test.cc
+13
-13
research/seq_flow_lite/tflite_ops/expected_value.h
research/seq_flow_lite/tflite_ops/expected_value.h
+3
-3
research/seq_flow_lite/tflite_ops/layer_norm.h
research/seq_flow_lite/tflite_ops/layer_norm.h
+3
-3
research/seq_flow_lite/tflite_ops/layer_norm_test.cc
research/seq_flow_lite/tflite_ops/layer_norm_test.cc
+7
-7
research/seq_flow_lite/tflite_ops/quantization_util.h
research/seq_flow_lite/tflite_ops/quantization_util.h
+3
-3
research/seq_flow_lite/tflite_ops/sequence_string_projection.cc
...ch/seq_flow_lite/tflite_ops/sequence_string_projection.cc
+8
-5
research/seq_flow_lite/tflite_ops/sequence_string_projection.h
...rch/seq_flow_lite/tflite_ops/sequence_string_projection.h
+4
-3
research/seq_flow_lite/tflite_ops/sequence_string_projection_test.cc
...q_flow_lite/tflite_ops/sequence_string_projection_test.cc
+49
-5
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.cc
...arch/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.cc
+1
-1
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.h
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.h
+3
-3
research/seq_flow_lite/tflite_ops/tflite_decoder_cache.h
research/seq_flow_lite/tflite_ops/tflite_decoder_cache.h
+5
-4
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.cc
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.cc
+7
-6
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.h
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.h
+6
-5
research/seq_flow_lite/tflite_ops/tflite_decoder_handler_test.cc
...h/seq_flow_lite/tflite_ops/tflite_decoder_handler_test.cc
+6
-6
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
+6
-4
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h
+7
-7
research/seq_flow_lite/trainer.py
research/seq_flow_lite/trainer.py
+18
-18
research/seq_flow_lite/trainer_v2.py
research/seq_flow_lite/trainer_v2.py
+5
-5
research/seq_flow_lite/utils/misc_utils.py
research/seq_flow_lite/utils/misc_utils.py
+0
-1
research/seq_flow_lite/utils/tflite_utils.py
research/seq_flow_lite/utils/tflite_utils.py
+7
-3
未找到文件。
research/seq_flow_lite/WORKSPACE
浏览文件 @
20cc2190
...
...
@@ -16,14 +16,11 @@ http_archive(
http_archive
(
name
=
"org_tensorflow"
,
sha256
=
"40d3203ab5f246d83bae328288a24209a2b85794f1b3e2cd0329458d8e7c1985"
,
strip_prefix
=
"tensorflow-2.6.0"
,
urls
=
[
"https://github.com/tensorflow/tensorflow/archive/v2.6.0.zip"
,
],
strip_prefix
=
"tensorflow-2.9.1"
,
sha256
=
"9f2dac244e5af6c6a13a7dad6481e390174ac989931942098e7a4373f1bccfc2"
,
urls
=
[
"https://github.com/tensorflow/tensorflow/archive/v2.9.1.zip"
],
)
http_archive
(
name
=
"org_tflite_support"
,
strip_prefix
=
"tflite-support-0861599711ef31de58f62ed3ff6bbcc1e4817ef6"
,
...
...
research/seq_flow_lite/export_to_tflite.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""A tool to export TFLite model."""
import
importlib
...
...
@@ -22,7 +21,7 @@ import os
from
absl
import
app
from
absl
import
flags
import
tensorflow.compat.v1
as
tf
import
tensorflow_text
as
tftext
from
layers
import
base_layers
# import seq_flow_lite module
from
layers
import
projection_layers
# import seq_flow_lite module
from
utils
import
tflite_utils
# import seq_flow_lite module
...
...
@@ -48,25 +47,33 @@ def main(_):
with
tf
.
Graph
().
as_default
()
as
graph
:
with
tf
.
Session
(
graph
=
graph
)
as
session
:
text
=
tf
.
placeholder
(
tf
.
string
,
shape
=
[
1
],
name
=
"Input"
)
prxlayer
=
projection_layers
.
ProjectionLayer
(
model_config
,
base_layers
.
TFLITE
)
encoder
=
model
.
Encoder
(
model_config
,
base_layers
.
TFLITE
)
projection
,
seq_lengh
=
prxlayer
(
text
)
logits
=
encoder
(
projection
,
seq_lengh
)
inputs
=
[
text
]
if
"pqrnn"
in
runner_config
[
"name"
]:
prxlayer
=
projection_layers
.
ProjectionLayer
(
model_config
,
base_layers
.
TFLITE
)
encoder
=
model
.
Encoder
(
model_config
,
base_layers
.
TFLITE
)
projection
,
seq_length
=
prxlayer
(
text
)
logits
=
encoder
(
projection
,
seq_length
)
else
:
byte_int
=
tftext
.
ByteSplitter
().
split
(
text
)
token_ids
=
tf
.
cast
(
byte_int
,
tf
.
int32
).
to_tensor
()
token_ids
=
tf
.
reshape
(
token_ids
,
[
1
,
-
1
])
token_ids
+=
3
encoder
=
model
.
Encoder
(
model_config
,
base_layers
.
TFLITE
)
logits
=
encoder
(
token_ids
,
None
)
if
FLAGS
.
output
==
"logits"
:
outputs
=
logits
outputs
=
[
logits
]
elif
FLAGS
.
output
==
"sigmoid"
:
outputs
=
tf
.
math
.
sigmoid
(
logits
)
outputs
=
[
tf
.
math
.
sigmoid
(
logits
)]
else
:
assert
FLAGS
.
output
==
"softmax"
,
"Unexpected output"
outputs
=
tf
.
nn
.
softmax
(
logits
)
outputs
=
[
tf
.
nn
.
softmax
(
logits
)]
session
.
run
(
tf
.
global_variables_initializer
())
session
.
run
(
tf
.
local_variables_initializer
())
saver
=
tf
.
train
.
Saver
()
saver
.
restore
(
session
,
tf
.
train
.
latest_checkpoint
(
FLAGS
.
output_dir
))
tflite_fb
=
tflite_utils
.
generate_tflite
(
session
,
graph
,
[
text
],
[
outputs
])
tflite_fb
=
tflite_utils
.
generate_tflite
(
session
,
graph
,
inputs
,
outputs
)
output_file_name
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
"tflite.fb"
)
with
tf
.
gfile
.
Open
(
output_file_name
,
"wb"
)
as
f
:
f
.
write
(
tflite_fb
)
...
...
research/seq_flow_lite/input_fn_reader.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Methods related to input datasets and readers."""
import
functools
...
...
@@ -21,6 +20,7 @@ import sys
from
absl
import
logging
import
tensorflow
as
tf
from
tensorflow
import
estimator
as
tf_estimator
import
tensorflow_datasets
as
tfds
import
tensorflow_text
as
tftext
...
...
@@ -83,13 +83,13 @@ def create_input_fn(runner_config, mode, drop_remainder):
def
_input_fn
(
params
):
"""Method to be used for reading the data."""
assert
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
split
=
"train"
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
else
"test"
assert
mode
!=
tf
_
estimator
.
ModeKeys
.
PREDICT
split
=
"train"
if
mode
==
tf
_
estimator
.
ModeKeys
.
TRAIN
else
"test"
ds
=
tfds
.
load
(
runner_config
[
"dataset"
],
split
=
split
)
ds
=
ds
.
batch
(
params
[
"batch_size"
],
drop_remainder
=
drop_remainder
)
ds
=
ds
.
prefetch
(
buffer_size
=
tf
.
data
.
experimental
.
AUTOTUNE
)
ds
=
ds
.
shuffle
(
buffer_size
=
100
)
ds
=
ds
.
repeat
(
count
=
1
if
mode
==
tf
.
estimator
.
ModeKeys
.
EVAL
else
None
)
ds
=
ds
.
repeat
(
count
=
1
if
mode
==
tf
_
estimator
.
ModeKeys
.
EVAL
else
None
)
ds
=
ds
.
map
(
functools
.
partial
(
_post_processor
,
batch_size
=
params
[
"batch_size"
]),
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
,
...
...
research/seq_flow_lite/layers/BUILD
浏览文件 @
20cc2190
...
...
@@ -82,12 +82,12 @@ py_strict_library(
srcs
=
[
"misc_layers.py"
],
srcs_version
=
"PY3"
,
deps
=
[
# package tensorflow
":embedding_layers"
,
# package tensorflow
"//layers:base_layers"
,
# sequence projection
"//layers:conv_layers"
,
"//layers:conv_layers"
,
# sequence projection
"//layers:dense_layers"
,
# sequence projection
"//layers:normalization_layers"
,
"//layers:normalization_layers"
,
# sequence projection
"//layers:quantization_layers"
,
# sequence projection
],
)
...
...
@@ -112,8 +112,8 @@ py_strict_library(
srcs_version
=
"PY3"
,
deps
=
[
# package tensorflow
"//layers:base_layers"
,
"//layers:quantization_layers"
,
"//layers:base_layers"
,
# sequence projection
"//layers:quantization_layers"
,
# sequence projection
],
)
...
...
@@ -124,11 +124,11 @@ py_strict_library(
deps
=
[
":embedding_layers"
,
# package tensorflow
"//layers:base_layers"
,
"//layers:dense_layers"
,
"//layers:normalization_layers"
,
"//layers:quantization_layers"
,
"//tf_ops:tf_custom_ops"
,
"//tf_ops:tf_custom_ops_py"
,
"//layers:base_layers"
,
# sequence projection
"//layers:dense_layers"
,
# sequence projection
"//layers:normalization_layers"
,
# sequence projection
"//layers:quantization_layers"
,
# sequence projection
# "//tf_ops:tf_custom_ops" # sequence projection
"//tf_ops:tf_custom_ops_py"
,
# sequence projection
],
)
research/seq_flow_lite/layers/base_layers.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Base layer for building models trained with quantization."""
import
tensorflow
as
tf
...
...
@@ -57,7 +56,7 @@ class BaseLayer(tf.keras.layers.Layer):
def
add_weight_wrapper
(
self
,
shape
):
"""Return a weight variable for the given shape."""
if
self
.
parameters
.
initializer
is
not
None
:
initializer
=
self
.
parameters
.
initializer
initializer
=
clone_initializer
(
self
.
parameters
.
initializer
)
else
:
initializer
=
tf
.
keras
.
initializers
.
GlorotUniform
()
weight
=
self
.
add_weight
(
...
...
@@ -136,3 +135,9 @@ class BaseLayer(tf.keras.layers.Layer):
maxval
=
(
1.0
-
zero_probability
),
dtype
=
tensor
.
dtype
)
return
tf
.
math
.
ceil
(
rnd
)
def
clone_initializer
(
initializer
):
if
isinstance
(
initializer
,
tf
.
keras
.
initializers
.
Initializer
):
return
initializer
.
__class__
.
from_config
(
initializer
.
get_config
())
return
initializer
research/seq_flow_lite/layers/conv_layers.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Base layer for convolution."""
import
tensorflow
as
tf
...
...
research/seq_flow_lite/layers/dense_layers.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Basic dense layers."""
import
tensorflow
as
tf
...
...
@@ -30,6 +29,7 @@ class BaseQDense(base_layers.BaseLayer):
bias
=
True
,
rank
=
2
,
normalize
=
True
,
quantize_output
=
True
,
**
kwargs
):
self
.
units
=
units
self
.
rank
=
rank
...
...
@@ -37,7 +37,9 @@ class BaseQDense(base_layers.BaseLayer):
self
.
activation
=
activation
self
.
bias
=
bias
self
.
normalize
=
normalize
self
.
qoutput
=
quantization_layers
.
ActivationQuantization
(
**
kwargs
)
self
.
quantize_output
=
quantize_output
if
quantize_output
:
self
.
qoutput
=
quantization_layers
.
ActivationQuantization
(
**
kwargs
)
self
.
_create_normalizer
(
**
kwargs
)
super
(
BaseQDense
,
self
).
__init__
(
**
kwargs
)
...
...
@@ -62,7 +64,10 @@ class BaseQDense(base_layers.BaseLayer):
outputs
=
normalize_method
(
outputs
)
if
self
.
activation
:
outputs
=
self
.
activation
(
outputs
)
return
self
.
qoutput
(
outputs
)
if
self
.
quantize_output
:
return
self
.
qoutput
(
outputs
)
else
:
return
outputs
def
_dense_r34
(
self
,
inputs
,
normalize_method
):
bsz
=
self
.
get_batch_dimension
(
inputs
)
...
...
research/seq_flow_lite/layers/embedding_layers.py
浏览文件 @
20cc2190
...
...
@@ -15,8 +15,8 @@
"""Layers for embedding."""
import
tensorflow
as
tf
from
layers
import
base_layers
from
layers
import
quantization_layers
from
layers
import
base_layers
# import seq_flow_lite module
from
layers
import
quantization_layers
# import seq_flow_lite module
class
EmbeddingLayer
(
base_layers
.
BaseLayer
):
...
...
research/seq_flow_lite/layers/misc_layers.py
浏览文件 @
20cc2190
...
...
@@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Layers for embedding."""
import
math
import
tensorflow
as
tf
from
layers
import
base_layers
# import seq_flow_lite module
from
layers
import
conv_layers
from
layers
import
conv_layers
# import seq_flow_lite module
from
layers
import
dense_layers
# import seq_flow_lite module
from
layers
import
embedding_layers
from
layers
import
embedding_layers
# import seq_flow_lite module
from
layers
import
quantization_layers
# import seq_flow_lite module
...
...
research/seq_flow_lite/layers/normalization_layers.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Layers for normalization."""
import
tensorflow
as
tf
...
...
research/seq_flow_lite/layers/qrnn_layers.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Layers for QRNN."""
import
tensorflow
as
tf
...
...
research/seq_flow_lite/layers/quantization_layers.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Layers for quantization."""
import
tensorflow
as
tf
...
...
research/seq_flow_lite/layers/transformer_layers.py
浏览文件 @
20cc2190
...
...
@@ -16,12 +16,12 @@
# pylint: disable=arguments-renamed
import
tensorflow
as
tf
from
layers
import
base_layers
from
layers
import
dense_layers
from
layers
import
embedding_layers
from
layers
import
normalization_layers
from
layers
import
quantization_layers
from
tf_ops
import
tf_custom_ops_py
from
layers
import
base_layers
# import seq_flow_lite module
from
layers
import
dense_layers
# import seq_flow_lite module
from
layers
import
embedding_layers
# import seq_flow_lite module
from
layers
import
normalization_layers
# import seq_flow_lite module
from
layers
import
quantization_layers
# import seq_flow_lite module
from
tf_ops
import
tf_custom_ops_py
# import seq_flow_lite module
class
SelfAttention
(
base_layers
.
BaseLayer
):
...
...
research/seq_flow_lite/metric_functions.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Metric functions."""
import
tensorflow.compat.v1
as
tf
...
...
research/seq_flow_lite/models/BUILD
浏览文件 @
20cc2190
...
...
@@ -45,13 +45,13 @@ py_library(
srcs_version
=
"PY3"
,
deps
=
[
# package tensorflow
"//layers:base_layers"
,
"//layers:dense_layers"
,
"//layers:embedding_layers"
,
"//layers:misc_layers"
,
"//layers:qrnn_layers"
,
#
//tf_ops:tf_custom_ops",
"//tf_ops:tf_custom_ops_py"
,
"//layers:base_layers"
,
# sequence projection
"//layers:dense_layers"
,
# sequence projection
"//layers:embedding_layers"
,
# sequence projection
"//layers:misc_layers"
,
# sequence projection
"//layers:qrnn_layers"
,
# sequence projection
#
"//tf_ops:tf_custom_ops" # sequence projection
"//tf_ops:tf_custom_ops_py"
,
# sequence projection
],
)
...
...
@@ -62,13 +62,13 @@ py_library(
deps
=
[
":transformer_encoder"
,
# package tensorflow
"//layers:base_layers"
,
"//layers:embedding_layers"
,
"//layers:misc_layers"
,
"//layers:normalization_layers"
,
"//layers:quantization_layers"
,
# "//tf_ops:tf_custom_ops"
,
"//tf_ops:tf_custom_ops_py"
,
"//layers:base_layers"
,
# sequence projection
"//layers:embedding_layers"
,
# sequence projection
"//layers:misc_layers"
,
# sequence projection
"//layers:normalization_layers"
,
# sequence projection
"//layers:quantization_layers"
,
# sequence projection
# "//tf_ops:tf_custom_ops"
# sequence projection
"//tf_ops:tf_custom_ops_py"
,
# sequence projection
],
)
...
...
@@ -79,11 +79,11 @@ py_library(
deps
=
[
# package absl/logging
# package tensorflow
"//layers:base_layers"
,
"//layers:embedding_layers"
,
"//layers:transformer_layers"
,
# "//tf_ops:tf_custom_ops"
,
"//tf_ops:tf_custom_ops_py"
,
"//layers:base_layers"
,
# sequence projection
"//layers:embedding_layers"
,
# sequence projection
"//layers:transformer_layers"
,
# sequence projection
# "//tf_ops:tf_custom_ops"
# sequence projection
"//tf_ops:tf_custom_ops_py"
,
# sequence projection
],
)
...
...
@@ -93,13 +93,13 @@ py_library(
srcs_version
=
"PY3"
,
deps
=
[
# package absl/logging
# package tensor2tensor/utils:beam_search
# package tensorflow
# tensor2tensor/utils:beam_search",
"//layers:base_layers"
,
"//layers:embedding_layers"
,
"//layers:misc_layers"
,
"//layers:transformer_layers"
,
"//tf_ops:tf_custom_ops"
,
"//tf_ops:tf_custom_ops_py"
,
"//layers:base_layers"
,
# sequence projection
"//layers:embedding_layers"
,
# sequence projection
"//layers:misc_layers"
,
# sequence projection
"//layers:transformer_layers"
,
# sequence projection
# "//tf_ops:tf_custom_ops" # sequence projection
"//tf_ops:tf_custom_ops_py"
,
# sequence projection
],
)
research/seq_flow_lite/models/byteqrnn.py
浏览文件 @
20cc2190
...
...
@@ -33,11 +33,11 @@ Sample model params:
from
absl
import
logging
import
tensorflow
as
tf
from
layers
import
base_layers
from
layers
import
dense_layers
from
layers
import
embedding_layers
from
layers
import
misc_layers
from
layers
import
qrnn_layers
from
layers
import
base_layers
# import seq_flow_lite module
from
layers
import
dense_layers
# import seq_flow_lite module
from
layers
import
embedding_layers
# import seq_flow_lite module
from
layers
import
misc_layers
# import seq_flow_lite module
from
layers
import
qrnn_layers
# import seq_flow_lite module
class
Encoder
(
tf
.
keras
.
layers
.
Layer
):
...
...
research/seq_flow_lite/models/charformer.py
浏览文件 @
20cc2190
...
...
@@ -16,13 +16,13 @@
from
absl
import
logging
import
tensorflow
as
tf
from
layers
import
base_layers
from
layers
import
dense_layers
from
layers
import
embedding_layers
from
layers
import
misc_layers
from
layers
import
normalization_layers
from
layers
import
quantization_layers
from
models
import
transformer_encoder
from
layers
import
base_layers
# import seq_flow_lite module
from
layers
import
dense_layers
# import seq_flow_lite module
from
layers
import
embedding_layers
# import seq_flow_lite module
from
layers
import
misc_layers
# import seq_flow_lite module
from
layers
import
normalization_layers
# import seq_flow_lite module
from
layers
import
quantization_layers
# import seq_flow_lite module
from
models
import
transformer_encoder
# import seq_flow_lite module
class
Encoder
(
tf
.
keras
.
layers
.
Layer
):
...
...
research/seq_flow_lite/models/pqrnn.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Implementation of pQRNN model."""
from
absl
import
logging
...
...
@@ -43,7 +42,7 @@ class Encoder(tf.keras.layers.Layer):
_get_params
(
"qrnn_kernel_width"
,
3
)
_get_params
(
"qrnn_zoneout_probability"
)
_get_params
(
"number_qrnn_layers"
)
_get_params
(
"labels"
)
_get_params
(
"labels"
,
[]
)
_get_params
(
"regularizer_scale"
)
_get_params
(
"quantize"
)
...
...
@@ -66,11 +65,12 @@ class Encoder(tf.keras.layers.Layer):
self
.
attention_pool
=
misc_layers
.
AttentionPooling
(
parameters
=
self
.
parameters
)
self
.
final_fc
=
dense_layers
.
BaseQDense
(
units
=
self
.
num_classes
,
rank
=
2
,
parameters
=
self
.
parameters
,
activation
=
None
)
if
self
.
num_classes
:
self
.
final_fc
=
dense_layers
.
BaseQDense
(
units
=
self
.
num_classes
,
rank
=
2
,
parameters
=
self
.
parameters
,
activation
=
None
)
def
call
(
self
,
projection
,
seq_length
):
mask
=
tf
.
sequence_mask
(
...
...
@@ -82,7 +82,11 @@ class Encoder(tf.keras.layers.Layer):
bottleneck
=
self
.
bottleneck_layer
(
projection
,
maskr3
,
inverse_normalizer
)
outputs
=
self
.
qrnn_stack
(
bottleneck
,
maskr3
,
inverse_normalizer
)
pre_logits
=
self
.
attention_pool
(
outputs
,
maskr3
,
inverse_normalizer
)
return
self
.
final_fc
(
pre_logits
)
if
self
.
num_classes
:
return
self
.
final_fc
(
pre_logits
)
else
:
return
pre_logits
class
Model
(
Encoder
):
...
...
research/seq_flow_lite/models/prado.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Implementation of PRADO model."""
import
copy
...
...
research/seq_flow_lite/models/sgnn/sgnn_test.py
浏览文件 @
20cc2190
...
...
@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for seq_flow_lite.sgnn."""
import
tensorflow
as
tf
...
...
research/seq_flow_lite/models/transformer_encoder.py
浏览文件 @
20cc2190
...
...
@@ -18,8 +18,8 @@
from
absl
import
logging
import
tensorflow
as
tf
from
layers
import
base_layers
from
layers
import
transformer_layers
from
layers
import
base_layers
# import seq_flow_lite module
from
layers
import
transformer_layers
# import seq_flow_lite module
class
Model
(
tf
.
keras
.
layers
.
Layer
):
...
...
research/seq_flow_lite/models/transformer_uniform_attn_decoder.py
浏览文件 @
20cc2190
...
...
@@ -20,12 +20,12 @@ from absl import logging
from
tensor2tensor.utils
import
beam_search
import
tensorflow
as
tf
from
layers
import
base_layers
from
layers
import
dense_layers
from
layers
import
embedding_layers
from
layers
import
normalization_layers
from
layers
import
quantization_layers
from
layers
import
transformer_layers
from
layers
import
base_layers
# import seq_flow_lite module
from
layers
import
dense_layers
# import seq_flow_lite module
from
layers
import
embedding_layers
# import seq_flow_lite module
from
layers
import
normalization_layers
# import seq_flow_lite module
from
layers
import
quantization_layers
# import seq_flow_lite module
from
layers
import
transformer_layers
# import seq_flow_lite module
class
TransformerUniformAttnDecoder
(
base_layers
.
BaseLayer
):
...
...
research/seq_flow_lite/tf_ops/BUILD
浏览文件 @
20cc2190
...
...
@@ -11,20 +11,23 @@ package(
)
cc_library
(
name
=
"sequence_string_projection_op"
,
srcs
=
[
"sequence_string_projection.cc"
,
name
=
"projection_normalizer_util"
,
srcs
=
[
"projection_normalizer_util.cc"
],
hdrs
=
[
"projection_normalizer_util.h"
],
deps
=
[
":projection_util"
,
"@utf_archive//:utf"
,
],
)
cc_library
(
name
=
"projection_tokenizer_util"
,
srcs
=
[
"projection_tokenizer_util.cc"
],
hdrs
=
[
"projection_tokenizer_util.h"
],
deps
=
[
":projection_normalizer_util"
,
":projection_tokenizer_util"
,
":projection_util"
,
":text_distorter"
,
"@com_google_absl//absl/container:flat_hash_map"
,
"@tensorflow_includes//:includes"
,
"@tensorflow_solib//:framework_lib"
,
"@utf_archive//:utf"
,
],
alwayslink
=
1
,
)
cc_library
(
...
...
@@ -37,22 +40,46 @@ cc_library(
)
cc_library
(
name
=
"
projection_tokenizer_util
"
,
srcs
=
[
"
projection_tokenizer_util
.cc"
],
hdrs
=
[
"
projection_tokenizer_util
.h"
],
name
=
"
skipgram_finder
"
,
srcs
=
[
"
skipgram_finder
.cc"
],
hdrs
=
[
"
skipgram_finder
.h"
],
deps
=
[
":projection_util"
,
"@utf_archive//:utf"
,
"@com_google_absl//absl/container:flat_hash_map"
,
"@com_google_absl//absl/container:flat_hash_set"
,
"@com_google_absl//absl/strings"
,
"@icu4c//:icu4c"
,
],
)
cc_test
(
name
=
"skipgram_finder_test"
,
srcs
=
[
"skipgram_finder_test.cc"
],
deps
=
[
":skipgram_finder"
,
"@com_google_absl//absl/strings"
,
"@com_google_googletest//:gtest_main"
,
"@icu4c//:icu4c"
,
],
)
cc_library
(
name
=
"
projection_normalizer_util
"
,
srcs
=
[
"
projection_normalizer_util
.cc"
],
hdrs
=
[
"
projection_normalizer_util
.h"
],
name
=
"
subsequence_finder
"
,
srcs
=
[
"
subsequence_finder
.cc"
],
hdrs
=
[
"
subsequence_finder
.h"
],
deps
=
[
":projection_util"
,
"@utf_archive//:utf"
,
"@com_google_absl//absl/container:flat_hash_map"
,
"@com_google_absl//absl/container:flat_hash_set"
,
"@com_google_absl//absl/strings"
,
"@icu4c//:icu4c"
,
],
)
cc_test
(
name
=
"subsequence_finder_test"
,
srcs
=
[
"subsequence_finder_test.cc"
],
deps
=
[
":subsequence_finder"
,
"@com_google_googletest//:gtest_main"
,
],
)
...
...
@@ -67,6 +94,55 @@ cc_library(
],
)
cc_library
(
name
=
"denylist_op"
,
srcs
=
[
"denylist_op.cc"
],
deps
=
[
":skipgram_finder"
,
":subsequence_finder"
,
"@com_google_absl//absl/cleanup"
,
"@com_google_absl//absl/container:flat_hash_set"
,
"@com_google_absl//absl/memory"
,
"@tensorflow_includes//:includes"
,
"@tensorflow_solib//:framework_lib"
,
],
alwayslink
=
1
,
)
gen_op_wrapper_py
(
name
=
"denylist_op_py"
,
out
=
"denylist_op.py"
,
kernel_lib
=
":denylist_op"
,
)
py_test
(
name
=
"denylist_op_py_test"
,
srcs
=
[
"denylist_op_test.py"
],
main
=
"denylist_op_test.py"
,
python_version
=
"PY3"
,
srcs_version
=
"PY3"
,
deps
=
[
":denylist_op_py"
,
],
)
cc_library
(
name
=
"sequence_string_projection_op"
,
srcs
=
[
"sequence_string_projection.cc"
,
],
deps
=
[
":projection_normalizer_util"
,
":projection_tokenizer_util"
,
":projection_util"
,
":text_distorter"
,
"@com_google_absl//absl/container:flat_hash_map"
,
"@tensorflow_includes//:includes"
,
"@tensorflow_solib//:framework_lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"sequence_string_projection_test"
,
size
=
"small"
,
...
...
@@ -78,6 +154,12 @@ cc_test(
],
)
gen_op_wrapper_py
(
name
=
"sequence_string_projection_op_py"
,
out
=
"sequence_string_projection_op.py"
,
kernel_lib
=
":sequence_string_projection_op"
,
)
cc_library
(
name
=
"sequence_string_projection_op_v2"
,
srcs
=
[
...
...
@@ -111,12 +193,6 @@ gen_op_wrapper_py(
kernel_lib
=
":sequence_string_projection_op_v2"
,
)
gen_op_wrapper_py
(
name
=
"sequence_string_projection_op_py"
,
out
=
"sequence_string_projection_op.py"
,
kernel_lib
=
":sequence_string_projection_op"
,
)
cc_library
(
name
=
"tf_custom_ops"
,
srcs
=
[
"tf_custom_ops.cc"
],
...
...
research/seq_flow_lite/tf_ops/denylist_op.cc
0 → 100644
浏览文件 @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tf_ops/skipgram_finder.h" // seq_flow_lite
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
namespace
seq_flow_lite
{
using
::
tensorflow
::
OpKernel
;
using
::
tensorflow
::
OpKernelConstruction
;
using
::
tensorflow
::
OpKernelContext
;
using
::
tensorflow
::
Status
;
using
::
tensorflow
::
Tensor
;
using
::
tensorflow
::
TensorShape
;
using
::
tensorflow
::
errors
::
InvalidArgument
;
using
::
tensorflow
::
shape_inference
::
InferenceContext
;
using
::
tensorflow
::
shape_inference
::
ShapeHandle
;
// Description of the outputs and attributes for the Denylist ops.
const
char
kDescription
[]
=
R"(
output: A floating point tensor that contains a prediction vector for each
input string. The vector will either be:
* [1, 1, ..., 0, 0, ...] if no denylisted skipgrams are found.
(All negative categories are 1.0 and all positive categories are 0.0.)
* an indicator vector if any denylisted skipgrams are found.
(0.0 if no skipgrams belonging to the category were found and 1.0 otherwise)
max_skip_size: The maximum number of tokens that can be skipped when generating
skipgrams.
denylist: A string vector containing denylisted skipgrams.
denylist_category: An int32 vector containing the category of the corresponding
skipgram in the denylist.
categories: An int32 scalar. This is the total number of categories.
All categories in denylist_category must be in [0, categories).
negative_categories: An int32 scalar. The total number of categories that
should be set if no entries in the denylist are triggered. These
negative categories are assumed to be [0, negative_categories).
)"
;
// The base class for all Denylist ops. It does two things:
// 1) It defines the output tensor of the op and it defines the attributes
// needed to specify the denylist and convert denylist categories into
// output vectors.
// 2) It defines a Compute() function. The compute function is responsible
// for filling in the output tensor, while the subclass is responsible
// for processing the input.
class
DenylistOpBase
:
public
OpKernel
{
public:
explicit
DenylistOpBase
(
OpKernelConstruction
*
context
)
:
OpKernel
(
context
)
{
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"categories"
,
&
categories_
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"negative_categories"
,
&
negative_categories_
));
OP_REQUIRES
(
context
,
categories_
>
0
,
InvalidArgument
(
"Number of categories ("
,
categories_
,
") must be positive."
));
OP_REQUIRES
(
context
,
negative_categories_
>=
0
,
InvalidArgument
(
"Number of negative_categories ("
,
negative_categories_
,
") must be non-negative."
));
OP_REQUIRES
(
context
,
negative_categories_
<
categories_
,
InvalidArgument
(
"Number of categories ("
,
categories_
,
") must be greater than the "
"number of negative_categories ("
,
negative_categories_
,
")."
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"max_skip_size"
,
&
max_skip_size_
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"denylist"
,
&
denylist_
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"denylist_category"
,
&
denylist_category_
));
OP_REQUIRES
(
context
,
denylist_
.
size
()
==
denylist_category_
.
size
(),
InvalidArgument
(
"denylist length ("
,
denylist_
.
size
(),
") != denylist_category length ("
,
denylist_category_
.
size
(),
")"
));
int
max
=
*
std
::
max_element
(
denylist_category_
.
begin
(),
denylist_category_
.
end
());
OP_REQUIRES
(
context
,
max
<
categories_
,
InvalidArgument
(
"max element of denylist_category ("
,
max
,
") >= categories ("
,
categories_
,
")"
));
int
min
=
*
std
::
min_element
(
denylist_category_
.
begin
(),
denylist_category_
.
end
());
OP_REQUIRES
(
context
,
min
>=
0
,
InvalidArgument
(
"min element of denylist_category ("
,
min
,
") < 0"
));
}
void
Compute
(
OpKernelContext
*
context
)
override
{
auto
compute_context
=
InitializeComputeContext
(
context
);
if
(
compute_context
==
nullptr
)
{
return
;
}
auto
context_cleaner
=
absl
::
MakeCleanup
([
this
,
compute_context
]
{
this
->
FinalizeComputeContext
(
compute_context
);
});
Tensor
*
output_tensor
;
TensorShape
output_shape
=
InputStringsShape
(
compute_context
);
output_shape
.
AddDim
(
categories_
);
OP_REQUIRES_OK
(
context
,
context
->
allocate_output
(
"output"
,
output_shape
,
&
output_tensor
));
auto
output_values
=
output_tensor
->
flat
<
float
>
();
for
(
int
i
=
0
;
i
<
NumInputStrings
(
compute_context
);
i
++
)
{
auto
category
=
GetCategories
(
i
,
compute_context
);
int
base_index
=
i
*
categories_
;
if
(
category
.
empty
())
{
for
(
int
j
=
0
;
j
<
categories_
;
j
++
)
{
output_values
(
base_index
+
j
)
=
j
<
negative_categories_
?
1.0
:
0.0
;
}
}
else
{
for
(
int
j
=
0
;
j
<
categories_
;
j
++
)
{
output_values
(
base_index
+
j
)
=
category
.
contains
(
j
)
?
1.0
:
0.0
;
}
}
}
}
protected:
int
max_skip_size
()
{
return
max_skip_size_
;
}
int
denylist_size
()
{
return
denylist_
.
size
();
}
const
std
::
string
&
denylist
(
int
i
)
{
return
denylist_
[
i
];
}
int32_t
denylist_category
(
int
i
)
{
return
denylist_category_
[
i
];
}
private:
// Called at the beginning of Compute(). This function should process
// the input and return a context object that can be used to identify
// the denylist categories of each input string.
virtual
void
*
InitializeComputeContext
(
OpKernelContext
*
context
)
=
0
;
// Called at the end of Compute(). Frees the context object.
virtual
void
FinalizeComputeContext
(
void
*
context
)
=
0
;
// Returns the shape of the input tensor, if it only consisted of strings.
// If the input tensor is strings, this is the shape of the input tensor.
// If the input tensor is tokens, this is the shape of the input tensor,
// minus the innermost dimension.
virtual
TensorShape
InputStringsShape
(
void
*
context
)
=
0
;
// Returns the number of strings in the input tensor.
virtual
int
NumInputStrings
(
void
*
context
)
=
0
;
// Returns the denylist categories of the index-th string.
virtual
absl
::
flat_hash_set
<
int
>
GetCategories
(
int
index
,
void
*
context
)
=
0
;
int32_t
categories_
;
int32_t
negative_categories_
;
int
max_skip_size_
;
std
::
vector
<
std
::
string
>
denylist_
;
std
::
vector
<
int32_t
>
denylist_category_
;
};
// A base class for Denylist ops that expect a string tensor input.
class
StringDenylistOp
:
public
DenylistOpBase
{
public:
explicit
StringDenylistOp
(
OpKernelConstruction
*
context
)
:
DenylistOpBase
(
context
)
{}
private:
void
*
InitializeComputeContext
(
OpKernelContext
*
context
)
override
{
const
Tensor
*
input_tensor
;
auto
status
=
context
->
input
(
"input"
,
&
input_tensor
);
if
(
!
status
.
ok
())
{
context
->
CtxFailureWithWarning
(
__FILE__
,
__LINE__
,
status
);
return
nullptr
;
}
return
new
ComputeContext
(
input_tensor
);
}
void
FinalizeComputeContext
(
void
*
context
)
override
{
delete
static_cast
<
ComputeContext
*>
(
context
);
}
TensorShape
InputStringsShape
(
void
*
context
)
override
{
return
static_cast
<
ComputeContext
*>
(
context
)
->
input_tensor
->
shape
();
}
int
NumInputStrings
(
void
*
context
)
override
{
return
static_cast
<
ComputeContext
*>
(
context
)
->
input_tensor_values
.
size
();
}
absl
::
flat_hash_set
<
int
>
GetCategories
(
int
index
,
void
*
context
)
override
{
return
FindTerms
(
static_cast
<
ComputeContext
*>
(
context
)
->
input_tensor_values
(
index
));
}
struct
ComputeContext
{
ComputeContext
(
const
Tensor
*
input_tensor
)
:
input_tensor
(
input_tensor
),
input_tensor_values
(
input_tensor
->
flat
<::
tensorflow
::
tstring
>
())
{}
const
Tensor
*
input_tensor
;
::
tensorflow
::
TTypes
<::
tensorflow
::
tstring
>::
ConstFlat
input_tensor_values
;
};
// Returns the set of denylist categories for the input string.
virtual
absl
::
flat_hash_set
<
int
>
FindTerms
(
const
std
::
string
&
input
)
=
0
;
};
// A denylist op that uses the SkipgramFinder on string inputs.
class
SkipgramDenylistOp
:
public
StringDenylistOp
{
public:
explicit
SkipgramDenylistOp
(
OpKernelConstruction
*
context
)
:
StringDenylistOp
(
context
)
{
skipgram_finder_
=
std
::
make_unique
<
SkipgramFinder
>
(
max_skip_size
());
for
(
int
i
=
0
;
i
<
denylist_size
();
i
++
)
{
skipgram_finder_
->
AddSkipgram
(
denylist
(
i
),
denylist_category
(
i
));
}
}
private:
absl
::
flat_hash_set
<
int
>
FindTerms
(
const
std
::
string
&
input
)
override
{
return
skipgram_finder_
->
FindSkipgrams
(
input
);
}
std
::
unique_ptr
<
SkipgramFinder
>
skipgram_finder_
;
};
REGISTER_KERNEL_BUILDER
(
Name
(
"SkipgramDenylist"
).
Device
(
::
tensorflow
::
DEVICE_CPU
),
SkipgramDenylistOp
);
// Shape inference function for Denylist ops with string inputs.
Status
StringDenylistShapeFn
(
InferenceContext
*
context
)
{
int32_t
categories
;
TF_RETURN_IF_ERROR
(
context
->
GetAttr
(
"categories"
,
&
categories
));
ShapeHandle
output_shape
;
TF_RETURN_IF_ERROR
(
context
->
Concatenate
(
context
->
input
(
0
),
context
->
MakeShape
({
categories
}),
&
output_shape
));
context
->
set_output
(
0
,
output_shape
);
return
::
tensorflow
::
Status
::
OK
();
}
REGISTER_OP
(
"SkipgramDenylist"
)
.
Input
(
"input: string"
)
.
Output
(
"output: float"
)
.
Attr
(
"max_skip_size: int"
)
.
Attr
(
"denylist: list(string)"
)
.
Attr
(
"denylist_category: list(int)"
)
.
Attr
(
"categories: int"
)
.
Attr
(
"negative_categories: int"
)
.
SetShapeFn
(
StringDenylistShapeFn
)
.
Doc
(
absl
::
StrCat
(
"Generates dense prediction vectors for input strings "
"using a skipgram denylist."
,
"
\n\n
"
,
"input: A string tensor."
,
"
\n\n
"
,
kDescription
));
// A Denylist op that uses the SubsequenceFinder on string inputs.
class
SubsequenceDenylistOp
:
public
StringDenylistOp
{
public:
explicit
SubsequenceDenylistOp
(
OpKernelConstruction
*
context
)
:
StringDenylistOp
(
context
)
{
subsequence_finder_
=
std
::
make_unique
<
SubsequenceFinder
>
(
max_skip_size
());
for
(
int
i
=
0
;
i
<
denylist_size
();
i
++
)
{
subsequence_finder_
->
AddSubsequence
(
denylist
(
i
),
denylist_category
(
i
));
}
}
private:
absl
::
flat_hash_set
<
int
>
FindTerms
(
const
std
::
string
&
input
)
override
{
return
subsequence_finder_
->
FindSubsequences
(
input
);
}
std
::
unique_ptr
<
SubsequenceFinder
>
subsequence_finder_
;
};
REGISTER_KERNEL_BUILDER
(
Name
(
"SubsequenceDenylist"
).
Device
(
::
tensorflow
::
DEVICE_CPU
),
SubsequenceDenylistOp
);
REGISTER_OP
(
"SubsequenceDenylist"
)
.
Input
(
"input: string"
)
.
Output
(
"output: float"
)
.
Attr
(
"max_skip_size: int"
)
.
Attr
(
"denylist: list(string)"
)
.
Attr
(
"denylist_category: list(int)"
)
.
Attr
(
"categories: int"
)
.
Attr
(
"negative_categories: int"
)
.
SetShapeFn
(
StringDenylistShapeFn
)
.
Doc
(
absl
::
StrCat
(
"Generates dense prediction vectors for inputs using a "
"subsequence denylist."
,
"
\n\n
"
,
"input: A string tensor."
,
"
\n\n
"
,
kDescription
));
// A denylist op that uses the SkipgramFinder on tokenized string inputs.
// The inputs are a pair of tensors: a token tensor of type string and
// a token count tensor of type T.
template
<
typename
T
>
class
TokenizedDenylistOp
:
public
DenylistOpBase
{
public:
explicit
TokenizedDenylistOp
(
OpKernelConstruction
*
context
)
:
DenylistOpBase
(
context
)
{
skipgram_finder_
=
std
::
make_unique
<
SkipgramFinder
>
(
max_skip_size
());
for
(
int
i
=
0
;
i
<
denylist_size
();
i
++
)
{
skipgram_finder_
->
AddSkipgram
(
denylist
(
i
),
denylist_category
(
i
));
}
}
private:
void
*
InitializeComputeContext
(
OpKernelContext
*
context
)
override
{
const
Tensor
*
input_tensor
;
{
auto
status
=
context
->
input
(
"input"
,
&
input_tensor
);
if
(
!
status
.
ok
())
{
context
->
CtxFailureWithWarning
(
__FILE__
,
__LINE__
,
status
);
return
nullptr
;
}
}
const
Tensor
*
token_count_tensor
;
{
auto
status
=
context
->
input
(
"token_count"
,
&
token_count_tensor
);
if
(
!
status
.
ok
())
{
context
->
CtxFailureWithWarning
(
__FILE__
,
__LINE__
,
status
);
return
nullptr
;
}
}
return
new
ComputeContext
(
input_tensor
,
token_count_tensor
);
}
void
FinalizeComputeContext
(
void
*
context
)
override
{
delete
static_cast
<
ComputeContext
*>
(
context
);
}
TensorShape
InputStringsShape
(
void
*
context
)
override
{
return
static_cast
<
ComputeContext
*>
(
context
)
->
shape
;
}
int
NumInputStrings
(
void
*
context
)
override
{
return
static_cast
<
ComputeContext
*>
(
context
)
->
size
;
}
absl
::
flat_hash_set
<
int
>
GetCategories
(
int
index
,
void
*
x
)
override
{
ComputeContext
*
context
=
static_cast
<
ComputeContext
*>
(
x
);
int64_t
num_tokens
=
context
->
token_count_flat
(
index
);
std
::
vector
<
absl
::
string_view
>
tokens
;
tokens
.
reserve
(
num_tokens
);
int64_t
start
=
index
*
context
->
max_tokens
;
for
(
int64_t
i
=
start
;
i
<
start
+
num_tokens
;
i
++
)
{
tokens
.
emplace_back
(
context
->
token_flat
(
i
).
data
(),
context
->
token_flat
(
i
).
size
());
}
return
skipgram_finder_
->
FindSkipgrams
(
tokens
);
}
struct
ComputeContext
{
ComputeContext
(
const
Tensor
*
token_tensor
,
const
Tensor
*
token_count_tensor
)
:
token_flat
(
token_tensor
->
flat
<::
tensorflow
::
tstring
>
()),
token_count_flat
(
token_count_tensor
->
flat
<
T
>
())
{
shape
=
token_tensor
->
shape
();
max_tokens
=
shape
.
dim_size
(
shape
.
dims
()
-
1
);
shape
.
RemoveLastDims
(
1
);
size
=
1
;
for
(
int64_t
i
=
0
;
i
<
shape
.
dims
();
i
++
)
{
size
=
size
*
shape
.
dim_size
(
i
);
}
}
const
typename
::
tensorflow
::
TTypes
<::
tensorflow
::
tstring
>::
ConstFlat
token_flat
;
const
typename
::
tensorflow
::
TTypes
<
T
>::
ConstFlat
token_count_flat
;
TensorShape
shape
;
int64_t
size
;
int64_t
max_tokens
;
};
std
::
unique_ptr
<
SkipgramFinder
>
skipgram_finder_
;
};
REGISTER_KERNEL_BUILDER
(
Name
(
"TokenizedDenylist"
)
.
Device
(
::
tensorflow
::
DEVICE_CPU
)
.
TypeConstraint
<
int32_t
>
(
"Ttoken_count"
),
TokenizedDenylistOp
<
int32_t
>
);
REGISTER_KERNEL_BUILDER
(
Name
(
"TokenizedDenylist"
)
.
Device
(
::
tensorflow
::
DEVICE_CPU
)
.
TypeConstraint
<
int64_t
>
(
"Ttoken_count"
),
TokenizedDenylistOp
<
int64_t
>
);
// Shape inference function for Denylist ops with tokenized string inputs.
Status
TokenizedDenylistShapeFn
(
InferenceContext
*
context
)
{
int32_t
categories
;
TF_RETURN_IF_ERROR
(
context
->
GetAttr
(
"categories"
,
&
categories
));
ShapeHandle
string_tensor_shape
;
TF_RETURN_IF_ERROR
(
context
->
Subshape
(
context
->
input
(
0
),
0
,
-
1
,
&
string_tensor_shape
));
ShapeHandle
output_shape
;
TF_RETURN_IF_ERROR
(
context
->
Concatenate
(
string_tensor_shape
,
context
->
MakeShape
({
categories
}),
&
output_shape
));
context
->
set_output
(
0
,
output_shape
);
return
::
tensorflow
::
Status
::
OK
();
}
REGISTER_OP
(
"TokenizedDenylist"
)
.
Input
(
"input: string"
)
.
Input
(
"token_count: Ttoken_count"
)
.
Output
(
"output: float"
)
.
Attr
(
"max_skip_size: int"
)
.
Attr
(
"denylist: list(string)"
)
.
Attr
(
"denylist_category: list(int)"
)
.
Attr
(
"categories: int"
)
.
Attr
(
"negative_categories: int"
)
.
Attr
(
"Ttoken_count: {int32, int64}"
)
.
SetShapeFn
(
TokenizedDenylistShapeFn
)
.
Doc
(
absl
::
StrCat
(
"Generates dense prediction vectors for tokens using a "
"skipgram denylist."
,
"
\n\n
"
,
"input: A string tensor of tokens."
,
"
\n\n
"
,
kDescription
));
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/denylist_op_test.cc
0 → 100644
浏览文件 @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.proto.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace
seq_flow_lite
{
namespace
{
using
::
tensorflow
::
DT_FLOAT
;
using
::
tensorflow
::
DT_INT32
;
using
::
tensorflow
::
DT_INT64
;
using
::
tensorflow
::
DT_STRING
;
using
::
tensorflow
::
NodeDefBuilder
;
using
::
tensorflow
::
OpsTestBase
;
using
::
tensorflow
::
Tensor
;
using
::
tensorflow
::
TensorShape
;
using
::
tensorflow
::
errors
::
InvalidArgument
;
using
::
tensorflow
::
test
::
ExpectTensorEqual
;
using
::
tensorflow
::
test
::
FillValues
;
class
SkipgramDenylistOpTest
:
public
OpsTestBase
{};
TEST_F
(
SkipgramDenylistOpTest
,
Correct
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SkipgramDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
2
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
TF_ASSERT_OK
(
InitOp
());
AddInputFromArray
<::
tensorflow
::
tstring
>
(
TensorShape
({
2
}),
{
"q a q b q c q"
,
"q a b q q c"
});
TF_ASSERT_OK
(
RunOpKernel
());
const
Tensor
&
output
=
*
GetOutput
(
0
);
Tensor
expected
(
allocator
(),
DT_FLOAT
,
TensorShape
({
2
,
2
}));
FillValues
<
float
>
(
&
expected
,
{
0.0
,
1.0
,
1.0
,
0.0
});
ExpectTensorEqual
<
float
>
(
expected
,
output
);
}
TEST_F
(
SkipgramDenylistOpTest
,
Prefix
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SkipgramDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b.* c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
2
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
TF_ASSERT_OK
(
InitOp
());
AddInputFromArray
<::
tensorflow
::
tstring
>
(
TensorShape
({
2
}),
{
"q a q bq q c q"
,
"q a bq q q c"
});
TF_ASSERT_OK
(
RunOpKernel
());
const
Tensor
&
output
=
*
GetOutput
(
0
);
Tensor
expected
(
allocator
(),
DT_FLOAT
,
TensorShape
({
2
,
2
}));
FillValues
<
float
>
(
&
expected
,
{
0.0
,
1.0
,
1.0
,
0.0
});
ExpectTensorEqual
<
float
>
(
expected
,
output
);
}
TEST_F
(
SkipgramDenylistOpTest
,
ZeroCategories
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SkipgramDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
0
)
.
Attr
(
"negative_categories"
,
0
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of categories (0) must be positive."
));
}
TEST_F
(
SkipgramDenylistOpTest
,
NegativeCategoriesLessThanZero
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SkipgramDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
1
)
.
Attr
(
"negative_categories"
,
-
1
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of negative_categories (-1) must be non-negative."
));
}
TEST_F
(
SkipgramDenylistOpTest
,
CategoriesEqualNegativeCategories
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SkipgramDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
1
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of categories (1) must be greater than the "
"number of negative_categories (1)."
));
}
class
SubsequenceDenylistOpTest
:
public
OpsTestBase
{};
TEST_F
(
SubsequenceDenylistOpTest
,
Correct
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SubsequenceDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
2
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
TF_ASSERT_OK
(
InitOp
());
AddInputFromArray
<::
tensorflow
::
tstring
>
(
TensorShape
({
2
}),
{
"qaqbqcq"
,
"qabqqc"
});
TF_ASSERT_OK
(
RunOpKernel
());
const
Tensor
&
output
=
*
GetOutput
(
0
);
Tensor
expected
(
allocator
(),
DT_FLOAT
,
TensorShape
({
2
,
2
}));
FillValues
<
float
>
(
&
expected
,
{
0.0
,
1.0
,
1.0
,
0.0
});
ExpectTensorEqual
<
float
>
(
expected
,
output
);
}
TEST_F
(
SubsequenceDenylistOpTest
,
ZeroCategories
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SubsequenceDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
0
)
.
Attr
(
"negative_categories"
,
0
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of categories (0) must be positive."
));
}
TEST_F
(
SubsequenceDenylistOpTest
,
NegativeCategoriesLessThanZero
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SubsequenceDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
1
)
.
Attr
(
"negative_categories"
,
-
1
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of negative_categories (-1) must be non-negative."
));
}
TEST_F
(
SubsequenceDenylistOpTest
,
CategoriesEqualNegativeCategories
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SubsequenceDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
1
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of categories (1) must be greater than the "
"number of negative_categories (1)."
));
}
class
TokenizedDenylistOpTest
:
public
OpsTestBase
{};
TEST_F
(
TokenizedDenylistOpTest
,
CorrectInt64TokenCount
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"TokenizedDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Input
({
"token_count"
,
0
,
DT_INT64
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
2
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
TF_ASSERT_OK
(
InitOp
());
AddInputFromArray
<::
tensorflow
::
tstring
>
(
TensorShape
({
2
,
7
}),
{
"q"
,
"a"
,
"q"
,
"b"
,
"q"
,
"c"
,
"q"
,
//
"q"
,
"a"
,
"b"
,
"q"
,
"q"
,
"c"
,
""
});
AddInputFromArray
<
int64_t
>
(
TensorShape
({
2
}),
{
7
,
6
});
TF_ASSERT_OK
(
RunOpKernel
());
const
Tensor
&
output
=
*
GetOutput
(
0
);
Tensor
expected
(
allocator
(),
DT_FLOAT
,
TensorShape
({
2
,
2
}));
FillValues
<
float
>
(
&
expected
,
{
0.0
,
1.0
,
1.0
,
0.0
});
ExpectTensorEqual
<
float
>
(
expected
,
output
);
}
TEST_F
(
TokenizedDenylistOpTest
,
CorrectInt32TokenCount
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"TokenizedDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Input
({
"token_count"
,
0
,
DT_INT32
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
2
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
TF_ASSERT_OK
(
InitOp
());
AddInputFromArray
<::
tensorflow
::
tstring
>
(
TensorShape
({
2
,
7
}),
{
"q"
,
"a"
,
"q"
,
"b"
,
"q"
,
"c"
,
"q"
,
//
"q"
,
"a"
,
"b"
,
"q"
,
"q"
,
"c"
,
""
});
AddInputFromArray
<
int32_t
>
(
TensorShape
({
2
}),
{
7
,
6
});
TF_ASSERT_OK
(
RunOpKernel
());
const
Tensor
&
output
=
*
GetOutput
(
0
);
Tensor
expected
(
allocator
(),
DT_FLOAT
,
TensorShape
({
2
,
2
}));
FillValues
<
float
>
(
&
expected
,
{
0.0
,
1.0
,
1.0
,
0.0
});
ExpectTensorEqual
<
float
>
(
expected
,
output
);
}
TEST_F
(
TokenizedDenylistOpTest
,
ZeroCategories
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"TokenizedDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Input
({
"token_count"
,
0
,
DT_INT64
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
0
)
.
Attr
(
"negative_categories"
,
0
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of categories (0) must be positive."
));
}
TEST_F
(
TokenizedDenylistOpTest
,
NegativeCategoriesLessThanZero
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"TokenizedDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Input
({
"token_count"
,
0
,
DT_INT64
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
1
)
.
Attr
(
"negative_categories"
,
-
1
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of negative_categories (-1) must be non-negative."
));
}
TEST_F
(
TokenizedDenylistOpTest
,
CategoriesEqualNegativeCategories
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"TokenizedDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Input
({
"token_count"
,
0
,
DT_INT64
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
1
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of categories (1) must be greater than the "
"number of negative_categories (1)."
));
}
}
// namespace
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/denylist_op_test.py
0 → 100644
浏览文件 @
20cc2190
# Copyright 2022 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test denylist op and show example usage from python wrapper."""
import
tensorflow
as
tf
from
tf_ops
import
denylist_op
# import seq_flow_lite module
class
SkipgramDenylistTest
(
tf
.
test
.
TestCase
):
def
test_correct
(
self
):
result
=
denylist_op
.
skipgram_denylist
(
input
=
[
"q a q b q c q"
,
"q a b q q c"
],
max_skip_size
=
1
,
denylist
=
[
"a b c"
],
denylist_category
=
[
1
],
categories
=
2
,
negative_categories
=
1
)
self
.
assertAllEqual
(
result
,
[[
0.0
,
1.0
],
[
1.0
,
0.0
]])
class
SubsequenceDenylistTest
(
tf
.
test
.
TestCase
):
def
test_correct
(
self
):
result
=
denylist_op
.
subsequence_denylist
(
input
=
[
"qaqbqcq"
,
"qabqqc"
],
max_skip_size
=
1
,
denylist
=
[
"a b c"
],
denylist_category
=
[
1
],
categories
=
2
,
negative_categories
=
1
)
self
.
assertAllEqual
(
result
,
[[
0.0
,
1.0
],
[
1.0
,
0.0
]])
class
TokenizedDenylistTest
(
tf
.
test
.
TestCase
):
def
test_correct
(
self
):
result
=
denylist_op
.
tokenized_denylist
(
input
=
[[
"q"
,
"a"
,
"q"
,
"b"
,
"q"
,
"c"
,
"q"
],
[
"q"
,
"a"
,
"b"
,
"q"
,
"q"
,
"c"
,
""
]],
token_count
=
[
7
,
6
],
max_skip_size
=
1
,
denylist
=
[
"a b c"
],
denylist_category
=
[
1
],
categories
=
2
,
negative_categories
=
1
)
self
.
assertAllEqual
(
result
,
[[
0.0
,
1.0
],
[
1.0
,
0.0
]])
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
research/seq_flow_lite/tf_ops/projection_normalizer_util.cc
浏览文件 @
20cc2190
...
...
@@ -26,7 +26,7 @@ limitations under the License.
bool
IsDigit
(
const
std
::
string
&
text
)
{
Rune
rune
;
for
(
size_t
i
=
0
;
i
<
text
.
length
();)
{
const
int
bytes_read
=
chartorune
(
&
rune
,
const_cast
<
char
*>
(
text
.
data
()));
const
int
bytes_read
=
chartorune
(
&
rune
,
const_cast
<
char
*>
(
text
.
data
()));
if
(
rune
==
Runeerror
||
bytes_read
==
0
)
break
;
if
(
rune
>=
static_cast
<
Rune
>
(
'0'
)
&&
rune
<=
static_cast
<
Rune
>
(
'9'
))
{
return
true
;
...
...
@@ -98,6 +98,29 @@ std::string ContractToken(const char* input_ptr, size_t len, size_t num_chars) {
return
token
;
}
void
NormalizeSpaces
(
std
::
string
&
input
)
{
// Whether to copy the next character if it's a space.
bool
copy_space
=
false
;
size_t
j
=
0
;
for
(
size_t
i
=
0
;
i
<
input
.
length
();
++
i
)
{
if
(
input
[
i
]
==
' '
)
{
if
(
!
copy_space
)
continue
;
copy_space
=
false
;
}
else
{
copy_space
=
true
;
}
if
(
j
!=
i
)
{
input
[
j
]
=
input
[
i
];
}
++
j
;
}
if
(
j
>
0
&&
input
[
j
-
1
]
==
' '
)
{
--
j
;
}
input
.
resize
(
j
);
}
void
ProjectionNormalizer
::
InitializeSeparators
(
const
std
::
string
&
separators
)
{
for
(
size_t
i
=
0
;
i
<
separators
.
length
();
++
i
)
{
if
(
separators
[
i
]
!=
' '
)
{
...
...
@@ -150,9 +173,14 @@ std::string ProjectionNormalizer::Normalize(const char* input_ptr, size_t len,
normalized
=
ContractToken
(
normalized
.
data
(),
normalized
.
length
(),
3
);
}
if
(
normalize_spaces_
)
{
NormalizeSpaces
(
normalized
);
}
if
(
!
separators_
.
empty
())
{
// Add space around separators_.
normalized
=
NormalizeInternal
(
normalized
.
data
(),
normalized
.
length
());
}
return
normalized
;
}
research/seq_flow_lite/tf_ops/projection_normalizer_util.h
浏览文件 @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#include <string>
#include <unordered_set>
...
...
@@ -24,14 +24,17 @@ limitations under the License.
// Normalizes the input with the given |separators| by adding a space before and
// after each separator. When |normalize_repetition| is true, it removes the
// repeated characters (except numbers) which consecutively appeared more than
// twice in a word.
// twice in a word. When |normalize_spaces| is true, it removes spaces from
// the beginning and ending of the input, as well as repeated spaces.
// Examples: arwwwww -> arww, good!!!!! -> good!!, hahaha => haha.
class
ProjectionNormalizer
{
public:
explicit
ProjectionNormalizer
(
const
std
::
string
&
separators
,
bool
normalize_repetition
=
false
)
{
bool
normalize_repetition
=
false
,
bool
normalize_spaces
=
false
)
:
normalize_repetition_
(
normalize_repetition
),
normalize_spaces_
(
normalize_spaces
)
{
InitializeSeparators
(
separators
);
normalize_repetition_
=
normalize_repetition
;
}
// Normalizes the repeated characters (except numbers) which consecutively
...
...
@@ -49,6 +52,7 @@ class ProjectionNormalizer {
std
::
unordered_set
<
char
>
separators_
;
bool
normalize_repetition_
;
bool
normalize_spaces_
;
};
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
research/seq_flow_lite/tf_ops/projection_tokenizer_util.h
浏览文件 @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#include <string>
#include <unordered_set>
...
...
@@ -55,4 +55,4 @@ class ProjectionTokenizer {
std
::
unordered_set
<
char
>
separators_
;
};
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
research/seq_flow_lite/tf_ops/projection_util.h
浏览文件 @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_UTIL_H_
#include <memory>
#include <string>
#include <unordered_map>
...
...
@@ -156,4 +156,4 @@ std::vector<std::string> SplitByChar(const char* input_ptr, size_t len,
std
::
string
JoinPairsBySpace
(
std
::
vector
<
std
::
pair
<
const
char
*
,
size_t
>>
words
);
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_UTIL_H_
research/seq_flow_lite/tf_ops/sequence_string_projection.cc
浏览文件 @
20cc2190
...
...
@@ -109,11 +109,14 @@ class SequenceStringProjectionOp : public OpKernel {
bool
normalize_repetition
;
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"normalize_repetition"
,
&
normalize_repetition
));
bool
normalize_spaces
;
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"normalize_spaces"
,
&
normalize_spaces
));
std
::
string
separators
;
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"token_separators"
,
&
separators
));
if
(
!
separators
.
empty
()
||
normalize_repetition
)
{
if
(
!
separators
.
empty
()
||
normalize_repetition
||
normalize_spaces
)
{
projection_normalizer_
=
absl
::
make_unique
<
ProjectionNormalizer
>
(
separators
,
normalize_repetition
);
separators
,
normalize_repetition
,
normalize_spaces
);
}
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"add_first_cap_feature"
,
...
...
@@ -326,6 +329,7 @@ REGISTER_OP("SequenceStringProjection")
.
Attr
(
"split_on_space: bool = True"
)
.
Attr
(
"token_separators: string = ''"
)
.
Attr
(
"normalize_repetition: bool = false"
)
.
Attr
(
"normalize_spaces: bool = false"
)
.
SetShapeFn
([](
InferenceContext
*
c
)
{
DimensionHandle
size
;
...
...
@@ -384,6 +388,10 @@ Attribute(s):
- add_all_caps_feature: Specifies the probability with which a feature to the
resulting projection tensor that helps discriminate if the input token is
ALLCAPS will be added.
- normalize_repetition: When true normalizes repetition in text tokens before
fingerprinting.
- normalize_spaces: When true strips leading and trailing spaces and removes
repeated spaces.
Output(s):
- projection: Floating point tensor with ternary values of shape
...
...
research/seq_flow_lite/tf_ops/skipgram_finder.cc
0 → 100644
浏览文件 @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/skipgram_finder.h" // seq_flow_lite
#include <cctype>
#include <deque>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace
seq_flow_lite
{
namespace
{
void
PreprocessToken
(
std
::
string
&
token
)
{
char
*
s
=
const_cast
<
char
*>
(
token
.
data
());
int32_t
size
=
token
.
size
();
int32_t
in
=
0
;
int32_t
out
=
0
;
while
(
in
<
size
)
{
UChar32
c
;
int32_t
old_in
=
in
;
U8_NEXT
(
s
,
in
,
size
,
c
);
if
(
c
<
0
)
{
break
;
}
if
(
u_ispunct
(
c
))
continue
;
UChar32
cl
=
u_tolower
(
c
);
// This is a hack, but there are exactly two unicode characters whose
// lowercase versions have longer UTF-8 encodings (0x23a to 0x2c65,
// 0x23e to 0x2c66). So, to avoid sizing issues, they're not lowercased.
if
(
U8_LENGTH
(
cl
)
>
(
in
-
old_in
))
{
cl
=
c
;
}
U8_APPEND_UNSAFE
(
s
,
out
,
cl
);
}
size_t
remaining
=
token
.
size
()
-
in
;
if
(
remaining
>
0
)
{
memmove
(
s
+
out
,
s
+
in
,
remaining
);
out
+=
remaining
;
}
token
.
resize
(
out
);
}
}
// namespace
void
SkipgramFinder
::
AddSkipgram
(
absl
::
string_view
skipgram
,
int
category
)
{
std
::
vector
<
std
::
string
>
tokens
=
absl
::
StrSplit
(
skipgram
,
' '
);
// Store the skipgram in a trie-like structure that uses tokens as the
// edge labels, instead of characters. Each node represents a skipgram made
// from the tokens used to reach the node, and stores the categories the
// skipgram is associated with.
TrieNode
*
cur
=
&
skipgram_trie_
;
for
(
auto
&
token
:
tokens
)
{
if
(
absl
::
EndsWith
(
token
,
".*"
))
{
token
.
resize
(
token
.
size
()
-
2
);
PreprocessToken
(
token
);
auto
iter
=
cur
->
prefix_to_node
.
find
(
token
);
if
(
iter
!=
cur
->
prefix_to_node
.
end
())
{
cur
=
&
iter
->
second
;
}
else
{
cur
=
&
cur
->
prefix_to_node
.
emplace
(
std
::
piecewise_construct
,
std
::
forward_as_tuple
(
token
),
std
::
make_tuple
<>
())
.
first
->
second
;
}
continue
;
}
PreprocessToken
(
token
);
auto
iter
=
cur
->
token_to_node
.
find
(
token
);
if
(
iter
!=
cur
->
token_to_node
.
end
())
{
cur
=
&
iter
->
second
;
}
else
{
cur
=
&
cur
->
token_to_node
.
emplace
(
std
::
piecewise_construct
,
std
::
forward_as_tuple
(
token
),
std
::
make_tuple
<>
())
.
first
->
second
;
}
}
cur
->
categories
.
insert
(
category
);
}
absl
::
flat_hash_set
<
int
>
SkipgramFinder
::
FindSkipgrams
(
absl
::
string_view
input
)
const
{
std
::
vector
<
std
::
string
>
tokens
=
absl
::
StrSplit
(
input
,
' '
);
std
::
vector
<
absl
::
string_view
>
sv_tokens
;
sv_tokens
.
reserve
(
tokens
.
size
());
for
(
auto
&
token
:
tokens
)
{
PreprocessToken
(
token
);
sv_tokens
.
emplace_back
(
token
.
data
(),
token
.
size
());
}
return
FindSkipgrams
(
sv_tokens
);
}
absl
::
flat_hash_set
<
int
>
SkipgramFinder
::
FindSkipgrams
(
const
std
::
vector
<
absl
::
string_view
>&
tokens
)
const
{
absl
::
flat_hash_set
<
int
>
categories
;
// Tracks skipgram prefixes and the index of their last token.
std
::
deque
<
std
::
pair
<
int
,
const
TrieNode
*>>
indices_and_skipgrams
;
for
(
int
token_i
=
0
;
token_i
<
tokens
.
size
();
token_i
++
)
{
const
absl
::
string_view
&
token
=
tokens
[
token_i
];
std
::
vector
<
absl
::
string_view
>
token_prefixes
;
{
const
char
*
s
=
token
.
data
();
int32_t
l
=
token
.
size
();
int32_t
n
=
0
;
while
(
n
<
l
)
{
int32_t
n_old
=
n
;
U8_FWD_1
(
s
,
n
,
l
);
if
(
n
==
n_old
)
break
;
token_prefixes
.
emplace_back
(
s
,
n
);
}
}
// Drop any skipgrams prefixes which would skip more than `max_skip_size_`
// tokens between the end of the prefix and the current token.
while
(
!
indices_and_skipgrams
.
empty
())
{
if
(
indices_and_skipgrams
.
front
().
first
+
max_skip_size_
+
1
<
token_i
)
{
indices_and_skipgrams
.
pop_front
();
}
else
{
break
;
}
}
// Check if we can form a valid skipgram prefix (or skipgram) by adding
// the current token to any of the existing skipgram prefixes, or
// if the current token is a valid skipgram prefix (or skipgram).
size_t
size
=
indices_and_skipgrams
.
size
();
for
(
size_t
skipgram_i
=
0
;
skipgram_i
<=
size
;
skipgram_i
++
)
{
const
auto
&
node
=
skipgram_i
<
size
?
*
indices_and_skipgrams
[
skipgram_i
].
second
:
skipgram_trie_
;
auto
iter
=
node
.
token_to_node
.
find
(
token
);
if
(
iter
!=
node
.
token_to_node
.
end
())
{
categories
.
insert
(
iter
->
second
.
categories
.
begin
(),
iter
->
second
.
categories
.
end
());
indices_and_skipgrams
.
push_back
(
std
::
make_pair
(
token_i
,
&
iter
->
second
));
}
for
(
auto
token_prefix
=
token_prefixes
.
rbegin
();
token_prefix
!=
token_prefixes
.
rend
();
token_prefix
++
)
{
auto
iter
=
node
.
prefix_to_node
.
find
(
*
token_prefix
);
if
(
iter
!=
node
.
prefix_to_node
.
end
())
{
categories
.
insert
(
iter
->
second
.
categories
.
begin
(),
iter
->
second
.
categories
.
end
());
indices_and_skipgrams
.
push_back
(
std
::
make_pair
(
token_i
,
&
iter
->
second
));
}
}
}
}
return
categories
;
}
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/skipgram_finder.h
0 → 100644
浏览文件 @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
namespace
seq_flow_lite
{
// SkipgramFinder finds skipgrams in strings.
//
// To use: First, add skipgrams using AddSkipgram() - each skipgram is
// associated with some category. Then, call FindSkipgrams() on a string,
// which will return the set of categories of the skipgrams in the string.
//
// Both the skipgrams and the input strings will be tokenzied by splitting
// on spaces. Additionally, the tokens will be lowercased and have any
// trailing punctuation removed.
class
SkipgramFinder
{
public:
explicit
SkipgramFinder
(
int
max_skip_size
)
:
max_skip_size_
(
max_skip_size
)
{}
// Adds a skipgram that SkipgramFinder should look for in input strings.
// Tokens may use the regex '.*' as a suffix.
void
AddSkipgram
(
absl
::
string_view
skipgram
,
int
category
);
// Find all of the skipgrams in `input`, and return their categories.
absl
::
flat_hash_set
<
int
>
FindSkipgrams
(
absl
::
string_view
input
)
const
;
// Find all of the skipgrams in `tokens`, and return their categories.
absl
::
flat_hash_set
<
int
>
FindSkipgrams
(
const
std
::
vector
<
absl
::
string_view
>&
tokens
)
const
;
private:
struct
TrieNode
{
absl
::
flat_hash_set
<
int
>
categories
;
// Maps tokens to the next node in the trie.
absl
::
flat_hash_map
<
std
::
string
,
TrieNode
>
token_to_node
;
// Maps token prefixes (<prefix>.*) to the next node in the trie.
absl
::
flat_hash_map
<
std
::
string
,
TrieNode
>
prefix_to_node
;
};
TrieNode
skipgram_trie_
;
int
max_skip_size_
;
};
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
research/seq_flow_lite/tf_ops/skipgram_finder_test.cc
0 → 100644
浏览文件 @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/skipgram_finder.h" // seq_flow_lite
#include <string>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace
seq_flow_lite
{
namespace
{
using
::
testing
::
UnorderedElementsAreArray
;
void
TestFindSkipgrams
(
const
SkipgramFinder
&
skipgram_finder
,
const
std
::
vector
<
std
::
string
>&
tokens
,
const
std
::
vector
<
int
>&
categories
,
const
std
::
vector
<
int
>&
token_categories
)
{
EXPECT_THAT
(
skipgram_finder
.
FindSkipgrams
(
absl
::
StrJoin
(
tokens
,
" "
)),
UnorderedElementsAreArray
(
categories
));
std
::
vector
<
absl
::
string_view
>
sv_tokens
;
sv_tokens
.
reserve
(
tokens
.
size
());
for
(
const
auto
&
token
:
tokens
)
{
sv_tokens
.
emplace_back
(
token
.
data
(),
token
.
size
());
}
EXPECT_THAT
(
skipgram_finder
.
FindSkipgrams
(
sv_tokens
),
UnorderedElementsAreArray
(
token_categories
));
}
// Test that u_tolower() will only increase the number of bytes in the
// UTF-8 encoding in two specific cases.
TEST
(
SkipgramFinderTest
,
UCharToLower
)
{
for
(
UChar32
c
=
0
;
c
<
0x10000
;
c
++
)
{
if
(
c
==
0x23a
||
c
==
0x23e
)
continue
;
UChar32
l
=
u_tolower
(
c
);
EXPECT_GE
(
U8_LENGTH
(
c
),
U8_LENGTH
(
l
))
<<
c
<<
" lowercases to "
<<
l
;
}
}
TEST
(
SkipgramFinderTest
,
SingleExists
)
{
SkipgramFinder
skipgram_finder
(
1
);
std
::
string
s
(
"q r s"
);
skipgram_finder
.
AddSkipgram
(
s
,
0
);
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"q"
,
"r"
,
"s"
,
"c"
},
{
0
},
{
0
});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"q"
,
"xyz"
,
"R!"
,
"xy"
,
"s"
,
"c"
},
{
0
},
{});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"q"
,
"r"
,
"q"
,
"R"
,
"s."
,
"c"
},
{
0
},
{});
}
TEST
(
SkipgramFinderTest
,
SingleNotExists
)
{
SkipgramFinder
skipgram_finder
(
1
);
std
::
string
s
(
"q r s"
);
skipgram_finder
.
AddSkipgram
(
s
,
0
);
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"q"
,
"x"
,
"x"
,
"r"
,
"x"
,
"s"
,
"c"
},
{},
{});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"q"
,
"x"
,
"r"
,
"x"
,
"c"
},
{},
{});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"r"
,
"x"
,
"s"
,
"q"
,
"c"
},
{},
{});
}
TEST
(
SkipgramFinderTest
,
SinglePrefixExists
)
{
SkipgramFinder
skipgram_finder
(
1
);
std
::
string
s
(
"q.* r s"
);
skipgram_finder
.
AddSkipgram
(
s
,
0
);
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"qa"
,
"r"
,
"s"
,
"c"
},
{
0
},
{
0
});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"q"
,
"xyz"
,
"R!"
,
"xy"
,
"s"
,
"c"
},
{
0
},
{});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"qc"
,
"r"
,
"qd"
,
"R"
,
"s."
,
"c"
},
{
0
},
{});
}
TEST
(
SkipgramFinderTest
,
SinglePrefixNotExists
)
{
SkipgramFinder
skipgram_finder
(
1
);
std
::
string
s
(
"q.* r s"
);
skipgram_finder
.
AddSkipgram
(
s
,
0
);
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"aq"
,
"r"
,
"s"
,
"c"
},
{},
{});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"aqc"
,
"xyz"
,
"R!"
,
"xy"
,
"s"
,
"c"
},
{},
{});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"q"
,
"ar"
,
"q"
,
"aR"
,
"s."
,
"c"
},
{},
{});
}
TEST
(
SkipgramFinderTest
,
Punctuation
)
{
SkipgramFinder
skipgram_finder
(
1
);
std
::
string
s
(
"a-b-c def"
);
skipgram_finder
.
AddSkipgram
(
s
,
0
);
TestFindSkipgrams
(
skipgram_finder
,
{
"q"
,
"abc"
,
"q"
,
"d-e-f"
,
"q"
},
{
0
},
{});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"'abc'"
,
"q"
,
"'def'"
,
"q"
},
{
0
},
{});
TestFindSkipgrams
(
skipgram_finder
,
{
"q"
,
"abc"
,
"q"
,
"def"
,
"q"
},
{
0
},
{
0
});
}
TEST
(
SkipgramFinderTest
,
HandlesMultibyteInput
)
{
SkipgramFinder
skipgram_finder
(
1
);
std
::
string
s
(
"hello
\363\243\243\243
!"
);
skipgram_finder
.
AddSkipgram
(
s
,
0
);
}
TEST
(
SkipgramFinderTest
,
Multiple
)
{
SkipgramFinder
skipgram_finder
(
1
);
std
::
string
s1
(
"a b c"
);
std
::
string
s2
(
"D e. F!"
);
std
::
string
s3
(
"ghi jkl mno"
);
std
::
string
s4
(
"S T U"
);
std
::
string
s5
(
"x. y, z!"
);
std
::
string
s6
(
"d.* e f"
);
skipgram_finder
.
AddSkipgram
(
s1
,
0
);
skipgram_finder
.
AddSkipgram
(
s2
,
2
);
skipgram_finder
.
AddSkipgram
(
s3
,
4
);
skipgram_finder
.
AddSkipgram
(
s4
,
6
);
skipgram_finder
.
AddSkipgram
(
s5
,
8
);
skipgram_finder
.
AddSkipgram
(
s6
,
10
);
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"d"
,
"b"
,
"e"
,
"c"
,
"f"
},
{
0
,
2
,
10
},
{
0
,
2
,
10
});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"dq"
,
"b"
,
"e"
,
"c"
,
"f"
},
{
0
,
10
},
{
0
,
10
});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"d"
,
"b"
,
"eq"
,
"c"
,
"f"
},
{
0
},
{
0
});
TestFindSkipgrams
(
skipgram_finder
,
{
"a"
,
"ghi"
,
"b"
,
"jkl"
,
"c"
,
"x"
,
"mno"
},
{
0
},
{
0
});
TestFindSkipgrams
(
skipgram_finder
,
{
"ghi"
,
"d"
,
"jkl"
,
"e"
,
"mno"
,
"f"
},
{
2
,
4
,
10
},
{
2
,
4
,
10
});
TestFindSkipgrams
(
skipgram_finder
,
{
"s"
,
"x"
,
"t"
,
"y"
,
"u"
,
"z"
},
{
6
,
8
},
{
6
,
8
});
}
TEST
(
SkipgramFinderTest
,
UnicodeLowercase
)
{
// Check that the lowercase has a smaller UTF-8 encoding than the uppercase.
UChar32
cu
;
U8_GET_UNSAFE
(
"Ɦ"
,
0
,
cu
);
UChar32
cl
=
u_tolower
(
cu
);
EXPECT_GT
(
U8_LENGTH
(
cu
),
U8_LENGTH
(
cl
));
SkipgramFinder
skipgram_finder
(
1
);
std
::
string
s
(
"Ɦ"
);
skipgram_finder
.
AddSkipgram
(
s
,
0
);
TestFindSkipgrams
(
skipgram_finder
,
{
"Ɦ"
},
{
0
},
{});
TestFindSkipgrams
(
skipgram_finder
,
{
"ɦ"
},
{
0
},
{
0
});
TestFindSkipgrams
(
skipgram_finder
,
{
"h"
},
{},
{});
}
}
// namespace
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/subsequence_finder.cc
0 → 100644
浏览文件 @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
#include <deque>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace
seq_flow_lite
{
void
SubsequenceFinder
::
AddSubsequence
(
absl
::
string_view
subsequence
,
int
category
)
{
const
char
*
s
=
subsequence
.
data
();
int32_t
length
=
subsequence
.
length
();
int32_t
n
=
0
;
TrieNode
*
trie
=
&
subsequence_trie_
;
bool
new_word
=
true
;
while
(
n
<
length
)
{
UChar32
c
;
U8_NEXT
(
s
,
n
,
length
,
c
);
if
(
c
<
0
)
return
;
c
=
u_tolower
(
c
);
if
(
c
==
' '
)
{
new_word
=
true
;
}
else
if
(
!
new_word
)
{
trie
=
&
trie
->
continue_token
[
c
];
}
else
{
trie
=
&
trie
->
next_token
[
c
];
new_word
=
false
;
}
}
trie
->
categories
.
insert
(
category
);
}
// Given a UChar32 and a trie node representing an in-progress subsequence,
// determine if we can use the UChar32 to continue the subsequence, and
// update `categories`, `next_tokens`, and `continue_tokens` if needed.
void
SubsequenceFinder
::
ProcessUChar32AndTrieNode
(
int
index
,
UChar32
c
,
const
absl
::
flat_hash_map
<
UChar32
,
TrieNode
>&
token_map
,
absl
::
flat_hash_set
<
int
>*
categories
,
std
::
deque
<
std
::
pair
<
int
,
const
TrieNode
*>>*
next_tokens
,
std
::
vector
<
const
TrieNode
*>*
continue_tokens
)
const
{
auto
iter
=
token_map
.
find
(
c
);
if
(
iter
!=
token_map
.
end
())
{
categories
->
insert
(
iter
->
second
.
categories
.
begin
(),
iter
->
second
.
categories
.
end
());
if
(
!
iter
->
second
.
continue_token
.
empty
())
{
continue_tokens
->
push_back
(
&
iter
->
second
);
}
if
(
!
iter
->
second
.
next_token
.
empty
())
{
next_tokens
->
emplace_back
(
index
,
&
iter
->
second
);
}
}
}
absl
::
flat_hash_set
<
int
>
SubsequenceFinder
::
FindSubsequences
(
absl
::
string_view
input
)
const
{
absl
::
flat_hash_set
<
int
>
categories
;
// Tracks subsequences in progress that are starting the next token,
// as well as the index of their last character.
std
::
deque
<
std
::
pair
<
int
,
const
TrieNode
*>>
next_tokens
;
// Tracks subsequences in progress that are looking for the next character
// in their corrent token. `current_continue_tokens` is the current set of
// subsequences being processed, while `future_continue_tokens` is the set
// of subsequences to process for the next character.
std
::
vector
<
const
TrieNode
*>
current_continue_tokens
;
std
::
vector
<
const
TrieNode
*>
future_continue_tokens
;
const
char
*
s
=
input
.
data
();
int32_t
length
=
input
.
length
();
int32_t
n
=
0
;
int
index
=
0
;
while
(
n
<
length
)
{
UChar32
c
;
U8_NEXT
(
s
,
n
,
length
,
c
);
if
(
c
<
0
)
return
categories
;
c
=
u_tolower
(
c
);
// Drop any subsequences which would need to skip more than `max_skip_size_`
// characters between the end of their last token and the current character.
while
(
!
next_tokens
.
empty
())
{
if
(
next_tokens
.
front
().
first
+
max_skip_size_
+
1
<
index
)
{
next_tokens
.
pop_front
();
}
else
{
break
;
}
}
// Check subsequences starting a new token.
size_t
size
=
next_tokens
.
size
();
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
ProcessUChar32AndTrieNode
(
index
,
c
,
next_tokens
[
i
].
second
->
next_token
,
&
categories
,
&
next_tokens
,
&
future_continue_tokens
);
}
// Check subsequences continuing a token.
for
(
const
TrieNode
*
continue_token
:
current_continue_tokens
)
{
ProcessUChar32AndTrieNode
(
index
,
c
,
continue_token
->
continue_token
,
&
categories
,
&
next_tokens
,
&
future_continue_tokens
);
}
// Check if we can start a new subsequence.
ProcessUChar32AndTrieNode
(
index
,
c
,
subsequence_trie_
.
next_token
,
&
categories
,
&
next_tokens
,
&
future_continue_tokens
);
current_continue_tokens
.
swap
(
future_continue_tokens
);
future_continue_tokens
.
clear
();
index
++
;
}
return
categories
;
}
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/subsequence_finder.h
0 → 100644
浏览文件 @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
#include <deque>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
namespace
seq_flow_lite
{
// SubsequenceFinder finds subsequences in UTF-8 strings.
//
// Specifically, given a subsequence t_1 t_2 ... t_n, we will check if a
// string matches '.*t_1.{0,N}t_2.{0,N} ... .{0,N}t_n.*', where N is the
// maximum skip size.
//
// To use: First, add subsequences using AddSubsequence() - each subsequence
// is associated with some category. Then call FindSubsequences() on a string,
// which will return the set of categories of the subsesequences in the string.
//
// The subsequences will be tokenized by splitting on spaces. Both subsequences
// and input strings will be normalized by lowercasing.
class
SubsequenceFinder
{
public:
explicit
SubsequenceFinder
(
int
max_skip_size
)
:
max_skip_size_
(
max_skip_size
)
{}
// Adds a subsequence that SubsequenceFinder should look for in input strings.
void
AddSubsequence
(
absl
::
string_view
subsequence
,
int
category
);
// Find all of the subsequences in `input`, and return their categories.
absl
::
flat_hash_set
<
int
>
FindSubsequences
(
absl
::
string_view
input
)
const
;
private:
// This trie tracks the next character needed to:
// * continue the current token
// * start the next token
struct
TrieNode
{
absl
::
flat_hash_set
<
int
>
categories
;
absl
::
flat_hash_map
<
UChar32
,
TrieNode
>
continue_token
;
absl
::
flat_hash_map
<
UChar32
,
TrieNode
>
next_token
;
};
void
ProcessUChar32AndTrieNode
(
int
index
,
UChar32
c
,
const
absl
::
flat_hash_map
<
UChar32
,
TrieNode
>&
token_map
,
absl
::
flat_hash_set
<
int
>*
categories
,
std
::
deque
<
std
::
pair
<
int
,
const
TrieNode
*>>*
next_tokens
,
std
::
vector
<
const
TrieNode
*>*
continue_tokens
)
const
;
TrieNode
subsequence_trie_
;
int
max_skip_size_
;
};
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
research/seq_flow_lite/tf_ops/subsequence_finder_test.cc
0 → 100644
浏览文件 @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace
seq_flow_lite
{
namespace
{
using
::
testing
::
UnorderedElementsAre
;
TEST
(
SubsequenceFinderTest
,
SingleExists
)
{
SubsequenceFinder
subsequence_finder
(
3
);
subsequence_finder
.
AddSubsequence
(
"ab cd"
,
0
);
EXPECT_THAT
(
subsequence_finder
.
FindSubsequences
(
"abcd"
),
UnorderedElementsAre
(
0
));
EXPECT_THAT
(
subsequence_finder
.
FindSubsequences
(
"ab012cd"
),
UnorderedElementsAre
(
0
));
EXPECT_THAT
(
subsequence_finder
.
FindSubsequences
(
"AB CD"
),
UnorderedElementsAre
(
0
));
}
TEST
(
SubsequenceFinderTest
,
SingleNotExists
)
{
SubsequenceFinder
subsequence_finder
(
3
);
subsequence_finder
.
AddSubsequence
(
"ab cd"
,
0
);
EXPECT_THAT
(
subsequence_finder
.
FindSubsequences
(
"a bcd"
),
UnorderedElementsAre
());
EXPECT_THAT
(
subsequence_finder
.
FindSubsequences
(
"ab0123cd"
),
UnorderedElementsAre
());
EXPECT_THAT
(
subsequence_finder
.
FindSubsequences
(
"abdc"
),
UnorderedElementsAre
());
}
TEST
(
SubsequenceFinderTest
,
Multiple
)
{
SubsequenceFinder
subsequence_finder
(
3
);
subsequence_finder
.
AddSubsequence
(
"a b c d"
,
0
);
subsequence_finder
.
AddSubsequence
(
"q r s"
,
2
);
subsequence_finder
.
AddSubsequence
(
"b c d e"
,
4
);
EXPECT_THAT
(
subsequence_finder
.
FindSubsequences
(
"a__b__c__d__e"
),
UnorderedElementsAre
(
0
,
4
));
EXPECT_THAT
(
subsequence_finder
.
FindSubsequences
(
"aqbrcsd"
),
UnorderedElementsAre
(
0
,
2
));
EXPECT_THAT
(
subsequence_finder
.
FindSubsequences
(
"b q c r d s e"
),
UnorderedElementsAre
(
2
,
4
));
}
TEST
(
SubsequenceFinderTest
,
Utf8
)
{
SubsequenceFinder
subsequence_finder
(
3
);
subsequence_finder
.
AddSubsequence
(
"一二 三四 五六"
,
0
);
EXPECT_THAT
(
subsequence_finder
.
FindSubsequences
(
"一二おはよ三四こんに五六"
),
UnorderedElementsAre
(
0
));
EXPECT_THAT
(
subsequence_finder
.
FindSubsequences
(
"一二三 四五六"
),
UnorderedElementsAre
());
}
}
// namespace
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/text_distorter.h
浏览文件 @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_TEXT_DISTORTER_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_TEXT_DISTORTER_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_TEXT_DISTORTER_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_TEXT_DISTORTER_H_
#include <assert.h>
...
...
@@ -40,4 +40,4 @@ class TextDistorter {
UChar32
random_char_
=
0
;
};
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_TEXT_DISTORTER_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_TEXT_DISTORTER_H_
research/seq_flow_lite/tf_ops/tf_custom_ops.cc
浏览文件 @
20cc2190
...
...
@@ -122,4 +122,4 @@ REGISTER_OP("UniformCausalAttn")
})
.
Doc
(
R"doc(
Dummy uniform causal attn op.
)doc"
;
)doc"
)
;
research/seq_flow_lite/tflite_ops/BUILD
浏览文件 @
20cc2190
...
...
@@ -121,9 +121,9 @@ cc_library(
hdrs
=
[
"tflite_qrnn_pooling.h"
],
copts
=
tflite_copts
(),
deps
=
[
"
//third_party/absl/base:core_header
s"
,
"//t
hird_party/tensorflow/lite/kernels:builtin_ops"
,
"
//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util
"
,
"
@org_tensorflow//tensorflow/lite/kernels:builtin_op
s"
,
"//t
flite_ops:quantization_util"
,
# sequence projection
"
@com_google_absl//absl/base:core_headers
"
,
],
alwayslink
=
1
,
)
...
...
@@ -132,7 +132,7 @@ cc_library(
name
=
"tflite_decoder_cache"
,
hdrs
=
[
"tflite_decoder_cache.h"
],
deps
=
[
"
//third_party
/tensorflow/lite/c:common"
,
"
@org_tensorflow/
/tensorflow/lite/c:common"
,
],
alwayslink
=
1
,
)
...
...
@@ -144,12 +144,12 @@ cc_library(
copts
=
tflite_copts
(),
deps
=
[
":tflite_decoder_cache"
,
"
//third_party/flatbuffers
"
,
"
//third_party/tensorflow/lite/c:common
"
,
"
//third_party/tensorflow/lite/kernels:builtin_ops
"
,
"
//third_party/tensorflow/lite/kernels:kernel_util
"
,
"//t
hird_party/tensorflow/lite/kernels/internal:tensor"
,
"
//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util
"
,
"
@org_tensorflow//tensorflow/lite/c:common
"
,
"
@org_tensorflow//tensorflow/lite/kernels:builtin_ops
"
,
"
@org_tensorflow//tensorflow/lite/kernels:kernel_util
"
,
"
@org_tensorflow//tensorflow/lite/kernels/internal:tensor
"
,
"//t
flite_ops:quantization_util"
,
# sequence projection
"
@flatbuffers
"
,
],
alwayslink
=
1
,
)
...
...
@@ -160,11 +160,11 @@ cc_test(
srcs
=
[
"tflite_decoder_handler_test.cc"
],
deps
=
[
":tflite_decoder_handler"
,
"
//testing/base/public:gunit
"
,
"
//third_party/flatbuffers
"
,
"
//third_party/tensorflow/lite:framework
"
,
"
//third_party/tensorflow/lite/c:common
"
,
"
//third_party/tensorflow/lite/kernels:test_util
"
,
"
@org_tensorflow//tensorflow/lite:framework
"
,
"
@org_tensorflow//tensorflow/lite/c:common
"
,
"
@org_tensorflow//tensorflow/lite/kernels:test_util
"
,
"
@com_google_googletest//:gtest
"
,
"
@flatbuffers
"
,
],
)
...
...
@@ -176,10 +176,10 @@ cc_library(
deps
=
[
"//base"
,
"//third_party/absl/strings"
,
"
//third_party
/tensorflow/lite/c:common"
,
"
//third_party
/tensorflow/lite/kernels/internal:tensor"
,
"
//third_party
/tensorflow/lite/kernels/internal:types"
,
"//t
hird_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util"
,
"
@org_tensorflow/
/tensorflow/lite/c:common"
,
"
@org_tensorflow/
/tensorflow/lite/kernels/internal:tensor"
,
"
@org_tensorflow/
/tensorflow/lite/kernels/internal:types"
,
"//t
flite_ops:quantization_util"
,
# sequence projection
],
)
...
...
@@ -189,14 +189,14 @@ cc_test(
copts
=
tflite_copts
(),
deps
=
[
":beam_search"
,
"//testing/base/public:gunit_main"
,
"//third_party/absl/strings"
,
"//third_party/tensorflow/lite/c:c_api_types"
,
"//third_party/tensorflow/lite/c:common"
,
"//third_party/tensorflow/lite/kernels/internal:legacy_reference_base"
,
"//third_party/tensorflow/lite/kernels/internal:optimized_base"
,
"//third_party/tensorflow/lite/kernels/internal:tensor"
,
"//third_party/tensorflow/lite/kernels/internal:types"
,
"//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util"
,
"@org_tensorflow//tensorflow/lite/c:c_api_types"
,
"@org_tensorflow//tensorflow/lite/c:common"
,
"@org_tensorflow//tensorflow/lite/kernels/internal:legacy_reference_base"
,
"@org_tensorflow//tensorflow/lite/kernels/internal:optimized_base"
,
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor"
,
"@org_tensorflow//tensorflow/lite/kernels/internal:types"
,
"//tflite_ops:quantization_util"
,
# sequence projection
"@com_google_googletest//:gtest_main"
,
],
)
research/seq_flow_lite/tflite_ops/beam_search.cc
浏览文件 @
20cc2190
...
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "t
hird_party/tensorflow_models/seq_flow_lite/tflite_ops/beam_search.h"
#include "t
flite_ops/beam_search.h" // seq_flow_lite
#include <algorithm>
#include <cstdint>
...
...
@@ -21,10 +21,10 @@ limitations under the License.
#include <vector>
#include "base/logging.h"
#include "
third_party/
absl/strings/str_join.h"
#include "t
hird_party/t
ensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "t
hird_party/t
ensorflow/lite/kernels/internal/types.h"
#include "t
hird_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include "absl/strings/str_join.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "t
flite_ops/quantization_util.h" // seq_flow_lite
namespace
seq_flow_lite
{
namespace
ops
{
...
...
@@ -86,6 +86,7 @@ void SequenceTracker::AddSequence(const int32_t *begin, const int32_t *end,
std
::
vector
<
std
::
vector
<
int32_t
>>
SequenceTracker
::
GetTopBeams
()
{
std
::
vector
<
std
::
vector
<
int32_t
>>
return_value
;
return_value
.
reserve
(
terminated_topk_
.
size
());
for
(
const
auto
&
v
:
terminated_topk_
)
{
return_value
.
push_back
(
v
.
second
);
}
...
...
@@ -255,8 +256,8 @@ void BeamSearch::FindTopKQuantizedFromLogitsV1(const TfLiteTensor &tensor,
}
}
// Updating topk across all beams.
for
(
uint32_t
k
=
0
;
k
<
std
::
min
(
topk_k
,
num_classes_
);
++
k
)
{
const
uint32_t
curr_beam_index
=
curr_beam
_topk
[
k
]
&
kClassIndexMask
;
for
(
uint32_t
curr_beam
:
curr_beam_top
k
)
{
const
uint32_t
curr_beam_index
=
curr_beam
&
kClassIndexMask
;
const
uint32_t
index
=
j
*
num_classes_
+
curr_beam_index
;
const
float
log_prob
=
tensor
.
params
.
scale
*
beam_logits
[
curr_beam_index
]
-
precomputed
;
...
...
research/seq_flow_lite/tflite_ops/beam_search.h
浏览文件 @
20cc2190
...
...
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef T
HIRD_PARTY_T
ENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#define T
HIRD_PARTY_T
ENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#include <cstdint>
#include <functional>
...
...
@@ -23,7 +23,7 @@ limitations under the License.
#include <set>
#include <vector>
#include "t
hird_party/t
ensorflow/lite/c/common.h"
#include "tensorflow/lite/c/common.h"
namespace
seq_flow_lite
{
namespace
ops
{
...
...
@@ -110,4 +110,4 @@ class BeamSearch {
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
#endif // T
HIRD_PARTY_T
ENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
research/seq_flow_lite/tflite_ops/beam_search_test.cc
浏览文件 @
20cc2190
...
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "t
hird_party/tensorflow_models/seq_flow_lite/tflite_ops/beam_search.h"
#include "t
flite_ops/beam_search.h" // seq_flow_lite
#include <cstdint>
#include <functional>
...
...
@@ -21,17 +21,17 @@ limitations under the License.
#include <memory>
#include <vector>
#include
"testing/base/public/gmock.h"
#include
"testing/base/public/gunit.h"
#include "
third_party/
absl/strings/str_join.h"
#include "t
hird_party/t
ensorflow/lite/c/c_api_types.h"
#include "t
hird_party/t
ensorflow/lite/c/common.h"
#include "t
hird_party/t
ensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "t
hird_party/t
ensorflow/lite/kernels/internal/reference/dequantize.h"
#include "t
hird_party/t
ensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "t
hird_party/t
ensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "t
hird_party/t
ensorflow/lite/kernels/internal/types.h"
#include "t
hird_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include
<gmock/gmock.h>
#include
<gtest/gtest.h>
#include "absl/strings/str_join.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "t
flite_ops/quantization_util.h" // seq_flow_lite
namespace
seq_flow_lite
{
namespace
ops
{
...
...
@@ -76,7 +76,7 @@ class BeamSearchImpl : public BeamSearch {
cur_cache
+
(
selected_beams
[
beam
]
*
NumClasses
());
for
(
int
j
=
0
;
j
<
NumClasses
();
++
j
,
index
++
)
{
next_cache
[
index
]
=
(
selected
[
j
]
+
next_cache
[
index
])
/
2
;
data_ptr
[
index
]
=
::
seq_flow_lite
::
PodQuantize
(
data_ptr
[
index
]
=
PodQuantize
(
next_cache
[
index
],
decoder_output_
->
params
.
zero_point
,
1.0
f
/
decoder_output_
->
params
.
scale
);
}
...
...
research/seq_flow_lite/tflite_ops/expected_value.h
浏览文件 @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_EXPECTED_VALUE_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_EXPECTED_VALUE_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_EXPECTED_VALUE_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_EXPECTED_VALUE_H_
#include "tensorflow/lite/kernels/register.h"
...
...
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_EXPECTED_VALUE();
}
// namespace ops
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_EXPECTED_VALUE_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_EXPECTED_VALUE_H_
research/seq_flow_lite/tflite_ops/layer_norm.h
浏览文件 @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef
LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLER
S_LAYER_NORM_H_
#define
LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLER
S_LAYER_NORM_H_
#ifndef
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OP
S_LAYER_NORM_H_
#define
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OP
S_LAYER_NORM_H_
#include "tensorflow/lite/kernels/register.h"
...
...
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_LAYER_NORM();
}
// namespace ops
}
// namespace seq_flow_lite
#endif //
LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLER
S_LAYER_NORM_H_
#endif //
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OP
S_LAYER_NORM_H_
research/seq_flow_lite/tflite_ops/layer_norm_test.cc
浏览文件 @
20cc2190
...
...
@@ -87,7 +87,7 @@ TEST(LayerNormModelTest, RegularInput) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -106,7 +106,7 @@ TEST(LayerNormModelTest, NegativeScale) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
-
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -125,7 +125,7 @@ TEST(LayerNormModelTest, NegativeOffset) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
1.0
,
/*offset=*/
-
1.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -144,7 +144,7 @@ TEST(LayerNormModelTest, NegativeScaleAndOffset) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
-
1.0
,
/*offset=*/
-
1.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -163,7 +163,7 @@ TEST(LayerNormModelTest, MultipleAxis) {
/*input_max=*/
3
,
/*output_min=*/
-
3
,
/*output_max=*/
3
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
1
,
3
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -182,7 +182,7 @@ TEST(LayerNormModelTest, MultipleNegativeAxis) {
/*input_max=*/
3
,
/*output_min=*/
-
3
,
/*output_max=*/
3
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
-
3
,
-
1
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -204,7 +204,7 @@ TEST(LayerNormModelTest, MultipleAxisWithLargeDepth) {
/*input_max=*/
1.0
,
/*output_min=*/
-
3.0
,
/*output_max=*/
3.0
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
1
,
3
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
research/seq_flow_lite/tflite_ops/quantization_util.h
浏览文件 @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#include <algorithm>
#include <cmath>
...
...
@@ -50,4 +50,4 @@ inline uint8_t PodQuantize(float value, int32_t zero_point,
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_QUANTIZATION_UTIL_H_
research/seq_flow_lite/tflite_ops/sequence_string_projection.cc
浏览文件 @
20cc2190
...
...
@@ -101,7 +101,7 @@ class ProjectionParams {
bool
exclude_nonalphaspace_unicodes
,
const
std
::
string
&
token_separators
,
bool
normalize_repetition
,
bool
add_first_cap_feature
,
bool
add_all_caps_feature
)
bool
add_all_caps_feature
,
bool
normalize_spaces
)
:
feature_size_
(
feature_size
),
unicode_handler_
(
vocabulary
,
exclude_nonalphaspace_unicodes
),
hasher_
(
Hasher
::
CreateHasher
(
feature_size
,
hashtype
)),
...
...
@@ -130,9 +130,9 @@ class ProjectionParams {
}
word_novelty_offset_
=
2.0
f
/
(
1
<<
word_novelty_bits_
);
if
(
!
token_separators
.
empty
()
||
normalize_repetition
)
{
if
(
!
token_separators
.
empty
()
||
normalize_repetition
||
normalize_spaces
)
{
projection_normalizer_
=
std
::
make_unique
<
ProjectionNormalizer
>
(
token_separators
,
normalize_repetition
);
token_separators
,
normalize_repetition
,
normalize_spaces
);
}
}
virtual
~
ProjectionParams
()
{}
...
...
@@ -242,7 +242,8 @@ class ProjectionParamsV2 : public ProjectionParams {
/*exclude_nonalphaspace_unicodes = */
false
,
/*token_separators = */
""
,
normalize_repetition
,
/*add_first_cap_feature = */
false
,
/*add_all_caps_feature = */
false
)
{}
/*add_all_caps_feature = */
false
,
/*normalize_spaces = */
false
)
{}
~
ProjectionParamsV2
()
override
{}
TfLiteStatus
PreprocessInput
(
TfLiteTensor
*
input_t
,
...
...
@@ -341,6 +342,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const
std
::
string
token_separators
=
m
[
"token_separators"
].
IsNull
()
?
""
:
m
[
"token_separators"
].
ToString
();
const
bool
normalize_repetition
=
m
[
"normalize_repetition"
].
AsBool
();
const
bool
normalize_spaces
=
m
[
"normalize_spaces"
].
AsBool
();
if
(
!
Hasher
::
SupportedHashType
(
hashtype
))
{
context
->
ReportError
(
context
,
"Unsupported hashtype %s
\n
"
,
hashtype
.
c_str
());
...
...
@@ -354,7 +356,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
add_bos_tag
?
BosTag
::
kGenerate
:
BosTag
::
kNone
,
add_eos_tag
?
EosTag
::
kGenerate
:
EosTag
::
kNone
,
exclude_nonalphaspace_unicodes
,
token_separators
,
normalize_repetition
,
add_first_cap_feature
==
1.0
f
,
add_all_caps_feature
==
1.0
f
);
add_first_cap_feature
==
1.0
f
,
add_all_caps_feature
==
1.0
f
,
normalize_spaces
);
}
void
*
InitV2
(
TfLiteContext
*
context
,
const
char
*
buffer
,
size_t
length
)
{
...
...
research/seq_flow_lite/tflite_ops/sequence_string_projection.h
浏览文件 @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#include "tensorflow/lite/kernels/register.h"
namespace
seq_flow_lite
{
...
...
@@ -27,8 +27,9 @@ TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION();
extern
const
char
kSequenceStringProjectionV2
[];
TfLiteRegistration
*
Register_SEQUENCE_STRING_PROJECTION_V2
();
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
research/seq_flow_lite/tflite_ops/sequence_string_projection_test.cc
浏览文件 @
20cc2190
...
...
@@ -39,6 +39,7 @@ using ::seq_flow_lite::testing::OpEquivTestCase;
using
::
seq_flow_lite
::
testing
::
StringTensor
;
using
::
seq_flow_lite
::
testing
::
TensorflowTfLiteOpTest
;
using
::
testing
::
ElementsAreArray
;
using
::
testing
::
Not
;
using
::
tflite
::
TensorType_FLOAT32
;
using
::
tflite
::
TensorType_STRING
;
using
::
tflite
::
TensorType_UINT8
;
...
...
@@ -50,7 +51,8 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
int
doc_size_levels
,
bool
add_eos_tag
,
::
tflite
::
TensorType
output_type
,
const
std
::
string
&
token_separators
=
""
,
bool
normalize_repetition
=
false
,
float
add_first_cap
=
0.0
,
float
add_all_caps
=
0.0
,
const
std
::
string
&
hashtype
=
kMurmurHash
)
{
float
add_all_caps
=
0.0
,
const
std
::
string
&
hashtype
=
kMurmurHash
,
bool
normalize_spaces
=
false
)
{
flexbuffers
::
Builder
fbb
;
fbb
.
Map
([
&
]
{
fbb
.
Int
(
"feature_size"
,
4
);
...
...
@@ -65,6 +67,7 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
fbb
.
Bool
(
"normalize_repetition"
,
normalize_repetition
);
fbb
.
Float
(
"add_first_cap_feature"
,
add_first_cap
);
fbb
.
Float
(
"add_all_caps_feature"
,
add_all_caps
);
fbb
.
Bool
(
"normalize_spaces"
,
normalize_spaces
);
});
fbb
.
Finish
();
output_
=
AddOutput
({
output_type
,
{}});
...
...
@@ -76,13 +79,13 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
PopulateStringTensor
(
input_
,
{
input
});
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
SingleOpModel
::
Invoke
(
);
CHECK_EQ
(
SingleOpModel
::
Invoke
(),
kTfLiteOk
);
}
TfLiteStatus
InvokeFailable
(
const
std
::
string
&
input
)
{
PopulateStringTensor
(
input_
,
{
input
});
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
return
SingleOpModel
::
Invoke
Unchecked
();
return
SingleOpModel
::
Invoke
();
}
template
<
typename
T
>
...
...
@@ -335,6 +338,32 @@ TEST(SequenceStringProjectionTest, NormalizeRepetition) {
EXPECT_THAT
(
output1
,
ElementsAreArray
(
output2
));
}
TEST
(
SequenceStringProjectionTest
,
NormalizeSpaces
)
{
SequenceStringProjectionModel
model_nonormalize
(
false
,
-
1
,
0
,
0
,
false
,
TensorType_UINT8
,
""
,
false
,
0.0
,
0.0
,
kMurmurHash
,
false
);
SequenceStringProjectionModel
model_normalize
(
false
,
-
1
,
0
,
0
,
false
,
TensorType_UINT8
,
""
,
false
,
0.0
,
0.0
,
kMurmurHash
,
true
);
const
char
kNoExtraSpaces
[]
=
"Hello there."
;
const
char
kExtraSpaces
[]
=
" Hello there. "
;
model_nonormalize
.
Invoke
(
kNoExtraSpaces
);
auto
output_noextra_nonorm
=
model_nonormalize
.
GetOutput
<
uint8_t
>
();
model_nonormalize
.
Invoke
(
kExtraSpaces
);
auto
output_extra_nonorm
=
model_nonormalize
.
GetOutput
<
uint8_t
>
();
model_normalize
.
Invoke
(
kNoExtraSpaces
);
auto
output_noextra_norm
=
model_normalize
.
GetOutput
<
uint8_t
>
();
model_normalize
.
Invoke
(
kExtraSpaces
);
auto
output_extra_norm
=
model_normalize
.
GetOutput
<
uint8_t
>
();
EXPECT_THAT
(
output_noextra_nonorm
,
ElementsAreArray
(
output_noextra_norm
));
EXPECT_THAT
(
output_noextra_nonorm
,
ElementsAreArray
(
output_extra_norm
));
EXPECT_THAT
(
output_noextra_nonorm
,
Not
(
ElementsAreArray
(
output_extra_nonorm
)));
}
class
SequenceStringProjectionTest
:
public
TensorflowTfLiteOpTest
{
std
::
function
<
TfLiteRegistration
*
()
>
TfLiteOpRegistration
()
override
{
return
ops
::
custom
::
Register_SEQUENCE_STRING_PROJECTION
;
...
...
@@ -710,6 +739,7 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
test_case
.
output_tensors
.
emplace_back
(
FloatTensor
({},
{}),
kScale
,
kZero
);
test_cases
.
push_back
(
test_case
);
}
{
OpEquivTestCase
test_case
;
test_case
.
test_name
=
"NormalizeRepetition"
;
...
...
@@ -794,6 +824,20 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
test_cases
.
push_back
(
test_case
);
}
{
OpEquivTestCase
test_case
;
test_case
.
test_name
=
"NormalizeSpaces"
;
test_case
.
attributes
[
"vocabulary"
]
=
AttrValue
(
""
);
test_case
.
attributes
[
"split_on_space"
]
=
AttrValue
(
true
);
test_case
.
attributes
[
"feature_size"
]
=
AttrValue
(
8
);
test_case
.
attributes
[
"add_eos_tag"
]
=
AttrValue
(
false
);
test_case
.
attributes
[
"add_bos_tag"
]
=
AttrValue
(
false
);
test_case
.
attributes
[
"normalize_spaces"
]
=
AttrValue
(
true
);
test_case
.
input_tensors
.
push_back
(
StringTensor
({
1
},
{
" Hello there. "
}));
test_case
.
output_tensors
.
emplace_back
(
FloatTensor
({},
{}),
kScale
,
kZero
);
test_cases
.
push_back
(
test_case
);
}
return
test_cases
;
}
...
...
@@ -822,13 +866,13 @@ class SequenceStringProjectionV2Model : public ::tflite::SingleOpModel {
PopulateStringTensor
(
input_
,
input
);
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
ASSERT_EQ
(
SingleOpModel
::
Invoke
Unchecked
(),
expected
);
ASSERT_EQ
(
SingleOpModel
::
Invoke
(),
expected
);
}
TfLiteStatus
InvokeFailable
(
const
std
::
string
&
input
)
{
PopulateStringTensor
(
input_
,
{
input
});
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
return
SingleOpModel
::
Invoke
Unchecked
();
return
SingleOpModel
::
Invoke
();
}
std
::
vector
<
int
>
GetOutputShape
()
{
return
GetTensorShape
(
output_
);
}
...
...
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.cc
浏览文件 @
20cc2190
...
...
@@ -309,7 +309,7 @@ void TensorflowTfLiteOpTest::RunTfLiteOp() {
input_index
++
;
}
tflite_op_
.
Invoke
(
);
ASSERT_EQ
(
tflite_op_
.
Invoke
(),
kTfLiteOk
);
}
void
TensorflowTfLiteOpTest
::
CompareOpOutput
()
{
...
...
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.h
浏览文件 @
20cc2190
...
...
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// Tests equivalence between TF and TFLite versions of an op.
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#include <string>
#include <vector>
...
...
@@ -146,4 +146,4 @@ class TensorflowTfLiteOpTest
}
// namespace testing
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
research/seq_flow_lite/tflite_ops/tflite_decoder_cache.h
浏览文件 @
20cc2190
...
...
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef T
HIRD_PARTY_T
ENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#define T
HIRD_PARTY_T
ENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#include <memory>
#include "third_party/tensorflow/lite/c/common.h"
#include "tensorflow/lite/c/common.h"
namespace
seq_flow_lite
{
namespace
ops
{
namespace
custom
{
...
...
@@ -113,4 +114,4 @@ class DynamicCacheOp {
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
#endif // T
HIRD_PARTY_T
ENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.cc
浏览文件 @
20cc2190
...
...
@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "t
hird_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_handler.h"
#include "t
flite_ops/tflite_decoder_handler.h" // seq_flow_lite
#include <cstdint>
#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "third_party/tensorflow/lite/kernels/kernel_util.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_cache.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include "flatbuffers/flexbuffers.h" // flatbuffer
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tflite_ops/quantization_util.h" // seq_flow_lite
#include "tflite_ops/tflite_decoder_cache.h" // seq_flow_lite
namespace
seq_flow_lite
{
namespace
ops
{
...
...
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.h
浏览文件 @
20cc2190
...
...
@@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef T
HIRD_PARTY_T
ENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#define T
HIRD_PARTY_T
ENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#include "t
hird_party/t
ensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register.h"
namespace
seq_flow_lite
{
namespace
ops
{
namespace
custom
{
TfLiteRegistration
*
Register_UNIFORM_CAUSAL_ATTENTION
();
}
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
#endif // T
HIRD_PARTY_T
ENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
research/seq_flow_lite/tflite_ops/tflite_decoder_handler_test.cc
浏览文件 @
20cc2190
...
...
@@ -13,17 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "t
hird_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_handler.h"
#include "t
flite_ops/tflite_decoder_handler.h" // seq_flow_lite
#include <cstdint>
#include <cstdlib>
#include <vector>
#include
"testing/base/public/gmock.h"
#include
"testing/base/public/gunit.h"
#include "
third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
#include "t
hird_party/t
ensorflow/lite/c/common.h"
#include "t
hird_party/t
ensorflow/lite/kernels/test_util.h"
#include
<gmock/gmock.h>
#include
<gtest/gtest.h>
#include "
flatbuffers/flexbuffers.h" // flatbuffer
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/test_util.h"
namespace
{
...
...
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
浏览文件 @
20cc2190
...
...
@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include "tflite_ops/quantization_util.h" // seq_flow_lite
#include "tflite_ops/tflite_qrnn_pooling.h" // seq_flow_lite
namespace
seq_flow_lite
{
namespace
ops
{
namespace
custom
{
namespace
{
...
...
@@ -126,9 +127,9 @@ TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
return
QRNNPooling
(
context
,
multiplier
,
constant
,
outputs
,
final_state
,
(
direction
->
data
.
uint8
[
0
]
==
kPoolingForward
));
}
}
// namespace
namespace
custom
{
const
char
kPoolingOp
[]
=
"PoolingOp"
;
void
RegisterQRNNPooling
(
::
tflite
::
ops
::
builtin
::
BuiltinOpResolver
*
resolver
)
{
...
...
@@ -141,4 +142,5 @@ TfLiteRegistration* Register_QRNN_POOLING() {
}
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h
浏览文件 @
20cc2190
...
...
@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#include "
third_party/
absl/base/macros.h"
#include "t
hird_party/t
ensorflow/lite/kernels/register.h"
#include "absl/base/macros.h"
#include "tensorflow/lite/kernels/register.h"
namespace
seq_flow_lite
{
namespace
ops
{
namespace
custom
{
extern
const
char
kPoolingOp
[];
...
...
@@ -27,7 +27,7 @@ extern const char kPoolingOp[];
TfLiteRegistration
*
Register_QRNN_POOLING
();
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
research/seq_flow_lite/trainer.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""A utility for PRADO model to do train, eval, inference and model export."""
import
importlib
...
...
@@ -22,6 +21,7 @@ from absl import app
from
absl
import
flags
from
absl
import
logging
import
tensorflow.compat.v1
as
tf
from
tensorflow.compat.v1
import
estimator
as
tf_estimator
import
input_fn_reader
# import root module
import
metric_functions
# import root module
...
...
@@ -48,14 +48,14 @@ def load_runner_config():
return
json
.
loads
(
f
.
read
())
def
create_model
(
model
,
model_config
,
features
,
mode
):
def
create_model
(
model
,
model_config
,
features
,
mode
,
model_name
):
"""Creates a sequence labeling model."""
keras_model
=
model
.
Encoder
(
model_config
,
mode
)
if
"pqrnn"
in
model_name
:
logits
=
keras_model
(
features
[
"projection"
],
features
[
"seq_length"
])
else
:
logits
=
keras_model
(
features
[
"token_ids"
],
features
[
"token_len"
])
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
if
mode
!=
tf
_
estimator
.
ModeKeys
.
PREDICT
:
if
not
model_config
[
"multilabel"
]:
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
features
[
"label"
],
logits
=
logits
)
...
...
@@ -94,33 +94,33 @@ def model_fn_builder(runner_config):
def
model_fn
(
features
,
mode
,
params
):
"""The `model_fn` for TPUEstimator."""
label_ids
=
None
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
if
mode
!=
tf
_
estimator
.
ModeKeys
.
PREDICT
:
label_ids
=
features
[
"label"
]
model_config
=
runner_config
[
"model_config"
]
loss
,
logits
=
create_model
(
model
,
model_config
,
features
,
mode
)
loss
,
logits
=
create_model
(
model
,
model_config
,
features
,
mode
,
runner_config
[
"name"
])
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
if
mode
==
tf
_
estimator
.
ModeKeys
.
TRAIN
:
train_op
=
create_optimizer
(
loss
,
runner_config
,
params
)
return
tf
.
compat
.
v1
.
estimator
.
tpu
.
TPUEstimatorSpec
(
return
tf
_
estimator
.
tpu
.
TPUEstimatorSpec
(
mode
=
mode
,
loss
=
loss
,
train_op
=
train_op
)
elif
mode
==
tf
.
estimator
.
ModeKeys
.
EVAL
:
elif
mode
==
tf
_
estimator
.
ModeKeys
.
EVAL
:
if
not
runner_config
[
"model_config"
][
"multilabel"
]:
metric_fn
=
metric_functions
.
classification_metric
else
:
metric_fn
=
metric_functions
.
labeling_metric
eval_metrics
=
(
metric_fn
,
[
loss
,
label_ids
,
logits
])
return
tf
.
compat
.
v1
.
estimator
.
tpu
.
TPUEstimatorSpec
(
return
tf
_
estimator
.
tpu
.
TPUEstimatorSpec
(
mode
=
mode
,
loss
=
loss
,
eval_metrics
=
eval_metrics
)
elif
mode
==
tf
.
estimator
.
ModeKeys
.
PREDICT
:
elif
mode
==
tf
_
estimator
.
ModeKeys
.
PREDICT
:
predictions
=
{
"logits"
:
logits
}
if
not
runner_config
[
"model_config"
][
"multilabel"
]:
predictions
[
"predictions"
]
=
tf
.
nn
.
softmax
(
logits
)
else
:
predictions
[
"predictions"
]
=
tf
.
math
.
sigmoid
(
logits
)
return
tf
.
compat
.
v1
.
estimator
.
EstimatorSpec
(
mode
=
mode
,
predictions
=
predictions
)
return
tf_estimator
.
EstimatorSpec
(
mode
=
mode
,
predictions
=
predictions
)
else
:
assert
False
,
"Expected to be called in TRAIN, EVAL, or PREDICT mode."
...
...
@@ -133,13 +133,13 @@ def main(_):
if
FLAGS
.
output_dir
:
tf
.
gfile
.
MakeDirs
(
FLAGS
.
output_dir
)
is_per_host
=
tf
.
estimator
.
tpu
.
InputPipelineConfig
.
PER_HOST_V2
run_config
=
tf
.
estimator
.
tpu
.
RunConfig
(
is_per_host
=
tf
_
estimator
.
tpu
.
InputPipelineConfig
.
PER_HOST_V2
run_config
=
tf
_
estimator
.
tpu
.
RunConfig
(
master
=
FLAGS
.
master
,
model_dir
=
FLAGS
.
output_dir
,
save_checkpoints_steps
=
runner_config
[
"save_checkpoints_steps"
],
keep_checkpoint_max
=
20
,
tpu_config
=
tf
.
estimator
.
tpu
.
TPUConfig
(
tpu_config
=
tf
_
estimator
.
tpu
.
TPUConfig
(
iterations_per_loop
=
runner_config
[
"iterations_per_loop"
],
num_shards
=
FLAGS
.
num_tpu_cores
,
per_host_input_for_training
=
is_per_host
))
...
...
@@ -149,7 +149,7 @@ def main(_):
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
batch_size
=
runner_config
[
"batch_size"
]
estimator
=
tf
.
estimator
.
tpu
.
TPUEstimator
(
estimator
=
tf
_
estimator
.
tpu
.
TPUEstimator
(
use_tpu
=
FLAGS
.
use_tpu
,
model_fn
=
model_fn
,
config
=
run_config
,
...
...
@@ -160,7 +160,7 @@ def main(_):
if
FLAGS
.
runner_mode
==
"train"
:
train_input_fn
=
input_fn_reader
.
create_input_fn
(
runner_config
=
runner_config
,
mode
=
tf
.
estimator
.
ModeKeys
.
TRAIN
,
mode
=
tf
_
estimator
.
ModeKeys
.
TRAIN
,
drop_remainder
=
True
)
estimator
.
train
(
input_fn
=
train_input_fn
,
max_steps
=
runner_config
[
"train_steps"
])
...
...
@@ -168,7 +168,7 @@ def main(_):
# TPU needs fixed shapes, so if the last batch is smaller, we drop it.
eval_input_fn
=
input_fn_reader
.
create_input_fn
(
runner_config
=
runner_config
,
mode
=
tf
.
estimator
.
ModeKeys
.
EVAL
,
mode
=
tf
_
estimator
.
ModeKeys
.
EVAL
,
drop_remainder
=
True
)
for
_
in
tf
.
train
.
checkpoints_iterator
(
FLAGS
.
output_dir
,
timeout
=
600
):
...
...
research/seq_flow_lite/trainer_v2.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Binary to train PRADO model with TF 2.0."""
import
importlib
...
...
@@ -23,6 +22,7 @@ from absl import flags
from
absl
import
logging
import
tensorflow
as
tf
from
tensorflow
import
estimator
as
tf_estimator
import
input_fn_reader
# import root module
...
...
@@ -48,7 +48,7 @@ def load_runner_config():
def
compute_loss
(
logits
,
labels
,
model_config
,
mode
):
"""Creates a sequence labeling model."""
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
if
mode
!=
tf
_
estimator
.
ModeKeys
.
PREDICT
:
if
not
model_config
[
"multilabel"
]:
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
labels
,
logits
=
logits
)
...
...
@@ -77,11 +77,11 @@ def main(_):
if
FLAGS
.
output_dir
:
tf
.
io
.
gfile
.
makedirs
(
FLAGS
.
output_dir
)
train_model
=
model_fn_builder
(
runner_config
,
tf
.
estimator
.
ModeKeys
.
TRAIN
)
train_model
=
model_fn_builder
(
runner_config
,
tf
_
estimator
.
ModeKeys
.
TRAIN
)
optimizer
=
tf
.
keras
.
optimizers
.
Adam
()
train_input_fn
=
input_fn_reader
.
create_input_fn
(
runner_config
=
runner_config
,
mode
=
tf
.
estimator
.
ModeKeys
.
TRAIN
,
mode
=
tf
_
estimator
.
ModeKeys
.
TRAIN
,
drop_remainder
=
True
)
params
=
{
"batch_size"
:
runner_config
[
"batch_size"
]}
train_ds
=
train_input_fn
(
params
)
...
...
@@ -93,7 +93,7 @@ def main(_):
logits
=
train_model
(
features
[
"projection"
],
features
[
"seq_length"
])
loss
=
compute_loss
(
logits
,
features
[
"label"
],
runner_config
[
"model_config"
],
tf
.
estimator
.
ModeKeys
.
TRAIN
)
tf
_
estimator
.
ModeKeys
.
TRAIN
)
gradients
=
tape
.
gradient
(
loss
,
train_model
.
trainable_variables
)
optimizer
.
apply_gradients
(
zip
(
gradients
,
train_model
.
trainable_variables
))
train_loss
(
loss
)
...
...
research/seq_flow_lite/utils/misc_utils.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""A module for miscelaneous utils."""
import
tensorflow
as
tf
...
...
research/seq_flow_lite/utils/tflite_utils.py
浏览文件 @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Utils to convert to a TFLite model."""
import
tensorflow.compat.v1
as
tf
...
...
@@ -65,9 +64,14 @@ def get_mean_stddev_values(min_value_of_features, max_value_of_features):
class
InterpreterWithCustomOps
(
tf
.
lite
.
Interpreter
):
"""Extended tf.lite.Interpreter."""
def
__init__
(
self
,
model_content
,
custom_op_registerers
=
None
):
def
__init__
(
self
,
model_content
,
custom_op_registerers
=
None
,
experimental_preserve_all_tensors
=
False
):
self
.
_custom_op_registerers
=
custom_op_registerers
or
[]
super
(
InterpreterWithCustomOps
,
self
).
__init__
(
model_content
=
model_content
)
super
(
InterpreterWithCustomOps
,
self
).
__init__
(
model_content
=
model_content
,
experimental_preserve_all_tensors
=
experimental_preserve_all_tensors
)
def
op_details
(
self
):
op_details
=
{}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录