提交 6747c59f 编写于 作者: Q qiaolongfei

remove WithExtraParent, add the logic into config_base.Layer

上级 a4a599ab
...@@ -67,7 +67,16 @@ class Layer(object): ...@@ -67,7 +67,16 @@ class Layer(object):
self.name = name self.name = name
self.__context__ = {} self.__context__ = {}
self.__parent_layers__ = parent_layers self.__parent_layers__ = parent_layers
self.__children_layers__ = [] # used for evaluator. # some layer may have some extra parent layer
self.__extra_parent__ = []
# used for evaluator.
self.__children_layers__ = []
def extra_parent(self):
return self.__extra_parent__
def append_extra_parent(self, parent):
self.__extra_parent__.append(parent)
def append_child(self, layer, parent_names): def append_child(self, layer, parent_names):
self.__children_layers__.append((layer, parent_names)) self.__children_layers__.append((layer, parent_names))
...@@ -78,14 +87,20 @@ class Layer(object): ...@@ -78,14 +87,20 @@ class Layer(object):
""" """
self.__context__ = context self.__context__ = context
# short cut if myself is parsed before. # 1. short cut if this layer is parsed before.
if self.context_name() in context: if self.context_name() in context:
if self.use_context_name(): if self.use_context_name():
return context[self.context_name()] return context[self.context_name()]
else: else:
return context[self.name] return context[self.name]
# parse parent before myself # 2. parse extra_parent that is not used by this layer but must
# be parsed before this layer.
for p in self.__extra_parent__:
p.to_proto(context=context)
# 3. parse parent that is used by this layer, get the result and
# insert into kwargs of the next layer's to_proto_impl method.
kwargs = dict() kwargs = dict()
for layer_name in self.__parent_layers__: for layer_name in self.__parent_layers__:
if not isinstance(self.__parent_layers__[layer_name], if not isinstance(self.__parent_layers__[layer_name],
...@@ -97,12 +112,12 @@ class Layer(object): ...@@ -97,12 +112,12 @@ class Layer(object):
self.__parent_layers__[layer_name]) self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer kwargs[layer_name] = v1_layer
# parse myself. # 4. parse myself and add myself into context.
ret_val = self.to_proto_impl(context=context, **kwargs) ret_val = self.to_proto_impl(context=context, **kwargs)
if self.context_name() is not None and self.context_name() not in context: if self.context_name() is not None and self.context_name() not in context:
context[self.context_name()] = ret_val context[self.context_name()] = ret_val
# parse children. # 5. parse children that should be pased after this layer.
for layer, pnames in self.__children_layers__: for layer, pnames in self.__children_layers__:
drop = False drop = False
...@@ -115,6 +130,7 @@ class Layer(object): ...@@ -115,6 +130,7 @@ class Layer(object):
continue continue
layer.to_proto(context=context) layer.to_proto(context=context)
# 6. return v1 layer result.g
if self.context_name() is None: if self.context_name() is None:
return ret_val return ret_val
elif self.use_context_name(): elif self.use_context_name():
......
...@@ -119,37 +119,7 @@ class DataLayerV2(Layer): ...@@ -119,37 +119,7 @@ class DataLayerV2(Layer):
return doc return doc
class WithExtraParent(Layer): class MemoryV2(Layer):
def extra_parent(self):
return self.__extra_parent__
def __init__(self, name=None, parent_layers=None):
self.__extra_parent__ = []
super(WithExtraParent, self).__init__(
name=name, parent_layers=parent_layers)
def append_extra_parent(self, parent):
self.__extra_parent__.append(parent)
def to_proto(self, context):
"""
function to set proto attribute
"""
# short cut if myself is parsed before.
if self.context_name() in context:
if self.use_context_name():
return context[self.context_name()]
else:
return context[self.name]
# parse extra_parent
for p in self.__extra_parent__:
p.to_proto(context=context)
return super(WithExtraParent, self).to_proto(context=context)
class MemoryV2(WithExtraParent):
def __init__(self, name, extra_input=None, **kwargs): def __init__(self, name, extra_input=None, **kwargs):
self.name = name self.name = name
super(MemoryV2, self).__init__(name=name, parent_layers=dict()) super(MemoryV2, self).__init__(name=name, parent_layers=dict())
...@@ -178,11 +148,10 @@ class MemoryV2(WithExtraParent): ...@@ -178,11 +148,10 @@ class MemoryV2(WithExtraParent):
assert begin_of_current_rnn is not None assert begin_of_current_rnn is not None
for extra in begin_of_current_rnn: for extra in begin_of_current_rnn:
self.append_extra_parent(extra) self.append_extra_parent(extra)
assert isinstance(extra, WithExtraParent)
extra.append_extra_parent(kwargs['boot_layer']) extra.append_extra_parent(kwargs['boot_layer'])
self.__boot_layer_name__ = kwargs['boot_layer'].name self.__boot_layer_name__ = kwargs['boot_layer'].name
def to_proto_impl(self, context, **kwargs): def to_proto_impl(self, context=None, **kwargs):
args = dict() args = dict()
for each in kwargs: for each in kwargs:
args[each] = kwargs[each] args[each] = kwargs[each]
...@@ -301,7 +270,7 @@ def mixed(size=0, ...@@ -301,7 +270,7 @@ def mixed(size=0,
return MixedLayerV2(size, input, name, act, bias_attr, layer_attr) return MixedLayerV2(size, input, name, act, bias_attr, layer_attr)
class RecurrentLayerInput(WithExtraParent): class RecurrentLayerInput(Layer):
def __init__(self, recurrent_name, index, parent_layers): def __init__(self, recurrent_name, index, parent_layers):
parents_len = len(parent_layers) parents_len = len(parent_layers)
assert parents_len <= 1 assert parents_len <= 1
...@@ -317,7 +286,7 @@ class RecurrentLayerInput(WithExtraParent): ...@@ -317,7 +286,7 @@ class RecurrentLayerInput(WithExtraParent):
def context_name(self): def context_name(self):
return self.__recurrent_name__ + ".begin" return self.__recurrent_name__ + ".begin"
def to_proto_impl(self, context, **kwargs): def to_proto_impl(self, context=None, **kwargs):
model_type('recurrent_nn') model_type('recurrent_nn')
RecurrentLayerGroupWithoutOutLinksBegin( RecurrentLayerGroupWithoutOutLinksBegin(
name=self.__recurrent_name__, name=self.__recurrent_name__,
......
...@@ -17,7 +17,6 @@ import collections ...@@ -17,7 +17,6 @@ import collections
from paddle.proto.ModelConfig_pb2 import ModelConfig from paddle.proto.ModelConfig_pb2 import ModelConfig
import layer as v2_layer import layer as v2_layer
from layer import WithExtraParent
__all__ = ['Topology'] __all__ = ['Topology']
...@@ -41,9 +40,8 @@ def __bfs_travel__(callback, *layers): ...@@ -41,9 +40,8 @@ def __bfs_travel__(callback, *layers):
__break__ = callback(each_layer) __break__ = callback(each_layer)
if __break__: if __break__:
return return
__layers__ = each_layer.__parent_layers__.values() __layers__ = each_layer.__parent_layers__.values() + \
if isinstance(each_layer, WithExtraParent): each_layer.extra_parent()
__layers__ = __layers__ + each_layer.extra_parent()
__bfs_travel__(callback, *__layers__) __bfs_travel__(callback, *__layers__)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册