Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
9e601230
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
9e601230
编写于
10月 20, 2016
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
10月 20, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Frequency Block processing to GridLSTMCell.
Change: 136742352
上级
876c4aa5
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
196 addition
and
67 deletion
+196
-67
tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+52
-5
tensorflow/contrib/rnn/python/ops/rnn_cell.py
tensorflow/contrib/rnn/python/ops/rnn_cell.py
+144
-62
未找到文件。
tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
浏览文件 @
9e601230
...
...
@@ -103,7 +103,7 @@ class RNNCellTest(tf.test.TestCase):
cell
=
tf
.
contrib
.
rnn
.
GridLSTMCell
(
num_units
=
num_units
,
feature_size
=
feature_size
,
frequency_skip
=
frequency_skip
,
forget_bias
=
1.0
,
num_frequency_blocks
=
num_shifts
,
num_frequency_blocks
=
[
num_shifts
]
,
couple_input_forget_gates
=
True
,
state_is_tuple
=
True
)
inputs
=
tf
.
constant
(
np
.
array
([[
1.
,
1.
,
1.
,
1.
],
...
...
@@ -129,7 +129,54 @@ class RNNCellTest(tf.test.TestCase):
self
.
assertTrue
(
float
(
np
.
linalg
.
norm
((
res
[
0
][
0
,
:]
-
res
[
0
][
i
,
:])))
>
1e-6
)
self
.
assertTrue
(
float
(
np
.
linalg
.
norm
(
(
res
[
1
].
state_f00_c
[
0
,
:]
-
res
[
1
].
state_f00_c
[
i
,
:])))
(
res
[
1
].
state_f00_b00_c
[
0
,
:]
-
res
[
1
].
state_f00_b00_c
[
i
,
:])))
>
1e-6
)
def
testGridLSTMCellWithFrequencyBlocks
(
self
):
with
self
.
test_session
()
as
sess
:
num_units
=
8
batch_size
=
3
input_size
=
4
feature_size
=
2
frequency_skip
=
1
num_frequency_blocks
=
[
1
,
1
]
total_blocks
=
num_frequency_blocks
[
0
]
+
num_frequency_blocks
[
1
]
start_freqindex_list
=
[
0
,
2
]
end_freqindex_list
=
[
2
,
4
]
with
tf
.
variable_scope
(
"root"
,
initializer
=
tf
.
constant_initializer
(
0.5
)):
cell
=
tf
.
contrib
.
rnn
.
GridLSTMCell
(
num_units
=
num_units
,
feature_size
=
feature_size
,
frequency_skip
=
frequency_skip
,
forget_bias
=
1.0
,
num_frequency_blocks
=
num_frequency_blocks
,
start_freqindex_list
=
start_freqindex_list
,
end_freqindex_list
=
end_freqindex_list
,
couple_input_forget_gates
=
True
,
state_is_tuple
=
True
)
inputs
=
tf
.
constant
(
np
.
array
([[
1.
,
1.
,
1.
,
1.
],
[
2.
,
2.
,
2.
,
2.
],
[
3.
,
3.
,
3.
,
3.
]],
dtype
=
np
.
float32
),
dtype
=
tf
.
float32
)
state_value
=
tf
.
constant
(
0.1
*
np
.
ones
((
batch_size
,
num_units
),
dtype
=
np
.
float32
),
dtype
=
tf
.
float32
)
init_state
=
cell
.
state_tuple_type
(
*
([
state_value
,
state_value
]
*
total_blocks
))
output
,
state
=
cell
(
inputs
,
init_state
)
sess
.
run
([
tf
.
initialize_all_variables
()])
res
=
sess
.
run
([
output
,
state
])
self
.
assertEqual
(
len
(
res
),
2
)
# The numbers in results were not calculated, this is mostly just a
# smoke test.
self
.
assertEqual
(
res
[
0
].
shape
,
(
batch_size
,
num_units
*
total_blocks
*
2
))
for
ss
in
res
[
1
]:
self
.
assertEqual
(
ss
.
shape
,
(
batch_size
,
num_units
))
# Different inputs so different outputs and states
for
i
in
range
(
1
,
batch_size
):
self
.
assertTrue
(
float
(
np
.
linalg
.
norm
((
res
[
0
][
0
,
:]
-
res
[
0
][
i
,
:])))
>
1e-6
)
self
.
assertTrue
(
float
(
np
.
linalg
.
norm
(
(
res
[
1
].
state_f00_b00_c
[
0
,
:]
-
res
[
1
].
state_f00_b00_c
[
i
,
:])))
>
1e-6
)
def
testGridLstmCellWithCoupledInputForgetGates
(
self
):
...
...
@@ -162,7 +209,7 @@ class RNNCellTest(tf.test.TestCase):
cell
=
tf
.
contrib
.
rnn
.
GridLSTMCell
(
num_units
=
num_units
,
feature_size
=
feature_size
,
frequency_skip
=
frequency_skip
,
forget_bias
=
1.0
,
num_frequency_blocks
=
num_shifts
,
num_frequency_blocks
=
[
num_shifts
]
,
couple_input_forget_gates
=
True
,
state_is_tuple
=
state_is_tuple
)
inputs
=
tf
.
constant
(
np
.
array
([[
1.
,
1.
,
1.
,
1.
],
...
...
@@ -238,7 +285,7 @@ class RNNCellTest(tf.test.TestCase):
num_units
=
num_units
,
feature_size
=
feature_size
,
share_time_frequency_weights
=
True
,
frequency_skip
=
frequency_skip
,
forget_bias
=
1.0
,
num_frequency_blocks
=
num_shifts
)
num_frequency_blocks
=
[
num_shifts
]
)
inputs
=
tf
.
constant
(
np
.
array
([[
1.0
,
1.1
,
1.2
,
1.3
],
[
2.0
,
2.1
,
2.2
,
2.3
],
[
3.0
,
3.1
,
3.2
,
3.3
]],
...
...
@@ -305,7 +352,7 @@ class RNNCellTest(tf.test.TestCase):
num_units
=
num_units
,
feature_size
=
feature_size
,
share_time_frequency_weights
=
True
,
frequency_skip
=
frequency_skip
,
forget_bias
=
1.0
,
num_frequency_blocks
=
num_shifts
,
num_frequency_blocks
=
[
num_shifts
]
,
backward_slice_offset
=
1
)
inputs
=
tf
.
constant
(
np
.
array
([[
1.0
,
1.1
,
1.2
,
1.3
],
[
2.0
,
2.1
,
2.2
,
2.3
],
...
...
tensorflow/contrib/rnn/python/ops/rnn_cell.py
浏览文件 @
9e601230
...
...
@@ -431,7 +431,9 @@ class GridLSTMCell(rnn_cell.RNNCell):
cell_clip
=
None
,
initializer
=
None
,
num_unit_shards
=
1
,
forget_bias
=
1.0
,
feature_size
=
None
,
frequency_skip
=
None
,
num_frequency_blocks
=
1
,
num_frequency_blocks
=
None
,
start_freqindex_list
=
None
,
end_freqindex_list
=
None
,
couple_input_forget_gates
=
False
,
state_is_tuple
=
False
):
"""Initialize the parameters for an LSTM cell.
...
...
@@ -455,14 +457,21 @@ class GridLSTMCell(rnn_cell.RNNCell):
the LSTM spans over.
frequency_skip: (optional) int, default None, The amount the LSTM filter
is shifted by in frequency.
num_frequency_blocks: (optional) int, default 1, The total number of
frequency blocks needed to cover the whole input feature.
num_frequency_blocks: [required] A list of frequency blocks needed to
cover the whole input feature splitting defined by start_freqindex_list
and end_freqindex_list.
start_freqindex_list: [optional], list of ints, default None, The
starting frequency index for each frequency block.
end_freqindex_list: [optional], list of ints, default None. The ending
frequency index for each frequency block.
couple_input_forget_gates: (optional) bool, default False, Whether to
couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
model parameters and computation cost.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. By default (False), they are concatenated
along the column axis. This default behavior will soon be deprecated.
Raises:
ValueError: if the num_frequency_blocks list is not specified
"""
if
not
state_is_tuple
:
logging
.
warn
(
"%s: Using a concatenated state is slower and will soon be "
...
...
@@ -478,20 +487,29 @@ class GridLSTMCell(rnn_cell.RNNCell):
self
.
_forget_bias
=
forget_bias
self
.
_feature_size
=
feature_size
self
.
_frequency_skip
=
frequency_skip
self
.
_num_frequency_blocks
=
int
(
num_frequency_blocks
)
self
.
_start_freqindex_list
=
start_freqindex_list
self
.
_end_freqindex_list
=
end_freqindex_list
self
.
_num_frequency_blocks
=
num_frequency_blocks
self
.
_total_blocks
=
0
if
self
.
_num_frequency_blocks
is
None
:
raise
ValueError
(
"Must specify num_frequency_blocks"
)
for
block_index
in
range
(
len
(
self
.
_num_frequency_blocks
)):
self
.
_total_blocks
+=
int
(
self
.
_num_frequency_blocks
[
block_index
])
if
state_is_tuple
:
state_names
=
""
for
freq_index
in
range
(
self
.
_num_frequency_blocks
):
name_prefix
=
"state_f%02d"
%
freq_index
state_names
+=
(
"%s_c, %s_m,"
%
(
name_prefix
,
name_prefix
))
for
block_index
in
range
(
len
(
self
.
_num_frequency_blocks
)):
for
freq_index
in
range
(
self
.
_num_frequency_blocks
[
block_index
]):
name_prefix
=
"state_f%02d_b%02d"
%
(
freq_index
,
block_index
)
state_names
+=
(
"%s_c, %s_m,"
%
(
name_prefix
,
name_prefix
))
self
.
_state_tuple_type
=
collections
.
namedtuple
(
"GridLSTMStateTuple"
,
state_names
.
strip
(
","
))
self
.
_state_size
=
self
.
_state_tuple_type
(
*
([
num_units
,
num_units
]
*
self
.
_
num_frequency
_blocks
))
*
([
num_units
,
num_units
]
*
self
.
_
total
_blocks
))
else
:
self
.
_state_tuple_type
=
None
self
.
_state_size
=
num_units
*
self
.
_
num_frequency
_blocks
*
2
self
.
_output_size
=
num_units
*
self
.
_
num_frequency
_blocks
*
2
self
.
_state_size
=
num_units
*
self
.
_
total
_blocks
*
2
self
.
_output_size
=
num_units
*
self
.
_
total
_blocks
*
2
@
property
def
output_size
(
self
):
...
...
@@ -530,8 +548,14 @@ class GridLSTMCell(rnn_cell.RNNCell):
freq_inputs
=
self
.
_make_tf_features
(
inputs
)
with
vs
.
variable_scope
(
scope
or
type
(
self
).
__name__
,
initializer
=
self
.
_initializer
):
# "GridLSTMCell"
m_out_lst
,
state_out_lst
=
self
.
_compute
(
freq_inputs
,
state
,
batch_size
,
state_is_tuple
=
self
.
_state_is_tuple
)
m_out_lst
=
[]
state_out_lst
=
[]
for
block
in
range
(
len
(
freq_inputs
)):
m_out_lst_current
,
state_out_lst_current
=
self
.
_compute
(
freq_inputs
[
block
],
block
,
state
,
batch_size
,
state_is_tuple
=
self
.
_state_is_tuple
)
m_out_lst
.
extend
(
m_out_lst_current
)
state_out_lst
.
extend
(
state_out_lst_current
)
if
self
.
_state_is_tuple
:
state_out
=
self
.
_state_tuple_type
(
*
state_out_lst
)
else
:
...
...
@@ -539,12 +563,14 @@ class GridLSTMCell(rnn_cell.RNNCell):
m_out
=
array_ops
.
concat
(
1
,
m_out_lst
)
return
m_out
,
state_out
def
_compute
(
self
,
freq_inputs
,
state
,
batch_size
,
state_prefix
=
"state"
,
def
_compute
(
self
,
freq_inputs
,
block
,
state
,
batch_size
,
state_prefix
=
"state"
,
state_is_tuple
=
True
):
"""Run the actual computation of one step LSTM.
Args:
freq_inputs: list of Tensors, 2D, [batch, feature_size].
block: int, current frequency block index to process.
state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on
the flag state_is_tuple.
batch_size: int32, batch size.
...
...
@@ -566,57 +592,57 @@ class GridLSTMCell(rnn_cell.RNNCell):
actual_input_size
=
freq_inputs
[
0
].
get_shape
().
as_list
()[
1
]
concat_w_f
=
_get_concat_variable
(
"W_f
"
,
[
actual_input_size
+
2
*
self
.
_num_units
,
num_gates
*
self
.
_num_units
],
"W_f
_%d"
%
block
,
[
actual_input_size
+
2
*
self
.
_num_units
,
num_gates
*
self
.
_num_units
],
dtype
,
self
.
_num_unit_shards
)
b_f
=
vs
.
get_variable
(
"B_f
"
,
shape
=
[
num_gates
*
self
.
_num_units
],
"B_f
_%d"
%
block
,
shape
=
[
num_gates
*
self
.
_num_units
],
initializer
=
init_ops
.
zeros_initializer
,
dtype
=
dtype
)
if
not
self
.
_share_time_frequency_weights
:
concat_w_t
=
_get_concat_variable
(
"W_t
"
,
[
actual_input_size
+
2
*
self
.
_num_units
,
num_gates
*
self
.
_num_units
],
"W_t
_%d"
%
block
,
[
actual_input_size
+
2
*
self
.
_num_units
,
num_gates
*
self
.
_num_units
],
dtype
,
self
.
_num_unit_shards
)
b_t
=
vs
.
get_variable
(
"B_t
"
,
shape
=
[
num_gates
*
self
.
_num_units
],
"B_t
_%d"
%
block
,
shape
=
[
num_gates
*
self
.
_num_units
],
initializer
=
init_ops
.
zeros_initializer
,
dtype
=
dtype
)
if
self
.
_use_peepholes
:
# Diagonal connections
if
not
self
.
_couple_input_forget_gates
:
w_f_diag_freqf
=
vs
.
get_variable
(
"W_F_diag_freqf
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_F_diag_freqf
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
w_f_diag_freqt
=
vs
.
get_variable
(
"W_F_diag_freqt
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_F_diag_freqt
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
w_i_diag_freqf
=
vs
.
get_variable
(
"W_I_diag_freqf
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_I_diag_freqf
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
w_i_diag_freqt
=
vs
.
get_variable
(
"W_I_diag_freqt
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_I_diag_freqt
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
w_o_diag_freqf
=
vs
.
get_variable
(
"W_O_diag_freqf
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_O_diag_freqf
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
w_o_diag_freqt
=
vs
.
get_variable
(
"W_O_diag_freqt
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_O_diag_freqt
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
if
not
self
.
_share_time_frequency_weights
:
if
not
self
.
_couple_input_forget_gates
:
w_f_diag_timef
=
vs
.
get_variable
(
"W_F_diag_timef
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_F_diag_timef
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
w_f_diag_timet
=
vs
.
get_variable
(
"W_F_diag_timet
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_F_diag_timet
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
w_i_diag_timef
=
vs
.
get_variable
(
"W_I_diag_timef
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_I_diag_timef
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
w_i_diag_timet
=
vs
.
get_variable
(
"W_I_diag_timet
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_I_diag_timet
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
w_o_diag_timef
=
vs
.
get_variable
(
"W_O_diag_timef
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_O_diag_timef
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
w_o_diag_timet
=
vs
.
get_variable
(
"W_O_diag_timet
"
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
"W_O_diag_timet
_%d"
%
block
,
shape
=
[
self
.
_num_units
],
dtype
=
dtype
)
# initialize the first freq state to be zero
m_prev_freq
=
array_ops
.
zeros
([
batch_size
,
self
.
_num_units
],
dtype
)
c_prev_freq
=
array_ops
.
zeros
([
batch_size
,
self
.
_num_units
],
dtype
)
for
freq_index
in
range
(
len
(
freq_inputs
)):
if
state_is_tuple
:
name_prefix
=
"%s_f%02d
"
%
(
state_prefix
,
freq_index
)
name_prefix
=
"%s_f%02d
_b%02d"
%
(
state_prefix
,
freq_index
,
block
)
c_prev_time
=
getattr
(
state
,
name_prefix
+
"_c"
)
m_prev_time
=
getattr
(
state
,
name_prefix
+
"_m"
)
else
:
...
...
@@ -773,13 +799,6 @@ class GridLSTMCell(rnn_cell.RNNCell):
input_size
=
input_feat
.
get_shape
().
with_rank
(
2
)[
-
1
].
value
if
input_size
is
None
:
raise
ValueError
(
"Cannot infer input_size from static shape inference."
)
num_feats
=
int
((
input_size
-
self
.
_feature_size
)
/
(
self
.
_frequency_skip
))
+
1
if
num_feats
!=
self
.
_num_frequency_blocks
:
raise
ValueError
(
"Invalid num_frequency_blocks, requires %d but gets %d, please check"
" the input size and filter config are correct."
%
(
self
.
_num_frequency_blocks
,
num_feats
))
if
slice_offset
>
0
:
# Padding to the end
inputs
=
array_ops
.
pad
(
...
...
@@ -796,11 +815,55 @@ class GridLSTMCell(rnn_cell.RNNCell):
else
:
inputs
=
input_feat
freq_inputs
=
[]
for
f
in
range
(
num_feats
):
cur_input
=
array_ops
.
slice
(
inputs
,
[
0
,
slice_offset
+
f
*
self
.
_frequency_skip
],
[
-
1
,
self
.
_feature_size
])
freq_inputs
.
append
(
cur_input
)
if
not
self
.
_start_freqindex_list
:
if
len
(
self
.
_num_frequency_blocks
)
!=
1
:
raise
ValueError
(
"Length of num_frequency_blocks"
" is not 1, but instead is %d"
,
len
(
self
.
_num_frequency_blocks
))
num_feats
=
int
((
input_size
-
self
.
_feature_size
)
/
(
self
.
_frequency_skip
))
+
1
if
num_feats
!=
self
.
_num_frequency_blocks
[
0
]:
raise
ValueError
(
"Invalid num_frequency_blocks, requires %d but gets %d, please"
" check the input size and filter config are correct."
%
(
self
.
_num_frequency_blocks
[
0
],
num_feats
))
block_inputs
=
[]
for
f
in
range
(
num_feats
):
cur_input
=
array_ops
.
slice
(
inputs
,
[
0
,
slice_offset
+
f
*
self
.
_frequency_skip
],
[
-
1
,
self
.
_feature_size
])
block_inputs
.
append
(
cur_input
)
freq_inputs
.
append
(
block_inputs
)
else
:
if
len
(
self
.
_start_freqindex_list
)
!=
len
(
self
.
_end_freqindex_list
):
raise
ValueError
(
"Length of start and end freqindex_list"
" does not match %d %d"
,
len
(
self
.
_start_freqindex_list
),
len
(
self
.
_end_freqindex_list
))
if
len
(
self
.
_num_frequency_blocks
)
!=
len
(
self
.
_start_freqindex_list
):
raise
ValueError
(
"Length of num_frequency_blocks"
" is not equal to start_freqindex_list %d %d"
,
len
(
self
.
_num_frequency_blocks
),
len
(
self
.
_start_freqindex_list
))
for
b
in
range
(
len
(
self
.
_start_freqindex_list
)):
start_index
=
self
.
_start_freqindex_list
[
b
]
end_index
=
self
.
_end_freqindex_list
[
b
]
cur_size
=
end_index
-
start_index
block_feats
=
int
((
cur_size
-
self
.
_feature_size
)
/
(
self
.
_frequency_skip
))
+
1
if
block_feats
!=
self
.
_num_frequency_blocks
[
b
]:
raise
ValueError
(
"Invalid num_frequency_blocks, requires %d but gets %d, please"
" check the input size and filter config are correct."
%
(
self
.
_num_frequency_blocks
[
b
],
block_feats
))
block_inputs
=
[]
for
f
in
range
(
block_feats
):
cur_input
=
array_ops
.
slice
(
inputs
,
[
0
,
start_index
+
slice_offset
+
f
*
self
.
_frequency_skip
],
[
-
1
,
self
.
_feature_size
])
block_inputs
.
append
(
cur_input
)
freq_inputs
.
append
(
block_inputs
)
return
freq_inputs
...
...
@@ -818,7 +881,9 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
cell_clip
=
None
,
initializer
=
None
,
num_unit_shards
=
1
,
forget_bias
=
1.0
,
feature_size
=
None
,
frequency_skip
=
None
,
num_frequency_blocks
=
1
,
num_frequency_blocks
=
None
,
start_freqindex_list
=
None
,
end_freqindex_list
=
None
,
couple_input_forget_gates
=
False
,
backward_slice_offset
=
0
):
"""Initialize the parameters for an LSTM cell.
...
...
@@ -842,8 +907,13 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
the LSTM spans over.
frequency_skip: (optional) int, default None, The amount the LSTM filter
is shifted by in frequency.
num_frequency_blocks: (optional) int, default 1, The total number of
frequency blocks needed to cover the whole input feature.
num_frequency_blocks: [required] A list of frequency blocks needed to
cover the whole input feature splitting defined by start_freqindex_list
and end_freqindex_list.
start_freqindex_list: [optional], list of ints, default None, The
starting frequency index for each frequency block.
end_freqindex_list: [optional], list of ints, default None. The ending
frequency index for each frequency block.
couple_input_forget_gates: (optional) bool, default False, Whether to
couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
model parameters and computation cost.
...
...
@@ -853,19 +923,22 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
super
(
BidirectionalGridLSTMCell
,
self
).
__init__
(
num_units
,
use_peepholes
,
share_time_frequency_weights
,
cell_clip
,
initializer
,
num_unit_shards
,
forget_bias
,
feature_size
,
frequency_skip
,
num_frequency_blocks
,
couple_input_forget_gates
=
False
,
num_frequency_blocks
,
start_freqindex_list
,
end_freqindex_list
,
couple_input_forget_gates
=
False
,
state_is_tuple
=
True
)
self
.
_backward_slice_offset
=
int
(
backward_slice_offset
)
state_names
=
""
for
direction
in
[
"fwd"
,
"bwd"
]:
for
freq_index
in
range
(
self
.
_num_frequency_blocks
):
name_prefix
=
"%s_state_f%02d"
%
(
direction
,
freq_index
)
state_names
+=
(
"%s_c, %s_m,"
%
(
name_prefix
,
name_prefix
))
for
block_index
in
range
(
len
(
self
.
_num_frequency_blocks
)):
for
freq_index
in
range
(
self
.
_num_frequency_blocks
[
block_index
]):
name_prefix
=
"%s_state_f%02d_b%02d"
%
(
direction
,
freq_index
,
block_index
)
state_names
+=
(
"%s_c, %s_m,"
%
(
name_prefix
,
name_prefix
))
self
.
_state_tuple_type
=
collections
.
namedtuple
(
"BidirectionalGridLSTMStateTuple"
,
state_names
.
strip
(
","
))
self
.
_state_size
=
self
.
_state_tuple_type
(
*
([
num_units
,
num_units
]
*
self
.
_
num_frequency
_blocks
*
2
))
self
.
_output_size
=
2
*
num_units
*
self
.
_
num_frequency
_blocks
*
2
*
([
num_units
,
num_units
]
*
self
.
_
total
_blocks
*
2
))
self
.
_output_size
=
2
*
num_units
*
self
.
_
total
_blocks
*
2
def
__call__
(
self
,
inputs
,
state
,
scope
=
None
):
"""Run one step of LSTM.
...
...
@@ -893,22 +966,31 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
bwd_inputs
=
self
.
_make_tf_features
(
inputs
,
self
.
_backward_slice_offset
)
else
:
bwd_inputs
=
fwd_inputs
# Reverse the blocks
bwd_inputs
=
bwd_inputs
[::
-
1
]
# Forward processing
with
vs
.
variable_scope
((
scope
or
type
(
self
).
__name__
)
+
"/fwd"
,
initializer
=
self
.
_initializer
):
fwd_m_out_lst
,
fwd_state_out_lst
=
self
.
_compute
(
fwd_inputs
,
state
,
batch_size
,
state_prefix
=
"fwd_state"
,
state_is_tuple
=
True
)
fwd_m_out_lst
=
[]
fwd_state_out_lst
=
[]
for
block
in
range
(
len
(
fwd_inputs
)):
fwd_m_out_lst_current
,
fwd_state_out_lst_current
=
self
.
_compute
(
fwd_inputs
[
block
],
block
,
state
,
batch_size
,
state_prefix
=
"fwd_state"
,
state_is_tuple
=
True
)
fwd_m_out_lst
.
extend
(
fwd_m_out_lst_current
)
fwd_state_out_lst
.
extend
(
fwd_state_out_lst_current
)
# Backward processing
bwd_m_out_lst
=
[]
bwd_state_out_lst
=
[]
with
vs
.
variable_scope
((
scope
or
type
(
self
).
__name__
)
+
"/bwd"
,
initializer
=
self
.
_initializer
):
bwd_m_out_lst
,
bwd_state_out_lst
=
self
.
_compute
(
bwd_inputs
,
state
,
batch_size
,
state_prefix
=
"bwd_state"
,
state_is_tuple
=
True
)
for
block
in
range
(
len
(
bwd_inputs
)):
# Reverse the blocks
bwd_inputs_reverse
=
bwd_inputs
[
block
][::
-
1
]
bwd_m_out_lst_current
,
bwd_state_out_lst_current
=
self
.
_compute
(
bwd_inputs_reverse
,
block
,
state
,
batch_size
,
state_prefix
=
"bwd_state"
,
state_is_tuple
=
True
)
bwd_m_out_lst
.
extend
(
bwd_m_out_lst_current
)
bwd_state_out_lst
.
extend
(
bwd_state_out_lst_current
)
state_out
=
self
.
_state_tuple_type
(
*
(
fwd_state_out_lst
+
bwd_state_out_lst
))
# Outputs are always concated as it is never used separately.
m_out
=
array_ops
.
concat
(
1
,
fwd_m_out_lst
+
bwd_m_out_lst
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录