提交 29520d77 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #2273 from luotao1/reverse

Expose the "reverse" argument for recurrent_group in V2 API.
...@@ -360,7 +360,7 @@ mixed.__doc__ = conf_helps.mixed_layer.__doc__ ...@@ -360,7 +360,7 @@ mixed.__doc__ = conf_helps.mixed_layer.__doc__
class RecurrentLayerInput(Layer): class RecurrentLayerInput(Layer):
def __init__(self, recurrent_name, index, parent_layers): def __init__(self, recurrent_name, index, parent_layers, reverse):
parents_len = len(parent_layers) parents_len = len(parent_layers)
assert parents_len <= 1 assert parents_len <= 1
if parents_len == 0: if parents_len == 0:
...@@ -368,6 +368,7 @@ class RecurrentLayerInput(Layer): ...@@ -368,6 +368,7 @@ class RecurrentLayerInput(Layer):
else: else:
self.__parents__ = parent_layers.values()[0] self.__parents__ = parent_layers.values()[0]
self.__recurrent_name__ = recurrent_name self.__recurrent_name__ = recurrent_name
self.__reverse__ = reverse
name = self.__parents__[ name = self.__parents__[
index].name if index >= 0 else self.context_name() index].name if index >= 0 else self.context_name()
super(RecurrentLayerInput, self).__init__( super(RecurrentLayerInput, self).__init__(
...@@ -380,7 +381,8 @@ class RecurrentLayerInput(Layer): ...@@ -380,7 +381,8 @@ class RecurrentLayerInput(Layer):
model_type('recurrent_nn') model_type('recurrent_nn')
RecurrentLayerGroupWithoutOutLinksBegin( RecurrentLayerGroupWithoutOutLinksBegin(
name=self.__recurrent_name__, name=self.__recurrent_name__,
in_links=map(lambda x: x.name, self.__parents__)) in_links=map(lambda x: x.name, self.__parents__),
seq_reversed=self.__reverse__)
return self return self
...@@ -461,7 +463,7 @@ del each_layer_name ...@@ -461,7 +463,7 @@ del each_layer_name
@wrap_name_default() @wrap_name_default()
def recurrent_group(step, input, name=None): def recurrent_group(step, input, reverse=False, name=None):
if not isinstance(input, collections.Sequence): if not isinstance(input, collections.Sequence):
input = [input] input = [input]
...@@ -471,14 +473,14 @@ def recurrent_group(step, input, name=None): ...@@ -471,14 +473,14 @@ def recurrent_group(step, input, name=None):
RecurrentLayerInput( RecurrentLayerInput(
recurrent_name=name, recurrent_name=name,
index=i, index=i,
parent_layers={'recurrent_inputs': non_static_inputs}) parent_layers={'recurrent_inputs': non_static_inputs},
for i in xrange(len(non_static_inputs)) reverse=reverse) for i in xrange(len(non_static_inputs))
] ]
extra_input = None extra_input = None
if len(non_static_inputs) == 0: if len(non_static_inputs) == 0:
extra_input = RecurrentLayerInput( extra_input = RecurrentLayerInput(
recurrent_name=name, index=-1, parent_layers={}) recurrent_name=name, index=-1, parent_layers={}, reverse=reverse)
def __real_step__(*args): def __real_step__(*args):
rnn_input = list(args) rnn_input = list(args)
......
...@@ -42,7 +42,8 @@ class RNNTest(unittest.TestCase): ...@@ -42,7 +42,8 @@ class RNNTest(unittest.TestCase):
def test(): def test():
data = conf_helps.data_layer(name="word", size=dict_dim) data = conf_helps.data_layer(name="word", size=dict_dim)
embd = conf_helps.embedding_layer(input=data, size=word_dim) embd = conf_helps.embedding_layer(input=data, size=word_dim)
conf_helps.recurrent_group(name="rnn", step=step, input=embd) conf_helps.recurrent_group(
name="rnn", step=step, input=embd, reverse=True)
return str(parse_network(test)) return str(parse_network(test))
...@@ -60,7 +61,7 @@ class RNNTest(unittest.TestCase): ...@@ -60,7 +61,7 @@ class RNNTest(unittest.TestCase):
name="word", type=data_type.integer_value(dict_dim)) name="word", type=data_type.integer_value(dict_dim))
embd = layer.embedding(input=data, size=word_dim) embd = layer.embedding(input=data, size=word_dim)
rnn_layer = layer.recurrent_group( rnn_layer = layer.recurrent_group(
name="rnn", step=new_step, input=embd) name="rnn", step=new_step, input=embd, reverse=True)
return str(layer.parse_network(rnn_layer)) return str(layer.parse_network(rnn_layer))
diff = difflib.unified_diff(parse_old_rnn().splitlines(1), diff = difflib.unified_diff(parse_old_rnn().splitlines(1),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册