提交 8c594376 编写于 作者: L Luo Tao

change is_train to is_generating

上级 ffbf00a0
...@@ -494,8 +494,7 @@ def scaling_projection(input, param_attr=None): ...@@ -494,8 +494,7 @@ def scaling_projection(input, param_attr=None):
:return: A ScalingProjection object :return: A ScalingProjection object
:rtype: ScalingProjection :rtype: ScalingProjection
""" """
proj = ScalingProjection(input_layer_name=input.name, proj = ScalingProjection(input_layer_name=input.name, **param_attr.attr)
**param_attr.attr)
proj.origin = input proj.origin = input
return proj return proj
...@@ -2788,7 +2787,7 @@ def recurrent_group(step, ...@@ -2788,7 +2787,7 @@ def recurrent_group(step,
reverse=False, reverse=False,
name=None, name=None,
targetInlink=None, targetInlink=None,
is_train=True): is_generating=False):
""" """
Recurrent layer group is an extremely flexible recurrent unit in Recurrent layer group is an extremely flexible recurrent unit in
PaddlePaddle. As long as the user defines the calculation done within a PaddlePaddle. As long as the user defines the calculation done within a
...@@ -2853,11 +2852,11 @@ def recurrent_group(step, ...@@ -2853,11 +2852,11 @@ def recurrent_group(step,
:type targetInlink: LayerOutput|SubsequenceInput :type targetInlink: LayerOutput|SubsequenceInput
:param is_train: recurrent_group is used for training (True) or generating (False). :param is_generating: If is generating, none of input type should be LayerOutput;
If is training, one of the input type must be LayerOutput; else, else, for training or testing, one of the input type must
none of input type should be LayerOutput. be LayerOutput.
: type is_train: bool : type is_generating: bool
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
...@@ -2928,7 +2927,7 @@ def recurrent_group(step, ...@@ -2928,7 +2927,7 @@ def recurrent_group(step,
mix += identity_projection(mem) mix += identity_projection(mem)
in_args.append(mem) in_args.append(mem)
assert (is_train == has_LayerOutput) assert (is_generating != has_LayerOutput)
layer_outs = step(*in_args) layer_outs = step(*in_args)
...@@ -3225,7 +3224,7 @@ def beam_search(step, ...@@ -3225,7 +3224,7 @@ def beam_search(step,
input=real_input, input=real_input,
reverse=False, reverse=False,
name=name, name=name,
is_train=False) is_generating=True)
return tmp return tmp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册