diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 796121a64136ee3f31b2ed09b761c6a83cdbe625..952b1f097133c68643cb201fa30452e9e08afcd8 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -2754,7 +2754,12 @@ class SubsequenceInput(object): @wrap_name_default("recurrent_group") -def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): +def recurrent_group(step, + input, + reverse=False, + name=None, + targetInlink=None, + is_train=True): """ Recurrent layer group is an extremely flexible recurrent unit in PaddlePaddle. As long as the user defines the calculation done within a @@ -2819,6 +2824,12 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): :type targetInlink: LayerOutput|SubsequenceInput + :param is_train: recurrent_group is used for training (True) or generating (False). + If is training, one of the input type must be LayerOutput; else, + none of input type should be LayerOutput. + + : type is_train: bool + :return: LayerOutput object. :rtype: LayerOutput """ @@ -2866,6 +2877,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): seq_reversed=reverse, target_inlinkname=targetInlinkName) in_args = [] + has_LayerOutput = True for each_input in input: assert is_single_input(each_input) if isinstance(each_input, LayerOutput): @@ -2873,6 +2885,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): elif isinstance(each_input, SubsequenceInput): in_args.append(each_input.input) else: + has_LayerOutput = False mem_name = "__%s_memory__" % each_input.input.name mem = memory( name=mem_name, @@ -2886,6 +2899,8 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): mix += identity_projection(mem) in_args.append(mem) + assert (is_train == has_LayerOutput) + layer_outs = step(*in_args) if isinstance(layer_outs, LayerOutput): @@ -3177,7 +3192,11 @@ def beam_search(step, return predict tmp = recurrent_group( - step=__real_step__, input=real_input, reverse=False, name=name) + step=__real_step__, + input=real_input, + reverse=False, + name=name, + is_train=False) return tmp