From a4a599ab23223ec563d1ec06148270d2d2e7c32e Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 10 Apr 2017 09:56:17 +0800 Subject: [PATCH] refine code, remove duplicate code between layer and with_extra_parent_layer --- demo/seqToseq/api_train_v2.py | 4 +-- python/paddle/v2/config_base.py | 10 +++--- python/paddle/v2/layer.py | 45 +++----------------------- python/paddle/v2/layers/beam_search.py | 2 +- 4 files changed, 12 insertions(+), 49 deletions(-) diff --git a/demo/seqToseq/api_train_v2.py b/demo/seqToseq/api_train_v2.py index c53714cefd..4eb1836855 100644 --- a/demo/seqToseq/api_train_v2.py +++ b/demo/seqToseq/api_train_v2.py @@ -1,6 +1,6 @@ import sys import paddle.v2 as paddle -import paddle.v2.layer.beam_search as beam_search +import paddle.v2.layers.beam_search as beam_search def seqToseq_net(source_dict_dim, target_dict_dim, is_generating): @@ -138,7 +138,7 @@ def main(): source_dict_dim = target_dict_dim = dict_size # define network topology - cost = seqToseq_net(source_dict_dim, target_dict_dim) + cost = seqToseq_net(source_dict_dim, target_dict_dim, False) parameters = paddle.parameters.create(cost) # define optimize method and trainer diff --git a/python/paddle/v2/config_base.py b/python/paddle/v2/config_base.py index cb98866d87..95d9a87a05 100644 --- a/python/paddle/v2/config_base.py +++ b/python/paddle/v2/config_base.py @@ -98,10 +98,8 @@ class Layer(object): kwargs[layer_name] = v1_layer # parse myself. - ret_val = self.to_proto_impl(**kwargs) - - if self.context_name() is not None and \ - self.context_name() not in context: + ret_val = self.to_proto_impl(context=context, **kwargs) + if self.context_name() is not None and self.context_name() not in context: context[self.context_name()] = ret_val # parse children. @@ -124,7 +122,7 @@ class Layer(object): else: return context[self.name] - def to_proto_impl(self, **kwargs): + def to_proto_impl(self, context=None, **kwargs): raise NotImplementedError() def context_name(self): @@ -188,7 +186,7 @@ def __convert_to_v2__(method_name, if wrapper is not None: __init__ = wrapper(__init__) - def to_proto_impl(self, **kwargs): + def to_proto_impl(self, context=None, **kwargs): args = dict() for each in kwargs: args[each] = kwargs[each] diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py index 8265b5c3df..d9e36baea1 100644 --- a/python/paddle/v2/layer.py +++ b/python/paddle/v2/layer.py @@ -98,7 +98,7 @@ class DataLayerV2(Layer): super(DataLayerV2, self).__init__(name=name, parent_layers=dict()) - def to_proto_impl(self, **kwargs): + def to_proto_impl(self, context=None, **kwargs): args = dict() args['size'] = self.type.dim for each in kwargs: @@ -142,46 +142,11 @@ class WithExtraParent(Layer): else: return context[self.name] - # parse parents - kwargs = dict() + # parse extra_parent for p in self.__extra_parent__: p.to_proto(context=context) - for layer_name in self.__parent_layers__: - if not isinstance(self.__parent_layers__[layer_name], - collections.Sequence): - v1_layer = self.__parent_layers__[layer_name].to_proto( - context=context) - else: - v1_layer = map(lambda x: x.to_proto(context=context), - self.__parent_layers__[layer_name]) - kwargs[layer_name] = v1_layer - - # parse self - if self.context_name() is None: - return self.to_proto_impl(context=context, **kwargs) - elif self.context_name() not in context: - context[self.context_name()] = self.to_proto_impl( - context=context, **kwargs) - - # parse children. - aaa = self.__children_layers__ - for layer, pnames in self.__children_layers__: - drop = False - - # child will only be parsed if all parents are in context. - for pname in pnames: - if pname not in context: - drop = True - break - if drop: - continue - layer.to_proto(context=context) - - if self.use_context_name(): - return context[self.context_name()] - else: - return context[self.name] + return super(WithExtraParent, self).to_proto(context=context) class MemoryV2(WithExtraParent): @@ -307,7 +272,7 @@ class MixedLayerV2(Layer): def __exit__(self, *args, **kwargs): self.finalized = True - def to_proto_impl(self, **kwargs): + def to_proto_impl(self, context=None, **kwargs): args = dict() for each in kwargs: args[each] = kwargs[each] @@ -371,7 +336,7 @@ class RecurrentLayerOutput(Layer): def context_name(self): return self.__recurrent_name__ + ".end" - def to_proto_impl(self, **kwargs): + def to_proto_impl(self, context=None, **kwargs): for l in self.__parents__: RecurrentLayerGroupSetOutLink(l.name) RecurrentLayerGroupEnd(name=self.__recurrent_name__) diff --git a/python/paddle/v2/layers/beam_search.py b/python/paddle/v2/layers/beam_search.py index 56beae7e5e..7c6dcbb227 100644 --- a/python/paddle/v2/layers/beam_search.py +++ b/python/paddle/v2/layers/beam_search.py @@ -48,7 +48,7 @@ class RecurrentLayerGroupSetGeneratorV2(Layer): super(RecurrentLayerGroupSetGeneratorV2, self).__init__( name=eos_name, parent_layers={}) - def to_proto_impl(self, **kwargs): + def to_proto_impl(self, context=None, **kwargs): RecurrentLayerGroupSetGenerator( Generator( eos_layer_name=self.eos_name, -- GitLab