提交 e64418c7 编写于 作者: Q qiaolongfei

Support beam_search; fix a bug in how the `mix` layer is handled in recurrent_group

上级 30aded19
...@@ -76,10 +76,6 @@ class Layer(object): ...@@ -76,10 +76,6 @@ class Layer(object):
""" """
function to set proto attribute function to set proto attribute
""" """
print "======"
# print self.name
print self.__parent_layers__
# print self.__context__
self.__context__ = context self.__context__ = context
# short cut if myself is parsed before. # short cut if myself is parsed before.
......
...@@ -135,10 +135,14 @@ class WithExtraParent(Layer): ...@@ -135,10 +135,14 @@ class WithExtraParent(Layer):
""" """
function to set proto attribute function to set proto attribute
""" """
print "*************" # short cut if myself is parsed before.
# print context if self.context_name() in context:
print self.name if self.use_context_name():
print self.__extra_parent__ return context[self.context_name()]
else:
return context[self.name]
# parse parents
kwargs = dict() kwargs = dict()
for p in self.__extra_parent__: for p in self.__extra_parent__:
p.to_proto(context=context) p.to_proto(context=context)
...@@ -153,12 +157,27 @@ class WithExtraParent(Layer): ...@@ -153,12 +157,27 @@ class WithExtraParent(Layer):
self.__parent_layers__[layer_name]) self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer kwargs[layer_name] = v1_layer
# parse self
if self.context_name() is None: if self.context_name() is None:
return self.to_proto_impl(context=context, **kwargs) return self.to_proto_impl(context=context, **kwargs)
elif self.context_name() not in context: elif self.context_name() not in context:
context[self.context_name()] = self.to_proto_impl( context[self.context_name()] = self.to_proto_impl(
context=context, **kwargs) context=context, **kwargs)
# parse children.
aaa = self.__children_layers__
for layer, pnames in self.__children_layers__:
drop = False
# child will only be parsed if all parents are in context.
for pname in pnames:
if pname not in context:
drop = True
break
if drop:
continue
layer.to_proto(context=context)
if self.use_context_name(): if self.use_context_name():
return context[self.context_name()] return context[self.context_name()]
else: else:
...@@ -456,7 +475,8 @@ def recurrent_group(step, input, name=None): ...@@ -456,7 +475,8 @@ def recurrent_group(step, input, name=None):
size=static_input.input.calculate_size, size=static_input.input.calculate_size,
act=activation.Identity()) as mix: act=activation.Identity()) as mix:
mix += identity_projection(input=mem) mix += identity_projection(input=mem)
rnn_input.insert(input.index(static_input), mix) mem.append_child(layer=mix, parent_names=[mem.context_name()])
rnn_input.insert(input.index(static_input), mem)
return step(*rnn_input) return step(*rnn_input)
actual_output = __real_step__(*actual_input) actual_output = __real_step__(*actual_input)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册