From ad4ab5ac811d90dd2bbb661ad34ba5ee3aa510a1 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Sun, 26 Feb 2017 16:29:02 +0800
Subject: [PATCH] remove step_input in recurrent_group step_input

---
 .../paddle/trainer_config_helpers/layers.py |  8 ++-
 python/paddle/v2/layer.py                   | 61 +++++++++++++++----
 python/paddle/v2/tests/test_layer.py        | 13 ++--
 3 files changed, 62 insertions(+), 20 deletions(-)

diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 00aef80691f..4e200517fc4 100755
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -3042,7 +3042,8 @@ def recurrent_group(step,
                     reverse=False,
                     name=None,
                     targetInlink=None,
-                    is_generating=False):
+                    is_generating=False,
+                    in_args_converter=None):
     """
     Recurrent layer group is an extremely flexible recurrent unit in
     PaddlePaddle. As long as the user defines the calculation done within a
@@ -3185,7 +3186,10 @@ def recurrent_group(step,
 
     assert (is_generating != has_LayerOutput)
 
-    layer_outs = step(*in_args)
+    if in_args_converter is None:
+        layer_outs = step(*in_args)
+    else:
+        layer_outs = step(*in_args_converter(*in_args)).to_proto(dict())
 
     if isinstance(layer_outs, LayerOutput):
         layer_outs = [layer_outs]
diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py
index 5ecc96c6856..44c7661b246 100644
--- a/python/paddle/v2/layer.py
+++ b/python/paddle/v2/layer.py
@@ -73,8 +73,6 @@
 from paddle.trainer_config_helpers.config_parser_utils import \
     parse_network_config as __parse__
 from paddle.trainer_config_helpers.default_decorators import wrap_name_default
-import activation
-import attr
 import data_type
 
 __all__ = [
@@ -101,11 +99,10 @@ def parse_network(*outputs):
 
 
 class Layer(object):
-    def __init__(self, name, parent_layers, step_input=None):
+    def __init__(self, name, parent_layers):
         assert isinstance(parent_layers, dict)
         assert isinstance(name, basestring)
         self.name = name
-        self.step_input = step_input
         self.__parent_layers__ = parent_layers
 
     def to_proto(self, context):
@@ -121,12 +118,13 @@ class Layer(object):
             else:
                 v1_layer = map(lambda x: x.to_proto(context=context),
                                self.__parent_layers__[layer_name])
-            if layer_name == "input" and self.step_input is not None:
-                v1_layer.insert(0, self.step_input)
             kwargs[layer_name] = v1_layer
 
+        if self.name is None:
+            return self.to_proto_impl(**kwargs)
+
         # memory may have the same name with some layer
-        if isinstance(self, MemoryV2):
+        if isinstance(self, MemoryV2) or isinstance(self, LayerOutputV2):
             return self.to_proto_impl(**kwargs)
 
         if self.name not in context:
@@ -144,7 +142,7 @@ def __convert_to_v2__(method_name, name_prefix, parent_names):
     wrapper = None
 
     class V2LayerImpl(Layer):
-        def __init__(self, name=None, step_input=None, **kwargs):
+        def __init__(self, name=None, **kwargs):
             parent_layers = dict()
             other_kwargs = dict()
             for pname in parent_names:
@@ -155,7 +153,7 @@ def __convert_to_v2__(method_name, name_prefix, parent_names):
                 if key not in parent_names:
                     other_kwargs[key] = kwargs[key]
 
-            super(V2LayerImpl, self).__init__(name, parent_layers, step_input)
+            super(V2LayerImpl, self).__init__(name, parent_layers)
             self.__other_kwargs__ = other_kwargs
 
             if wrapper is not None:
@@ -214,6 +212,48 @@ class MemoryV2(Layer):
         return conf_helps.memory(name=self.name, size=self.size, **args)
 
 
+class LayerOutputV2(Layer):
+    def __init__(self, layer_output):
+        assert isinstance(layer_output, conf_helps.LayerOutput)
+        self.layer_output = layer_output
+        super(LayerOutputV2, self).__init__(
+            name=layer_output.name, parent_layers=dict())
+
+    def to_proto_impl(self):
+        return self.layer_output
+
+
+class RecurrentGroupV2(Layer):
+    def __init__(self, name, **kwargs):
+        self.__parent_names__ = ['input']
+        other_kwargs = dict()
+        parent_layers = dict()
+        for pname in self.__parent_names__:
+            if kwargs.has_key(pname):
+                parent_layers[pname] = kwargs[pname]
+        for key in kwargs.keys():
+            if key not in self.__parent_names__:
+                other_kwargs[key] = kwargs[key]
+        self.__kwargs__ = other_kwargs
+
+        super(RecurrentGroupV2, self).__init__(
+            name=name, parent_layers=parent_layers)
+
+    def to_proto_impl(self, **kwargs):
+        def in_args_converter(in_args):
+            if not isinstance(in_args, collections.Sequence):
+                in_args = [in_args]
+            return [LayerOutputV2(input) for input in in_args]
+
+        args = dict()
+        for each in kwargs:
+            args[each] = kwargs[each]
+        for each in self.__kwargs__:
+            args[each] = self.__kwargs__[each]
+        return conf_helps.recurrent_group(
+            name=self.name, in_args_converter=in_args_converter, **args)
+
+
 data = DataLayerV2
 fc = __convert_to_v2__('fc_layer', name_prefix='fc', parent_names=['input'])
 max_id = __convert_to_v2__(
@@ -234,8 +274,7 @@ embedding = __convert_to_v2__(
     'embedding_layer', name_prefix='embedding', parent_names=['input'])
 last_seq = __convert_to_v2__(
     'last_seq', name_prefix='last_seq', parent_names=['input'])
-recurrent_group = __convert_to_v2__(
-    'recurrent_group', name_prefix='recurrent_layer', parent_names=['input'])
+recurrent_group = RecurrentGroupV2
 memory = MemoryV2
 cross_entropy_with_selfnorm_cost = __convert_to_v2__(
diff --git a/python/paddle/v2/tests/test_layer.py b/python/paddle/v2/tests/test_layer.py
index 73d769a3582..04c0fc7cb0b 100644
--- a/python/paddle/v2/tests/test_layer.py
+++ b/python/paddle/v2/tests/test_layer.py
@@ -63,7 +63,7 @@ class RNNTest(unittest.TestCase):
         word_dim = 8
         hidden_dim = 8
 
-        def test_old_rnn():
+        def parse_old_rnn():
             def step(y):
                 mem = conf_helps.memory(name="rnn_state", size=hidden_dim)
                 out = conf_helps.fc_layer(
@@ -81,16 +81,15 @@
 
             return str(parse_network(test))
 
-        def test_new_rnn():
+        def parse_new_rnn():
             def new_step(y):
                 mem = layer.memory(name="rnn_state", size=hidden_dim)
-                out = layer.fc(input=[mem],
-                               step_input=y,
+                out = layer.fc(input=[y, mem],
                                size=hidden_dim,
                                act=activation.Tanh(),
                                bias_attr=True,
                                name="rnn_state")
-                return out.to_proto(dict())
+                return out
 
             data1 = layer.data(
                 name="word", type=data_type.integer_value(dict_dim))
@@ -99,8 +98,8 @@
                 name="rnn", step=new_step, input=embd)
             return str(layer.parse_network(rnn_layer))
 
-        diff = difflib.unified_diff(test_old_rnn().splitlines(1),
-                                    test_new_rnn().splitlines(1))
+        diff = difflib.unified_diff(parse_old_rnn().splitlines(1),
+                                    parse_new_rnn().splitlines(1))
         print ''.join(diff)
 
 
-- 
GitLab
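
Usage sketch (not part of the patch): with this change, a v2 step function no
longer threads its recurrent input through the removed step_input keyword.
The layers passed as input= go to recurrent_group, whose in_args_converter
wraps each v1 LayerOutput in a LayerOutputV2 before calling the user's step
function, so the step body is written purely in v2 layers and simply returns
its output. A minimal sketch of the resulting API, mirroring parse_new_rnn in
test_layer.py; the import paths and the dict_dim value are assumptions made
for illustration and do not appear in the diff itself:

    # Hedged sketch: module paths are assumed from the v2 package layout.
    import paddle.v2.layer as layer
    import paddle.v2.activation as activation
    import paddle.v2.data_type as data_type

    dict_dim = 10    # assumed vocabulary size (illustrative)
    word_dim = 8     # embedding width, as in the test
    hidden_dim = 8   # recurrent state width, as in the test

    def step(y):
        # `y` is the per-timestep input. RecurrentGroupV2.to_proto_impl
        # wraps it in a LayerOutputV2, so it composes directly with other
        # v2 layers such as the memory below.
        mem = layer.memory(name="rnn_state", size=hidden_dim)
        return layer.fc(input=[y, mem],
                        size=hidden_dim,
                        act=activation.Tanh(),
                        bias_attr=True,
                        name="rnn_state")

    word = layer.data(name="word", type=data_type.integer_value(dict_dim))
    embd = layer.embedding(input=word, size=word_dim)
    rnn = layer.recurrent_group(name="rnn", step=step, input=embd)
    print layer.parse_network(rnn)

Note the design choice visible in the diff: the step function now returns the
v2 layer itself, and the to_proto(dict()) call the old test performed inside
new_step has moved into conf_helps.recurrent_group, which evaluates
step(*in_args_converter(*in_args)).to_proto(dict()), keeping protobuf
conversion out of user code.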