Commit f9e6aa2c authored by qiaolongfei

refine code

Drop the `size` parameter from the Layer, WithExtraParent, MemoryV2, and
MixedLayerV2 constructors, rename the misspelled `calcalted_size` to
`calculate_size` (now documented), and resolve a layer's size lazily in
to_proto_impl, where the `size` kwarg may be either an integer or a
zero-argument callable.

Parent c9bb48b3
@@ -19,7 +19,7 @@ import paddle.trainer_config_helpers as conf_helps

 class Layer(object):
-    def __init__(self, name=None, size=None, parent_layers=None):
+    def __init__(self, name=None, parent_layers=None):
         assert isinstance(parent_layers, dict)
         self.name = name
         self.__contex__ = {}
@@ -64,7 +64,12 @@ class Layer(object):
     def use_context_name(self):
         return False

-    def calcalted_size(self):
+    def calculate_size(self):
+        """
+        Lazily calculate the size of the layer; should only be called once
+        to_proto_impl of this layer has been called.
+        :return:
+        """
         return self.__contex__[self.context_name()].size
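Besides fixing the `calcalted_size` typo, the rename documents the contract: `__contex__` is only populated once `to_proto_impl` runs, so callers must hold on to the bound method and invoke it later, after config parsing. A minimal standalone sketch of the idea (`ParsedConfig` and the `to_proto_impl` signature here are illustrative stand-ins, not Paddle's actual API):

class ParsedConfig(object):
    # Stand-in for the parsed config object stored in __contex__.
    def __init__(self, size):
        self.size = size

class Layer(object):
    def __init__(self, name=None):
        self.name = name
        self.__contex__ = {}  # filled in during config parsing

    def context_name(self):
        return self.name

    def to_proto_impl(self, size):
        # Parsing produces the config; only now is the size known.
        self.__contex__[self.context_name()] = ParsedConfig(size)

    def calculate_size(self):
        # Lazy: valid only after to_proto_impl has populated the context.
        return self.__contex__[self.context_name()].size

layer = Layer(name='fc1')
get_size = layer.calculate_size   # bound method, not yet called
layer.to_proto_impl(size=256)     # config parsing happens first
assert get_size() == 256          # deferred lookup returns the real size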
@@ -87,8 +92,7 @@ def __convert_to_v2__(method_name, parent_names, is_default_name=True):
                    other_kwargs[key] = kwargs[key]

            name = kwargs.get('name', None)
-            size = kwargs.get('size', None)
-            super(V2LayerImpl, self).__init__(name, size, parent_layers)
+            super(V2LayerImpl, self).__init__(name, parent_layers)
            self.__other_kwargs__ = other_kwargs

        if wrapper is not None:
@@ -139,10 +139,10 @@ class WithExtraParent(Layer):
     def extra_parent(self):
         return self.__extra_parent__

-    def __init__(self, name=None, size=None, parent_layers=None):
+    def __init__(self, name=None, parent_layers=None):
         self.__extra_parent__ = []
         super(WithExtraParent, self).__init__(
-            name=name, size=size, parent_layers=parent_layers)
+            name=name, parent_layers=parent_layers)

     def append_extra_parent(self, parent):
         self.__extra_parent__.append(parent)
@@ -178,11 +178,9 @@ class WithExtraParent(Layer):

 class MemoryV2(WithExtraParent):
-    def __init__(self, name, size, **kwargs):
+    def __init__(self, name, **kwargs):
         self.name = name
-        self.size = size
-        super(MemoryV2, self).__init__(
-            name=name, size=size, parent_layers=dict())
+        super(MemoryV2, self).__init__(name=name, parent_layers=dict())
         self.__kwargs__ = kwargs
         self.__boot_layer_name__ = None
         if 'boot_layer' in kwargs:
@@ -221,10 +219,13 @@ class MemoryV2(WithExtraParent):
         if self.__boot_layer_name__ is not None:
             args['boot_layer'] = context[self.__boot_layer_name__]

-        if callable(self.size):
-            real_size = self.size()
-        else:
-            real_size = self.size
-        args['size'] = real_size
+        size = args.get('size', None)
+        if size is not None:
+            if callable(size):
+                real_size = size()
+            else:
+                real_size = size
+            print(real_size)
+            args['size'] = real_size
         return conf_helps.memory(name=self.name, **args)
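MemoryV2 no longer stores `size` on the instance; it reads it from the kwargs at proto-generation time, where it may be either a concrete integer or a zero-argument callable such as another layer's `calculate_size`. A sketch of the resolution pattern that both `to_proto_impl` bodies now share (`resolve_size` is a hypothetical helper; the committed code inlines this logic and also leaves a stray debug `print(real_size)` behind):

def resolve_size(size):
    # Accept a concrete int, a zero-arg callable producing one, or None.
    if size is None:
        return None
    return size() if callable(size) else size

assert resolve_size(256) == 256
assert resolve_size(lambda: 256) == 256
assert resolve_size(None) is None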
@@ -298,7 +299,7 @@ class MixedLayerV2(Layer):
         other_kwargs['bias_attr'] = bias_attr
         other_kwargs['layer_attr'] = layer_attr
         parent_layers = {"input": self.__inputs__}
-        super(MixedLayerV2, self).__init__(name, size, parent_layers)
+        super(MixedLayerV2, self).__init__(name, parent_layers)
         self.__other_kwargs__ = other_kwargs

     def __iadd__(self, other):
@@ -322,6 +323,7 @@ class MixedLayerV2(Layer):
         for each in self.__other_kwargs__:
             args[each] = self.__other_kwargs__[each]
         size = args.get('size', None)
-        if callable(size):
-            real_size = size()
-        else:
+        if size is not None:
+            if callable(size):
+                real_size = size()
+            else:
@@ -473,11 +475,11 @@ def recurrent_group(step, input, name=None):
             mem = memory(
                 name=mem_name,
                 is_seq=static_input.is_seq,
-                size=static_input.input.calcalted_size,
+                size=static_input.input.calculate_size,
                 boot_layer=static_input.input)
             with mixed(
                     name=mem_name,
-                    size=static_input.input.calcalted_size,
+                    size=static_input.input.calculate_size,
                     act=activation.Identity()) as mix:
                 mix += identity_projection(input=mem)
             rnn_input.insert(input.index(static_input), mix)
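Note that both call sites pass `static_input.input.calculate_size` without invoking it: the bound method itself becomes the `size` argument, and it is only called later inside `to_proto_impl`, by which point the static input's own config has been parsed and its size is known. Calling it eagerly at these call sites would fail, since the layer's context is still empty at graph-construction time.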