From c9bb48b308807f80b3ba238cafb97ba4b0eda983 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 2 Mar 2017 15:09:26 +0800 Subject: [PATCH] support calculate size --- python/paddle/v2/config_base.py | 7 +- python/paddle/v2/layer.py | 110 ++++++++++++++++++++------------ 2 files changed, 75 insertions(+), 42 deletions(-) diff --git a/python/paddle/v2/config_base.py b/python/paddle/v2/config_base.py index be3e39a06ef..573539a30cc 100644 --- a/python/paddle/v2/config_base.py +++ b/python/paddle/v2/config_base.py @@ -22,7 +22,7 @@ class Layer(object): def __init__(self, name=None, size=None, parent_layers=None): assert isinstance(parent_layers, dict) self.name = name - self.size = size + self.__contex__ = {} self.__parent_layers__ = parent_layers def to_proto(self, context): @@ -44,7 +44,7 @@ class Layer(object): return self.to_proto_impl(**kwargs) elif self.context_name() not in context: context[self.context_name()] = self.to_proto_impl(**kwargs) - + self.__contex__ = context if self.use_context_name(): return context[self.context_name()] else: @@ -64,6 +64,9 @@ class Layer(object): def use_context_name(self): return False + def calcalted_size(self): + return self.__contex__[self.context_name()].size + def __convert_to_v2__(method_name, parent_names, is_default_name=True): if is_default_name: diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py index e24244a48c9..a97518ed525 100644 --- a/python/paddle/v2/layer.py +++ b/python/paddle/v2/layer.py @@ -197,6 +197,10 @@ class MemoryV2(WithExtraParent): val = locs[key] if isinstance(val, RecurrentLayerInput): begin_of_current_rnn.append(val) + elif isinstance(val, collections.Sequence): + for v in val: + if isinstance(v, RecurrentLayerInput): + begin_of_current_rnn.append(v) if begin_of_current_rnn: break @@ -216,7 +220,13 @@ class MemoryV2(WithExtraParent): if self.__boot_layer_name__ is not None: args['boot_layer'] = context[self.__boot_layer_name__] - return conf_helps.memory(name=self.name, size=self.size, **args) + + if callable(self.size): + real_size = self.size() + else: + real_size = self.size + args['size'] = real_size + return conf_helps.memory(name=self.name, **args) def context_name(self): return self.name + "#memory" @@ -311,6 +321,12 @@ class MixedLayerV2(Layer): args[each] = kwargs[each] for each in self.__other_kwargs__: args[each] = self.__other_kwargs__[each] + size = args.get('size', None) + if callable(size): + real_size = size() + else: + real_size = size + args['size'] = real_size return getattr(conf_helps, self.__method_name__)(**args) @@ -363,53 +379,15 @@ class RecurrentLayerOutput(Layer): RecurrentLayerGroupEnd(name=self.__recurrent_name__) -@wrap_name_default() -def recurrent_group(step, input, name=None): - if not isinstance(input, collections.Sequence): - input = [input] - - # TODO(qiaolongfei) convert StaticInput to memory according to v2 recurrent_group - for i in xrange(len(input)): - cur_input = input[i] - if isinstance(cur_input, StaticInputV2): - input[i] = cur_input.input - - actual_input = [ - RecurrentLayerInput( - recurrent_name=name, - index=i, - parent_layers={'recurrent_inputs': input}) - for i in xrange(len(input)) - ] - - actual_output = step(*actual_input) - - if not isinstance(actual_output, collections.Sequence): - actual_output = [actual_output] - - retv = [ - RecurrentLayerOutput( - recurrent_name=name, - index=i, - parent_layers={'recurrent_outputs': actual_output}) - for i in xrange(len(actual_output)) - ] - if len(retv) == 1: - return retv[0] - else: - return retv - - LayerV2 = Layer data = DataLayerV2 AggregateLevel = conf_helps.layers.AggregateLevel ExpandLevel = conf_helps.layers.ExpandLevel -recurrent_group = recurrent_group memory = MemoryV2 def __layer_name_mapping__(inname): - if inname in ['data_layer', 'memory', 'mixed_layer']: + if inname in ['data_layer', 'memory', 'mixed_layer', 'recurrent_group']: # Do Not handle these layers return elif inname == 'maxid_layer': @@ -469,3 +447,55 @@ operator_list = [ for op in operator_list: globals()[op[0]] = __convert_to_v2__( op[0], parent_names=op[1], is_default_name=False) + + +@wrap_name_default() +def recurrent_group(step, input, name=None): + if not isinstance(input, collections.Sequence): + input = [input] + + non_static_inputs = filter(lambda x: not isinstance(x, StaticInputV2), + input) + actual_input = [ + RecurrentLayerInput( + recurrent_name=name, + index=i, + parent_layers={'recurrent_inputs': non_static_inputs}) + for i in xrange(len(non_static_inputs)) + ] + + def __real_step__(*args): + rnn_input = list(args) + static_inputs = filter(lambda x: isinstance(x, StaticInputV2), input) + for static_input in static_inputs: + mem_name = "__%s_memory__" % static_input.input.name + print memory + mem = memory( + name=mem_name, + is_seq=static_input.is_seq, + size=static_input.input.calcalted_size, + boot_layer=static_input.input) + with mixed( + name=mem_name, + size=static_input.input.calcalted_size, + act=activation.Identity()) as mix: + mix += identity_projection(input=mem) + rnn_input.insert(input.index(static_input), mix) + return step(*rnn_input) + + actual_output = __real_step__(*actual_input) + + if not isinstance(actual_output, collections.Sequence): + actual_output = [actual_output] + + retv = [ + RecurrentLayerOutput( + recurrent_name=name, + index=i, + parent_layers={'recurrent_outputs': actual_output}) + for i in xrange(len(actual_output)) + ] + if len(retv) == 1: + return retv[0] + else: + return retv -- GitLab