提交 0752b3b7 编写于 作者: L Luo Tao

add layer check for recurrent_group

上级 35c175dd
......@@ -2754,7 +2754,12 @@ class SubsequenceInput(object):
@wrap_name_default("recurrent_group")
def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
def recurrent_group(step,
input,
reverse=False,
name=None,
targetInlink=None,
is_train=True):
"""
Recurrent layer group is an extremely flexible recurrent unit in
PaddlePaddle. As long as the user defines the calculation done within a
......@@ -2819,6 +2824,12 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
:type targetInlink: LayerOutput|SubsequenceInput
:param is_train: recurrent_group is used for training (True) or generating (False).
If is training, one of the input type must be LayerOutput; else,
none of input type should be LayerOutput.
: type is_train: bool
:return: LayerOutput object.
:rtype: LayerOutput
"""
......@@ -2866,6 +2877,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
seq_reversed=reverse,
target_inlinkname=targetInlinkName)
in_args = []
has_LayerOutput = True
for each_input in input:
assert is_single_input(each_input)
if isinstance(each_input, LayerOutput):
......@@ -2873,6 +2885,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
elif isinstance(each_input, SubsequenceInput):
in_args.append(each_input.input)
else:
has_LayerOutput = False
mem_name = "__%s_memory__" % each_input.input.name
mem = memory(
name=mem_name,
......@@ -2886,6 +2899,8 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
mix += identity_projection(mem)
in_args.append(mem)
assert (is_train == has_LayerOutput)
layer_outs = step(*in_args)
if isinstance(layer_outs, LayerOutput):
......@@ -3177,7 +3192,11 @@ def beam_search(step,
return predict
tmp = recurrent_group(
step=__real_step__, input=real_input, reverse=False, name=name)
step=__real_step__,
input=real_input,
reverse=False,
name=name,
is_train=False)
return tmp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册