Commit 24466c2e authored by Alexander Rosenberg Johansen, committed by drpngx

List of 2Ds -> 3D Tensor, seq2seq_loss (#4382)

* sequence loss function for seq2seq loss

* loss raises and docstring updated

* detailed docstring

* moving initializer one line down

* change order of arguments, sparse_crossent

* Use tf.stack instead of tf.pack.
Parent 041b1762
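The "List of 2Ds -> 3D Tensor" change in the title means sequence_loss now takes a single [batch_size, sequence_length, num_classes] logits tensor rather than a Python list of per-timestep [batch_size, num_classes] tensors. A minimal sketch of that conversion with tf.stack (TF 1.x API; the shapes are illustrative, not taken from the commit):

import tensorflow as tf

batch_size, num_classes = 2, 5
# One [batch_size, num_classes] logits tensor per timestep, list-of-2Ds style.
step_logits = [tf.zeros([batch_size, num_classes]) for _ in range(3)]
# Stacking along axis=1 yields the [batch_size, sequence_length, num_classes]
# tensor the new loss expects; tf.pack is the deprecated name for tf.stack.
logits_3d = tf.stack(step_logits, axis=1)
print(logits_3d.shape)  # (2, 3, 5)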
@@ -40,6 +40,18 @@ cuda_py_test(
],
)
cuda_py_test(
name = "loss_test",
size = "medium",
srcs = ["python/kernel_tests/loss_test.py"],
additional_deps = [
":seq2seq_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "seq2seq_test",
size = "medium",
...
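With the loss_test target added above, the new tests should be runnable the usual way from a TensorFlow source checkout, presumably bazel test //tensorflow/contrib/seq2seq:loss_test (the package path is inferred from the ":seq2seq_py" dep and the test file's location, not stated explicitly in this diff).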
@@ -20,14 +20,58 @@ from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import numpy as np
import tensorflow as tf


class LossTest(tf.test.TestCase):

  def testSequenceLoss(self):
    with self.test_session() as sess:
      with tf.variable_scope(
          "root", initializer=tf.constant_initializer(0.5)):
        batch_size = 2
        sequence_length = 3
        number_of_classes = 5
        # Each timestep's logits are the same constant across all classes, so
        # the softmax is uniform and the expected per-step cross-entropy is
        # log(number_of_classes).
        logits = [tf.constant(i + 0.5, shape=[batch_size, number_of_classes])
                  for i in range(sequence_length)]
        logits = tf.stack(logits, axis=1)
        targets = [tf.constant(i, tf.int32, shape=[batch_size])
                   for i in range(sequence_length)]
        targets = tf.stack(targets, axis=1)
        weights = [tf.constant(1.0, shape=[batch_size])
                   for _ in range(sequence_length)]
        weights = tf.stack(weights, axis=1)

        average_loss_per_example = tf.contrib.seq2seq.sequence_loss(
            logits, targets, weights,
            average_across_timesteps=True,
            average_across_batch=True)
        res = sess.run(average_loss_per_example)
        self.assertAllClose(1.60944, res)

        average_loss_per_sequence = tf.contrib.seq2seq.sequence_loss(
            logits, targets, weights,
            average_across_timesteps=False,
            average_across_batch=True)
        res = sess.run(average_loss_per_sequence)
        compare_per_sequence = np.ones(sequence_length) * 1.60944
        self.assertAllClose(compare_per_sequence, res)

        average_loss_per_batch = tf.contrib.seq2seq.sequence_loss(
            logits, targets, weights,
            average_across_timesteps=True,
            average_across_batch=False)
        res = sess.run(average_loss_per_batch)
        compare_per_batch = np.ones(batch_size) * 1.60944
        self.assertAllClose(compare_per_batch, res)

        total_loss = tf.contrib.seq2seq.sequence_loss(
            logits, targets, weights,
            average_across_timesteps=False,
            average_across_batch=False)
        res = sess.run(total_loss)
        compare_total = np.ones((batch_size, sequence_length)) * 1.60944
        self.assertAllClose(compare_total, res)


if __name__ == '__main__':
  tf.test.main()
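A note on the expected value 1.60944 in the assertions above: at every timestep the logits are the same constant across all five classes, so the softmax is uniform and the per-step cross-entropy is -log(1/5) = log(5), regardless of the target. A quick standalone check:

import numpy as np

# Uniform softmax over 5 classes => cross-entropy is log(5) for any target.
print(np.log(5))  # 1.6094379124341003, matching the rounded 1.60944 above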
@@ -13,18 +13,88 @@
# limitations under the License.
# ==============================================================================
"""Seq2seq loss operations for use in neural networks.
"""Seq2seq loss operations for use in sequence models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import math_ops
__all__ = ["sequence_loss"]
__all__ = ["seq2seq_loss"]
def sequence_loss(logits, targets, weights,
                  average_across_timesteps=True, average_across_batch=True,
                  softmax_loss_function=None, name=None):
  """Weighted cross-entropy loss for a sequence of logits (per example).

  Args:
    logits: A 3D Tensor of shape
      [batch_size x sequence_length x num_decoder_symbols] and dtype float.
      The logits correspond to the prediction across all classes at each
      timestep.
    targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype
      int. The target represents the true class at each timestep.
    weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype
      float. The weights constitute the weighting of each prediction in the
      sequence. When using weights as a mask, set all valid timesteps to 1
      and all padded timesteps to 0.
    average_across_timesteps: If set, sum the cost across the sequence
      dimension and divide the cost by the total label weight across
      timesteps.
    average_across_batch: If set, sum the cost across the batch dimension and
      divide the returned cost by the batch size.
    softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
    name: Optional name for this operation, defaults to "sequence_loss".

  Returns:
    A float Tensor holding the weighted cross-entropy: a scalar if both
    averaging flags are set, a Tensor of shape [batch_size] if only
    average_across_timesteps is set, a Tensor of shape [sequence_length] if
    only average_across_batch is set, and a Tensor of shape
    [batch_size x sequence_length] if neither is set.

  Raises:
    ValueError: if logits does not have 3 dimensions, targets does not have
      2 dimensions, or weights does not have 2 dimensions.
  """
  if len(logits.get_shape()) != 3:
    raise ValueError("Logits must be a "
                     "[batch_size x sequence_length x num_decoder_symbols] "
                     "tensor")
  if len(targets.get_shape()) != 2:
    raise ValueError("Targets must be a [batch_size x sequence_length] "
                     "tensor")
  if len(weights.get_shape()) != 2:
    raise ValueError("Weights must be a [batch_size x sequence_length] "
                     "tensor")
  with ops.name_scope(name, "sequence_loss", [logits, targets, weights]):
    num_classes = array_ops.shape(logits)[2]
    # Flatten to [batch_size * sequence_length, num_classes] so that a single
    # call to the cross-entropy op covers every timestep of every example.
    logits_flat = array_ops.reshape(logits, [-1, num_classes])
    targets = array_ops.reshape(targets, [-1])
    if softmax_loss_function is None:
      crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
          labels=targets, logits=logits_flat)
    else:
      crossent = softmax_loss_function(logits_flat, targets)
    # Zero out (or down-weight) the loss at individual timesteps.
    crossent = crossent * array_ops.reshape(weights, [-1])
    if average_across_timesteps and average_across_batch:
      crossent = math_ops.reduce_sum(crossent)
      total_size = math_ops.reduce_sum(weights)
      total_size += 1e-12  # to avoid division by 0 for all-0 weights
      crossent /= total_size
    else:
      batch_size = array_ops.shape(logits)[0]
      sequence_length = array_ops.shape(logits)[1]
      crossent = array_ops.reshape(crossent, [batch_size, sequence_length])
    if average_across_timesteps and not average_across_batch:
      crossent = math_ops.reduce_sum(crossent, axis=[1])
      total_size = math_ops.reduce_sum(weights, axis=[1])
      total_size += 1e-12  # to avoid division by 0 for all-0 weights
      crossent /= total_size
    if not average_across_timesteps and average_across_batch:
      crossent = math_ops.reduce_sum(crossent, axis=[0])
      total_size = math_ops.reduce_sum(weights, axis=[0])
      total_size += 1e-12  # to avoid division by 0 for all-0 weights
      crossent /= total_size
    return crossent
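A minimal usage sketch of the new function against the TF 1.x contrib API, showing the mask-style weights described in the docstring; the tensors and values here are illustrative, not from the commit:

import tensorflow as tf

batch_size, sequence_length, num_classes = 2, 4, 5
logits = tf.random_normal([batch_size, sequence_length, num_classes])
targets = tf.zeros([batch_size, sequence_length], dtype=tf.int32)
# Mask: the last timestep of example 0 and the last two of example 1 are padding.
weights = tf.constant([[1., 1., 1., 0.],
                       [1., 1., 0., 0.]])

# Scalar: summed cross-entropy divided by total weight (both flags default to True).
loss = tf.contrib.seq2seq.sequence_loss(logits, targets, weights)
# Shape [batch_size]: per-example loss averaged over each example's valid steps.
loss_per_example = tf.contrib.seq2seq.sequence_loss(
    logits, targets, weights, average_across_batch=False)

with tf.Session() as sess:
  print(sess.run([loss, loss_per_example]))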