From 8c59437612280738f59df859d9f512b318153d05 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 15 Nov 2016 12:35:36 +0800 Subject: [PATCH] change is_train to is_generating --- python/paddle/trainer_config_helpers/layers.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index f0757b9ce21..7cd290023ab 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -494,8 +494,7 @@ def scaling_projection(input, param_attr=None): :return: A ScalingProjection object :rtype: ScalingProjection """ - proj = ScalingProjection(input_layer_name=input.name, - **param_attr.attr) + proj = ScalingProjection(input_layer_name=input.name, **param_attr.attr) proj.origin = input return proj @@ -2788,7 +2787,7 @@ def recurrent_group(step, reverse=False, name=None, targetInlink=None, - is_train=True): + is_generating=False): """ Recurrent layer group is an extremely flexible recurrent unit in PaddlePaddle. As long as the user defines the calculation done within a @@ -2853,11 +2852,11 @@ def recurrent_group(step, :type targetInlink: LayerOutput|SubsequenceInput - :param is_train: recurrent_group is used for training (True) or generating (False). - If is training, one of the input type must be LayerOutput; else, - none of input type should be LayerOutput. + :param is_generating: If is generating, none of input type should be LayerOutput; + else, for training or testing, one of the input type must + be LayerOutput. - : type is_train: bool + : type is_generating: bool :return: LayerOutput object. :rtype: LayerOutput @@ -2928,7 +2927,7 @@ def recurrent_group(step, mix += identity_projection(mem) in_args.append(mem) - assert (is_train == has_LayerOutput) + assert (is_generating != has_LayerOutput) layer_outs = step(*in_args) @@ -3225,7 +3224,7 @@ def beam_search(step, input=real_input, reverse=False, name=name, - is_train=False) + is_generating=True) return tmp -- GitLab