提交 b400c8f0 编写于 作者: Q qiaolongfei

update to latest

上级 22f7b9ab
...@@ -19,9 +19,10 @@ import paddle.trainer_config_helpers as conf_helps ...@@ -19,9 +19,10 @@ import paddle.trainer_config_helpers as conf_helps
class Layer(object): class Layer(object):
def __init__(self, name=None, parent_layers=None): def __init__(self, name=None, size=None, parent_layers=None):
assert isinstance(parent_layers, dict) assert isinstance(parent_layers, dict)
self.name = name self.name = name
self.size = size
self.__parent_layers__ = parent_layers self.__parent_layers__ = parent_layers
def to_proto(self, context): def to_proto(self, context):
...@@ -39,16 +40,30 @@ class Layer(object): ...@@ -39,16 +40,30 @@ class Layer(object):
self.__parent_layers__[layer_name]) self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer kwargs[layer_name] = v1_layer
if self.name is None: if self.context_name() is None:
return self.to_proto_impl(**kwargs) return self.to_proto_impl(**kwargs)
elif self.name not in context: elif self.context_name() not in context:
context[self.name] = self.to_proto_impl(**kwargs) context[self.context_name()] = self.to_proto_impl(**kwargs)
if self.use_context_name():
return context[self.context_name()]
else:
return context[self.name] return context[self.name]
def to_proto_impl(self, **kwargs): def to_proto_impl(self, **kwargs):
raise NotImplementedError() raise NotImplementedError()
def context_name(self):
"""
Context name means the context which stores `to_proto_impl` result.
If multiple layer share same context_name, the `to_proto_impl` of them
will be invoked only once.
"""
return self.name
def use_context_name(self):
return False
def __convert_to_v2__(method_name, parent_names, is_default_name=True): def __convert_to_v2__(method_name, parent_names, is_default_name=True):
if is_default_name: if is_default_name:
...@@ -69,7 +84,8 @@ def __convert_to_v2__(method_name, parent_names, is_default_name=True): ...@@ -69,7 +84,8 @@ def __convert_to_v2__(method_name, parent_names, is_default_name=True):
other_kwargs[key] = kwargs[key] other_kwargs[key] = kwargs[key]
name = kwargs.get('name', None) name = kwargs.get('name', None)
super(V2LayerImpl, self).__init__(name, parent_layers) size = kwargs.get('size', None)
super(V2LayerImpl, self).__init__(name, size, parent_layers)
self.__other_kwargs__ = other_kwargs self.__other_kwargs__ = other_kwargs
if wrapper is not None: if wrapper is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册