diff --git a/python/paddle/v2/config_base.py b/python/paddle/v2/config_base.py
index 95d9a87a0583d720c2a5759882b001e860087c43..0c82e03925bb07c8486e71351736b6b94a406b2e 100644
--- a/python/paddle/v2/config_base.py
+++ b/python/paddle/v2/config_base.py
@@ -67,7 +67,16 @@ class Layer(object):
         self.name = name
         self.__context__ = {}
         self.__parent_layers__ = parent_layers
-        self.__children_layers__ = []  # used for evaluator.
+        # some layer may have some extra parent layer
+        self.__extra_parent__ = []
+        # used for evaluator.
+        self.__children_layers__ = []
+
+    def extra_parent(self):
+        return self.__extra_parent__
+
+    def append_extra_parent(self, parent):
+        self.__extra_parent__.append(parent)
 
     def append_child(self, layer, parent_names):
         self.__children_layers__.append((layer, parent_names))
@@ -78,14 +87,20 @@ class Layer(object):
         """
         self.__context__ = context
 
-        # short cut if myself is parsed before.
+        # 1. short cut if this layer is parsed before.
         if self.context_name() in context:
             if self.use_context_name():
                 return context[self.context_name()]
             else:
                 return context[self.name]
 
-        # parse parent before myself
+        # 2. parse extra_parent that is not used by this layer but must
+        # be parsed before this layer.
+        for p in self.__extra_parent__:
+            p.to_proto(context=context)
+
+        # 3. parse parent that is used by this layer, get the result and
+        # insert into kwargs of the next layer's to_proto_impl method.
         kwargs = dict()
         for layer_name in self.__parent_layers__:
             if not isinstance(self.__parent_layers__[layer_name],
@@ -97,12 +112,12 @@ class Layer(object):
                 self.__parent_layers__[layer_name])
             kwargs[layer_name] = v1_layer
 
-        # parse myself.
+        # 4. parse myself and add myself into context.
         ret_val = self.to_proto_impl(context=context, **kwargs)
         if self.context_name() is not None \
                 and self.context_name() not in context:
             context[self.context_name()] = ret_val
 
-        # parse children.
+        # 5. parse children that should be parsed after this layer.
         for layer, pnames in self.__children_layers__:
             drop = False
@@ -115,6 +130,7 @@ class Layer(object):
                     continue
             layer.to_proto(context=context)
 
+        # 6. return v1 layer result.
         if self.context_name() is None:
             return ret_val
         elif self.use_context_name():
diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py
index d9e36baea18edf0edc6799f219715e0bba92a504..e052930c09904fecc5d68aba9588674aff2c5905 100644
--- a/python/paddle/v2/layer.py
+++ b/python/paddle/v2/layer.py
@@ -119,37 +119,7 @@ class DataLayerV2(Layer):
         return doc
 
 
-class WithExtraParent(Layer):
-    def extra_parent(self):
-        return self.__extra_parent__
-
-    def __init__(self, name=None, parent_layers=None):
-        self.__extra_parent__ = []
-        super(WithExtraParent, self).__init__(
-            name=name, parent_layers=parent_layers)
-
-    def append_extra_parent(self, parent):
-        self.__extra_parent__.append(parent)
-
-    def to_proto(self, context):
-        """
-        function to set proto attribute
-        """
-        # short cut if myself is parsed before.
-        if self.context_name() in context:
-            if self.use_context_name():
-                return context[self.context_name()]
-            else:
-                return context[self.name]
-
-        # parse extra_parent
-        for p in self.__extra_parent__:
-            p.to_proto(context=context)
-
-        return super(WithExtraParent, self).to_proto(context=context)
-
-
-class MemoryV2(WithExtraParent):
+class MemoryV2(Layer):
     def __init__(self, name, extra_input=None, **kwargs):
         self.name = name
         super(MemoryV2, self).__init__(name=name, parent_layers=dict())
@@ -178,11 +148,10 @@ class MemoryV2(WithExtraParent):
             assert begin_of_current_rnn is not None
             for extra in begin_of_current_rnn:
                 self.append_extra_parent(extra)
-                assert isinstance(extra, WithExtraParent)
                 extra.append_extra_parent(kwargs['boot_layer'])
             self.__boot_layer_name__ = kwargs['boot_layer'].name
 
-    def to_proto_impl(self, context, **kwargs):
+    def to_proto_impl(self, context=None, **kwargs):
         args = dict()
         for each in kwargs:
             args[each] = kwargs[each]
@@ -301,7 +270,7 @@ def mixed(size=0,
     return MixedLayerV2(size, input, name, act, bias_attr, layer_attr)
 
 
-class RecurrentLayerInput(WithExtraParent):
+class RecurrentLayerInput(Layer):
     def __init__(self, recurrent_name, index, parent_layers):
         parents_len = len(parent_layers)
         assert parents_len <= 1
@@ -317,7 +286,7 @@ class RecurrentLayerInput(WithExtraParent):
     def context_name(self):
         return self.__recurrent_name__ + ".begin"
 
-    def to_proto_impl(self, context, **kwargs):
+    def to_proto_impl(self, context=None, **kwargs):
         model_type('recurrent_nn')
         RecurrentLayerGroupWithoutOutLinksBegin(
             name=self.__recurrent_name__,
diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py
index f0679c5675b0c0f24f28f3df22efd4eb51ccbb3a..702e1ce958f9dd179b7a34db39befd83904a970d 100644
--- a/python/paddle/v2/topology.py
+++ b/python/paddle/v2/topology.py
@@ -17,7 +17,6 @@ import collections
 from paddle.proto.ModelConfig_pb2 import ModelConfig
 
 import layer as v2_layer
-from layer import WithExtraParent
 
 __all__ = ['Topology']
 
@@ -41,9 +40,8 @@ def __bfs_travel__(callback, *layers):
         __break__ = callback(each_layer)
         if __break__:
             return
-        __layers__ = each_layer.__parent_layers__.values()
-        if isinstance(each_layer, WithExtraParent):
-            __layers__ = __layers__ + each_layer.extra_parent()
+        __layers__ = each_layer.__parent_layers__.values() + \
+            each_layer.extra_parent()
         __bfs_travel__(callback, *__layers__)
 
 