From b400c8f02c76ce74828cc999d6bef335cca18a57 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 2 Mar 2017 11:47:33 +0800 Subject: [PATCH] update to latest --- python/paddle/v2/config_base.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/python/paddle/v2/config_base.py b/python/paddle/v2/config_base.py index 035f96b0f..be3e39a06 100644 --- a/python/paddle/v2/config_base.py +++ b/python/paddle/v2/config_base.py @@ -19,9 +19,10 @@ import paddle.trainer_config_helpers as conf_helps class Layer(object): - def __init__(self, name=None, parent_layers=None): + def __init__(self, name=None, size=None, parent_layers=None): assert isinstance(parent_layers, dict) self.name = name + self.size = size self.__parent_layers__ = parent_layers def to_proto(self, context): @@ -39,16 +40,30 @@ class Layer(object): self.__parent_layers__[layer_name]) kwargs[layer_name] = v1_layer - if self.name is None: + if self.context_name() is None: return self.to_proto_impl(**kwargs) - elif self.name not in context: - context[self.name] = self.to_proto_impl(**kwargs) + elif self.context_name() not in context: + context[self.context_name()] = self.to_proto_impl(**kwargs) - return context[self.name] + if self.use_context_name(): + return context[self.context_name()] + else: + return context[self.name] def to_proto_impl(self, **kwargs): raise NotImplementedError() + def context_name(self): + """ + Context name means the context which stores `to_proto_impl` result. + If multiple layer share same context_name, the `to_proto_impl` of them + will be invoked only once. + """ + return self.name + + def use_context_name(self): + return False + def __convert_to_v2__(method_name, parent_names, is_default_name=True): if is_default_name: @@ -69,7 +84,8 @@ def __convert_to_v2__(method_name, parent_names, is_default_name=True): other_kwargs[key] = kwargs[key] name = kwargs.get('name', None) - super(V2LayerImpl, self).__init__(name, parent_layers) + size = kwargs.get('size', None) + super(V2LayerImpl, self).__init__(name, size, parent_layers) self.__other_kwargs__ = other_kwargs if wrapper is not None: -- GitLab