Commit f3641f23 authored by Hongkun Yu, committed by A. Unique TensorFlower

Change Seq2SeqTransformer inputs to dictionary.

PiperOrigin-RevId: 338309524
Parent 09b9dad7
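This change replaces the positional list that Seq2SeqTransformer.call used to accept with a named feature dictionary: the required `inputs` feature carries the source tokens, and the optional `targets` feature switches the model between the training and decoding paths. The sketch below is a minimal, hypothetical ToySeq2Seq stand-in that mirrors only this calling convention; it is not the real Seq2SeqTransformer, whose constructor arguments and internals are outside this diff.

import tensorflow as tf


class ToySeq2Seq(tf.keras.Model):
  """Hypothetical stand-in mirroring the dict-input convention of this diff."""

  def __init__(self, vocab_size=100, hidden_size=16):
    super().__init__()
    self._embed = tf.keras.layers.Embedding(vocab_size, hidden_size)
    self._logits = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, training=False):
    sources = inputs["inputs"]             # required feature
    targets = inputs.get("targets", None)  # optional; absent means decoding
    # A crude "encoding": mean-pooled source embeddings, broadcast over time.
    encoded = tf.reduce_mean(self._embed(sources), axis=1, keepdims=True)
    if targets is None:
      # Decoding path: return the same dict keys the real model exposes.
      logits = self._logits(self._embed(sources) + encoded)
      return dict(outputs=tf.argmax(logits, axis=-1),
                  scores=tf.reduce_max(logits, axis=-1))
    # Training path: per-position logits over the target sequence.
    return self._logits(self._embed(targets) + encoded)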
......
@@ -130,9 +130,9 @@ class Seq2SeqTransformer(tf.keras.Model):
     """Calculate target logits or inferred target sequences.
 
     Args:
-      inputs: input tensor list of size 1 or 2.
-        First item, inputs: int tensor with shape [batch_size, input_length].
-        Second item (optional), targets: None or int tensor with shape
+      inputs: a dictionary of tensors.
+        Feature `inputs`: int tensor with shape [batch_size, input_length].
+        Feature `targets` (optional): None or int tensor with shape
           [batch_size, target_length].
 
     Returns:
......
@@ -147,12 +147,8 @@ class Seq2SeqTransformer(tf.keras.Model):
     Raises:
       NotImplementedError: If try to use padded decode method on CPU/GPUs.
     """
-    inputs = inputs if isinstance(inputs, list) else [inputs]
-    if len(inputs) == 2:
-      sources, targets = inputs[0], inputs[1]
-    else:
-      # Decoding path.
-      sources, targets = inputs[0], None
+    sources = inputs["inputs"]
+    targets = inputs.get("targets", None)
     attention_bias = model_utils.get_padding_bias(sources)
     attention_bias = tf.cast(attention_bias, self._dtype)
     # Prepare inputs to the layer stack by adding positional encodings and
......
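With the new parsing above, the presence or absence of the `targets` feature is what selects the training or decoding path. A usage sketch, reusing the hypothetical ToySeq2Seq from the top of this page (shapes match the test below but are otherwise illustrative):

import numpy as np

model = ToySeq2Seq(vocab_size=100)
sources = np.zeros((4, 10), dtype=np.int32)
targets = np.zeros((4, 8), dtype=np.int32)

# Both features present: the training path returns per-token logits.
logits = model(dict(inputs=sources, targets=targets), training=True)
assert logits.shape == (4, 8, 100)

# Only `inputs`: the decoding path returns a dict of outputs and scores.
ret = model(dict(inputs=sources), training=False)
outputs, scores = ret["outputs"], ret["scores"]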
......
@@ -82,15 +82,15 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
       return tf.nest.map_structure(distribution.experimental_local_results,
                                    outputs)
 
-    fake_inputs = [np.zeros((batch_size, decode_max_length), dtype=np.int32)]
+    fake_inputs = dict(
+        inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))
     local_outputs = step(fake_inputs)
     logging.info("local_outputs=%s", local_outputs)
     self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
 
-    fake_inputs = [
-        np.zeros((batch_size, decode_max_length), dtype=np.int32),
-        np.zeros((batch_size, 8), dtype=np.int32)
-    ]
+    fake_inputs = dict(
+        inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32),
+        targets=np.zeros((batch_size, 8), dtype=np.int32))
     local_outputs = step(fake_inputs)
     logging.info("local_outputs=%s", local_outputs)
     self.assertEqual(local_outputs[0].shape, (4, 8, 100))
......
@@ -108,7 +108,7 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
 
       @tf.function
       def serve(self, inputs):
-        return self.model.call([inputs])
+        return self.model.call(dict(inputs=inputs))
 
     save_module = SaveModule(model)
     if padded_decode:
......
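For export, the test now re-wraps the flat serving tensor into the feature dictionary inside the wrapped call. A sketch of that pattern under the same ToySeq2Seq assumption (the input_signature here is an assumption, not the one the test actually uses):

class SaveModule(tf.Module):

  def __init__(self, model):
    super().__init__()
    self.model = model

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None, None], dtype=tf.int32, name="inputs")
  ])
  def serve(self, inputs):
    # Re-wrap the flat serving tensor into the model's feature dictionary.
    return self.model.call(dict(inputs=inputs))


model = ToySeq2Seq()
model(dict(inputs=tf.zeros([1, 4], tf.int32)))  # Build variables eagerly.
save_module = SaveModule(model)
tf.saved_model.save(save_module, "/tmp/toy_seq2seq",
                    signatures=save_module.serve)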
......
@@ -70,7 +70,8 @@ def _create_model(params, is_train):
     inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
     targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
     internal_model = models.Seq2SeqTransformer(**model_kwargs)
-    logits = internal_model([inputs, targets], training=is_train)
+    logits = internal_model(
+        dict(inputs=inputs, targets=targets), training=is_train)
     vocab_size = params["vocab_size"]
     label_smoothing = params["label_smoothing"]
     if params["enable_metrics_in_training"]:
......
@@ -90,7 +91,7 @@ def _create_model(params, is_train):
         dtype="int64",
         name="inputs")
     internal_model = models.Seq2SeqTransformer(**model_kwargs)
-    ret = internal_model([inputs], training=is_train)
+    ret = internal_model(dict(inputs=inputs), training=is_train)
     outputs, scores = ret["outputs"], ret["scores"]
     return tf.keras.Model(inputs, [outputs, scores])
......
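The inference branch above threads a symbolic Keras input through the same dictionary convention and flattens the returned dict into a two-output functional model. The same wiring with the toy stand-in (assuming the subclassed call traces cleanly on symbolic inputs, as the real model does):

inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
internal_model = ToySeq2Seq()
ret = internal_model(dict(inputs=inputs), training=False)
outputs, scores = ret["outputs"], ret["scores"]
infer_model = tf.keras.Model(inputs, [outputs, scores])
infer_model.summary()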