提交 d2ff3e49 编写于 作者: L LCY-Seso 提交者: GitHub

Merge pull request #458 from luotao1/group

add layer check for recurrent_group
...@@ -494,8 +494,7 @@ def scaling_projection(input, param_attr=None): ...@@ -494,8 +494,7 @@ def scaling_projection(input, param_attr=None):
:return: A ScalingProjection object :return: A ScalingProjection object
:rtype: ScalingProjection :rtype: ScalingProjection
""" """
proj = ScalingProjection(input_layer_name=input.name, proj = ScalingProjection(input_layer_name=input.name, **param_attr.attr)
**param_attr.attr)
proj.origin = input proj.origin = input
return proj return proj
...@@ -2783,7 +2782,12 @@ class SubsequenceInput(object): ...@@ -2783,7 +2782,12 @@ class SubsequenceInput(object):
@wrap_name_default("recurrent_group") @wrap_name_default("recurrent_group")
def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): def recurrent_group(step,
input,
reverse=False,
name=None,
targetInlink=None,
is_generating=False):
""" """
Recurrent layer group is an extremely flexible recurrent unit in Recurrent layer group is an extremely flexible recurrent unit in
PaddlePaddle. As long as the user defines the calculation done within a PaddlePaddle. As long as the user defines the calculation done within a
...@@ -2848,6 +2852,12 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): ...@@ -2848,6 +2852,12 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
:type targetInlink: LayerOutput|SubsequenceInput :type targetInlink: LayerOutput|SubsequenceInput
:param is_generating: If is generating, none of input type should be LayerOutput;
else, for training or testing, one of the input type must
be LayerOutput.
: type is_generating: bool
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
""" """
...@@ -2895,6 +2905,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): ...@@ -2895,6 +2905,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
seq_reversed=reverse, seq_reversed=reverse,
target_inlinkname=targetInlinkName) target_inlinkname=targetInlinkName)
in_args = [] in_args = []
has_LayerOutput = True
for each_input in input: for each_input in input:
assert is_single_input(each_input) assert is_single_input(each_input)
if isinstance(each_input, LayerOutput): if isinstance(each_input, LayerOutput):
...@@ -2902,6 +2913,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): ...@@ -2902,6 +2913,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
elif isinstance(each_input, SubsequenceInput): elif isinstance(each_input, SubsequenceInput):
in_args.append(each_input.input) in_args.append(each_input.input)
else: else:
has_LayerOutput = False
mem_name = "__%s_memory__" % each_input.input.name mem_name = "__%s_memory__" % each_input.input.name
mem = memory( mem = memory(
name=mem_name, name=mem_name,
...@@ -2915,6 +2927,8 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): ...@@ -2915,6 +2927,8 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
mix += identity_projection(mem) mix += identity_projection(mem)
in_args.append(mem) in_args.append(mem)
assert (is_generating != has_LayerOutput)
layer_outs = step(*in_args) layer_outs = step(*in_args)
if isinstance(layer_outs, LayerOutput): if isinstance(layer_outs, LayerOutput):
...@@ -3206,7 +3220,11 @@ def beam_search(step, ...@@ -3206,7 +3220,11 @@ def beam_search(step,
return predict return predict
tmp = recurrent_group( tmp = recurrent_group(
step=__real_step__, input=real_input, reverse=False, name=name) step=__real_step__,
input=real_input,
reverse=False,
name=name,
is_generating=True)
return tmp return tmp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部