提交 0ed51ce2 编写于 作者: C caoying03

Fix a bug in the type check of inputs to recurrent_group.

上级 45ce1649
...@@ -3529,12 +3529,7 @@ def SubsequenceInput(input): ...@@ -3529,12 +3529,7 @@ def SubsequenceInput(input):
@wrap_name_default("recurrent_group") @wrap_name_default("recurrent_group")
def recurrent_group(step, def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
input,
reverse=False,
name=None,
targetInlink=None,
is_generating=False):
""" """
Recurrent layer group is an extremely flexible recurrent unit in Recurrent layer group is an extremely flexible recurrent unit in
PaddlePaddle. As long as the user defines the calculation done within a PaddlePaddle. As long as the user defines the calculation done within a
...@@ -3600,21 +3595,12 @@ def recurrent_group(step, ...@@ -3600,21 +3595,12 @@ def recurrent_group(step,
:type targetInlink: LayerOutput|SubsequenceInput :type targetInlink: LayerOutput|SubsequenceInput
:param is_generating: If is generating, none of input type should be LayerOutput;
else, for training or testing, one of the input type must
be LayerOutput.
:type is_generating: bool
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
""" """
model_type('recurrent_nn') model_type('recurrent_nn')
def is_single_input(x): if isinstance(input, LayerOutput) or isinstance(input, StaticInput):
return isinstance(x, LayerOutput) or isinstance(x, StaticInput)
if is_single_input(input):
input = [input] input = [input]
assert isinstance(input, collections.Sequence) assert isinstance(input, collections.Sequence)
...@@ -3628,13 +3614,8 @@ def recurrent_group(step, ...@@ -3628,13 +3614,8 @@ def recurrent_group(step,
in_links=map(lambda x: x.name, in_links), in_links=map(lambda x: x.name, in_links),
seq_reversed=reverse) seq_reversed=reverse)
in_args = [] in_args = []
has_LayerOutput = False
for each_input in input: for each_input in input:
assert is_single_input(each_input) if isinstance(each_input, StaticInput): # StaticInput
if isinstance(each_input, LayerOutput):
in_args.append(each_input)
has_LayerOutput = True
else: # StaticInput
mem_name = "__%s_memory__" % each_input.input.name mem_name = "__%s_memory__" % each_input.input.name
mem = memory( mem = memory(
name=None, name=None,
...@@ -3642,8 +3623,8 @@ def recurrent_group(step, ...@@ -3642,8 +3623,8 @@ def recurrent_group(step,
boot_layer=each_input.input) boot_layer=each_input.input)
mem.set_input(mem) mem.set_input(mem)
in_args.append(mem) in_args.append(mem)
else:
assert (is_generating != has_LayerOutput) in_args.append(each_input)
layer_outs = step(*in_args) layer_outs = step(*in_args)
...@@ -3869,6 +3850,7 @@ def beam_search(step, ...@@ -3869,6 +3850,7 @@ def beam_search(step,
:type step: callable :type step: callable
:param input: Input data for the recurrent unit, which should include the :param input: Input data for the recurrent unit, which should include the
previously generated words as a GeneratedInput object. previously generated words as a GeneratedInput object.
In beam_search, none of the inputs should be of type LayerOutput.
:type input: list :type input: list
:param bos_id: Index of the start symbol in the dictionary. The start symbol :param bos_id: Index of the start symbol in the dictionary. The start symbol
is a special token for NLP task, which indicates the is a special token for NLP task, which indicates the
...@@ -3910,15 +3892,18 @@ def beam_search(step, ...@@ -3910,15 +3892,18 @@ def beam_search(step,
real_input = [] real_input = []
for i, each_input in enumerate(input): for i, each_input in enumerate(input):
assert isinstance(each_input, StaticInput) or isinstance( assert not isinstance(each_input, LayerOutput), (
each_input, BaseGeneratedInput) "in beam_search, "
"none of the input should has a type of LayerOutput.")
if isinstance(each_input, BaseGeneratedInput): if isinstance(each_input, BaseGeneratedInput):
assert generated_input_index == -1 assert generated_input_index == -1, ("recurrent_group accepts "
"only one GeneratedInput.")
generated_input_index = i generated_input_index = i
else: else:
real_input.append(each_input) real_input.append(each_input)
assert generated_input_index != -1 assert generated_input_index != -1, "No GeneratedInput is given."
gipt = input[generated_input_index] gipt = input[generated_input_index]
...@@ -3942,14 +3927,8 @@ def beam_search(step, ...@@ -3942,14 +3927,8 @@ def beam_search(step,
eos_layer(input=predict, eos_id=eos_id, name=eos_name) eos_layer(input=predict, eos_id=eos_id, name=eos_name)
return predict return predict
tmp = recurrent_group( return recurrent_group(
step=__real_step__, step=__real_step__, input=real_input, reverse=False, name=name)
input=real_input,
reverse=False,
name=name,
is_generating=True)
return tmp
def __cost_input__(input, label, weight=None): def __cost_input__(input, label, weight=None):
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
""" """
# from activations import * # from activations import *
import pdb
from activations import LinearActivation, ReluActivation, SoftmaxActivation, \ from activations import LinearActivation, ReluActivation, SoftmaxActivation, \
IdentityActivation, TanhActivation, SequenceSoftmaxActivation IdentityActivation, TanhActivation, SequenceSoftmaxActivation
from attrs import ExtraAttr from attrs import ExtraAttr
...@@ -614,6 +615,7 @@ def simple_lstm(input, ...@@ -614,6 +615,7 @@ def simple_lstm(input,
@wrap_name_default('lstm_unit') @wrap_name_default('lstm_unit')
def lstmemory_unit(input, def lstmemory_unit(input,
out_memory=None,
memory_boot=None, memory_boot=None,
name=None, name=None,
size=None, size=None,
...@@ -694,7 +696,11 @@ def lstmemory_unit(input, ...@@ -694,7 +696,11 @@ def lstmemory_unit(input,
if size is None: if size is None:
assert input.size % 4 == 0 assert input.size % 4 == 0
size = input.size / 4 size = input.size / 4
if out_memory is None:
out_mem = memory(name=name, size=size) out_mem = memory(name=name, size=size)
else:
out_mem = out_memory
state_mem = memory( state_mem = memory(
name="%s_state" % name, size=size, boot_layer=memory_boot) name="%s_state" % name, size=size, boot_layer=memory_boot)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册