提交 0ed51ce2 编写于 作者: C caoying03

fix bug of type check of inputs to recurrent_group.

上级 45ce1649
......@@ -3529,12 +3529,7 @@ def SubsequenceInput(input):
@wrap_name_default("recurrent_group")
def recurrent_group(step,
input,
reverse=False,
name=None,
targetInlink=None,
is_generating=False):
def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
"""
Recurrent layer group is an extremely flexible recurrent unit in
PaddlePaddle. As long as the user defines the calculation done within a
......@@ -3600,21 +3595,12 @@ def recurrent_group(step,
:type targetInlink: LayerOutput|SubsequenceInput
:param is_generating: If is generating, none of input type should be LayerOutput;
else, for training or testing, one of the input type must
be LayerOutput.
:type is_generating: bool
:return: LayerOutput object.
:rtype: LayerOutput
"""
model_type('recurrent_nn')
def is_single_input(x):
return isinstance(x, LayerOutput) or isinstance(x, StaticInput)
if is_single_input(input):
if isinstance(input, LayerOutput) or isinstance(input, StaticInput):
input = [input]
assert isinstance(input, collections.Sequence)
......@@ -3628,13 +3614,8 @@ def recurrent_group(step,
in_links=map(lambda x: x.name, in_links),
seq_reversed=reverse)
in_args = []
has_LayerOutput = False
for each_input in input:
assert is_single_input(each_input)
if isinstance(each_input, LayerOutput):
in_args.append(each_input)
has_LayerOutput = True
else: # StaticInput
if isinstance(each_input, StaticInput): # StaticInput
mem_name = "__%s_memory__" % each_input.input.name
mem = memory(
name=None,
......@@ -3642,8 +3623,8 @@ def recurrent_group(step,
boot_layer=each_input.input)
mem.set_input(mem)
in_args.append(mem)
assert (is_generating != has_LayerOutput)
else:
in_args.append(each_input)
layer_outs = step(*in_args)
......@@ -3869,6 +3850,7 @@ def beam_search(step,
:type step: callable
:param input: Input data for the recurrent unit, which should include the
previously generated words as a GeneratedInput object.
In beam_search, none of the input's type should be LayerOutput.
:type input: list
:param bos_id: Index of the start symbol in the dictionary. The start symbol
is a special token for NLP task, which indicates the
......@@ -3910,15 +3892,18 @@ def beam_search(step,
real_input = []
for i, each_input in enumerate(input):
assert isinstance(each_input, StaticInput) or isinstance(
each_input, BaseGeneratedInput)
assert not isinstance(each_input, LayerOutput), (
"in beam_search, "
"none of the input should has a type of LayerOutput.")
if isinstance(each_input, BaseGeneratedInput):
assert generated_input_index == -1
assert generated_input_index == -1, ("recurrent_group accepts "
"only one GeneratedInput.")
generated_input_index = i
else:
real_input.append(each_input)
assert generated_input_index != -1
assert generated_input_index != -1, "No GeneratedInput is given."
gipt = input[generated_input_index]
......@@ -3942,14 +3927,8 @@ def beam_search(step,
eos_layer(input=predict, eos_id=eos_id, name=eos_name)
return predict
tmp = recurrent_group(
step=__real_step__,
input=real_input,
reverse=False,
name=name,
is_generating=True)
return tmp
return recurrent_group(
step=__real_step__, input=real_input, reverse=False, name=name)
def __cost_input__(input, label, weight=None):
......
......@@ -15,6 +15,7 @@
"""
# from activations import *
import pdb
from activations import LinearActivation, ReluActivation, SoftmaxActivation, \
IdentityActivation, TanhActivation, SequenceSoftmaxActivation
from attrs import ExtraAttr
......@@ -614,6 +615,7 @@ def simple_lstm(input,
@wrap_name_default('lstm_unit')
def lstmemory_unit(input,
out_memory=None,
memory_boot=None,
name=None,
size=None,
......@@ -694,7 +696,11 @@ def lstmemory_unit(input,
if size is None:
assert input.size % 4 == 0
size = input.size / 4
if out_memory is None:
out_mem = memory(name=name, size=size)
else:
out_mem = out_memory
state_mem = memory(
name="%s_state" % name, size=size, boot_layer=memory_boot)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册