Commit 5fc572c2 authored by Yu Yang

Complete Memory

Parent 6b199367
@@ -3474,8 +3474,6 @@ def update_g_config():
     for name in g_config.model_config.output_layer_names:
         assert name in g_layer_map, \
             'input name "%s" does not correspond to a layer name' % name
-    for hook in _parse_config_hooks:
-        hook()
     return g_config
@@ -3487,8 +3485,8 @@ def parse_config(trainer_config, config_arg_str):
     passed to config script as a dictionary CONFIG_ARGS
     '''
     init_config_environment()
-    # for hook in _parse_config_hooks:
-    #     hook()
+    for hook in _parse_config_hooks:
+        hook()
     config_args = {}
......
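Note: this hunk moves the hook invocation out of update_g_config() and into parse_config(), so registered hooks fire before the config script is evaluated instead of after it. A minimal standalone sketch of the pattern, under stated assumptions (all names here are hypothetical except _parse_config_hooks and parse_config, whose real versions live in config_parser.py):

_parse_config_hooks = set()


def register_parse_config_hook(hook):
    """Register a callable that runs at the start of every config parse."""
    _parse_config_hooks.add(hook)


def parse_config(define_network):
    # Hooks fire once per parse, before the user's config script runs,
    # so any state they reset cannot leak between parses.
    for hook in _parse_config_hooks:
        hook()
    return define_network()


register_parse_config_hook(lambda: None)  # e.g. a reset-defaults hook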
@@ -67,7 +67,7 @@ paddle.v2.parameters.create, no longer exposed to users.
 """
 import collections
+import inspect
 import paddle.trainer_config_helpers as conf_helps
 from paddle.trainer_config_helpers.config_parser_utils import \
     parse_network_config as __parse__
@@ -216,31 +216,83 @@ class DataLayerV2(Layer):
         return getattr(conf_helps, self.__method_name__)(name=self.name, **args)


-class MemoryV2(Layer):
-    def __init__(self, name, size, **kwargs):
-        self.name = name
-        self.size = size
-        parent_names = ['boot_layer']
-        parent_layers = dict()
-        other_kwargs = dict()
-        for pname in parent_names:
-            if kwargs.has_key(pname):
-                parent_layers[pname] = kwargs[pname]
-        for key in kwargs.keys():
-            if key not in parent_names:
-                other_kwargs[key] = kwargs[key]
-        super(MemoryV2, self).__init__(name=name, parent_layers=parent_layers)
-        self.__kwargs__ = other_kwargs
-
-    def to_proto_impl(self, **kwargs):
+class WithExtraParent(Layer):
+    def extra_parent(self):
+        return self.__extra_parent__
+
+    def __init__(self, name=None, parent_layers=None):
+        self.__extra_parent__ = []
+        super(WithExtraParent, self).__init__(name, parent_layers)
+
+    def append_extra_parent(self, parent):
+        self.__extra_parent__.append(parent)
+
+    def to_proto(self, context):
+        """
+        function to set proto attribute
+        """
+        kwargs = dict()
+        for p in self.__extra_parent__:
+            p.to_proto(context=context)
+
+        for layer_name in self.__parent_layers__:
+            if not isinstance(self.__parent_layers__[layer_name],
+                              collections.Sequence):
+                v1_layer = self.__parent_layers__[layer_name].to_proto(
+                    context=context)
+            else:
+                v1_layer = map(lambda x: x.to_proto(context=context),
+                               self.__parent_layers__[layer_name])
+            kwargs[layer_name] = v1_layer
+
+        if self.context_name() is None:
+            return self.to_proto_impl(context=context, **kwargs)
+        elif self.context_name() not in context:
+            context[self.context_name()] = self.to_proto_impl(
+                context=context, **kwargs)
+
+        if self.use_context_name():
+            return context[self.context_name()]
+        else:
+            return context[self.name]
+
+
+class MemoryV2(WithExtraParent):
+    def __init__(self, name, size, **kwargs):
+        self.name = name
+        self.size = size
+        super(MemoryV2, self).__init__(name=name, parent_layers=dict())
+        self.__kwargs__ = kwargs
+        self.__boot_layer_name__ = None
+        if 'boot_layer' in kwargs:
+            begin_of_current_rnn = []
+            # TODO(yuyang18): Fix inspect, it could be wrong when user invoke a
+            # function inside step.
+            st = inspect.stack()
+            for i in xrange(len(st)):
+                locs = inspect.stack()[i][0].f_locals
+                for val in locs.viewvalues():
+                    if isinstance(val, RecurrentLayerInput):
+                        begin_of_current_rnn.append(val)
+                if begin_of_current_rnn:
+                    break
+
+            assert begin_of_current_rnn is not None
+            for extra in begin_of_current_rnn:
+                self.append_extra_parent(extra)
+                assert isinstance(extra, WithExtraParent)
+                extra.append_extra_parent(kwargs['boot_layer'])
+                self.__boot_layer_name__ = kwargs['boot_layer'].name
+
+    def to_proto_impl(self, context, **kwargs):
         args = dict()
         for each in kwargs:
             args[each] = kwargs[each]
         for each in self.__kwargs__:
             args[each] = self.__kwargs__[each]
+        if self.__boot_layer_name__ is not None:
+            args['boot_layer'] = context[self.__boot_layer_name__]
         return conf_helps.memory(name=self.name, size=self.size, **args)

     def context_name(self):
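Note: the to_proto/context machinery above is easiest to see in isolation. A toy sketch of the same idea (all names here are hypothetical, not the PaddlePaddle classes): each node converts itself at most once per context dict, and "extra parents" are emitted first even though they are not ordinary inputs, which is how a memory's boot_layer is forced into the graph before the recurrent group opens.

class Node(object):
    def __init__(self, name, parents=None):
        self.name = name
        self.parents = parents or []
        self.extra_parents = []

    def to_proto(self, context):
        for p in self.extra_parents:   # emit out-of-band dependencies first
            p.to_proto(context)
        for p in self.parents:         # then ordinary inputs
            p.to_proto(context)
        if self.name not in context:   # convert at most once per parse
            context[self.name] = self.emit(context)
        return context[self.name]

    def emit(self, context):
        return 'proto(%s)' % self.name


boot = Node('boot')
rnn_in = Node('rnn_in')
rnn_in.extra_parents.append(boot)
ctx = {}
rnn_in.to_proto(ctx)
assert set(ctx) == set(['boot', 'rnn_in'])  # boot emitted despite not being an input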
@@ -328,7 +380,7 @@ class MixedLayerV2(Layer):
             self.__inputs__.append(other)
             return self
         else:
-            raise MixedLayerTypeV2.AddToSealedMixedLayerExceptionV2()
+            raise MixedLayerV2.AddToSealedMixedLayerExceptionV2()

     def __enter__(self):
         assert len(self.__inputs__) == 0
@@ -359,11 +411,10 @@ def mixed(size=0,
     return MixedLayerV2(size, input, name, act, bias_attr, layer_attr)


-class RecurrentLayerInput(Layer):
+class RecurrentLayerInput(WithExtraParent):
     def __init__(self, recurrent_name, index, parent_layers):
         assert len(parent_layers) == 1
         self.__parents__ = parent_layers.values()[0]
-        print self.__parents__, parent_layers
         super(RecurrentLayerInput, self).__init__(
             name=self.__parents__[index].name, parent_layers=parent_layers)
         self.__recurrent_name__ = recurrent_name
@@ -371,7 +422,7 @@ class RecurrentLayerInput(Layer):
     def context_name(self):
         return self.__recurrent_name__ + ".begin"

-    def to_proto_impl(self, **kwargs):
+    def to_proto_impl(self, context, **kwargs):
         model_type('recurrent_nn')
         RecurrentLayerGroupWithoutOutLinksBegin(
             name=self.__recurrent_name__,
@@ -458,8 +509,10 @@ def __layer_name_mapping__(inname):
 def __layer_name_mapping_parent_names__(inname):
     all_args = getattr(conf_helps, inname).argspec.args
     return filter(
-        lambda x: x in ['input1', 'input2','label', 'input', 'a', 'b', 'expand_as',
-                        'weights', 'vectors', 'weight', 'score', 'left', 'right'],
+        lambda x: x in ['input1', 'input2', 'label', 'input', 'a', 'b',
+                        'expand_as',
+                        'weights', 'vectors', 'weight', 'score', 'left',
+                        'right'],
         all_args)
......
@@ -106,9 +106,21 @@ class RNNTest(unittest.TestCase):
             return str(parse_network(test))

         def parse_new_rnn():
+            data = layer.data(
+                name="word", type=data_type.dense_vector(dict_dim))
+            label = layer.data(
+                name="label", type=data_type.dense_vector(label_dim))
+            emb = layer.embedding(input=data, size=word_dim)
+            boot_layer = layer.data(
+                name="boot", type=data_type.dense_vector(10))
+            boot_layer = layer.fc(name='wtf', input=boot_layer, size=10)
+
             def step(y, wid):
                 z = layer.embedding(input=wid, size=word_dim)
-                mem = layer.memory(name="rnn_state", size=hidden_dim)
+                mem = layer.memory(
+                    name="rnn_state", size=hidden_dim, boot_layer=boot_layer)
                 out = layer.fc(input=[y, z, mem],
                                size=hidden_dim,
                                act=activation.Tanh(),
@@ -116,11 +128,6 @@ class RNNTest(unittest.TestCase):
                                name="rnn_state")
                 return out

-            data = layer.data(
-                name="word", type=data_type.dense_vector(dict_dim))
-            label = layer.data(
-                name="label", type=data_type.dense_vector(label_dim))
-            emb = layer.embedding(input=data, size=word_dim)
             out = layer.recurrent_group(
                 name="rnn", step=step, input=[emb, data])
@@ -134,9 +141,11 @@ class RNNTest(unittest.TestCase):
             return str(layer.parse_network(cost))

-        diff = difflib.unified_diff(parse_old_rnn().splitlines(1),
-                                    parse_new_rnn().splitlines(1))
-        print ''.join(diff)
+        with open("/Users/baidu/old.out", 'w') as f:
+            print >> f, parse_old_rnn()
+        with open("/Users/baidu/new.out", "w") as f:
+            print >> f, parse_new_rnn()
+        # print ''.join(diff)


 if __name__ == '__main__':
......
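Note: for reference, a hedged sketch of the user-facing pattern the updated test exercises: seeding a v2 memory from another layer via boot_layer. The import paths, sizes, and layer names below are assumptions for illustration, not taken from this commit.

# Assumed import paths for the v2 API; the test file's imports are not
# shown in this diff.
import paddle.v2.layer as layer
import paddle.v2.data_type as data_type
import paddle.v2.activation as activation

# A layer whose output seeds the memory at the first time step.
boot = layer.data(name="boot", type=data_type.dense_vector(10))
boot = layer.fc(input=boot, size=10)


def step(wid):
    emb = layer.embedding(input=wid, size=32)
    # boot_layer supplies the memory's value at t=0; afterwards it reads
    # the previous output of the layer named "state".
    mem = layer.memory(name="state", size=10, boot_layer=boot)
    return layer.fc(input=[emb, mem], size=10,
                    act=activation.Tanh(), name="state")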