提交 f9e6aa2c 编写于 作者: Q qiaolongfei

refine code

上级 c9bb48b3
......@@ -19,7 +19,7 @@ import paddle.trainer_config_helpers as conf_helps
class Layer(object):
def __init__(self, name=None, size=None, parent_layers=None):
def __init__(self, name=None, parent_layers=None):
assert isinstance(parent_layers, dict)
self.name = name
self.__contex__ = {}
......@@ -64,7 +64,12 @@ class Layer(object):
def use_context_name(self):
return False
def calcalted_size(self):
def calculate_size(self):
"""
lazy calculate size of the layer, should be called when to_proto_impl of
this layer is called.
:return:
"""
return self.__contex__[self.context_name()].size
......@@ -87,8 +92,7 @@ def __convert_to_v2__(method_name, parent_names, is_default_name=True):
other_kwargs[key] = kwargs[key]
name = kwargs.get('name', None)
size = kwargs.get('size', None)
super(V2LayerImpl, self).__init__(name, size, parent_layers)
super(V2LayerImpl, self).__init__(name, parent_layers)
self.__other_kwargs__ = other_kwargs
if wrapper is not None:
......
......@@ -139,10 +139,10 @@ class WithExtraParent(Layer):
def extra_parent(self):
return self.__extra_parent__
def __init__(self, name=None, size=None, parent_layers=None):
def __init__(self, name=None, parent_layers=None):
self.__extra_parent__ = []
super(WithExtraParent, self).__init__(
name=name, size=size, parent_layers=parent_layers)
name=name, parent_layers=parent_layers)
def append_extra_parent(self, parent):
self.__extra_parent__.append(parent)
......@@ -178,11 +178,9 @@ class WithExtraParent(Layer):
class MemoryV2(WithExtraParent):
def __init__(self, name, size, **kwargs):
def __init__(self, name, **kwargs):
self.name = name
self.size = size
super(MemoryV2, self).__init__(
name=name, size=size, parent_layers=dict())
super(MemoryV2, self).__init__(name=name, parent_layers=dict())
self.__kwargs__ = kwargs
self.__boot_layer_name__ = None
if 'boot_layer' in kwargs:
......@@ -221,10 +219,13 @@ class MemoryV2(WithExtraParent):
if self.__boot_layer_name__ is not None:
args['boot_layer'] = context[self.__boot_layer_name__]
if callable(self.size):
real_size = self.size()
size = args.get('size', None)
if size is not None:
if callable(size):
real_size = size()
else:
real_size = self.size
real_size = size
print(real_size)
args['size'] = real_size
return conf_helps.memory(name=self.name, **args)
......@@ -298,7 +299,7 @@ class MixedLayerV2(Layer):
other_kwargs['bias_attr'] = bias_attr
other_kwargs['layer_attr'] = layer_attr
parent_layers = {"input": self.__inputs__}
super(MixedLayerV2, self).__init__(name, size, parent_layers)
super(MixedLayerV2, self).__init__(name, parent_layers)
self.__other_kwargs__ = other_kwargs
def __iadd__(self, other):
......@@ -322,6 +323,7 @@ class MixedLayerV2(Layer):
for each in self.__other_kwargs__:
args[each] = self.__other_kwargs__[each]
size = args.get('size', None)
if size is not None:
if callable(size):
real_size = size()
else:
......@@ -473,11 +475,11 @@ def recurrent_group(step, input, name=None):
mem = memory(
name=mem_name,
is_seq=static_input.is_seq,
size=static_input.input.calcalted_size,
size=static_input.input.calculate_size,
boot_layer=static_input.input)
with mixed(
name=mem_name,
size=static_input.input.calcalted_size,
size=static_input.input.calculate_size,
act=activation.Identity()) as mix:
mix += identity_projection(input=mem)
rnn_input.insert(input.index(static_input), mix)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册