Commit c9bb48b3 authored by Q qiaolongfei

support calculate size

Parent b400c8f0
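In short: this change lets a layer's `size` be supplied as a zero-argument callable that is only evaluated when the protobuf config is generated. `Layer` now remembers the context it was parsed into and exposes the resolved width through a new method, `MemoryV2` and `MixedLayerV2` accept a callable `size`, and the rewritten `recurrent_group` uses this to turn each `StaticInputV2` into a boot-strapped memory plus an identity-projection mixed layer of the same size. A minimal sketch of the deferred-size pattern, with illustrative names that are not part of the patch:

    # Minimal sketch of the deferred-size pattern this commit introduces.
    # LazySized / resolve_size are illustrative names, not PaddlePaddle API.
    class LazySized(object):
        def __init__(self, size):
            # `size` may be a plain int or a zero-argument callable
            self.size = size

        def resolve_size(self):
            # evaluate the callable only when the config is actually built
            return self.size() if callable(self.size) else self.size

    fixed = LazySized(128)
    deferred = LazySized(lambda: 64 * 2)  # value only known later
    assert fixed.resolve_size() == 128
    assert deferred.resolve_size() == 128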
@@ -22,7 +22,7 @@ class Layer(object):
def __init__(self, name=None, size=None, parent_layers=None):
assert isinstance(parent_layers, dict)
self.name = name
self.size = size
self.__context__ = {}
self.__parent_layers__ = parent_layers
def to_proto(self, context):
@@ -44,7 +44,7 @@ class Layer(object):
return self.to_proto_impl(**kwargs)
elif self.context_name() not in context:
context[self.context_name()] = self.to_proto_impl(**kwargs)
self.__context__ = context
if self.use_context_name():
return context[self.context_name()]
else:
@@ -64,6 +64,9 @@ class Layer(object):
def use_context_name(self):
return False
def calculated_size(self):
return self.__context__[self.context_name()].size
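The lookup above works because `to_proto` now stores the context dict on the instance: once this layer's config has been written into the context, its width can be read back later through `calculated_size()`. A toy, self-contained sketch of that bookkeeping (FakeConfig and ToyLayer are stand-ins for the parsed config and for `Layer`, not PaddlePaddle API):

    # Toy sketch of the bookkeeping behind calculated_size().
    class FakeConfig(object):
        def __init__(self, size):
            self.size = size

    class ToyLayer(object):
        def __init__(self, name):
            self.name = name
            self.__context__ = {}

        def context_name(self):
            return self.name + "#memory"

        def to_proto(self, context):
            # the real code fills this entry from to_proto_impl(); faked here
            context[self.context_name()] = FakeConfig(size=256)
            self.__context__ = context
            return context[self.context_name()]

        def calculated_size(self):
            return self.__context__[self.context_name()].size

    layer = ToyLayer("mem")
    layer.to_proto({})
    assert layer.calculated_size() == 256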
def __convert_to_v2__(method_name, parent_names, is_default_name=True):
if is_default_name:
@@ -197,6 +197,10 @@ class MemoryV2(WithExtraParent):
val = locs[key]
if isinstance(val, RecurrentLayerInput):
begin_of_current_rnn.append(val)
elif isinstance(val, collections.Sequence):
for v in val:
if isinstance(v, RecurrentLayerInput):
begin_of_current_rnn.append(v)
if begin_of_current_rnn:
break
@@ -216,7 +220,13 @@ class MemoryV2(WithExtraParent):
if self.__boot_layer_name__ is not None:
args['boot_layer'] = context[self.__boot_layer_name__]
return conf_helps.memory(name=self.name, size=self.size, **args)
if callable(self.size):
real_size = self.size()
else:
real_size = self.size
args['size'] = real_size
return conf_helps.memory(name=self.name, **args)
def context_name(self):
return self.name + "#memory"
@@ -311,6 +321,12 @@ class MixedLayerV2(Layer):
args[each] = kwargs[each]
for each in self.__other_kwargs__:
args[each] = self.__other_kwargs__[each]
size = args.get('size', None)
if callable(size):
real_size = size()
else:
real_size = size
args['size'] = real_size
return getattr(conf_helps, self.__method_name__)(**args)
@@ -363,53 +379,15 @@ class RecurrentLayerOutput(Layer):
RecurrentLayerGroupEnd(name=self.__recurrent_name__)
@wrap_name_default()
def recurrent_group(step, input, name=None):
if not isinstance(input, collections.Sequence):
input = [input]
# TODO(qiaolongfei) convert StaticInput to memory according to v2 recurrent_group
for i in xrange(len(input)):
cur_input = input[i]
if isinstance(cur_input, StaticInputV2):
input[i] = cur_input.input
actual_input = [
RecurrentLayerInput(
recurrent_name=name,
index=i,
parent_layers={'recurrent_inputs': input})
for i in xrange(len(input))
]
actual_output = step(*actual_input)
if not isinstance(actual_output, collections.Sequence):
actual_output = [actual_output]
retv = [
RecurrentLayerOutput(
recurrent_name=name,
index=i,
parent_layers={'recurrent_outputs': actual_output})
for i in xrange(len(actual_output))
]
if len(retv) == 1:
return retv[0]
else:
return retv
LayerV2 = Layer
data = DataLayerV2
AggregateLevel = conf_helps.layers.AggregateLevel
ExpandLevel = conf_helps.layers.ExpandLevel
recurrent_group = recurrent_group
memory = MemoryV2
def __layer_name_mapping__(inname):
if inname in ['data_layer', 'memory', 'mixed_layer']:
if inname in ['data_layer', 'memory', 'mixed_layer', 'recurrent_group']:
# Do Not handle these layers
return
elif inname == 'maxid_layer':
@@ -469,3 +447,55 @@ operator_list = [
for op in operator_list:
globals()[op[0]] = __convert_to_v2__(
op[0], parent_names=op[1], is_default_name=False)
@wrap_name_default()
def recurrent_group(step, input, name=None):
if not isinstance(input, collections.Sequence):
input = [input]
non_static_inputs = filter(lambda x: not isinstance(x, StaticInputV2),
input)
actual_input = [
RecurrentLayerInput(
recurrent_name=name,
index=i,
parent_layers={'recurrent_inputs': non_static_inputs})
for i in xrange(len(non_static_inputs))
]
def __real_step__(*args):
rnn_input = list(args)
static_inputs = filter(lambda x: isinstance(x, StaticInputV2), input)
for static_input in static_inputs:
mem_name = "__%s_memory__" % static_input.input.name
mem = memory(
name=mem_name,
is_seq=static_input.is_seq,
size=static_input.input.calculated_size,
boot_layer=static_input.input)
with mixed(
name=mem_name,
size=static_input.input.calculated_size,
act=activation.Identity()) as mix:
mix += identity_projection(input=mem)
rnn_input.insert(input.index(static_input), mix)
return step(*rnn_input)
actual_output = __real_step__(*actual_input)
if not isinstance(actual_output, collections.Sequence):
actual_output = [actual_output]
retv = [
RecurrentLayerOutput(
recurrent_name=name,
index=i,
parent_layers={'recurrent_outputs': actual_output})
for i in xrange(len(actual_output))
]
if len(retv) == 1:
return retv[0]
else:
return retv
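For clarity, a small self-contained sketch of the splice step in `__real_step__` above: the step function receives the non-static inputs positionally, and the replacement built for each static input is inserted back at the position that input held in the original `input` list (StaticTag and make_replacement are illustrative stand-ins, not PaddlePaddle API):

    # Self-contained sketch of the input-splicing logic in __real_step__.
    class StaticTag(object):
        def __init__(self, value):
            self.value = value

    def splice(original_input, non_static_values, make_replacement):
        # non-static values keep their relative order; every static entry is
        # replaced and re-inserted at the index it held in original_input
        rnn_input = list(non_static_values)
        for item in original_input:
            if isinstance(item, StaticTag):
                rnn_input.insert(original_input.index(item), make_replacement(item))
        return rnn_input

    inp = ["w", StaticTag("enc"), "h"]
    assert splice(inp, ["w", "h"], lambda s: s.value.upper()) == ["w", "ENC", "h"]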