提交 b10c65c2 编写于 作者: E Eugene Brevdo 提交者: TensorFlower Gardener

Remove use of conitional pass-through from dynamic_rnn in favor of just

calling select() at each step.  conditionals add a bunch of extra ops
that slow things down; and generally the size of the input tensor into
dynamic_rnn matches max_sequence_length so they provide no benefit.

Before change, benchmarks:

Calculation: Static Unroll with Dynamic Flow LSTM vs. Dynamic Unroll LSTM
batch    max_t   units   gpu     dt(static)      dt(dynamic)     dt(dynamic)/dt(static)
256      50      512     False   1.795002        1.774248        0.988437
256      50      512     True    0.186828        0.200752        1.074525
256      50      256     False   0.597320        0.750226        1.255986
256      50      256     True    0.082047        0.091411        1.114130
256      50      128     False   0.250596        0.238233        0.950666
256      50      128     True    0.056480        0.063086        1.116960

After change, benchmarks:

Calculation: Static Unroll with Dynamic Flow LSTM vs. Dynamic Unroll LSTM                                                                                              batch    max_t   units   gpu     dt(static)      dt(dynamic)     dt(dynamic)/dt(static)
256      50      512     False   1.723348        1.763019        1.023020
256      50      512     True    0.186794        0.196334        1.051072
256      50      256     False   0.644540        0.704506        1.093036
256      50      256     True    0.082274        0.087785        1.066985
256      50      128     False   0.241971        0.234559        0.969368
256      50      128     True    0.056356        0.059771        1.060611

Basically expect a more significant decrease in GPU step time when the matrices are smaller.
Change: 117254684
上级 c9c341e4
......@@ -957,7 +957,7 @@ def graph_creation_static_vs_dynamic_rnn_benchmark(max_time):
def _timer(sess, ops):
# Warm in
for _ in range(5):
for _ in range(2):
sess.run(ops)
# Timing run
......@@ -1100,24 +1100,24 @@ def rnn_long_sequence_benchmark(batch_size, seqlen, num_units,
def main(_):
print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM")
print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)")
for max_time in (1, 25, 50, 100, 200):
for max_time in (1, 25, 50):
graph_creation_static_vs_dynamic_rnn_benchmark(max_time)
print("Calculation: Static Unroll with Dynamic Flow LSTM "
"vs. Dynamic Unroll LSTM")
print("batch \t max_t \t units \t gpu \t dt(static) \t dt(dynamic) "
"\t dt(dynamic)/dt(static)")
for use_gpu in (False, True):
for batch_size in (256, 512):
for max_time in (50, 100):
for num_units in (512, 256, 128):
for batch_size in (256,):
for max_time in (50,):
for num_units in (512, 256, 128):
for use_gpu in (False, True):
static_vs_dynamic_rnn_benchmark(
batch_size, max_time, num_units, use_gpu)
print("Calculation: Dynamic LSTM No Memory Swap vs. Memory Swap")
print("batch \t max_t \t units \t no_swap \t swap \t swap/no_swap")
for batch_size in (256, 512):
for max_time in (50, 100):
for max_time in (100,):
for num_units in (512, 256, 128):
dynamic_rnn_swap_memory_benchmark(batch_size, max_time, num_units)
......
......@@ -182,11 +182,11 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
def _rnn_step(
time, sequence_length, min_sequence_length, max_sequence_length,
zero_output, state, call_cell):
zero_output, state, call_cell, skip_conditionals=False):
"""Calculate one step of a dynamic RNN minibatch.
Returns an (output, state) pair conditioned on the sequence_lengths.
The pseudocode is something like:
When skip_conditionals=False, the pseudocode is something like:
if t >= max_sequence_length:
return (zero_output, state)
......@@ -216,6 +216,10 @@ def _rnn_step(
call_cell: lambda returning tuple of (new_output, new_state) where
new_output is a `Tensor` matrix of shape [batch_size, output_size]
new_state is a `Tensor` matrix of shape [batch_size, state_size]
skip_conditionals: Python bool, whether to skip using the conditional
calculations. This is useful for dynamic_rnn, where the input tensor
matches max_sequence_length, and using conditionals just slows
everything down.
Returns:
A tuple of (final_output, final_state) as given by the pseudocode above:
......@@ -225,8 +229,15 @@ def _rnn_step(
# Step 1: determine whether we need to call_cell or not
empty_update = lambda: (zero_output, state)
state_shape = state.get_shape()
output, new_state = control_flow_ops.cond(
time < max_sequence_length, call_cell, empty_update)
if skip_conditionals:
# Skip using conditionals: calculate the RNN step at all time
# steps. This is faster for dynamic_rnn, where the time steps
# should cap out at max_sequence_length anyway.
output, new_state = call_cell()
else:
output, new_state = control_flow_ops.cond(
time < max_sequence_length, call_cell, empty_update)
# Step 2: determine whether we need to copy through state and/or outputs
existing_output_state = lambda: (output, new_state)
......@@ -239,8 +250,17 @@ def _rnn_step(
return (math_ops.select(copy_cond, zero_output, output),
math_ops.select(copy_cond, state, new_state))
(output, state) = control_flow_ops.cond(
time < min_sequence_length, existing_output_state, copy_through)
# TODO(ebrevdo): skipping these conditionals may cause a slowdown,
# but benefits from removing cond() and its gradient. We should
# profile with and without this switch here.
if skip_conditionals:
# Skip using conditionals: perform the selective copy at all time
# steps. This is usually faster.
(output, state) = copy_through()
else:
(output, state) = control_flow_ops.cond(
time < min_sequence_length, existing_output_state, copy_through)
output.set_shape(zero_output.get_shape())
state.set_shape(state_shape)
return (output, state)
......@@ -549,8 +569,14 @@ def _dynamic_rnn_loop(
if sequence_length is not None:
(output, new_state) = _rnn_step(
time, sequence_length, min_sequence_length, max_sequence_length,
zero_output, state, call_cell)
time=time,
sequence_length=sequence_length,
min_sequence_length=min_sequence_length,
max_sequence_length=max_sequence_length,
zero_output=zero_output,
state=state,
call_cell=call_cell,
skip_conditionals=True)
else:
(output, new_state) = call_cell()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册