From 0752b3b710a5f29fba7f7d5e595aa9da3a76fdda Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Mon, 14 Nov 2016 14:08:29 +0800
Subject: [PATCH] add layer check for recurrent_group

---
 .../paddle/trainer_config_helpers/layers.py   | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 796121a6413..952b1f09713 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -2754,7 +2754,12 @@ class SubsequenceInput(object):
 
 
 @wrap_name_default("recurrent_group")
-def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
+def recurrent_group(step,
+                    input,
+                    reverse=False,
+                    name=None,
+                    targetInlink=None,
+                    is_train=True):
     """
     Recurrent layer group is an extremely flexible recurrent unit in
     PaddlePaddle. As long as the user defines the calculation done within a
@@ -2819,6 +2824,12 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
 
     :type targetInlink: LayerOutput|SubsequenceInput
 
+    :param is_train: whether recurrent_group is used for training (True) or
+                     generating (False). In training, at least one of the
+                     inputs must be a LayerOutput; in generating, none of the
+                     inputs should be a LayerOutput.
+    :type is_train: bool
+
     :return: LayerOutput object.
     :rtype: LayerOutput
     """
@@ -2866,6 +2877,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
         seq_reversed=reverse,
         target_inlinkname=targetInlinkName)
     in_args = []
+    has_LayerOutput = True
     for each_input in input:
         assert is_single_input(each_input)
         if isinstance(each_input, LayerOutput):
@@ -2873,6 +2885,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
         elif isinstance(each_input, SubsequenceInput):
             in_args.append(each_input.input)
         else:
+            has_LayerOutput = False
             mem_name = "__%s_memory__" % each_input.input.name
             mem = memory(
                 name=mem_name,
@@ -2886,6 +2899,8 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
                 mix += identity_projection(mem)
             in_args.append(mem)
 
+    assert (is_train == has_LayerOutput)
+
     layer_outs = step(*in_args)
 
     if isinstance(layer_outs, LayerOutput):
@@ -3177,7 +3192,11 @@ def beam_search(step,
 
         return predict
 
     tmp = recurrent_group(
-        step=__real_step__, input=real_input, reverse=False, name=name)
+        step=__real_step__,
+        input=real_input,
+        reverse=False,
+        name=name,
+        is_train=False)
 
     return tmp
--
GitLab
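
Usage note (not part of the patch): a minimal sketch of how the new
is_train flag behaves, assuming the v0.x trainer_config_helpers API.
The layer names and sizes below are illustrative only, not taken from
the patch.

    from paddle.trainer_config_helpers import *

    def step(word):
        # Illustrative one-layer step function; `memory` carries the
        # previous step's output of the layer named "rnn_state".
        mem = memory(name="rnn_state", size=128)
        return fc_layer(input=[word, mem], size=128, name="rnn_state")

    seq = embedding_layer(input=data_layer(name="word", size=10000), size=128)

    # Training (default is_train=True): every input is a LayerOutput or
    # SubsequenceInput, so has_LayerOutput stays True and the new
    # `assert (is_train == has_LayerOutput)` passes.
    train_out = recurrent_group(step=step, input=seq)

    # Generating: beam_search now forwards is_train=False, and its real
    # inputs are StaticInput/GeneratedInput rather than LayerOutput, so
    # has_LayerOutput becomes False and the assertion passes as well.

As a reading of the diff, the check is all-or-nothing: a training-mode
group that mixes LayerOutput inputs with a StaticInput would also trip
the assertion, since any input falling into the else branch sets
has_LayerOutput to False.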