You need to sign in or sign up before continuing.
提交 0752b3b7 编写于 作者: L Luo Tao

add layer check for recurrent_group

上级 35c175dd
...@@ -2754,7 +2754,12 @@ class SubsequenceInput(object): ...@@ -2754,7 +2754,12 @@ class SubsequenceInput(object):
@wrap_name_default("recurrent_group") @wrap_name_default("recurrent_group")
def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): def recurrent_group(step,
input,
reverse=False,
name=None,
targetInlink=None,
is_train=True):
""" """
Recurrent layer group is an extremely flexible recurrent unit in Recurrent layer group is an extremely flexible recurrent unit in
PaddlePaddle. As long as the user defines the calculation done within a PaddlePaddle. As long as the user defines the calculation done within a
...@@ -2819,6 +2824,12 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): ...@@ -2819,6 +2824,12 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
:type targetInlink: LayerOutput|SubsequenceInput :type targetInlink: LayerOutput|SubsequenceInput
:param is_train: recurrent_group is used for training (True) or generating (False).
If is training, one of the input type must be LayerOutput; else,
none of input type should be LayerOutput.
: type is_train: bool
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
""" """
...@@ -2866,6 +2877,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): ...@@ -2866,6 +2877,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
seq_reversed=reverse, seq_reversed=reverse,
target_inlinkname=targetInlinkName) target_inlinkname=targetInlinkName)
in_args = [] in_args = []
has_LayerOutput = True
for each_input in input: for each_input in input:
assert is_single_input(each_input) assert is_single_input(each_input)
if isinstance(each_input, LayerOutput): if isinstance(each_input, LayerOutput):
...@@ -2873,6 +2885,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): ...@@ -2873,6 +2885,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
elif isinstance(each_input, SubsequenceInput): elif isinstance(each_input, SubsequenceInput):
in_args.append(each_input.input) in_args.append(each_input.input)
else: else:
has_LayerOutput = False
mem_name = "__%s_memory__" % each_input.input.name mem_name = "__%s_memory__" % each_input.input.name
mem = memory( mem = memory(
name=mem_name, name=mem_name,
...@@ -2886,6 +2899,8 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): ...@@ -2886,6 +2899,8 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
mix += identity_projection(mem) mix += identity_projection(mem)
in_args.append(mem) in_args.append(mem)
assert (is_train == has_LayerOutput)
layer_outs = step(*in_args) layer_outs = step(*in_args)
if isinstance(layer_outs, LayerOutput): if isinstance(layer_outs, LayerOutput):
...@@ -3177,7 +3192,11 @@ def beam_search(step, ...@@ -3177,7 +3192,11 @@ def beam_search(step,
return predict return predict
tmp = recurrent_group( tmp = recurrent_group(
step=__real_step__, input=real_input, reverse=False, name=name) step=__real_step__,
input=real_input,
reverse=False,
name=name,
is_train=False)
return tmp return tmp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册