diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index b5e10ef81009a00e76b0c4147b404ba0aaba72b3..7cd290023aba10eba63d307b96aac57734d043c0 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -494,8 +494,7 @@ def scaling_projection(input, param_attr=None): :return: A ScalingProjection object :rtype: ScalingProjection """ - proj = ScalingProjection(input_layer_name=input.name, - **param_attr.attr) + proj = ScalingProjection(input_layer_name=input.name, **param_attr.attr) proj.origin = input return proj @@ -2783,7 +2782,12 @@ class SubsequenceInput(object): @wrap_name_default("recurrent_group") -def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): +def recurrent_group(step, + input, + reverse=False, + name=None, + targetInlink=None, + is_generating=False): """ Recurrent layer group is an extremely flexible recurrent unit in PaddlePaddle. As long as the user defines the calculation done within a @@ -2848,6 +2852,12 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): :type targetInlink: LayerOutput|SubsequenceInput + :param is_generating: If is generating, none of input type should be LayerOutput; + else, for training or testing, one of the input type must + be LayerOutput. + + : type is_generating: bool + :return: LayerOutput object. :rtype: LayerOutput """ @@ -2895,6 +2905,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): seq_reversed=reverse, target_inlinkname=targetInlinkName) in_args = [] + has_LayerOutput = True for each_input in input: assert is_single_input(each_input) if isinstance(each_input, LayerOutput): @@ -2902,6 +2913,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): elif isinstance(each_input, SubsequenceInput): in_args.append(each_input.input) else: + has_LayerOutput = False mem_name = "__%s_memory__" % each_input.input.name mem = memory( name=mem_name, @@ -2915,6 +2927,8 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): mix += identity_projection(mem) in_args.append(mem) + assert (is_generating != has_LayerOutput) + layer_outs = step(*in_args) if isinstance(layer_outs, LayerOutput): @@ -3206,7 +3220,11 @@ def beam_search(step, return predict tmp = recurrent_group( - step=__real_step__, input=real_input, reverse=False, name=name) + step=__real_step__, + input=real_input, + reverse=False, + name=name, + is_generating=True) return tmp