Commit 03abac7f authored by Eugene Brevdo, committed by TensorFlower Gardener

Updates to RNNCells to allow easy storage of attention TensorArray in the state.

The main change is that RNNCells that wrap other RNNCells now override self.zero_state to call the wrapped cell's zero_state and then optionally perform some post-processing, instead of relying on the state_size property to provide all information about the state.

Also made zero_state calls create ops inside their own name scope.
Change: 150413265
Parent 9cc50983
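For context, the pattern the commit message describes is: a wrapping cell delegates zero_state to the wrapped cell and then post-processes the result, instead of trying to describe everything through state_size. Below is a minimal sketch of that idea; the TensorArrayStateWrapper class and its running-output TensorArray are illustrative assumptions rather than code from this change, and only the zero_state delegation and name-scope wrapping mirror the diff that follows.

import tensorflow as tf


class TensorArrayStateWrapper(object):
  """Hypothetical wrapper: keeps a TensorArray alongside the wrapped state."""

  def __init__(self, cell):
    self._cell = cell

  @property
  def state_size(self):
    # A TensorArray has no static size entry, which is why zero_state is
    # overridden instead of relying on state_size alone.
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    # Same pattern as the diff: build the initial state inside the cell's
    # own name scope, starting from the wrapped cell's zero_state.
    with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      cell_state = self._cell.zero_state(batch_size, dtype)
      history = tf.TensorArray(dtype=dtype, size=0, dynamic_size=True)
      return (cell_state, history)

  def __call__(self, inputs, state, scope=None):
    cell_state, history = state
    output, new_cell_state = self._cell(inputs, cell_state)
    new_history = history.write(history.size(), output)
    return output, (new_cell_state, new_history)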
@@ -489,6 +489,10 @@ class OutputProjectionWrapper(RNNCell):
   def output_size(self):
     return self._output_size

+  def zero_state(self, batch_size, dtype):
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      return self._cell.zero_state(batch_size, dtype)
+
   def __call__(self, inputs, state, scope=None):
     """Run the cell and output projection on inputs, starting from state."""
     output, res_state = self._cell(inputs, state)
@@ -533,6 +537,10 @@ class InputProjectionWrapper(RNNCell):
   def output_size(self):
     return self._cell.output_size

+  def zero_state(self, batch_size, dtype):
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      return self._cell.zero_state(batch_size, dtype)
+
   def __call__(self, inputs, state, scope=None):
     """Run the input projection and then the cell."""
     # Default scope: "InputProjectionWrapper"
@@ -665,6 +673,10 @@ class DropoutWrapper(RNNCell):
   def output_size(self):
     return self._cell.output_size

+  def zero_state(self, batch_size, dtype):
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      return self._cell.zero_state(batch_size, dtype)
+
   def _variational_recurrent_dropout_value(
       self, index, value, noise, keep_prob):
     """Performs dropout given the pre-calculated noise tensor."""
@@ -729,6 +741,10 @@ class ResidualWrapper(RNNCell):
   def output_size(self):
     return self._cell.output_size

+  def zero_state(self, batch_size, dtype):
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      return self._cell.zero_state(batch_size, dtype)
+
   def __call__(self, inputs, state, scope=None):
     """Run the cell and add its inputs to its outputs.
@@ -778,6 +794,10 @@ class DeviceWrapper(RNNCell):
   def output_size(self):
     return self._cell.output_size

+  def zero_state(self, batch_size, dtype):
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      return self._cell.zero_state(batch_size, dtype)
+
   def __call__(self, inputs, state, scope=None):
     """Run the cell on specified device."""
     with ops.device(self._device):
@@ -830,6 +850,10 @@ class EmbeddingWrapper(RNNCell):
   def output_size(self):
     return self._cell.output_size

+  def zero_state(self, batch_size, dtype):
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      return self._cell.zero_state(batch_size, dtype)
+
   def __call__(self, inputs, state, scope=None):
     """Run the cell on embedded inputs."""
     with _checked_scope(self, scope or "embedding_wrapper", reuse=self._reuse):
@@ -899,6 +923,15 @@ class MultiRNNCell(RNNCell):
   def output_size(self):
     return self._cells[-1].output_size

+  def zero_state(self, batch_size, dtype):
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      if self._state_is_tuple:
+        return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
+      else:
+        # We know here that state_size of each cell is not a tuple and
+        # presumably does not contain TensorArrays or anything else fancy
+        return super(MultiRNNCell, self).zero_state(batch_size, dtype)
+
   def __call__(self, inputs, state, scope=None):
     """Run this multi-layer cell on inputs, starting from state."""
     with vs.variable_scope(scope or "multi_rnn_cell"):
@@ -1460,6 +1460,10 @@ class CompiledWrapper(core_rnn_cell.RNNCell):
   def output_size(self):
     return self._cell.output_size

+  def zero_state(self, batch_size, dtype):
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      return self._cell.zero_state(batch_size, dtype)
+
   def __call__(self, inputs, state, scope=None):
     if self._compile_stateful:
       compile_ops = True
@@ -249,8 +249,14 @@ def dynamic_decode(decoder,
       # Copy through states past finish
       def _maybe_copy_state(new, cur):
-        return (new if isinstance(cur, tensor_array_ops.TensorArray) else
-                array_ops.where(finished, cur, new))
+        # TensorArrays and scalar states get passed through.
+        if isinstance(cur, tensor_array_ops.TensorArray):
+          pass_through = True
+        else:
+          new.set_shape(cur.shape)
+          pass_through = (new.shape.ndims == 0)
+        return new if pass_through else array_ops.where(finished, cur, new)

       if impute_finished:
         next_state = nest.map_structure(
             _maybe_copy_state, decoder_state, state)
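The new _maybe_copy_state body above encodes a simple rule: TensorArray states and scalar (0-D) states pass through unchanged, because array_ops.where(finished, cur, new) only makes sense for batch-major tensors whose leading dimension matches finished. A rough standalone restatement of that rule, written against the public tf namespace purely for illustration:

import tensorflow as tf


def maybe_copy_state(finished, cur, new):
  # TensorArrays cannot be selected element-wise, so take the new value.
  if isinstance(cur, tf.TensorArray):
    return new
  # Assert/refine the static shape of the new value against the old one.
  new.set_shape(cur.shape)
  # Scalar states (e.g. a step counter) also pass through; everything else
  # keeps the old value wherever that batch entry has already finished.
  if new.shape.ndims == 0:
    return new
  return tf.where(finished, cur, new)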
@@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.util import nest
@@ -44,6 +45,9 @@ __all__ = [
 ]

+_zero_state_tensors = rnn_cell_impl._zero_state_tensors  # pylint: disable=protected-access
+
+
 class AttentionMechanism(object):
   pass
@@ -478,6 +482,13 @@ class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
         cell_state=self._cell.state_size,
         attention=self._attention_size)

+  def zero_state(self, batch_size, dtype):
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      return DynamicAttentionWrapperState(
+          cell_state=self._cell.zero_state(batch_size, dtype),
+          attention=_zero_state_tensors(
+              self._attention_size, batch_size, dtype))
+
   def __call__(self, inputs, state, scope=None):
     """Perform a step of attention-wrapped RNN.
@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.util import nest
@@ -46,6 +47,28 @@ def _state_size_with_prefix(state_size, prefix=None):
   return result_state_size


+def _zero_state_tensors(state_size, batch_size, dtype):
+  """Create tensors of zeros based on state_size, batch_size, and dtype."""
+  if nest.is_sequence(state_size):
+    state_size_flat = nest.flatten(state_size)
+    zeros_flat = [
+        array_ops.zeros(
+            array_ops.stack(_state_size_with_prefix(
+                s, prefix=[batch_size])),
+            dtype=dtype) for s in state_size_flat
+    ]
+    for s, z in zip(state_size_flat, zeros_flat):
+      z.set_shape(_state_size_with_prefix(s, prefix=[None]))
+    zeros = nest.pack_sequence_as(structure=state_size,
+                                  flat_sequence=zeros_flat)
+  else:
+    zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
+    zeros = array_ops.zeros(array_ops.stack(zeros_size), dtype=dtype)
+    zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))
+
+  return zeros
+
+
 class _RNNCell(object):
   """Abstract object representing an RNN cell.
@@ -119,22 +142,6 @@ class _RNNCell(object):
       a nested list or tuple (of the same structure) of `2-D` tensors with
       the shapes `[batch_size x s]` for each s in `state_size`.
     """
-    state_size = self.state_size
-    if nest.is_sequence(state_size):
-      state_size_flat = nest.flatten(state_size)
-      zeros_flat = [
-          array_ops.zeros(
-              array_ops.stack(_state_size_with_prefix(
-                  s, prefix=[batch_size])),
-              dtype=dtype) for s in state_size_flat
-      ]
-      for s, z in zip(state_size_flat, zeros_flat):
-        z.set_shape(_state_size_with_prefix(s, prefix=[None]))
-      zeros = nest.pack_sequence_as(structure=state_size,
-                                    flat_sequence=zeros_flat)
-    else:
-      zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
-      zeros = array_ops.zeros(array_ops.stack(zeros_size), dtype=dtype)
-      zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))
-    return zeros
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      state_size = self.state_size
+      return _zero_state_tensors(state_size, batch_size, dtype)
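Since the attention wrapper above aliases the new helper as rnn_cell_impl._zero_state_tensors, it can also be exercised directly. A small usage sketch with an assumed nested state size (the leading underscore marks it as private API):

import tensorflow as tf
from tensorflow.python.ops import rnn_cell_impl

# An example nested state size, e.g. an LSTM-style (c, h) pair of widths.
state_size = (64, 64)
batch_size = 8

# Returns a structure matching state_size: here, a tuple of two
# [batch_size, 64] float32 tensors filled with zeros.
zeros = rnn_cell_impl._zero_state_tensors(state_size, batch_size, tf.float32)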