提交 b400c8f0 编写于 作者: Q qiaolongfei

update to latest

上级 22f7b9ab
......@@ -19,9 +19,10 @@ import paddle.trainer_config_helpers as conf_helps
class Layer(object):
def __init__(self, name=None, parent_layers=None):
def __init__(self, name=None, size=None, parent_layers=None):
assert isinstance(parent_layers, dict)
self.name = name
self.size = size
self.__parent_layers__ = parent_layers
def to_proto(self, context):
......@@ -39,16 +40,30 @@ class Layer(object):
self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer
if self.name is None:
if self.context_name() is None:
return self.to_proto_impl(**kwargs)
elif self.name not in context:
context[self.name] = self.to_proto_impl(**kwargs)
elif self.context_name() not in context:
context[self.context_name()] = self.to_proto_impl(**kwargs)
return context[self.name]
if self.use_context_name():
return context[self.context_name()]
else:
return context[self.name]
def to_proto_impl(self, **kwargs):
raise NotImplementedError()
def context_name(self):
"""
Context name means the context which stores `to_proto_impl` result.
If multiple layer share same context_name, the `to_proto_impl` of them
will be invoked only once.
"""
return self.name
def use_context_name(self):
return False
def __convert_to_v2__(method_name, parent_names, is_default_name=True):
if is_default_name:
......@@ -69,7 +84,8 @@ def __convert_to_v2__(method_name, parent_names, is_default_name=True):
other_kwargs[key] = kwargs[key]
name = kwargs.get('name', None)
super(V2LayerImpl, self).__init__(name, parent_layers)
size = kwargs.get('size', None)
super(V2LayerImpl, self).__init__(name, size, parent_layers)
self.__other_kwargs__ = other_kwargs
if wrapper is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册