提交 e64418c7 编写于 作者: Q qiaolongfei

support beam_search, fix mix bug

上级 30aded19
......@@ -76,10 +76,6 @@ class Layer(object):
"""
function to set proto attribute
"""
print "======"
# print self.name
print self.__parent_layers__
# print self.__context__
self.__context__ = context
# short cut if myself is parsed before.
......
......@@ -135,10 +135,14 @@ class WithExtraParent(Layer):
"""
function to set proto attribute
"""
print "*************"
# print context
print self.name
print self.__extra_parent__
# short cut if myself is parsed before.
if self.context_name() in context:
if self.use_context_name():
return context[self.context_name()]
else:
return context[self.name]
# parse parents
kwargs = dict()
for p in self.__extra_parent__:
p.to_proto(context=context)
......@@ -153,12 +157,27 @@ class WithExtraParent(Layer):
self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer
# parse self
if self.context_name() is None:
return self.to_proto_impl(context=context, **kwargs)
elif self.context_name() not in context:
context[self.context_name()] = self.to_proto_impl(
context=context, **kwargs)
# parse children.
aaa = self.__children_layers__
for layer, pnames in self.__children_layers__:
drop = False
# child will only be parsed if all parents are in context.
for pname in pnames:
if pname not in context:
drop = True
break
if drop:
continue
layer.to_proto(context=context)
if self.use_context_name():
return context[self.context_name()]
else:
......@@ -456,7 +475,8 @@ def recurrent_group(step, input, name=None):
size=static_input.input.calculate_size,
act=activation.Identity()) as mix:
mix += identity_projection(input=mem)
rnn_input.insert(input.index(static_input), mix)
mem.append_child(layer=mix, parent_names=[mem.context_name()])
rnn_input.insert(input.index(static_input), mem)
return step(*rnn_input)
actual_output = __real_step__(*actual_input)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册