提交 d7eabefa 编写于 作者: F Frederick Liu 提交者: A. Unique TensorFlower

[translation] Add text2text export module.

PiperOrigin-RevId: 418559537
上级 439d515a
......@@ -13,12 +13,14 @@
# limitations under the License.
"""A binary/library to export TF-NLP serving `SavedModel`."""
import dataclasses
import os
from typing import Any, Dict, Text
from absl import app
from absl import flags
import dataclasses
import yaml
from official.core import base_task
from official.core import task_factory
from official.modeling import hyperparams
......@@ -29,6 +31,7 @@ from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
from official.nlp.tasks import translation
FLAGS = flags.FLAGS
......@@ -40,7 +43,9 @@ SERVING_MODULES = {
question_answering.QuestionAnsweringTask:
serving_modules.QuestionAnswering,
tagging.TaggingTask:
serving_modules.Tagging
serving_modules.Tagging,
translation.TranslationTask:
serving_modules.Translation
}
......
......@@ -14,10 +14,12 @@
"""Serving export modules for TF Model Garden NLP models."""
# pylint:disable=missing-class-docstring
import dataclasses
from typing import Dict, List, Optional, Text
import dataclasses
import tensorflow as tf
import tensorflow_text as tf_text
from official.core import export_base
from official.modeling.hyperparams import base_config
from official.nlp.data import sentence_prediction_dataloader
......@@ -407,3 +409,48 @@ class Tagging(export_base.ExportModule):
signatures[signature_key] = self.serve_examples.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
return signatures
class Translation(export_base.ExportModule):
"""The export module for the translation task."""
@dataclasses.dataclass
class Params(base_config.Config):
sentencepiece_model_path: str = ""
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
self._sp_tokenizer = tf_text.SentencepieceTokenizer(
model=tf.io.gfile.GFile(params.sentencepiece_model_path, "rb").read(),
add_eos=True)
try:
empty_str_tokenized = self._sp_tokenizer.tokenize("").numpy()
except tf.errors.InternalError:
raise ValueError(
"EOS token not in tokenizer vocab."
"Please make sure the tokenizer generates a single token for an "
"empty string.")
self._eos_id = empty_str_tokenized.item()
@tf.function
def serve(self, inputs) -> Dict[str, tf.Tensor]:
return self.inference_step(inputs)
@tf.function
def serve_text(self, text: tf.Tensor) -> Dict[str, tf.Tensor]:
tokenized = self._sp_tokenizer.tokenize(text).to_tensor(0)
return self._sp_tokenizer.detokenize(
self.serve({"inputs": tokenized})["outputs"])
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
signatures = {}
valid_keys = ("serve_text")
for func_key, signature_key in function_keys.items():
if func_key not in valid_keys:
raise ValueError("Invalid function key for the module: %s with key %s. "
"Valid keys are: %s" %
(self.__class__, func_key, valid_keys))
if func_key == "serve_text":
signatures[signature_key] = self.serve_text.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="text"))
return signatures
......@@ -15,8 +15,11 @@
"""Tests for nlp.serving.serving_modules."""
import os
from absl.testing import parameterized
import tensorflow as tf
from sentencepiece import SentencePieceTrainer
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.serving import serving_modules
......@@ -24,6 +27,7 @@ from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
from official.nlp.tasks import translation
def _create_fake_serialized_examples(features_dict):
......@@ -59,6 +63,33 @@ def _create_fake_vocab_file(vocab_file_path):
outfile.write("\n".join(tokens))
def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
argstr = " ".join([
f"--input={input_path}", f"--vocab_size={vocab_size}",
"--character_coverage=0.995",
f"--model_prefix={model_path}", "--model_type=bpe",
"--bos_id=-1", "--pad_id=0", f"--eos_id={eos_id}", "--unk_id=2"
])
SentencePieceTrainer.Train(argstr)
def _generate_line_file(filepath, lines):
with tf.io.gfile.GFile(filepath, "w") as f:
for l in lines:
f.write("{}\n".format(l))
def _make_sentencepeice(output_dir):
src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
tgt_lines = ["dd cc a ef g", "bcd ef a g", "gef cd ba"]
sentencepeice_input_path = os.path.join(output_dir, "inputs.txt")
_generate_line_file(sentencepeice_input_path, src_lines + tgt_lines)
sentencepeice_model_prefix = os.path.join(output_dir, "sp")
_train_sentencepiece(sentencepeice_input_path, 11, sentencepeice_model_prefix)
sentencepeice_model_path = "{}.model".format(sentencepeice_model_prefix)
return sentencepeice_model_path
class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
......@@ -312,6 +343,31 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaises(ValueError):
_ = export_module.get_inference_signatures({"foo": None})
def test_translation(self):
sp_path = _make_sentencepeice(self.get_temp_dir())
encdecoder = translation.EncDecoder(
num_attention_heads=4, intermediate_size=256)
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=encdecoder,
decoder=encdecoder,
embedding_width=256,
padded_decode=False,
decode_max_length=100),
sentencepiece_model_path=sp_path,
)
task = translation.TranslationTask(config)
model = task.build_model()
params = serving_modules.Translation.Params(
sentencepiece_model_path=sp_path)
export_module = serving_modules.Translation(params=params, model=model)
functions = export_module.get_inference_signatures({
"serve_text": "serving_default"
})
outputs = functions["serving_default"](tf.constant(["abcd", "ef gh"]))
self.assertEqual(outputs.shape, (2,))
self.assertEqual(outputs.dtype, tf.string)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册