diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py index 3d9caeec5897fcd5b9e084aff496d150efee2066..919c531d184b0a95ce8b456d57465b90eee5003e 100644 --- a/python/paddle/v2/layer.py +++ b/python/paddle/v2/layer.py @@ -360,7 +360,7 @@ mixed.__doc__ = conf_helps.mixed_layer.__doc__ class RecurrentLayerInput(Layer): - def __init__(self, recurrent_name, index, parent_layers): + def __init__(self, recurrent_name, index, parent_layers, reverse): parents_len = len(parent_layers) assert parents_len <= 1 if parents_len == 0: @@ -368,6 +368,7 @@ class RecurrentLayerInput(Layer): else: self.__parents__ = parent_layers.values()[0] self.__recurrent_name__ = recurrent_name + self.__reverse__ = reverse name = self.__parents__[ index].name if index >= 0 else self.context_name() super(RecurrentLayerInput, self).__init__( @@ -380,7 +381,8 @@ class RecurrentLayerInput(Layer): model_type('recurrent_nn') RecurrentLayerGroupWithoutOutLinksBegin( name=self.__recurrent_name__, - in_links=map(lambda x: x.name, self.__parents__)) + in_links=map(lambda x: x.name, self.__parents__), + seq_reversed=self.__reverse__) return self @@ -461,7 +463,7 @@ del each_layer_name @wrap_name_default() -def recurrent_group(step, input, name=None): +def recurrent_group(step, input, reverse=False, name=None): if not isinstance(input, collections.Sequence): input = [input] @@ -471,14 +473,14 @@ def recurrent_group(step, input, name=None): RecurrentLayerInput( recurrent_name=name, index=i, - parent_layers={'recurrent_inputs': non_static_inputs}) - for i in xrange(len(non_static_inputs)) + parent_layers={'recurrent_inputs': non_static_inputs}, + reverse=reverse) for i in xrange(len(non_static_inputs)) ] extra_input = None if len(non_static_inputs) == 0: extra_input = RecurrentLayerInput( - recurrent_name=name, index=-1, parent_layers={}) + recurrent_name=name, index=-1, parent_layers={}, reverse=reverse) def __real_step__(*args): rnn_input = list(args) diff --git a/python/paddle/v2/tests/test_rnn_layer.py b/python/paddle/v2/tests/test_rnn_layer.py index 5fbbd20eb76bb9daab2bcf98c4adad989106a377..845277c01288f99f75a148ddab5895d00864f60c 100644 --- a/python/paddle/v2/tests/test_rnn_layer.py +++ b/python/paddle/v2/tests/test_rnn_layer.py @@ -42,7 +42,8 @@ class RNNTest(unittest.TestCase): def test(): data = conf_helps.data_layer(name="word", size=dict_dim) embd = conf_helps.embedding_layer(input=data, size=word_dim) - conf_helps.recurrent_group(name="rnn", step=step, input=embd) + conf_helps.recurrent_group( + name="rnn", step=step, input=embd, reverse=True) return str(parse_network(test)) @@ -60,7 +61,7 @@ class RNNTest(unittest.TestCase): name="word", type=data_type.integer_value(dict_dim)) embd = layer.embedding(input=data, size=word_dim) rnn_layer = layer.recurrent_group( - name="rnn", step=new_step, input=embd) + name="rnn", step=new_step, input=embd, reverse=True) return str(layer.parse_network(rnn_layer)) diff = difflib.unified_diff(parse_old_rnn().splitlines(1),