提交 a4a599ab 编写于 作者: Q qiaolongfei

refine code, remove duplicate code between layer and with_extra_parent_layer

上级 e64418c7
import sys import sys
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.layer.beam_search as beam_search import paddle.v2.layers.beam_search as beam_search
def seqToseq_net(source_dict_dim, target_dict_dim, is_generating): def seqToseq_net(source_dict_dim, target_dict_dim, is_generating):
...@@ -138,7 +138,7 @@ def main(): ...@@ -138,7 +138,7 @@ def main():
source_dict_dim = target_dict_dim = dict_size source_dict_dim = target_dict_dim = dict_size
# define network topology # define network topology
cost = seqToseq_net(source_dict_dim, target_dict_dim) cost = seqToseq_net(source_dict_dim, target_dict_dim, False)
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# define optimize method and trainer # define optimize method and trainer
......
...@@ -98,10 +98,8 @@ class Layer(object): ...@@ -98,10 +98,8 @@ class Layer(object):
kwargs[layer_name] = v1_layer kwargs[layer_name] = v1_layer
# parse myself. # parse myself.
ret_val = self.to_proto_impl(**kwargs) ret_val = self.to_proto_impl(context=context, **kwargs)
if self.context_name() is not None and self.context_name() not in context:
if self.context_name() is not None and \
self.context_name() not in context:
context[self.context_name()] = ret_val context[self.context_name()] = ret_val
# parse children. # parse children.
...@@ -124,7 +122,7 @@ class Layer(object): ...@@ -124,7 +122,7 @@ class Layer(object):
else: else:
return context[self.name] return context[self.name]
def to_proto_impl(self, **kwargs): def to_proto_impl(self, context=None, **kwargs):
raise NotImplementedError() raise NotImplementedError()
def context_name(self): def context_name(self):
...@@ -188,7 +186,7 @@ def __convert_to_v2__(method_name, ...@@ -188,7 +186,7 @@ def __convert_to_v2__(method_name,
if wrapper is not None: if wrapper is not None:
__init__ = wrapper(__init__) __init__ = wrapper(__init__)
def to_proto_impl(self, **kwargs): def to_proto_impl(self, context=None, **kwargs):
args = dict() args = dict()
for each in kwargs: for each in kwargs:
args[each] = kwargs[each] args[each] = kwargs[each]
......
...@@ -98,7 +98,7 @@ class DataLayerV2(Layer): ...@@ -98,7 +98,7 @@ class DataLayerV2(Layer):
super(DataLayerV2, self).__init__(name=name, parent_layers=dict()) super(DataLayerV2, self).__init__(name=name, parent_layers=dict())
def to_proto_impl(self, **kwargs): def to_proto_impl(self, context=None, **kwargs):
args = dict() args = dict()
args['size'] = self.type.dim args['size'] = self.type.dim
for each in kwargs: for each in kwargs:
...@@ -142,46 +142,11 @@ class WithExtraParent(Layer): ...@@ -142,46 +142,11 @@ class WithExtraParent(Layer):
else: else:
return context[self.name] return context[self.name]
# parse parents # parse extra_parent
kwargs = dict()
for p in self.__extra_parent__: for p in self.__extra_parent__:
p.to_proto(context=context) p.to_proto(context=context)
for layer_name in self.__parent_layers__: return super(WithExtraParent, self).to_proto(context=context)
if not isinstance(self.__parent_layers__[layer_name],
collections.Sequence):
v1_layer = self.__parent_layers__[layer_name].to_proto(
context=context)
else:
v1_layer = map(lambda x: x.to_proto(context=context),
self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer
# parse self
if self.context_name() is None:
return self.to_proto_impl(context=context, **kwargs)
elif self.context_name() not in context:
context[self.context_name()] = self.to_proto_impl(
context=context, **kwargs)
# parse children.
aaa = self.__children_layers__
for layer, pnames in self.__children_layers__:
drop = False
# child will only be parsed if all parents are in context.
for pname in pnames:
if pname not in context:
drop = True
break
if drop:
continue
layer.to_proto(context=context)
if self.use_context_name():
return context[self.context_name()]
else:
return context[self.name]
class MemoryV2(WithExtraParent): class MemoryV2(WithExtraParent):
...@@ -307,7 +272,7 @@ class MixedLayerV2(Layer): ...@@ -307,7 +272,7 @@ class MixedLayerV2(Layer):
def __exit__(self, *args, **kwargs): def __exit__(self, *args, **kwargs):
self.finalized = True self.finalized = True
def to_proto_impl(self, **kwargs): def to_proto_impl(self, context=None, **kwargs):
args = dict() args = dict()
for each in kwargs: for each in kwargs:
args[each] = kwargs[each] args[each] = kwargs[each]
...@@ -371,7 +336,7 @@ class RecurrentLayerOutput(Layer): ...@@ -371,7 +336,7 @@ class RecurrentLayerOutput(Layer):
def context_name(self): def context_name(self):
return self.__recurrent_name__ + ".end" return self.__recurrent_name__ + ".end"
def to_proto_impl(self, **kwargs): def to_proto_impl(self, context=None, **kwargs):
for l in self.__parents__: for l in self.__parents__:
RecurrentLayerGroupSetOutLink(l.name) RecurrentLayerGroupSetOutLink(l.name)
RecurrentLayerGroupEnd(name=self.__recurrent_name__) RecurrentLayerGroupEnd(name=self.__recurrent_name__)
......
...@@ -48,7 +48,7 @@ class RecurrentLayerGroupSetGeneratorV2(Layer): ...@@ -48,7 +48,7 @@ class RecurrentLayerGroupSetGeneratorV2(Layer):
super(RecurrentLayerGroupSetGeneratorV2, self).__init__( super(RecurrentLayerGroupSetGeneratorV2, self).__init__(
name=eos_name, parent_layers={}) name=eos_name, parent_layers={})
def to_proto_impl(self, **kwargs): def to_proto_impl(self, context=None, **kwargs):
RecurrentLayerGroupSetGenerator( RecurrentLayerGroupSetGenerator(
Generator( Generator(
eos_layer_name=self.eos_name, eos_layer_name=self.eos_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册