提交 ad4ab5ac 编写于 作者: Q qiaolongfei

remove step_input in recurrent_group step_input

上级 f13f1f1c
......@@ -3042,7 +3042,8 @@ def recurrent_group(step,
reverse=False,
name=None,
targetInlink=None,
is_generating=False):
is_generating=False,
in_args_converter=None):
"""
Recurrent layer group is an extremely flexible recurrent unit in
PaddlePaddle. As long as the user defines the calculation done within a
......@@ -3185,7 +3186,10 @@ def recurrent_group(step,
assert (is_generating != has_LayerOutput)
layer_outs = step(*in_args)
if in_args_converter is None:
layer_outs = step(*in_args)
else:
layer_outs = step(*in_args_converter(*in_args)).to_proto(dict())
if isinstance(layer_outs, LayerOutput):
layer_outs = [layer_outs]
......
......@@ -73,8 +73,6 @@ from paddle.trainer_config_helpers.config_parser_utils import \
parse_network_config as __parse__
from paddle.trainer_config_helpers.default_decorators import wrap_name_default
import activation
import attr
import data_type
__all__ = [
......@@ -101,11 +99,10 @@ def parse_network(*outputs):
class Layer(object):
def __init__(self, name, parent_layers, step_input=None):
def __init__(self, name, parent_layers):
assert isinstance(parent_layers, dict)
assert isinstance(name, basestring)
self.name = name
self.step_input = step_input
self.__parent_layers__ = parent_layers
def to_proto(self, context):
......@@ -121,12 +118,13 @@ class Layer(object):
else:
v1_layer = map(lambda x: x.to_proto(context=context),
self.__parent_layers__[layer_name])
if layer_name == "input" and self.step_input is not None:
v1_layer.insert(0, self.step_input)
kwargs[layer_name] = v1_layer
if self.name is None:
return self.to_proto_impl(**kwargs)
# memory may have the same name with some layer
if isinstance(self, MemoryV2):
if isinstance(self, MemoryV2) or isinstance(self, LayerOutputV2):
return self.to_proto_impl(**kwargs)
if self.name not in context:
......@@ -144,7 +142,7 @@ def __convert_to_v2__(method_name, name_prefix, parent_names):
wrapper = None
class V2LayerImpl(Layer):
def __init__(self, name=None, step_input=None, **kwargs):
def __init__(self, name=None, **kwargs):
parent_layers = dict()
other_kwargs = dict()
for pname in parent_names:
......@@ -155,7 +153,7 @@ def __convert_to_v2__(method_name, name_prefix, parent_names):
if key not in parent_names:
other_kwargs[key] = kwargs[key]
super(V2LayerImpl, self).__init__(name, parent_layers, step_input)
super(V2LayerImpl, self).__init__(name, parent_layers)
self.__other_kwargs__ = other_kwargs
if wrapper is not None:
......@@ -214,6 +212,48 @@ class MemoryV2(Layer):
return conf_helps.memory(name=self.name, size=self.size, **args)
class LayerOutputV2(Layer):
def __init__(self, layer_output):
assert isinstance(layer_output, conf_helps.LayerOutput)
self.layer_output = layer_output
super(LayerOutputV2, self).__init__(
name=layer_output.name, parent_layers=dict())
def to_proto_impl(self):
return self.layer_output
class RecurrentGroupV2(Layer):
def __init__(self, name, **kwargs):
self.__parent_names__ = ['input']
other_kwargs = dict()
parent_layers = dict()
for pname in self.__parent_names__:
if kwargs.has_key(pname):
parent_layers[pname] = kwargs[pname]
for key in kwargs.keys():
if key not in self.__parent_names__:
other_kwargs[key] = kwargs[key]
self.__kwargs__ = other_kwargs
super(RecurrentGroupV2, self).__init__(
name=name, parent_layers=parent_layers)
def to_proto_impl(self, **kwargs):
def in_args_converter(in_args):
if not isinstance(in_args, collections.Sequence):
in_args = [in_args]
return [LayerOutputV2(input) for input in in_args]
args = dict()
for each in kwargs:
args[each] = kwargs[each]
for each in self.__kwargs__:
args[each] = self.__kwargs__[each]
return conf_helps.recurrent_group(
name=self.name, in_args_converter=in_args_converter, **args)
data = DataLayerV2
fc = __convert_to_v2__('fc_layer', name_prefix='fc', parent_names=['input'])
max_id = __convert_to_v2__(
......@@ -234,8 +274,7 @@ embedding = __convert_to_v2__(
'embedding_layer', name_prefix='embedding', parent_names=['input'])
last_seq = __convert_to_v2__(
'last_seq', name_prefix='last_seq', parent_names=['input'])
recurrent_group = __convert_to_v2__(
'recurrent_group', name_prefix='recurrent_layer', parent_names=['input'])
recurrent_group = RecurrentGroupV2
memory = MemoryV2
cross_entropy_with_selfnorm_cost = __convert_to_v2__(
......
......@@ -63,7 +63,7 @@ class RNNTest(unittest.TestCase):
word_dim = 8
hidden_dim = 8
def test_old_rnn():
def parse_old_rnn():
def step(y):
mem = conf_helps.memory(name="rnn_state", size=hidden_dim)
out = conf_helps.fc_layer(
......@@ -81,16 +81,15 @@ class RNNTest(unittest.TestCase):
return str(parse_network(test))
def test_new_rnn():
def parse_new_rnn():
def new_step(y):
mem = layer.memory(name="rnn_state", size=hidden_dim)
out = layer.fc(input=[mem],
step_input=y,
out = layer.fc(input=[y, mem],
size=hidden_dim,
act=activation.Tanh(),
bias_attr=True,
name="rnn_state")
return out.to_proto(dict())
return out
data1 = layer.data(
name="word", type=data_type.integer_value(dict_dim))
......@@ -99,8 +98,8 @@ class RNNTest(unittest.TestCase):
name="rnn", step=new_step, input=embd)
return str(layer.parse_network(rnn_layer))
diff = difflib.unified_diff(test_old_rnn().splitlines(1),
test_new_rnn().splitlines(1))
diff = difflib.unified_diff(parse_old_rnn().splitlines(1),
parse_new_rnn().splitlines(1))
print ''.join(diff)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册