提交 73af1942 编写于 作者: Q qiaolongfei

add the implementation of rnn by yuyang

上级 e9cd3867
......@@ -822,7 +822,7 @@ def data_layer(name, size, height=None, width=None, layer_attr=None):
return LayerOutput(name, LayerType.DATA, size=size)
@wrap_name_default("embedding")
@wrap_name_default("embedding_layer")
@wrap_param_attr_default()
@layer_support(ERROR_CLIPPING)
def embedding_layer(input, size, name=None, param_attr=None, layer_attr=None):
......
......@@ -76,6 +76,9 @@ from paddle.trainer_config_helpers.default_decorators import \
wrap_bias_attr_default
from paddle.trainer_config_helpers.default_decorators import wrap_name_default
from paddle.trainer_config_helpers.layers import layer_support
from paddle.trainer.config_parser import \
RecurrentLayerGroupWithoutOutLinksBegin, RecurrentLayerGroupSetOutLink, \
RecurrentLayerGroupEnd, model_type
import activation
import data_type
......@@ -126,21 +129,28 @@ class Layer(object):
self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer
if self.name is None:
if self.context_name() is None:
return self.to_proto_impl(**kwargs)
elif isinstance(self, MemoryV2):
name = self.name + "#__memory__"
if name not in context:
context[name] = self.to_proto_impl(**kwargs)
return context[name]
if self.name not in context:
context[self.name] = self.to_proto_impl(**kwargs)
elif self.context_name() not in context:
context[self.context_name()] = self.to_proto_impl(**kwargs)
return context[self.name]
def to_proto_impl(self, **kwargs):
raise NotImplementedError()
def context_name(self):
"""
Context name means the context which stores `to_proto_impl` result.
If multiple layer share same context_name, the `to_proto_impl` of them
will be invoked only once.
"""
return self.name
def __convert_to_v2__(method_name, parent_names, is_default_name=True):
if is_default_name:
......@@ -231,6 +241,9 @@ class MemoryV2(Layer):
return conf_helps.memory(name=self.name, size=self.size, **args)
def context_name(self):
return self.name + "#memory"
class LayerOutputV2(Layer):
"""
......@@ -249,60 +262,20 @@ class LayerOutputV2(Layer):
class StaticInputV2(Layer):
def __init__(self, **kwargs):
self.__parent_names__ = ['input']
other_kwargs = dict()
parent_layers = dict()
for pname in self.__parent_names__:
if kwargs.has_key(pname):
parent_layers[pname] = kwargs[pname]
for key in kwargs.keys():
if key not in self.__parent_names__:
other_kwargs[key] = kwargs[key]
self.__kwargs__ = other_kwargs
super(StaticInputV2, self).__init__(parent_layers=parent_layers)
def to_proto_impl(self, **kwargs):
args = dict()
for each in kwargs:
args[each] = kwargs[each]
for each in self.__kwargs__:
args[each] = self.__kwargs__[each]
return conf_helps.StaticInput(**args)
class RecurrentGroupV2(Layer):
def __init__(self, name, **kwargs):
self.__parent_names__ = ['input', 'boot_layer']
other_kwargs = dict()
parent_layers = dict()
for pname in self.__parent_names__:
if kwargs.has_key(pname):
parent_layers[pname] = kwargs[pname]
for key in kwargs.keys():
if key not in self.__parent_names__:
other_kwargs[key] = kwargs[key]
self.__kwargs__ = other_kwargs
super(RecurrentGroupV2, self).__init__(
name=name, parent_layers=parent_layers)
def __init__(self, input=None, **kwargs):
assert input is not None
self.__kwargs__ = kwargs
super(StaticInputV2, self).__init__(
name=input.name, parent_layers={'input': input})
wrapper = wrap_name_default(name_prefix='recurrent_group')
__init__ = wrapper(__init__)
def context_name(self):
return self.name + "#static_input"
def to_proto_impl(self, **kwargs):
def in_args_converter(*in_args):
if not isinstance(in_args, collections.Sequence):
in_args = [in_args]
return [LayerOutputV2(input) for input in in_args]
args = dict()
for each in kwargs:
args[each] = kwargs[each]
for each in self.__kwargs__:
args[each] = self.__kwargs__[each]
return conf_helps.recurrent_group(
name=self.name, in_args_converter=in_args_converter, **args)
args.update(kwargs)
args.update(self.__kwargs__)
return conf_helps.StaticInput(**args)
class MixedLayerV2(Layer):
......@@ -377,11 +350,79 @@ def mixed(size=0,
return MixedLayerV2(size, input, name, act, bias_attr, layer_attr)
class RecurrentLayerInput(Layer):
def __init__(self, recurrent_name, index, parent_layers):
assert len(parent_layers) == 1
self.__parents__ = parent_layers.values()[0]
print self.__parents__, parent_layers
super(RecurrentLayerInput, self).__init__(
name=self.__parents__[index].name, parent_layers=parent_layers)
self.__recurrent_name__ = recurrent_name
def context_name(self):
return self.__recurrent_name__ + ".begin"
def to_proto_impl(self, **kwargs):
model_type('recurrent_nn')
RecurrentLayerGroupWithoutOutLinksBegin(
name=self.__recurrent_name__,
in_links=map(lambda x: x.name, self.__parents__))
return self
class RecurrentLayerOutput(Layer):
def __init__(self, recurrent_name, index, parent_layers):
assert len(parent_layers) == 1
self.__parents__ = parent_layers.values()[0]
super(RecurrentLayerOutput, self).__init__(
name=self.__parents__[index].name, parent_layers=parent_layers)
self.__recurrent_name__ = recurrent_name
def context_name(self):
return self.__recurrent_name__ + ".end"
def to_proto_impl(self, **kwargs):
for l in self.__parents__:
RecurrentLayerGroupSetOutLink(l.name)
RecurrentLayerGroupEnd(name=self.__recurrent_name__)
@wrap_name_default()
def recurrent_group(step, input, name=None):
if not isinstance(input, collections.Sequence):
input = [input]
actual_input = [
RecurrentLayerInput(
recurrent_name=name,
index=i,
parent_layers={'recurrent_inputs': input})
for i in xrange(len(input))
]
actual_output = step(*actual_input)
if not isinstance(actual_output, collections.Sequence):
actual_output = [actual_output]
retv = [
RecurrentLayerOutput(
recurrent_name=name,
index=i,
parent_layers={'recurrent_outputs': actual_output})
for i in xrange(len(actual_output))
]
if len(retv) == 1:
return retv[0]
else:
return retv
LayerV2 = Layer
data = DataLayerV2
AggregateLevel = conf_helps.layers.AggregateLevel
ExpandLevel = conf_helps.layers.ExpandLevel
recurrent_group = RecurrentGroupV2
recurrent_group = recurrent_group
memory = MemoryV2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册