Commit c9bb48b3, authored by: Q qiaolongfei

support calculate size

Parent: b400c8f0
@@ -22,7 +22,7 @@ class Layer(object):
     def __init__(self, name=None, size=None, parent_layers=None):
         assert isinstance(parent_layers, dict)
         self.name = name
-        self.size = size
+        self.__context__ = {}
         self.__parent_layers__ = parent_layers

     def to_proto(self, context):
@@ -44,7 +44,7 @@ class Layer(object):
             return self.to_proto_impl(**kwargs)
         elif self.context_name() not in context:
             context[self.context_name()] = self.to_proto_impl(**kwargs)
+        self.__context__ = context
         if self.use_context_name():
             return context[self.context_name()]
         else:
@@ -64,6 +64,9 @@ class Layer(object):
     def use_context_name(self):
         return False

+    def calculated_size(self):
+        return self.__context__[self.context_name()].size
+
 def __convert_to_v2__(method_name, parent_names, is_default_name=True):
     if is_default_name:
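The point of storing the context is that a layer's final size may only be known after its protobuf config has been generated into that context. A minimal standalone sketch of the pattern (ToyLayer and FakeProto are invented stand-ins, not PaddlePaddle classes):

    # Sketch only: a layer remembers the context dict its proto was written
    # into, so calculated_size() can look the size up lazily afterwards.
    class FakeProto(object):
        def __init__(self, size):
            self.size = size

    class ToyLayer(object):
        def __init__(self, name):
            self.name = name
            self.__context__ = {}

        def context_name(self):
            return self.name

        def to_proto(self, context):
            if self.context_name() not in context:
                context[self.context_name()] = FakeProto(size=128)
            self.__context__ = context  # remember where our proto lives
            return context[self.context_name()]

        def calculated_size(self):
            # Only meaningful after to_proto() has been called.
            return self.__context__[self.context_name()].size

    layer = ToyLayer('fc1')
    layer.to_proto({})
    print(layer.calculated_size())  # prints 128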
@@ -197,6 +197,10 @@ class MemoryV2(WithExtraParent):
                 val = locs[key]
                 if isinstance(val, RecurrentLayerInput):
                     begin_of_current_rnn.append(val)
+                elif isinstance(val, collections.Sequence):
+                    for v in val:
+                        if isinstance(v, RecurrentLayerInput):
+                            begin_of_current_rnn.append(v)

             if begin_of_current_rnn:
                 break
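The extended scan also finds RecurrentLayerInput markers sitting one level inside a list or tuple (e.g. the input list handed to recurrent_group), not just ones bound directly to a local variable. An illustrative toy version of the same check, with invented names:

    # Illustrative only: collect marker objects whether they appear directly
    # in a frame's locals or one level inside a sequence.
    import collections

    def find_markers(locs, marker_type):
        found = []
        for val in locs.values():
            if isinstance(val, marker_type):
                found.append(val)
            elif isinstance(val, collections.Sequence) and not isinstance(val, str):
                # the str guard is just tidiness; strings are Sequences too
                found.extend(v for v in val if isinstance(v, marker_type))
        return found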
@@ -216,7 +220,13 @@ class MemoryV2(WithExtraParent):
         if self.__boot_layer_name__ is not None:
             args['boot_layer'] = context[self.__boot_layer_name__]
-        return conf_helps.memory(name=self.name, size=self.size, **args)
+        if callable(self.size):
+            real_size = self.size()
+        else:
+            real_size = self.size
+        args['size'] = real_size
+        return conf_helps.memory(name=self.name, **args)

     def context_name(self):
         return self.name + "#memory"
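After this hunk, a memory's size may be either an int or a zero-argument callable (such as another layer's calculated_size bound method) that is only invoked when the protobuf is generated. A small sketch of the resolution rule, with made-up values:

    # Sketch of the size-resolution rule now used in MemoryV2.to_proto_impl.
    def resolve_size(size):
        return size() if callable(size) else size

    assert resolve_size(256) == 256
    assert resolve_size(lambda: 256) == 256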
@@ -311,6 +321,12 @@ class MixedLayerV2(Layer):
             args[each] = kwargs[each]
         for each in self.__other_kwargs__:
             args[each] = self.__other_kwargs__[each]
+        size = args.get('size', None)
+        if callable(size):
+            real_size = size()
+        else:
+            real_size = size
+        args['size'] = real_size
         return getattr(conf_helps, self.__method_name__)(**args)
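MixedLayerV2 applies the same rule, so a mixed layer can borrow its size from another layer lazily. A hedged usage sketch, assuming a hypothetical, already-rendered v2 layer emb:

    # 'emb' is a hypothetical layer; its calculated_size bound method is
    # passed uncalled and resolved later inside to_proto_impl.
    with mixed(size=emb.calculated_size, act=activation.Identity()) as m:
        m += identity_projection(input=emb)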
@@ -363,53 +379,15 @@ class RecurrentLayerOutput(Layer):
         RecurrentLayerGroupEnd(name=self.__recurrent_name__)

-@wrap_name_default()
-def recurrent_group(step, input, name=None):
-    if not isinstance(input, collections.Sequence):
-        input = [input]
-
-    # TODO(qiaolongfei) convert StaticInput to memory according to v2 recurrent_group
-    for i in xrange(len(input)):
-        cur_input = input[i]
-        if isinstance(cur_input, StaticInputV2):
-            input[i] = cur_input.input
-
-    actual_input = [
-        RecurrentLayerInput(
-            recurrent_name=name,
-            index=i,
-            parent_layers={'recurrent_inputs': input})
-        for i in xrange(len(input))
-    ]
-
-    actual_output = step(*actual_input)
-
-    if not isinstance(actual_output, collections.Sequence):
-        actual_output = [actual_output]
-
-    retv = [
-        RecurrentLayerOutput(
-            recurrent_name=name,
-            index=i,
-            parent_layers={'recurrent_outputs': actual_output})
-        for i in xrange(len(actual_output))
-    ]
-    if len(retv) == 1:
-        return retv[0]
-    else:
-        return retv

 LayerV2 = Layer
 data = DataLayerV2
 AggregateLevel = conf_helps.layers.AggregateLevel
 ExpandLevel = conf_helps.layers.ExpandLevel
-recurrent_group = recurrent_group
 memory = MemoryV2

 def __layer_name_mapping__(inname):
-    if inname in ['data_layer', 'memory', 'mixed_layer']:
+    if inname in ['data_layer', 'memory', 'mixed_layer', 'recurrent_group']:
         # Do Not handle these layers
         return
     elif inname == 'maxid_layer':
@@ -469,3 +447,55 @@ operator_list = [
 for op in operator_list:
     globals()[op[0]] = __convert_to_v2__(
         op[0], parent_names=op[1], is_default_name=False)
+
+
+@wrap_name_default()
+def recurrent_group(step, input, name=None):
+    if not isinstance(input, collections.Sequence):
+        input = [input]
+
+    non_static_inputs = filter(lambda x: not isinstance(x, StaticInputV2),
+                               input)
+    actual_input = [
+        RecurrentLayerInput(
+            recurrent_name=name,
+            index=i,
+            parent_layers={'recurrent_inputs': non_static_inputs})
+        for i in xrange(len(non_static_inputs))
+    ]
+
+    def __real_step__(*args):
+        rnn_input = list(args)
+        static_inputs = filter(lambda x: isinstance(x, StaticInputV2), input)
+        for static_input in static_inputs:
+            mem_name = "__%s_memory__" % static_input.input.name
+            mem = memory(
+                name=mem_name,
+                is_seq=static_input.is_seq,
+                size=static_input.input.calculated_size,
+                boot_layer=static_input.input)
+            with mixed(
+                    name=mem_name,
+                    size=static_input.input.calculated_size,
+                    act=activation.Identity()) as mix:
+                mix += identity_projection(input=mem)
+            rnn_input.insert(input.index(static_input), mix)
+        return step(*rnn_input)
+
+    actual_output = __real_step__(*actual_input)
+
+    if not isinstance(actual_output, collections.Sequence):
+        actual_output = [actual_output]
+
+    retv = [
+        RecurrentLayerOutput(
+            recurrent_name=name,
+            index=i,
+            parent_layers={'recurrent_outputs': actual_output})
+        for i in xrange(len(actual_output))
+    ]
+    if len(retv) == 1:
+        return retv[0]
+    else:
+        return retv
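Taken together, a hedged usage sketch of the rewritten recurrent_group (seq and encoding are hypothetical layers and the sizes are invented; the StaticInputV2 entry is turned into a read-only memory plus identity mixed layer inside __real_step__ before step runs):

    # step receives one argument per entry of 'input', in order; the memory
    # named 'state' is linked to the fc output of the same name.
    def step(word, context_vec):
        mem = memory(name='state', size=128)
        out = fc(input=[word, context_vec, mem], size=128, name='state')
        return out

    out = recurrent_group(
        step=step,
        input=[seq, StaticInputV2(input=encoding, is_seq=False)])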