Commit 44c55806 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Updated bucket_ops to support different batch sizes per bucket.

(This can be used to give smaller input sequences larger batch sizes.)
Change: 144621786
Parent acdbd689
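For context, a minimal usage sketch of the new per-bucket batch sizes, assuming the `tf.contrib.training.bucket_by_sequence_length` API as of this change; the tensor names, boundary values, and batch sizes below are illustrative only, not taken from this commit:

    import tensorflow as tf

    # A variable-length integer sequence and its length.
    sequence = tf.placeholder(tf.int64, shape=[None])
    length = tf.shape(sequence)[0]

    # Three buckets: length < 10, 10 <= length < 20, and length >= 20.
    # batch_size may now be a list with one entry per bucket, so buckets
    # holding shorter sequences can use larger batches.
    lengths, batched = tf.contrib.training.bucket_by_sequence_length(
        input_length=length,
        tensors=[sequence],
        batch_size=[64, 32, 16],  # one batch size per bucket
        bucket_boundaries=[10, 20],
        dynamic_pad=True)

The returned ops are fed by queue runners, so a session would call tf.train.start_queue_runners(sess) before evaluating `lengths` or `batched`.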
@@ -115,8 +115,10 @@ def bucket(tensors,
     tensors: The list or dictionary of tensors, representing a single element,
       to bucket. Nested lists are not supported.
     which_bucket: An `int32` scalar Tensor taking a value in `[0, num_buckets)`.
-    batch_size: The new batch size pulled from the queue
-      (python int or int32 scalar).
+    batch_size: The new batch size pulled from the queue (all queues will have
+      the same size). If a list is passed in then each bucket will have a
+      different batch_size.
+      (python int, int32 scalar or iterable of integers of length num_buckets).
     num_buckets: A python integer, the number of buckets.
     num_threads: An integer. The number of threads enqueuing `tensors`.
     capacity: An integer. The maximum number of minibatches in the top queue,
@@ -145,8 +147,17 @@ def bucket(tensors,

   Raises:
     ValueError: If the `shapes` are not specified, and cannot be
-      inferred from the elements of `tensors`.
+      inferred from the elements of `tensors` or if batch_size is a sequence
+      but its length != num_buckets.
   """
+  batch_size_per_bucket = False
+  if isinstance(batch_size, (list, tuple)):
+    batch_size_per_bucket = True
+    if len(batch_size) != num_buckets:
+      raise ValueError(
+          "If batch_size is a list it must have num_buckets elements")
+  else:
+    batch_size = [batch_size] * num_buckets
   tensor_list = _as_tensor_list(tensors)
   with ops.name_scope(name, "bucket", tensor_list) as name:
     tensor_list = _validate_bucket(tensor_list)
@@ -154,11 +165,12 @@ def bucket(tensors,
         tensor_list, enqueue_many=False, keep_input=constant_op.constant(True))

     # Round-trip batch_size to a tensor, and possibly back
-    batch_size = ops.convert_to_tensor(
-        batch_size, dtype=dtypes.int32, name="batch_size")
-    static_batch_size = tensor_util.constant_value(batch_size)
-    batch_size = (static_batch_size if static_batch_size is not None else
-                  batch_size)
+    for i, bucket_batch_size in enumerate(batch_size):
+      bucket_batch_size = ops.convert_to_tensor(
+          bucket_batch_size, dtype=dtypes.int32, name="batch_size")
+      static_batch_size = tensor_util.constant_value(bucket_batch_size)
+      batch_size[i] = (static_batch_size if static_batch_size is not None else
+                       bucket_batch_size)

     types = _dtypes([tensor_list])
     shapes = _shapes([tensor_list], shapes, enqueue_many=False)
@@ -179,8 +191,9 @@ def bucket(tensors,
             shared_name=shared_name_i,
             name="bucket_queue_%d" % i))

-    maybe_static_batch_size = (None if allow_smaller_final_batch else
-                               static_batch_size)
+    maybe_static_batch_size = (
+        None if (allow_smaller_final_batch or batch_size_per_bucket)
+        else static_batch_size)

     bucket_shapes = [
         tensor_shape.vector(maybe_static_batch_size).concatenate(s)
@@ -229,9 +242,9 @@ def bucket(tensors,
     enqueues_to_top = [
         top_queue.enqueue(
             [constant_op.constant(i)] + which_dequeue(q)(
-                batch_size, name="read_bucket_%d" % i),
+                bs, name="read_bucket_%d" % i),
             name="enqueue_from_bucket_%d" % i)
-        for i, q in enumerate(bucket_queues)
+        for i, (q, bs) in enumerate(zip(bucket_queues, batch_size))
     ]

     for i, q in enumerate(bucket_queues):
@@ -284,8 +297,10 @@ def bucket_by_sequence_length(input_length,
     input_length: `int32` scalar `Tensor`, the sequence length of tensors.
     tensors: The list or dictionary of tensors, representing a single element,
       to bucket. Nested lists are not supported.
-    batch_size: The new batch size pulled from the queue
-      (python int or int32 scalar).
+    batch_size: The new batch size pulled from the queue (all queues will have
+      the same size). If a list is passed in then each bucket will have a
+      different batch_size.
+      (python int, int32 scalar or iterable of integers of length num_buckets).
     bucket_boundaries: int list, increasing non-negative numbers.
       The edges of the buckets to use when bucketing tensors. Two extra buckets
       are created, one for `input_length < bucket_boundaries[0]` and
@@ -317,7 +332,8 @@ def bucket_by_sequence_length(input_length,
   Raises:
     TypeError: if `bucket_boundaries` is not a list of python integers.
     ValueError: if `bucket_boundaries` is empty or contains non-increasing
-      values.
+      values or if batch_size is a list and its length doesn't equal the number
+      of buckets.
   """
   tensor_list = _as_tensor_list(tensors)
   if not isinstance(bucket_boundaries, (list, tuple)):
......
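The validation added at the top of `bucket()` above normalizes every accepted form of `batch_size` to a per-bucket Python list before any tensor conversion. A standalone sketch of that logic, with a hypothetical helper name used purely for illustration:

    def _normalize_batch_size(batch_size, num_buckets):
      # A list or tuple must supply exactly one batch size per bucket.
      if isinstance(batch_size, (list, tuple)):
        if len(batch_size) != num_buckets:
          raise ValueError(
              "If batch_size is a list it must have num_buckets elements")
        return list(batch_size)
      # A scalar is replicated so every bucket shares the same batch size.
      return [batch_size] * num_buckets

    _normalize_batch_size(32, num_buckets=3)   # -> [32, 32, 32]
    _normalize_batch_size([64, 32, 16], 3)     # -> [64, 32, 16]
    _normalize_batch_size([64, 32], 3)         # raises ValueError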
@@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -139,6 +140,51 @@ class BucketTest(test.TestCase):
     self.assertAllEqual(expected_unk_int64, bucketed_values[1][1][resort])
     self.assertAllEqual(expected_vec3_str, bucketed_values[1][2][resort])

+  def testBatchSizePerBucket(self):
+    which_bucket = control_flow_ops.cond(self.scalar_int < 5,
+                                         lambda: constant_op.constant(0),
+                                         lambda: constant_op.constant(1))
+    batch_sizes = [5, 10]
+    bucketed_dynamic = bucket_ops.bucket(
+        tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
+        which_bucket=which_bucket,
+        num_buckets=2,
+        batch_size=batch_sizes,
+        num_threads=1,
+        dynamic_pad=True)
+    # Check shape inference on bucketing outputs
+    self.assertAllEqual(
+        [[None], [None, None], [None, 3]],
+        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
+    with self.test_session() as sess:
+      for v in range(15):
+        self.enqueue_inputs(sess, {
+            self.scalar_int_feed: v,
+            self.unk_int64_feed: v * [v],
+            self.vec3_str_feed: 3 * [str(v)]
+        })
+      self.start_queue_runners(sess)
+
+      # Get two minibatches (one with small values, one with large).
+      bucketed_values_0 = sess.run(bucketed_dynamic)
+      bucketed_values_1 = sess.run(bucketed_dynamic)
+
+      # Figure out which output has the small values
+      if bucketed_values_0[0] < 5:
+        bucketed_values_large, bucketed_values_small = (bucketed_values_1,
+                                                        bucketed_values_0)
+      else:
+        bucketed_values_small, bucketed_values_large = (bucketed_values_0,
+                                                        bucketed_values_1)
+
+      # Ensure bucket 0 was used for all minibatch entries.
+      self.assertAllEqual(0, bucketed_values_small[0])
+      self.assertAllEqual(1, bucketed_values_large[0])
+
+      # Check that the batch sizes differ per bucket
+      self.assertEqual(5, len(bucketed_values_small[1][0]))
+      self.assertEqual(10, len(bucketed_values_large[1][0]))
+
   def testEvenOddBuckets(self):
     which_bucket = (self.scalar_int % 2)
     bucketed_dynamic = bucket_ops.bucket(
......