Commit ef61288f authored by xuwei06

Clean up recurrent group-related Python code

Callers no longer need to specify target_inlinkname or is_seq.

Parent 17994e38
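For context, a minimal before/after sketch of the user-facing change (illustrative only; `my_step` and `encoding` are placeholder names, not taken from this diff):

```python
# Before this change: sequence StaticInputs had to be flagged explicitly.
decoder = recurrent_group(
    step=my_step, input=[StaticInput(encoding, is_seq=True)])

# After: is_seq is deprecated and ignored; StaticInput infers what it needs
# from the wrapped layer, so the flag can simply be dropped.
decoder = recurrent_group(
    step=my_step, input=[StaticInput(encoding)])
```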
@@ -59,7 +59,8 @@ def outer_step(subseq, seq, nonseq, encoding):
         return out
 
     decoder = recurrent_group(
-        step=inner_step, name='inner', input=[subseq, seq, nonseq])
+        step=inner_step, name='inner',
+        input=[subseq, StaticInput(seq), nonseq])
     last = last_seq(name="outer_rnn_state", input=decoder)
     context = simple_attention(
         encoded_sequence=encoding, encoded_proj=encoding, decoder_state=last)
@@ -69,7 +70,7 @@ def outer_step(subseq, seq, nonseq, encoding):
 
 out = recurrent_group(
     name="outer",
     step=outer_step,
-    input=[data1, data2, label, StaticInput(encoding)])
+    input=[data1, data2, StaticInput(label), StaticInput(encoding)])
 rep = last_seq(input=out)
 prob = fc_layer(
...
@@ -328,16 +328,12 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name,
     SubModelBegin(name)
     g_current_submodel.is_recurrent_layer_group = True
     g_current_submodel.reversed = seq_reversed
-    g_current_submodel.target_inlinkid = -1
 
     in_links_count = 0
     for linkid, link in enumerate(in_links):
         if isinstance(link, basestring):
             name = link
         else:
             name = link.link_name
-        # assign target_inlinkid according to target_inlinkname
-        if target_inlinkname == name:
-            g_current_submodel.target_inlinkid = linkid
         in_links_count += 1
         layer_name = MakeLayerNameInParentSubmodel(name)
@@ -373,8 +369,7 @@ def RecurrentLayerGroupBegin(name,
                              generator=None,
                              target_inlinkname="",
                              seq_reversed=False):
-    RecurrentLayerGroupWithoutOutLinksBegin(name, in_links, seq_reversed,
-                                            target_inlinkname)
+    RecurrentLayerGroupWithoutOutLinksBegin(name, in_links, seq_reversed)
 
     for link in out_links:
         RecurrentLayerGroupSetOutLink(link)
@@ -2309,7 +2304,6 @@ def Memory(name,
     if name is not None:
         memory.layer_name = MakeLayerNameInSubmodel(name)
         memory.link_name = MakeLayerNameInSubmodel(agent_name)
-    memory.is_sequence = is_sequence
     options = sum((boot_layer is not None, bool(boot_bias),
                    boot_with_const_id is not None))
     config_assert(
...
@@ -311,18 +311,6 @@ class LayerOutput(object):
         self.outputs = outputs
         self.reverse = reverse
 
-    def __repr__(self):
-        """
-        Disable __repr__ for debug reason. Will be implemented when release
-        """
-        assert False, "this method should not be invoked"
-
-    def __str__(self):
-        """
-        Disable __str__ for debug reason. Will be implemented when release
-        """
-        assert False, "this method should not be invoked"
-
     def set_input(self, input):
         """
         Set the input for a memory layer. Can only be used for memory layer
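A side effect of removing these overrides is that inspecting a LayerOutput no longer raises; Python falls back to the default object repr. A hedged illustration (`some_layer` is a hypothetical variable):

```python
print(some_layer)  # previously: AssertionError("this method should not be invoked")
                   # now: e.g. <...LayerOutput object at 0x7f...> (default repr)
```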
@@ -2944,7 +2932,7 @@ def memory(name,
     :param memory_name: the name of the memory.
                         It is ignored when name is provided.
     :type memory_name: basestring
-    :param is_seq: is sequence for boot_layer
+    :param is_seq: DEPRECATED. is sequence for boot_layer
     :type is_seq: bool
     :param boot_layer: boot layer of memory.
     :type boot_layer: LayerOutput|None
@@ -2971,7 +2959,6 @@ def memory(name,
     memory_name = Memory(
         name,
         size,
-        is_sequence=is_seq,
         boot_layer=boot_layer.name if boot_layer is not None else None,
         boot_bias=boot_bias,
         boot_bias_active_type=boot_bias_active_type.name,
@@ -3318,15 +3305,16 @@ class StaticInput(object):
     """
     StaticInput is only used in recurrent_group which defines a read-only memory
     that can be a sequence or non-sequence.
+    :param size: DEPRECATED
+    :param is_seq: DEPRECATED
     """
 
     def __init__(self, input, is_seq=False, size=None):
         assert isinstance(input, LayerOutput)
         self.input = input
-        self.is_seq = is_seq
-        assert input.size is not None or size is not None
+        assert input.size is not None
         if size is not None:
-            input.size = size
+            assert input.size == size
 
 
 def SubsequenceInput(input):
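In short, StaticInput now takes its size from the wrapped layer and treats an explicit size purely as a consistency check, while is_seq is accepted but ignored. A usage sketch (the value 512 is a hypothetical size):

```python
enc = StaticInput(encoding)               # size inferred from encoding.size
enc = StaticInput(encoding, size=512)     # only asserts encoding.size == 512
enc = StaticInput(encoding, is_seq=True)  # is_seq is a deprecated no-op
```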
@@ -3452,15 +3440,10 @@ def recurrent_group(step,
         else:  # StaticInput
             mem_name = "__%s_memory__" % each_input.input.name
             mem = memory(
-                name=mem_name,
-                is_seq=each_input.is_seq,
+                name=None,
                 size=each_input.input.size,
                 boot_layer=each_input.input)
-            with mixed_layer(
-                    name=mem_name,
-                    size=each_input.input.size,
-                    act=IdentityActivation()) as mix:
-                mix += identity_projection(mem)
+            mem.set_input(mem)
             in_args.append(mem)
 
     assert (is_generating != has_LayerOutput)
...
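The deleted mixed_layer/identity_projection block existed only to copy a StaticInput's value into the group at every step. The new code builds the same read-only memory more directly: the memory boots from the wrapped layer and is then wired as its own input, so its value is carried unchanged across time steps. A sketch of the pattern using the API names from the diff (not runnable outside PaddlePaddle's trainer_config_helpers; `x` is a placeholder layer):

```python
mem = memory(name=None, size=x.size, boot_layer=x)  # step 0 reads x's value
mem.set_input(mem)  # every later step reads the previous value, i.e. still x
```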
@@ -256,7 +256,6 @@ sub_models {
   memories {
     layer_name: "__simple_gru_0__@__simple_gru_0___recurrent_group"
     link_name: "__simple_gru_0__+delay1@__simple_gru_0___recurrent_group"
-    is_sequence: false
   }
   in_links {
     layer_name: "__simple_gru_0___transform"
@@ -266,7 +265,6 @@ sub_models {
     layer_name: "__simple_gru_0__@__simple_gru_0___recurrent_group"
     link_name: "__simple_gru_0__"
   }
-  target_inlinkid: -1
 }
 sub_models {
   name: "__simple_gru_1___recurrent_group"
@@ -278,7 +276,6 @@ sub_models {
   memories {
     layer_name: "__simple_gru_1__@__simple_gru_1___recurrent_group"
     link_name: "__simple_gru_1__+delay1@__simple_gru_1___recurrent_group"
-    is_sequence: false
   }
   in_links {
     layer_name: "__simple_gru_1___transform"
@@ -288,6 +285,5 @@ sub_models {
     layer_name: "__simple_gru_1__@__simple_gru_1___recurrent_group"
     link_name: "__simple_gru_1__"
   }
-  target_inlinkid: -1
 }
@@ -341,12 +341,10 @@ sub_models {
   memories {
     layer_name: "__lstm_group_0__@__lstm_group_0___recurrent_group"
     link_name: "__lstm_group_0__+delay1@__lstm_group_0___recurrent_group"
-    is_sequence: false
   }
   memories {
     layer_name: "__lstm_group_0___state@__lstm_group_0___recurrent_group"
     link_name: "__lstm_group_0___state+delay1@__lstm_group_0___recurrent_group"
-    is_sequence: false
   }
   in_links {
     layer_name: "__mixed_0__"
@@ -356,7 +354,6 @@ sub_models {
     layer_name: "__lstm_group_0__@__lstm_group_0___recurrent_group"
     link_name: "__lstm_group_0__"
   }
-  target_inlinkid: -1
 }
 sub_models {
   name: "__lstm_group_1___recurrent_group"
@@ -371,12 +368,10 @@ sub_models {
   memories {
     layer_name: "__lstm_group_1__@__lstm_group_1___recurrent_group"
     link_name: "__lstm_group_1__+delay1@__lstm_group_1___recurrent_group"
-    is_sequence: false
   }
   memories {
     layer_name: "__lstm_group_1___state@__lstm_group_1___recurrent_group"
     link_name: "__lstm_group_1___state+delay1@__lstm_group_1___recurrent_group"
-    is_sequence: false
   }
   in_links {
     layer_name: "__mixed_1__"
@@ -386,6 +381,5 @@ sub_models {
     layer_name: "__lstm_group_1__@__lstm_group_1___recurrent_group"
     link_name: "__lstm_group_1__"
   }
-  target_inlinkid: -1
 }
@@ -618,7 +618,6 @@ sub_models {
   memories {
     layer_name: "rnn_forward@__recurrent_group_0__"
     link_name: "rnn_forward+delay1@__recurrent_group_0__"
-    is_sequence: false
   }
   in_links {
     layer_name: "seq_input"
@@ -628,7 +627,6 @@ sub_models {
     layer_name: "rnn_forward@__recurrent_group_0__"
     link_name: "rnn_forward"
   }
-  target_inlinkid: -1
 }
 sub_models {
   name: "__recurrent_group_1__"
@@ -640,7 +638,6 @@ sub_models {
   memories {
     layer_name: "rnn_back@__recurrent_group_1__"
     link_name: "rnn_back+delay1@__recurrent_group_1__"
-    is_sequence: false
   }
   in_links {
     layer_name: "seq_input"
@@ -650,7 +647,6 @@ sub_models {
     layer_name: "rnn_back@__recurrent_group_1__"
     link_name: "rnn_back"
   }
-  target_inlinkid: -1
 }
 sub_models {
   name: "__recurrent_group_2__"
@@ -662,7 +658,6 @@ sub_models {
   memories {
     layer_name: "rnn_subseq_forward@__recurrent_group_2__"
     link_name: "rnn_subseq_forward+delay1@__recurrent_group_2__"
-    is_sequence: false
   }
   in_links {
     layer_name: "sub_seq_input"
@@ -672,7 +667,6 @@ sub_models {
     layer_name: "rnn_subseq_forward@__recurrent_group_2__"
     link_name: "rnn_subseq_forward"
   }
-  target_inlinkid: -1
 }
 sub_models {
   name: "__lstm_group_0___recurrent_group"
@@ -687,12 +681,10 @@ sub_models {
   memories {
     layer_name: "__lstm_group_0__@__lstm_group_0___recurrent_group"
     link_name: "__lstm_group_0__+delay1@__lstm_group_0___recurrent_group"
-    is_sequence: false
   }
   memories {
     layer_name: "__lstm_group_0___state@__lstm_group_0___recurrent_group"
     link_name: "__lstm_group_0___state+delay1@__lstm_group_0___recurrent_group"
-    is_sequence: false
   }
   in_links {
     layer_name: "__mixed_0__"
@@ -702,7 +694,6 @@ sub_models {
     layer_name: "__lstm_group_0__@__lstm_group_0___recurrent_group"
     link_name: "__lstm_group_0__"
   }
-  target_inlinkid: -1
 }
 sub_models {
   name: "__gru_group_0___recurrent_group"
@@ -714,7 +705,6 @@ sub_models {
   memories {
     layer_name: "__gru_group_0__@__gru_group_0___recurrent_group"
     link_name: "__gru_group_0__+delay1@__gru_group_0___recurrent_group"
-    is_sequence: false
   }
   in_links {
     layer_name: "__mixed_1__"
@@ -724,7 +714,6 @@ sub_models {
     layer_name: "__gru_group_0__@__gru_group_0___recurrent_group"
     link_name: "__gru_group_0__"
   }
-  target_inlinkid: -1
 }
 sub_models {
   name: "__recurrent_group_3__"
@@ -736,7 +725,6 @@ sub_models {
   memories {
     layer_name: "__fc_layer_0__@__recurrent_group_3__"
     link_name: "__memory_6__@__recurrent_group_3__"
-    is_sequence: false
   }
   in_links {
     layer_name: "seq_input"
@@ -746,6 +734,5 @@ sub_models {
     layer_name: "__fc_layer_0__@__recurrent_group_3__"
     link_name: "__fc_layer_0__"
   }
-  target_inlinkid: -1
 }
@@ -260,7 +260,7 @@ def parse_network(output_layers, extra_layers=None):
     else:
         extra_layers = []
 
-    layer_names = __get_used_layers__(output_layers + extra_layers)
+    layer_names = __get_used_layers__(list(output_layers) + list(extra_layers))
     submodel_names = __get_used_submodels__(layer_names)
     submodel_names.add('root')
     evaluator_names = __get_used_evaluators__(layer_names)
...
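The list() conversions make parse_network accept any sequence type for its arguments: in Python, concatenating a tuple and a list raises a TypeError, while converting both to lists first always works. A standalone illustration:

```python
output_layers = ("layer_a", "layer_b")  # a caller may pass a tuple
extra_layers = ["layer_c"]

# output_layers + extra_layers  # TypeError: can only concatenate tuple (not "list") to tuple
combined = list(output_layers) + list(extra_layers)
print(combined)  # ['layer_a', 'layer_b', 'layer_c']
```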