Commit d2ff3e49 authored by LCY-Seso, committed by GitHub

Merge pull request #458 from luotao1/group

add layer check for recurrent_group
@@ -494,8 +494,7 @@ def scaling_projection(input, param_attr=None):
     :return: A ScalingProjection object
     :rtype: ScalingProjection
     """
-    proj = ScalingProjection(input_layer_name=input.name,
-                             **param_attr.attr)
+    proj = ScalingProjection(input_layer_name=input.name, **param_attr.attr)
     proj.origin = input
     return proj
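For context, a projection like this is normally added to a mixed_layer. A minimal usage sketch (not part of the diff), assuming an existing LayerOutput named `hidden` of size 128; ScalingProjection multiplies its input by a single learned scalar, so the sizes match:

    # Hypothetical wiring; `hidden` and the size 128 are assumptions.
    with mixed_layer(size=128) as scaled:
        scaled += scaling_projection(input=hidden)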
@@ -2783,7 +2782,12 @@ class SubsequenceInput(object):
 @wrap_name_default("recurrent_group")
-def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
+def recurrent_group(step,
+                    input,
+                    reverse=False,
+                    name=None,
+                    targetInlink=None,
+                    is_generating=False):
     """
     Recurrent layer group is an extremely flexible recurrent unit in
     PaddlePaddle. As long as the user defines the calculation done within a
@@ -2848,6 +2852,12 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
     :type targetInlink: LayerOutput|SubsequenceInput
+    :param is_generating: If True (generating mode), none of the inputs should be
+                          a LayerOutput; otherwise (training or testing), at least
+                          one of the inputs must be a LayerOutput.
+    :type is_generating: bool
     :return: LayerOutput object.
     :rtype: LayerOutput
     """
@@ -2895,6 +2905,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
         seq_reversed=reverse,
         target_inlinkname=targetInlinkName)
     in_args = []
+    has_LayerOutput = True
     for each_input in input:
         assert is_single_input(each_input)
         if isinstance(each_input, LayerOutput):
@@ -2902,6 +2913,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
         elif isinstance(each_input, SubsequenceInput):
             in_args.append(each_input.input)
         else:
+            has_LayerOutput = False
             mem_name = "__%s_memory__" % each_input.input.name
             mem = memory(
                 name=mem_name,
@@ -2915,6 +2927,8 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
                 mix += identity_projection(mem)
             in_args.append(mem)
+    assert (is_generating != has_LayerOutput)
     layer_outs = step(*in_args)
     if isinstance(layer_outs, LayerOutput):
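The new flag is True only when every input is a LayerOutput or SubsequenceInput, and it is forced to False as soon as an input has to be wrapped in a memory. The assertion then simply requires the flag to be the opposite of is_generating. A tiny standalone illustration with plain booleans:

    # Accepted combinations: generating with memory-backed inputs, or
    # training/testing with LayerOutput/SubsequenceInput inputs only.
    for is_generating, has_layer_output in [(True, False), (False, True)]:
        assert is_generating != has_layer_output
    # (True, True) and (False, False) would trip the new check.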
@@ -3206,7 +3220,11 @@ def beam_search(step,
         return predict
     tmp = recurrent_group(
-        step=__real_step__, input=real_input, reverse=False, name=name)
+        step=__real_step__,
+        input=real_input,
+        reverse=False,
+        name=name,
+        is_generating=True)
     return tmp
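beam_search is the generation entry point, so it is the caller that now passes is_generating=True; its inputs are StaticInput/GeneratedInput objects rather than LayerOutput. A configuration-style sketch of such a call, loosely following the seqToseq demo; `decoder_step`, `encoded_vector`, the embedding name, and all sizes are assumptions:

    # Assumed: `decoder_step` is a user-defined step function and
    # `encoded_vector` is the encoder output used as a read-only input.
    gen_inputs = [
        StaticInput(input=encoded_vector, is_seq=True),
        GeneratedInput(size=30000,  # target dictionary size (assumed)
                       embedding_name="_target_language_embedding",
                       embedding_size=512)
    ]

    beam_gen = beam_search(name="decoder",
                           step=decoder_step,
                           input=gen_inputs,
                           bos_id=0,
                           eos_id=1,
                           beam_size=3,
                           max_length=250)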