diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 78aa0778f8d1dca9fae82f0411be5a00e636cbc9..f6e8819e0f86d8f6c36a72de2b3759dce28dd0f5 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -3529,12 +3529,7 @@ def SubsequenceInput(input): @wrap_name_default("recurrent_group") -def recurrent_group(step, - input, - reverse=False, - name=None, - targetInlink=None, - is_generating=False): +def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): """ Recurrent layer group is an extremely flexible recurrent unit in PaddlePaddle. As long as the user defines the calculation done within a @@ -3600,21 +3595,12 @@ def recurrent_group(step, :type targetInlink: LayerOutput|SubsequenceInput - :param is_generating: If is generating, none of input type should be LayerOutput; - else, for training or testing, one of the input type must - be LayerOutput. - - :type is_generating: bool - :return: LayerOutput object. :rtype: LayerOutput """ model_type('recurrent_nn') - def is_single_input(x): - return isinstance(x, LayerOutput) or isinstance(x, StaticInput) - - if is_single_input(input): + if isinstance(input, LayerOutput) or isinstance(input, StaticInput): input = [input] assert isinstance(input, collections.Sequence) @@ -3628,13 +3614,8 @@ def recurrent_group(step, in_links=map(lambda x: x.name, in_links), seq_reversed=reverse) in_args = [] - has_LayerOutput = False for each_input in input: - assert is_single_input(each_input) - if isinstance(each_input, LayerOutput): - in_args.append(each_input) - has_LayerOutput = True - else: # StaticInput + if isinstance(each_input, StaticInput): # StaticInput mem_name = "__%s_memory__" % each_input.input.name mem = memory( name=None, @@ -3642,8 +3623,8 @@ def recurrent_group(step, boot_layer=each_input.input) mem.set_input(mem) in_args.append(mem) - - assert (is_generating != has_LayerOutput) + else: + in_args.append(each_input) layer_outs = step(*in_args) @@ -3869,6 +3850,7 @@ def beam_search(step, :type step: callable :param input: Input data for the recurrent unit, which should include the previously generated words as a GeneratedInput object. + In beam_search, none of the input's type should be LayerOutput. :type input: list :param bos_id: Index of the start symbol in the dictionary. The start symbol is a special token for NLP task, which indicates the @@ -3910,15 +3892,18 @@ def beam_search(step, real_input = [] for i, each_input in enumerate(input): - assert isinstance(each_input, StaticInput) or isinstance( - each_input, BaseGeneratedInput) + assert not isinstance(each_input, LayerOutput), ( + "in beam_search, " + "none of the input should has a type of LayerOutput.") if isinstance(each_input, BaseGeneratedInput): - assert generated_input_index == -1 + assert generated_input_index == -1, ("recurrent_group accepts " + "only one GeneratedInput.") generated_input_index = i + else: real_input.append(each_input) - assert generated_input_index != -1 + assert generated_input_index != -1, "No GeneratedInput is given." gipt = input[generated_input_index] @@ -3942,14 +3927,8 @@ def beam_search(step, eos_layer(input=predict, eos_id=eos_id, name=eos_name) return predict - tmp = recurrent_group( - step=__real_step__, - input=real_input, - reverse=False, - name=name, - is_generating=True) - - return tmp + return recurrent_group( + step=__real_step__, input=real_input, reverse=False, name=name) def __cost_input__(input, label, weight=None): diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 810bea913ec79b2df0eb63ed5a4fd411549ff2e9..396073236c347865be95a1a5a6641d7047c3b183 100755 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -15,6 +15,7 @@ """ # from activations import * +import pdb from activations import LinearActivation, ReluActivation, SoftmaxActivation, \ IdentityActivation, TanhActivation, SequenceSoftmaxActivation from attrs import ExtraAttr @@ -614,6 +615,7 @@ def simple_lstm(input, @wrap_name_default('lstm_unit') def lstmemory_unit(input, + out_memory=None, memory_boot=None, name=None, size=None, @@ -694,7 +696,11 @@ def lstmemory_unit(input, if size is None: assert input.size % 4 == 0 size = input.size / 4 - out_mem = memory(name=name, size=size) + if out_memory is None: + out_mem = memory(name=name, size=size) + else: + out_mem = out_memory + state_mem = memory( name="%s_state" % name, size=size, boot_layer=memory_boot)