提交 a4a599ab 编写于 作者: Q qiaolongfei

refine code, remove duplicate code between layer and with_extra_parent_layer

上级 e64418c7
import sys
import paddle.v2 as paddle
import paddle.v2.layer.beam_search as beam_search
import paddle.v2.layers.beam_search as beam_search
def seqToseq_net(source_dict_dim, target_dict_dim, is_generating):
......@@ -138,7 +138,7 @@ def main():
source_dict_dim = target_dict_dim = dict_size
# define network topology
cost = seqToseq_net(source_dict_dim, target_dict_dim)
cost = seqToseq_net(source_dict_dim, target_dict_dim, False)
parameters = paddle.parameters.create(cost)
# define optimize method and trainer
......
......@@ -98,10 +98,8 @@ class Layer(object):
kwargs[layer_name] = v1_layer
# parse myself.
ret_val = self.to_proto_impl(**kwargs)
if self.context_name() is not None and \
self.context_name() not in context:
ret_val = self.to_proto_impl(context=context, **kwargs)
if self.context_name() is not None and self.context_name() not in context:
context[self.context_name()] = ret_val
# parse children.
......@@ -124,7 +122,7 @@ class Layer(object):
else:
return context[self.name]
def to_proto_impl(self, **kwargs):
def to_proto_impl(self, context=None, **kwargs):
raise NotImplementedError()
def context_name(self):
......@@ -188,7 +186,7 @@ def __convert_to_v2__(method_name,
if wrapper is not None:
__init__ = wrapper(__init__)
def to_proto_impl(self, **kwargs):
def to_proto_impl(self, context=None, **kwargs):
args = dict()
for each in kwargs:
args[each] = kwargs[each]
......
......@@ -98,7 +98,7 @@ class DataLayerV2(Layer):
super(DataLayerV2, self).__init__(name=name, parent_layers=dict())
def to_proto_impl(self, **kwargs):
def to_proto_impl(self, context=None, **kwargs):
args = dict()
args['size'] = self.type.dim
for each in kwargs:
......@@ -142,46 +142,11 @@ class WithExtraParent(Layer):
else:
return context[self.name]
# parse parents
kwargs = dict()
# parse extra_parent
for p in self.__extra_parent__:
p.to_proto(context=context)
for layer_name in self.__parent_layers__:
if not isinstance(self.__parent_layers__[layer_name],
collections.Sequence):
v1_layer = self.__parent_layers__[layer_name].to_proto(
context=context)
else:
v1_layer = map(lambda x: x.to_proto(context=context),
self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer
# parse self
if self.context_name() is None:
return self.to_proto_impl(context=context, **kwargs)
elif self.context_name() not in context:
context[self.context_name()] = self.to_proto_impl(
context=context, **kwargs)
# parse children.
aaa = self.__children_layers__
for layer, pnames in self.__children_layers__:
drop = False
# child will only be parsed if all parents are in context.
for pname in pnames:
if pname not in context:
drop = True
break
if drop:
continue
layer.to_proto(context=context)
if self.use_context_name():
return context[self.context_name()]
else:
return context[self.name]
return super(WithExtraParent, self).to_proto(context=context)
class MemoryV2(WithExtraParent):
......@@ -307,7 +272,7 @@ class MixedLayerV2(Layer):
def __exit__(self, *args, **kwargs):
self.finalized = True
def to_proto_impl(self, **kwargs):
def to_proto_impl(self, context=None, **kwargs):
args = dict()
for each in kwargs:
args[each] = kwargs[each]
......@@ -371,7 +336,7 @@ class RecurrentLayerOutput(Layer):
def context_name(self):
return self.__recurrent_name__ + ".end"
def to_proto_impl(self, **kwargs):
def to_proto_impl(self, context=None, **kwargs):
for l in self.__parents__:
RecurrentLayerGroupSetOutLink(l.name)
RecurrentLayerGroupEnd(name=self.__recurrent_name__)
......
......@@ -48,7 +48,7 @@ class RecurrentLayerGroupSetGeneratorV2(Layer):
super(RecurrentLayerGroupSetGeneratorV2, self).__init__(
name=eos_name, parent_layers={})
def to_proto_impl(self, **kwargs):
def to_proto_impl(self, context=None, **kwargs):
RecurrentLayerGroupSetGenerator(
Generator(
eos_layer_name=self.eos_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册