diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index 9566d03211577bab93c486b4a0b1bd57d6e74105..3c314e2f285ff8de45a500807e4097b42fb3440d 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -40,6 +40,18 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "loss_test",
+    size = "medium",
+    srcs = ["python/kernel_tests/loss_test.py"],
+    additional_deps = [
+        ":seq2seq_py",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 cuda_py_test(
     name = "seq2seq_test",
     size = "medium",
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
index f99de76f175dc8f485423d17a27de9d176aaa26d..95560fb254d3065b3d56d128370ce610bafdb46f 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
@@ -20,14 +20,59 @@ from __future__ import division
 from __future__ import print_function
 # pylint: enable=unused-import
 
+import numpy as np
 import tensorflow as tf
 
-
 class LossTest(tf.test.TestCase):
 
-  def testLoss(self):
-    pass
+  def testSequenceLoss(self):
+    with self.test_session() as sess:
+      with tf.variable_scope("root",
+                             initializer=tf.constant_initializer(0.5)):
+        batch_size = 2
+        sequence_length = 3
+        number_of_classes = 5
+        logits = [tf.constant(i + 0.5, shape=[batch_size, number_of_classes])
+                  for i in range(sequence_length)]
+        logits = tf.stack(logits, axis=1)
+        targets = [tf.constant(i, tf.int32, shape=[batch_size])
+                   for i in range(sequence_length)]
+        targets = tf.stack(targets, axis=1)
+        weights = [tf.constant(1.0, shape=[batch_size])
+                   for _ in range(sequence_length)]
+        weights = tf.stack(weights, axis=1)
+
+        average_loss_per_example = tf.contrib.seq2seq.sequence_loss(
+            logits, targets, weights,
+            average_across_timesteps=True,
+            average_across_batch=True)
+        res = sess.run(average_loss_per_example)
+        self.assertAllClose(1.60944, res)
+
+        average_loss_per_sequence = tf.contrib.seq2seq.sequence_loss(
+            logits, targets, weights,
+            average_across_timesteps=False,
+            average_across_batch=True)
+        res = sess.run(average_loss_per_sequence)
+        compare_per_sequence = np.ones((sequence_length)) * 1.60944
+        self.assertAllClose(compare_per_sequence, res)
+
+        average_loss_per_batch = tf.contrib.seq2seq.sequence_loss(
+            logits, targets, weights,
+            average_across_timesteps=True,
+            average_across_batch=False)
+        res = sess.run(average_loss_per_batch)
+        compare_per_batch = np.ones((batch_size)) * 1.60944
+        self.assertAllClose(compare_per_batch, res)
+
+        total_loss = tf.contrib.seq2seq.sequence_loss(
+            logits, targets, weights,
+            average_across_timesteps=False,
+            average_across_batch=False)
+        res = sess.run(total_loss)
+        compare_total = np.ones((batch_size, sequence_length)) * 1.60944
+        self.assertAllClose(compare_total, res)
 
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/tensorflow/contrib/seq2seq/python/ops/loss.py b/tensorflow/contrib/seq2seq/python/ops/loss.py
index b8a33b3f6f66b15bdb04e9ecf9875e70d78c22b9..bb8711126667b387f02531f878e64da9717984b4 100644
--- a/tensorflow/contrib/seq2seq/python/ops/loss.py
+++ b/tensorflow/contrib/seq2seq/python/ops/loss.py
@@ -13,18 +13,90 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Seq2seq loss operations for use in neural networks.
+"""Seq2seq loss operations for use in sequence models.
""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import math_ops +__all__ = ["sequence_loss"] -__all__ = ["seq2seq_loss"] +def sequence_loss(logits, targets, weights, + average_across_timesteps=True, average_across_batch=True, + softmax_loss_function=None, name=None): + """Weighted cross-entropy loss for a sequence of logits (per example). + Args: + logits: A 3D Tensor of shape + [batch_size x sequence_length x num_decoder_symbols] and dtype float. + The logits correspond to the prediction across all classes at each + timestep. + targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype + int. The target represents the true class at each timestep. + weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype + float. Weights constitutes the weighting of each prediction in the + sequence. When using weights as masking set all valid timesteps to 1 and + all padded timesteps to 0. + average_across_timesteps: If set, sum the cost across the sequence + dimension and divide by the cost by the total label weight across + timesteps. + average_across_batch: If set, sum the cost across the batch dimension and + divide the returned cost by the batch size. + softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch + to be used instead of the standard softmax (the default if this is None). + name: Optional name for this operation, defaults to "sequence_loss". -def seq2seq_loss(*args, **kwargs): - pass + Returns: + A scalar float Tensor: The average log-perplexity per symbol (weighted). + + Raises: + ValueError: logits does not have 3 dimensions or targets does not have 2 + dimensions or weights does not have 2 dimensions. 
+ """ + if len(logits.get_shape()) != 3: + raise ValueError("Logits must be a " + "[batch_size x sequence_length x logits] tensor") + if len(targets.get_shape()) != 2: + raise ValueError("Targets must be a [batch_size x sequence_length] " + "tensor") + if len(weights.get_shape()) != 2: + raise ValueError("Weights must be a [batch_size x sequence_length] " + "tensor") + with ops.name_scope(name, "sequence_loss", [logits, targets, weights]): + num_classes = array_ops.shape(logits)[2] + probs_flat = array_ops.reshape(logits, [-1, num_classes]) + targets = array_ops.reshape(targets, [-1]) + if softmax_loss_function is None: + crossent = nn_ops.sparse_softmax_cross_entropy_with_logits( + labels=targets, logits=probs_flat) + else: + crossent = softmax_loss_function(probs_flat, targets) + crossent = crossent * array_ops.reshape(weights, [-1]) + if average_across_timesteps and average_across_batch: + crossent = math_ops.reduce_sum(crossent) + total_size = math_ops.reduce_sum(weights) + total_size += 1e-12 # to avoid division by 0 for all-0 weights + crossent /= total_size + else: + batch_size = array_ops.shape(logits)[0] + sequence_length = array_ops.shape(logits)[1] + crossent = array_ops.reshape(crossent, [batch_size, sequence_length]) + if average_across_timesteps and not average_across_batch: + crossent = math_ops.reduce_sum(crossent, axis=[1]) + total_size = math_ops.reduce_sum(weights, axis=[1]) + total_size += 1e-12 # to avoid division by 0 for all-0 weights + crossent /= total_size + if not average_across_timesteps and average_across_batch: + crossent = math_ops.reduce_sum(crossent, axis=[0]) + total_size = math_ops.reduce_sum(weights, axis=[0]) + total_size += 1e-12 # to avoid division by 0 for all-0 weights + crossent /= total_size + return crossent