Commit 9e601230 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Add Frequency Block processing to GridLSTMCell.

Change: 136742352
Parent 876c4aa5
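For orientation, a minimal usage sketch of the changed interface: num_frequency_blocks now takes a list (one entry per frequency split) instead of a single int, and the new start_freqindex_list / end_freqindex_list arguments optionally carve the input feature into splits. The values below are illustrative, mirroring the new test added in this commit.

import tensorflow as tf

# Two frequency splits over a 4-dim input feature: [0, 2) and [2, 4).
# With feature_size=2 and frequency_skip=1, each split yields one block.
cell = tf.contrib.rnn.GridLSTMCell(
    num_units=8, feature_size=2, frequency_skip=1, forget_bias=1.0,
    num_frequency_blocks=[1, 1],    # was a single int before this change
    start_freqindex_list=[0, 2],    # new: per-split start frequency index
    end_freqindex_list=[2, 4],      # new: per-split end frequency index
    couple_input_forget_gates=True,
    state_is_tuple=True)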
@@ -103,7 +103,7 @@ class RNNCellTest(tf.test.TestCase):
cell = tf.contrib.rnn.GridLSTMCell(
num_units=num_units, feature_size=feature_size,
frequency_skip=frequency_skip, forget_bias=1.0,
num_frequency_blocks=num_shifts,
num_frequency_blocks=[num_shifts],
couple_input_forget_gates=True,
state_is_tuple=True)
inputs = tf.constant(np.array([[1., 1., 1., 1.],
@@ -129,7 +129,54 @@ class RNNCellTest(tf.test.TestCase):
self.assertTrue(
float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6)
self.assertTrue(float(np.linalg.norm(
(res[1].state_f00_c[0, :] - res[1].state_f00_c[i, :])))
(res[1].state_f00_b00_c[0, :] - res[1].state_f00_b00_c[i, :])))
> 1e-6)
def testGridLSTMCellWithFrequencyBlocks(self):
with self.test_session() as sess:
num_units = 8
batch_size = 3
input_size = 4
feature_size = 2
frequency_skip = 1
num_frequency_blocks = [1, 1]
total_blocks = num_frequency_blocks[0] + num_frequency_blocks[1]
start_freqindex_list = [0, 2]
end_freqindex_list = [2, 4]
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
cell = tf.contrib.rnn.GridLSTMCell(
num_units=num_units, feature_size=feature_size,
frequency_skip=frequency_skip, forget_bias=1.0,
num_frequency_blocks=num_frequency_blocks,
start_freqindex_list=start_freqindex_list,
end_freqindex_list=end_freqindex_list,
couple_input_forget_gates=True,
state_is_tuple=True)
inputs = tf.constant(np.array([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]],
dtype=np.float32), dtype=tf.float32)
state_value = tf.constant(
0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=tf.float32)
init_state = cell.state_tuple_type(
*([state_value, state_value] * total_blocks))
output, state = cell(inputs, init_state)
sess.run([tf.initialize_all_variables()])
res = sess.run([output, state])
self.assertEqual(len(res), 2)
      # The numbers in the results were not calculated by hand; this is
      # mostly just a smoke test.
self.assertEqual(res[0].shape,
(batch_size, num_units * total_blocks * 2))
for ss in res[1]:
self.assertEqual(ss.shape, (batch_size, num_units))
      # Different inputs should give different outputs and states.
for i in range(1, batch_size):
self.assertTrue(
float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6)
self.assertTrue(float(np.linalg.norm(
(res[1].state_f00_b00_c[0, :] - res[1].state_f00_b00_c[i, :])))
> 1e-6)
def testGridLstmCellWithCoupledInputForgetGates(self):
@@ -162,7 +209,7 @@ class RNNCellTest(tf.test.TestCase):
cell = tf.contrib.rnn.GridLSTMCell(
num_units=num_units, feature_size=feature_size,
frequency_skip=frequency_skip, forget_bias=1.0,
num_frequency_blocks=num_shifts,
num_frequency_blocks=[num_shifts],
couple_input_forget_gates=True,
state_is_tuple=state_is_tuple)
inputs = tf.constant(np.array([[1., 1., 1., 1.],
@@ -238,7 +285,7 @@ class RNNCellTest(tf.test.TestCase):
num_units=num_units, feature_size=feature_size,
share_time_frequency_weights=True,
frequency_skip=frequency_skip, forget_bias=1.0,
num_frequency_blocks=num_shifts)
num_frequency_blocks=[num_shifts])
inputs = tf.constant(np.array([[1.0, 1.1, 1.2, 1.3],
[2.0, 2.1, 2.2, 2.3],
[3.0, 3.1, 3.2, 3.3]],
@@ -305,7 +352,7 @@ class RNNCellTest(tf.test.TestCase):
num_units=num_units, feature_size=feature_size,
share_time_frequency_weights=True,
frequency_skip=frequency_skip, forget_bias=1.0,
num_frequency_blocks=num_shifts,
num_frequency_blocks=[num_shifts],
backward_slice_offset=1)
inputs = tf.constant(np.array([[1.0, 1.1, 1.2, 1.3],
[2.0, 2.1, 2.2, 2.3],
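Note the renamed state fields in the assertions above: each field now carries a block index, so state_f00_c becomes state_f00_b00_c. A small sketch of the naming scheme (assumed num_frequency_blocks value; mirrors the constructor loop further down):

num_frequency_blocks = [2, 1]
state_names = []
for block_index in range(len(num_frequency_blocks)):
    for freq_index in range(num_frequency_blocks[block_index]):
        prefix = "state_f%02d_b%02d" % (freq_index, block_index)
        state_names += [prefix + "_c", prefix + "_m"]
# state_names == ['state_f00_b00_c', 'state_f00_b00_m',
#                 'state_f01_b00_c', 'state_f01_b00_m',
#                 'state_f00_b01_c', 'state_f00_b01_m']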
@@ -431,7 +431,9 @@ class GridLSTMCell(rnn_cell.RNNCell):
cell_clip=None, initializer=None,
num_unit_shards=1, forget_bias=1.0,
feature_size=None, frequency_skip=None,
num_frequency_blocks=1,
num_frequency_blocks=None,
start_freqindex_list=None,
end_freqindex_list=None,
couple_input_forget_gates=False,
state_is_tuple=False):
"""Initialize the parameters for an LSTM cell.
@@ -455,14 +457,21 @@ class GridLSTMCell(rnn_cell.RNNCell):
the LSTM spans over.
frequency_skip: (optional) int, default None, The amount the LSTM filter
is shifted by in frequency.
num_frequency_blocks: (optional) int, default 1, The total number of
frequency blocks needed to cover the whole input feature.
      num_frequency_blocks: [required] A list of ints, one per frequency
        split, giving the number of frequency blocks needed to cover each
        split of the input feature; the splits are defined by
        start_freqindex_list and end_freqindex_list.
start_freqindex_list: [optional], list of ints, default None, The
starting frequency index for each frequency block.
end_freqindex_list: [optional], list of ints, default None. The ending
frequency index for each frequency block.
couple_input_forget_gates: (optional) bool, default False, Whether to
couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
model parameters and computation cost.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. By default (False), they are concatenated
along the column axis. This default behavior will soon be deprecated.
Raises:
ValueError: if the num_frequency_blocks list is not specified
"""
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
@@ -478,20 +487,29 @@ class GridLSTMCell(rnn_cell.RNNCell):
self._forget_bias = forget_bias
self._feature_size = feature_size
self._frequency_skip = frequency_skip
self._num_frequency_blocks = int(num_frequency_blocks)
self._start_freqindex_list = start_freqindex_list
self._end_freqindex_list = end_freqindex_list
self._num_frequency_blocks = num_frequency_blocks
self._total_blocks = 0
if self._num_frequency_blocks is None:
raise ValueError("Must specify num_frequency_blocks")
for block_index in range(len(self._num_frequency_blocks)):
self._total_blocks += int(self._num_frequency_blocks[block_index])
if state_is_tuple:
state_names = ""
for freq_index in range(self._num_frequency_blocks):
name_prefix = "state_f%02d" % freq_index
state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
for block_index in range(len(self._num_frequency_blocks)):
for freq_index in range(self._num_frequency_blocks[block_index]):
name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
self._state_tuple_type = collections.namedtuple(
"GridLSTMStateTuple", state_names.strip(","))
self._state_size = self._state_tuple_type(
*([num_units, num_units] * self._num_frequency_blocks))
*([num_units, num_units] * self._total_blocks))
else:
self._state_tuple_type = None
self._state_size = num_units * self._num_frequency_blocks * 2
self._output_size = num_units * self._num_frequency_blocks * 2
self._state_size = num_units * self._total_blocks * 2
self._output_size = num_units * self._total_blocks * 2
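Two consequences of the constructor changes above, sketched with assumed values: the argument is now mandatory, and the state/output sizes scale with the total block count.

import tensorflow as tf

# num_frequency_blocks is now mandatory; omitting it raises the error
# added above (a sketch, assuming the constructor defaults at this revision):
try:
    tf.contrib.rnn.GridLSTMCell(num_units=8, feature_size=2,
                                frequency_skip=1)
except ValueError as e:
    print(e)  # -> Must specify num_frequency_blocks

# With num_units=8 and num_frequency_blocks=[1, 1], _total_blocks is 2, so
# the non-tuple state_size and output_size are both 8 * 2 * 2 = 32.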
@property
def output_size(self):
@@ -530,8 +548,14 @@ class GridLSTMCell(rnn_cell.RNNCell):
freq_inputs = self._make_tf_features(inputs)
with vs.variable_scope(scope or type(self).__name__,
initializer=self._initializer): # "GridLSTMCell"
m_out_lst, state_out_lst = self._compute(
freq_inputs, state, batch_size, state_is_tuple=self._state_is_tuple)
m_out_lst = []
state_out_lst = []
for block in range(len(freq_inputs)):
m_out_lst_current, state_out_lst_current = self._compute(
freq_inputs[block], block, state, batch_size,
state_is_tuple=self._state_is_tuple)
m_out_lst.extend(m_out_lst_current)
state_out_lst.extend(state_out_lst_current)
if self._state_is_tuple:
state_out = self._state_tuple_type(*state_out_lst)
else:
@@ -539,12 +563,14 @@ class GridLSTMCell(rnn_cell.RNNCell):
m_out = array_ops.concat(1, m_out_lst)
return m_out, state_out
def _compute(self, freq_inputs, state, batch_size, state_prefix="state",
def _compute(self, freq_inputs, block, state, batch_size,
state_prefix="state",
state_is_tuple=True):
"""Run the actual computation of one step LSTM.
Args:
freq_inputs: list of Tensors, 2D, [batch, feature_size].
block: int, current frequency block index to process.
      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depending
        on the flag state_is_tuple.
batch_size: int32, batch size.
@@ -566,57 +592,57 @@ class GridLSTMCell(rnn_cell.RNNCell):
actual_input_size = freq_inputs[0].get_shape().as_list()[1]
concat_w_f = _get_concat_variable(
"W_f", [actual_input_size + 2 * self._num_units,
num_gates * self._num_units],
"W_f_%d" % block, [actual_input_size + 2 * self._num_units,
num_gates * self._num_units],
dtype, self._num_unit_shards)
b_f = vs.get_variable(
"B_f", shape=[num_gates * self._num_units],
"B_f_%d" % block, shape=[num_gates * self._num_units],
initializer=init_ops.zeros_initializer, dtype=dtype)
if not self._share_time_frequency_weights:
concat_w_t = _get_concat_variable(
"W_t", [actual_input_size + 2 * self._num_units,
num_gates * self._num_units],
"W_t_%d" % block, [actual_input_size + 2 * self._num_units,
num_gates * self._num_units],
dtype, self._num_unit_shards)
b_t = vs.get_variable(
"B_t", shape=[num_gates * self._num_units],
"B_t_%d" % block, shape=[num_gates * self._num_units],
initializer=init_ops.zeros_initializer, dtype=dtype)
if self._use_peepholes:
# Diagonal connections
if not self._couple_input_forget_gates:
w_f_diag_freqf = vs.get_variable(
"W_F_diag_freqf", shape=[self._num_units], dtype=dtype)
"W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
w_f_diag_freqt = vs.get_variable(
"W_F_diag_freqt", shape=[self._num_units], dtype=dtype)
"W_F_diag_freqt_%d"% block, shape=[self._num_units], dtype=dtype)
w_i_diag_freqf = vs.get_variable(
"W_I_diag_freqf", shape=[self._num_units], dtype=dtype)
"W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
w_i_diag_freqt = vs.get_variable(
"W_I_diag_freqt", shape=[self._num_units], dtype=dtype)
"W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
w_o_diag_freqf = vs.get_variable(
"W_O_diag_freqf", shape=[self._num_units], dtype=dtype)
"W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
w_o_diag_freqt = vs.get_variable(
"W_O_diag_freqt", shape=[self._num_units], dtype=dtype)
"W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
if not self._share_time_frequency_weights:
if not self._couple_input_forget_gates:
w_f_diag_timef = vs.get_variable(
"W_F_diag_timef", shape=[self._num_units], dtype=dtype)
"W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
w_f_diag_timet = vs.get_variable(
"W_F_diag_timet", shape=[self._num_units], dtype=dtype)
"W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
w_i_diag_timef = vs.get_variable(
"W_I_diag_timef", shape=[self._num_units], dtype=dtype)
"W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
w_i_diag_timet = vs.get_variable(
"W_I_diag_timet", shape=[self._num_units], dtype=dtype)
"W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
w_o_diag_timef = vs.get_variable(
"W_O_diag_timef", shape=[self._num_units], dtype=dtype)
"W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
w_o_diag_timet = vs.get_variable(
"W_O_diag_timet", shape=[self._num_units], dtype=dtype)
"W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
# initialize the first freq state to be zero
m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
for freq_index in range(len(freq_inputs)):
if state_is_tuple:
name_prefix = "%s_f%02d" % (state_prefix, freq_index)
name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block)
c_prev_time = getattr(state, name_prefix + "_c")
m_prev_time = getattr(state, name_prefix + "_m")
else:
@@ -773,13 +799,6 @@ class GridLSTMCell(rnn_cell.RNNCell):
input_size = input_feat.get_shape().with_rank(2)[-1].value
if input_size is None:
raise ValueError("Cannot infer input_size from static shape inference.")
num_feats = int((input_size - self._feature_size) / (
self._frequency_skip)) + 1
if num_feats != self._num_frequency_blocks:
raise ValueError(
"Invalid num_frequency_blocks, requires %d but gets %d, please check"
" the input size and filter config are correct." % (
self._num_frequency_blocks, num_feats))
if slice_offset > 0:
# Padding to the end
inputs = array_ops.pad(
@@ -796,11 +815,55 @@
else:
inputs = input_feat
freq_inputs = []
for f in range(num_feats):
cur_input = array_ops.slice(
inputs, [0, slice_offset + f * self._frequency_skip],
[-1, self._feature_size])
freq_inputs.append(cur_input)
if not self._start_freqindex_list:
if len(self._num_frequency_blocks) != 1:
raise ValueError("Length of num_frequency_blocks"
" is not 1, but instead is %d",
len(self._num_frequency_blocks))
num_feats = int((input_size - self._feature_size) / (
self._frequency_skip)) + 1
if num_feats != self._num_frequency_blocks[0]:
raise ValueError(
"Invalid num_frequency_blocks, requires %d but gets %d, please"
" check the input size and filter config are correct." % (
self._num_frequency_blocks[0], num_feats))
block_inputs = []
for f in range(num_feats):
cur_input = array_ops.slice(
inputs, [0, slice_offset + f * self._frequency_skip],
[-1, self._feature_size])
block_inputs.append(cur_input)
freq_inputs.append(block_inputs)
else:
if len(self._start_freqindex_list) != len(self._end_freqindex_list):
raise ValueError("Length of start and end freqindex_list"
" does not match %d %d",
len(self._start_freqindex_list),
len(self._end_freqindex_list))
if len(self._num_frequency_blocks) != len(self._start_freqindex_list):
raise ValueError("Length of num_frequency_blocks"
" is not equal to start_freqindex_list %d %d",
len(self._num_frequency_blocks),
len(self._start_freqindex_list))
for b in range(len(self._start_freqindex_list)):
start_index = self._start_freqindex_list[b]
end_index = self._end_freqindex_list[b]
cur_size = end_index - start_index
block_feats = int((cur_size - self._feature_size) / (
self._frequency_skip)) + 1
if block_feats != self._num_frequency_blocks[b]:
raise ValueError(
"Invalid num_frequency_blocks, requires %d but gets %d, please"
" check the input size and filter config are correct." % (
self._num_frequency_blocks[b], block_feats))
block_inputs = []
for f in range(block_feats):
cur_input = array_ops.slice(
inputs, [0, start_index + slice_offset + f *
self._frequency_skip],
[-1, self._feature_size])
block_inputs.append(cur_input)
freq_inputs.append(block_inputs)
return freq_inputs
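A worked example of the block-count arithmetic above, using the values from the new test (feature_size=2, frequency_skip=1):

# Without start/end lists, one split spans the whole input feature:
#   input_size=4 -> num_feats = (4 - 2) // 1 + 1 = 3,
#   so num_frequency_blocks must be [3].
# With start_freqindex_list=[0, 2] and end_freqindex_list=[2, 4]:
#   split 0: cur_size = 2 - 0 = 2 -> block_feats = (2 - 2) // 1 + 1 = 1
#   split 1: cur_size = 4 - 2 = 2 -> block_feats = 1
#   so num_frequency_blocks must be [1, 1].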
@@ -818,7 +881,9 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
cell_clip=None, initializer=None,
num_unit_shards=1, forget_bias=1.0,
feature_size=None, frequency_skip=None,
num_frequency_blocks=1,
num_frequency_blocks=None,
start_freqindex_list=None,
end_freqindex_list=None,
couple_input_forget_gates=False,
backward_slice_offset=0):
"""Initialize the parameters for an LSTM cell.
@@ -842,8 +907,13 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
the LSTM spans over.
frequency_skip: (optional) int, default None, The amount the LSTM filter
is shifted by in frequency.
num_frequency_blocks: (optional) int, default 1, The total number of
frequency blocks needed to cover the whole input feature.
      num_frequency_blocks: [required] A list of ints, one per frequency
        split, giving the number of frequency blocks needed to cover each
        split of the input feature; the splits are defined by
        start_freqindex_list and end_freqindex_list.
start_freqindex_list: [optional], list of ints, default None, The
starting frequency index for each frequency block.
end_freqindex_list: [optional], list of ints, default None. The ending
frequency index for each frequency block.
couple_input_forget_gates: (optional) bool, default False, Whether to
couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
model parameters and computation cost.
@@ -853,19 +923,22 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
super(BidirectionalGridLSTMCell, self).__init__(
num_units, use_peepholes, share_time_frequency_weights, cell_clip,
initializer, num_unit_shards, forget_bias, feature_size, frequency_skip,
num_frequency_blocks, couple_input_forget_gates=False,
num_frequency_blocks, start_freqindex_list, end_freqindex_list,
couple_input_forget_gates=False,
state_is_tuple=True)
self._backward_slice_offset = int(backward_slice_offset)
state_names = ""
for direction in ["fwd", "bwd"]:
for freq_index in range(self._num_frequency_blocks):
name_prefix = "%s_state_f%02d" % (direction, freq_index)
state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
for block_index in range(len(self._num_frequency_blocks)):
for freq_index in range(self._num_frequency_blocks[block_index]):
name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index,
block_index)
state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
self._state_tuple_type = collections.namedtuple(
"BidirectionalGridLSTMStateTuple", state_names.strip(","))
self._state_size = self._state_tuple_type(
*([num_units, num_units] * self._num_frequency_blocks * 2))
self._output_size = 2 * num_units * self._num_frequency_blocks * 2
*([num_units, num_units] * self._total_blocks * 2))
self._output_size = 2 * num_units * self._total_blocks * 2
def __call__(self, inputs, state, scope=None):
"""Run one step of LSTM.
@@ -893,22 +966,31 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
else:
bwd_inputs = fwd_inputs
# Reverse the blocks
bwd_inputs = bwd_inputs[::-1]
# Forward processing
with vs.variable_scope((scope or type(self).__name__) + "/fwd",
initializer=self._initializer):
fwd_m_out_lst, fwd_state_out_lst = self._compute(
fwd_inputs, state, batch_size, state_prefix="fwd_state",
state_is_tuple=True)
fwd_m_out_lst = []
fwd_state_out_lst = []
for block in range(len(fwd_inputs)):
fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
fwd_inputs[block], block, state, batch_size,
state_prefix="fwd_state", state_is_tuple=True)
fwd_m_out_lst.extend(fwd_m_out_lst_current)
fwd_state_out_lst.extend(fwd_state_out_lst_current)
# Backward processing
bwd_m_out_lst = []
bwd_state_out_lst = []
with vs.variable_scope((scope or type(self).__name__) + "/bwd",
initializer=self._initializer):
bwd_m_out_lst, bwd_state_out_lst = self._compute(
bwd_inputs, state, batch_size, state_prefix="bwd_state",
state_is_tuple=True)
for block in range(len(bwd_inputs)):
# Reverse the blocks
bwd_inputs_reverse = bwd_inputs[block][::-1]
bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
bwd_inputs_reverse, block, state, batch_size,
state_prefix="bwd_state", state_is_tuple=True)
bwd_m_out_lst.extend(bwd_m_out_lst_current)
bwd_state_out_lst.extend(bwd_state_out_lst_current)
state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
    # Outputs are always concatenated, as they are never used separately.
m_out = array_ops.concat(1, fwd_m_out_lst + bwd_m_out_lst)
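As a final size sanity check for the bidirectional cell, the output concatenates c and m for every block in both directions. A sketch with assumed values:

# num_units=8, num_frequency_blocks=[1, 1] -> _total_blocks = 2
# output_size = 2 (directions) * 8 (units) * 2 (blocks) * 2 (c and m) = 64
# state fields: fwd_state_f00_b00_{c,m}, fwd_state_f00_b01_{c,m},
#               bwd_state_f00_b00_{c,m}, bwd_state_f00_b01_{c,m}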