提交 0ee31a03 编写于 作者: L Luo Tao

Expose the "reverse" argument for recurrent_group in V2 API

上级 21be601b
......@@ -360,7 +360,7 @@ mixed.__doc__ = conf_helps.mixed_layer.__doc__
class RecurrentLayerInput(Layer):
def __init__(self, recurrent_name, index, parent_layers):
def __init__(self, recurrent_name, index, parent_layers, reverse):
parents_len = len(parent_layers)
assert parents_len <= 1
if parents_len == 0:
......@@ -368,6 +368,7 @@ class RecurrentLayerInput(Layer):
else:
self.__parents__ = parent_layers.values()[0]
self.__recurrent_name__ = recurrent_name
self.__reverse__ = reverse
name = self.__parents__[
index].name if index >= 0 else self.context_name()
super(RecurrentLayerInput, self).__init__(
......@@ -380,7 +381,8 @@ class RecurrentLayerInput(Layer):
model_type('recurrent_nn')
RecurrentLayerGroupWithoutOutLinksBegin(
name=self.__recurrent_name__,
in_links=map(lambda x: x.name, self.__parents__))
in_links=map(lambda x: x.name, self.__parents__),
seq_reversed=self.__reverse__)
return self
......@@ -461,7 +463,7 @@ del each_layer_name
@wrap_name_default()
def recurrent_group(step, input, name=None):
def recurrent_group(step, input, reverse=False, name=None):
if not isinstance(input, collections.Sequence):
input = [input]
......@@ -471,14 +473,14 @@ def recurrent_group(step, input, name=None):
RecurrentLayerInput(
recurrent_name=name,
index=i,
parent_layers={'recurrent_inputs': non_static_inputs})
for i in xrange(len(non_static_inputs))
parent_layers={'recurrent_inputs': non_static_inputs},
reverse=reverse) for i in xrange(len(non_static_inputs))
]
extra_input = None
if len(non_static_inputs) == 0:
extra_input = RecurrentLayerInput(
recurrent_name=name, index=-1, parent_layers={})
recurrent_name=name, index=-1, parent_layers={}, reverse=reverse)
def __real_step__(*args):
rnn_input = list(args)
......
......@@ -42,7 +42,8 @@ class RNNTest(unittest.TestCase):
def test():
data = conf_helps.data_layer(name="word", size=dict_dim)
embd = conf_helps.embedding_layer(input=data, size=word_dim)
conf_helps.recurrent_group(name="rnn", step=step, input=embd)
conf_helps.recurrent_group(
name="rnn", step=step, input=embd, reverse=True)
return str(parse_network(test))
......@@ -60,7 +61,7 @@ class RNNTest(unittest.TestCase):
name="word", type=data_type.integer_value(dict_dim))
embd = layer.embedding(input=data, size=word_dim)
rnn_layer = layer.recurrent_group(
name="rnn", step=new_step, input=embd)
name="rnn", step=new_step, input=embd, reverse=True)
return str(layer.parse_network(rnn_layer))
diff = difflib.unified_diff(parse_old_rnn().splitlines(1),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册