提交 6747c59f 编写于 作者: Q qiaolongfei

remove WithExtraParent, add the logic into config_base.Layer

上级 a4a599ab
......@@ -67,7 +67,16 @@ class Layer(object):
self.name = name
self.__context__ = {}
self.__parent_layers__ = parent_layers
self.__children_layers__ = [] # used for evaluator.
# some layer may have some extra parent layer
self.__extra_parent__ = []
# used for evaluator.
self.__children_layers__ = []
def extra_parent(self):
return self.__extra_parent__
def append_extra_parent(self, parent):
self.__extra_parent__.append(parent)
def append_child(self, layer, parent_names):
self.__children_layers__.append((layer, parent_names))
......@@ -78,14 +87,20 @@ class Layer(object):
"""
self.__context__ = context
# short cut if myself is parsed before.
# 1. short cut if this layer is parsed before.
if self.context_name() in context:
if self.use_context_name():
return context[self.context_name()]
else:
return context[self.name]
# parse parent before myself
# 2. parse extra_parent that is not used by this layer but must
# be parsed before this layer.
for p in self.__extra_parent__:
p.to_proto(context=context)
# 3. parse parent that is used by this layer, get the result and
# insert into kwargs of the next layer's to_proto_impl method.
kwargs = dict()
for layer_name in self.__parent_layers__:
if not isinstance(self.__parent_layers__[layer_name],
......@@ -97,12 +112,12 @@ class Layer(object):
self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer
# parse myself.
# 4. parse myself and add myself into context.
ret_val = self.to_proto_impl(context=context, **kwargs)
if self.context_name() is not None and self.context_name() not in context:
context[self.context_name()] = ret_val
# parse children.
# 5. parse children that should be pased after this layer.
for layer, pnames in self.__children_layers__:
drop = False
......@@ -115,6 +130,7 @@ class Layer(object):
continue
layer.to_proto(context=context)
# 6. return v1 layer result.g
if self.context_name() is None:
return ret_val
elif self.use_context_name():
......
......@@ -119,37 +119,7 @@ class DataLayerV2(Layer):
return doc
class WithExtraParent(Layer):
def extra_parent(self):
return self.__extra_parent__
def __init__(self, name=None, parent_layers=None):
self.__extra_parent__ = []
super(WithExtraParent, self).__init__(
name=name, parent_layers=parent_layers)
def append_extra_parent(self, parent):
self.__extra_parent__.append(parent)
def to_proto(self, context):
"""
function to set proto attribute
"""
# short cut if myself is parsed before.
if self.context_name() in context:
if self.use_context_name():
return context[self.context_name()]
else:
return context[self.name]
# parse extra_parent
for p in self.__extra_parent__:
p.to_proto(context=context)
return super(WithExtraParent, self).to_proto(context=context)
class MemoryV2(WithExtraParent):
class MemoryV2(Layer):
def __init__(self, name, extra_input=None, **kwargs):
self.name = name
super(MemoryV2, self).__init__(name=name, parent_layers=dict())
......@@ -178,11 +148,10 @@ class MemoryV2(WithExtraParent):
assert begin_of_current_rnn is not None
for extra in begin_of_current_rnn:
self.append_extra_parent(extra)
assert isinstance(extra, WithExtraParent)
extra.append_extra_parent(kwargs['boot_layer'])
self.__boot_layer_name__ = kwargs['boot_layer'].name
def to_proto_impl(self, context, **kwargs):
def to_proto_impl(self, context=None, **kwargs):
args = dict()
for each in kwargs:
args[each] = kwargs[each]
......@@ -301,7 +270,7 @@ def mixed(size=0,
return MixedLayerV2(size, input, name, act, bias_attr, layer_attr)
class RecurrentLayerInput(WithExtraParent):
class RecurrentLayerInput(Layer):
def __init__(self, recurrent_name, index, parent_layers):
parents_len = len(parent_layers)
assert parents_len <= 1
......@@ -317,7 +286,7 @@ class RecurrentLayerInput(WithExtraParent):
def context_name(self):
return self.__recurrent_name__ + ".begin"
def to_proto_impl(self, context, **kwargs):
def to_proto_impl(self, context=None, **kwargs):
model_type('recurrent_nn')
RecurrentLayerGroupWithoutOutLinksBegin(
name=self.__recurrent_name__,
......
......@@ -17,7 +17,6 @@ import collections
from paddle.proto.ModelConfig_pb2 import ModelConfig
import layer as v2_layer
from layer import WithExtraParent
__all__ = ['Topology']
......@@ -41,9 +40,8 @@ def __bfs_travel__(callback, *layers):
__break__ = callback(each_layer)
if __break__:
return
__layers__ = each_layer.__parent_layers__.values()
if isinstance(each_layer, WithExtraParent):
__layers__ = __layers__ + each_layer.extra_parent()
__layers__ = each_layer.__parent_layers__.values() + \
each_layer.extra_parent()
__bfs_travel__(callback, *__layers__)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册