Commit ad4ab5ac authored by qiaolongfei

remove step_input in recurrent_group step_input

Parent f13f1f1c
@@ -3042,7 +3042,8 @@ def recurrent_group(step,
                     reverse=False,
                     name=None,
                     targetInlink=None,
-                    is_generating=False):
+                    is_generating=False,
+                    in_args_converter=None):
     """
     Recurrent layer group is an extremely flexible recurrent unit in
     PaddlePaddle. As long as the user defines the calculation done within a
@@ -3185,7 +3186,10 @@ def recurrent_group(step,
     assert (is_generating != has_LayerOutput)
 
-    layer_outs = step(*in_args)
+    if in_args_converter is None:
+        layer_outs = step(*in_args)
+    else:
+        layer_outs = step(*in_args_converter(*in_args)).to_proto(dict())
 
     if isinstance(layer_outs, LayerOutput):
         layer_outs = [layer_outs]
...
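The heart of the change above is the new `in_args_converter` hook: when the v2 wrapper supplies a converter, `recurrent_group` wraps the raw step inputs before calling the user's step function and then lowers the result back to a v1 `LayerOutput` via `to_proto`. Below is a minimal standalone sketch of that dispatch; `run_step`, `Wrapped`, and the toy step functions are illustrative stand-ins, and only the branching on `in_args_converter` mirrors the patch.

```python
# Toy model of the dispatch added to recurrent_group (illustrative names).
def run_step(step, in_args, in_args_converter=None):
    if in_args_converter is None:
        # v1 path: the step function consumes the raw inputs directly.
        return step(*in_args)
    # v2 path: wrap the inputs, call the step, then lower the result.
    return step(*in_args_converter(*in_args)).to_proto(dict())


class Wrapped(object):
    """Stand-in for a v2 layer whose to_proto() yields the v1 object."""

    def __init__(self, value):
        self.value = value

    def to_proto(self, context):
        return self.value + 1


print(run_step(lambda x: x + 1, [1]))                  # v1 path -> 2
print(run_step(lambda w: Wrapped(w.value),
               [1],
               lambda *xs: [Wrapped(x) for x in xs]))  # v2 path -> 2
```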
@@ -73,8 +73,6 @@ from paddle.trainer_config_helpers.config_parser_utils import \
     parse_network_config as __parse__
 from paddle.trainer_config_helpers.default_decorators import wrap_name_default
-import activation
-import attr
 import data_type
 
 __all__ = [
@@ -101,11 +99,10 @@ def parse_network(*outputs):
 class Layer(object):
-    def __init__(self, name, parent_layers, step_input=None):
+    def __init__(self, name, parent_layers):
         assert isinstance(parent_layers, dict)
         assert isinstance(name, basestring)
         self.name = name
-        self.step_input = step_input
         self.__parent_layers__ = parent_layers
 
     def to_proto(self, context):
@@ -121,12 +118,13 @@ class Layer(object):
             else:
                 v1_layer = map(lambda x: x.to_proto(context=context),
                                self.__parent_layers__[layer_name])
-                if layer_name == "input" and self.step_input is not None:
-                    v1_layer.insert(0, self.step_input)
             kwargs[layer_name] = v1_layer
 
+        if self.name is None:
+            return self.to_proto_impl(**kwargs)
+
         # memory may have the same name with some layer
-        if isinstance(self, MemoryV2):
+        if isinstance(self, MemoryV2) or isinstance(self, LayerOutputV2):
             return self.to_proto_impl(**kwargs)
 
         if self.name not in context:
@@ -144,7 +142,7 @@ def __convert_to_v2__(method_name, name_prefix, parent_names):
     wrapper = None
 
     class V2LayerImpl(Layer):
-        def __init__(self, name=None, step_input=None, **kwargs):
+        def __init__(self, name=None, **kwargs):
             parent_layers = dict()
             other_kwargs = dict()
             for pname in parent_names:
@@ -155,7 +153,7 @@ def __convert_to_v2__(method_name, name_prefix, parent_names):
                 if key not in parent_names:
                     other_kwargs[key] = kwargs[key]
 
-            super(V2LayerImpl, self).__init__(name, parent_layers, step_input)
+            super(V2LayerImpl, self).__init__(name, parent_layers)
             self.__other_kwargs__ = other_kwargs
 
         if wrapper is not None:
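Taken together, the hunks above remove the `step_input` escape hatch from the generic v2 wrapper: previously `Layer.to_proto` prepended `self.step_input` to the converted "input" parents, whereas now the caller puts the step input into the input list itself, as the updated test below does with `layer.fc(input=[y, mem], ...)`. A small self-contained sketch of the before/after behaviour, with hypothetical helper names:

```python
# Hypothetical helpers contrasting the removed behaviour with the new one.
def build_inputs_old(inputs, step_input=None):
    v1_layer = list(inputs)
    if step_input is not None:
        v1_layer.insert(0, step_input)   # what the deleted to_proto code did
    return v1_layer


def build_inputs_new(step_input, inputs):
    # The caller now lists the step input explicitly,
    # e.g. layer.fc(input=[y, mem], ...).
    return [step_input] + list(inputs)


assert build_inputs_old(['mem'], step_input='y') == build_inputs_new('y', ['mem'])
```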
@@ -214,6 +212,48 @@ class MemoryV2(Layer):
         return conf_helps.memory(name=self.name, size=self.size, **args)
 
 
+class LayerOutputV2(Layer):
+    def __init__(self, layer_output):
+        assert isinstance(layer_output, conf_helps.LayerOutput)
+        self.layer_output = layer_output
+        super(LayerOutputV2, self).__init__(
+            name=layer_output.name, parent_layers=dict())
+
+    def to_proto_impl(self):
+        return self.layer_output
+
+
+class RecurrentGroupV2(Layer):
+    def __init__(self, name, **kwargs):
+        self.__parent_names__ = ['input']
+        other_kwargs = dict()
+        parent_layers = dict()
+        for pname in self.__parent_names__:
+            if kwargs.has_key(pname):
+                parent_layers[pname] = kwargs[pname]
+        for key in kwargs.keys():
+            if key not in self.__parent_names__:
+                other_kwargs[key] = kwargs[key]
+        self.__kwargs__ = other_kwargs
+
+        super(RecurrentGroupV2, self).__init__(
+            name=name, parent_layers=parent_layers)
+
+    def to_proto_impl(self, **kwargs):
+        def in_args_converter(in_args):
+            if not isinstance(in_args, collections.Sequence):
+                in_args = [in_args]
+            return [LayerOutputV2(input) for input in in_args]
+
+        args = dict()
+        for each in kwargs:
+            args[each] = kwargs[each]
+        for each in self.__kwargs__:
+            args[each] = self.__kwargs__[each]
+        return conf_helps.recurrent_group(
+            name=self.name, in_args_converter=in_args_converter, **args)
+
+
 data = DataLayerV2
 fc = __convert_to_v2__('fc_layer', name_prefix='fc', parent_names=['input'])
 max_id = __convert_to_v2__(
@@ -234,8 +274,7 @@ embedding = __convert_to_v2__(
     'embedding_layer', name_prefix='embedding', parent_names=['input'])
 last_seq = __convert_to_v2__(
     'last_seq', name_prefix='last_seq', parent_names=['input'])
-recurrent_group = __convert_to_v2__(
-    'recurrent_group', name_prefix='recurrent_layer', parent_names=['input'])
+recurrent_group = RecurrentGroupV2
 memory = MemoryV2
 cross_entropy_with_selfnorm_cost = __convert_to_v2__(
...
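`RecurrentGroupV2` bridges the two layer APIs: its `to_proto_impl` passes an `in_args_converter` that wraps each v1 `LayerOutput` in a `LayerOutputV2`, so a v2-style step function can treat the inputs like any other v2 layer, while `LayerOutputV2.to_proto_impl` simply hands back the original object. Here is a standalone sketch of that round trip, using stub classes in place of the real `conf_helps.LayerOutput` and `LayerOutputV2`:

```python
import collections


class FakeV1Output(object):
    """Stub for conf_helps.LayerOutput: just carries a name."""

    def __init__(self, name):
        self.name = name


class FakeLayerOutputV2(object):
    """Stub for LayerOutputV2: wraps a v1 output and returns it unchanged."""

    def __init__(self, layer_output):
        self.layer_output = layer_output
        self.name = layer_output.name

    def to_proto_impl(self):
        return self.layer_output


def in_args_converter(in_args):
    # Same normalization as RecurrentGroupV2.to_proto_impl: a single input
    # becomes a one-element list before wrapping.
    if not isinstance(in_args, collections.Sequence):
        in_args = [in_args]
    return [FakeLayerOutputV2(x) for x in in_args]


wrapped = in_args_converter(FakeV1Output("embedding"))
assert wrapped[0].to_proto_impl().name == "embedding"
```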
@@ -63,7 +63,7 @@ class RNNTest(unittest.TestCase):
         word_dim = 8
         hidden_dim = 8
 
-        def test_old_rnn():
+        def parse_old_rnn():
             def step(y):
                 mem = conf_helps.memory(name="rnn_state", size=hidden_dim)
                 out = conf_helps.fc_layer(
@@ -81,16 +81,15 @@ class RNNTest(unittest.TestCase):
 
             return str(parse_network(test))
 
-        def test_new_rnn():
+        def parse_new_rnn():
             def new_step(y):
                 mem = layer.memory(name="rnn_state", size=hidden_dim)
-                out = layer.fc(input=[mem],
-                               step_input=y,
+                out = layer.fc(input=[y, mem],
                                size=hidden_dim,
                                act=activation.Tanh(),
                                bias_attr=True,
                                name="rnn_state")
-                return out.to_proto(dict())
+                return out
 
             data1 = layer.data(
                 name="word", type=data_type.integer_value(dict_dim))
@@ -99,8 +98,8 @@ class RNNTest(unittest.TestCase):
                 name="rnn", step=new_step, input=embd)
             return str(layer.parse_network(rnn_layer))
 
-        diff = difflib.unified_diff(test_old_rnn().splitlines(1),
-                                    test_new_rnn().splitlines(1))
+        diff = difflib.unified_diff(parse_old_rnn().splitlines(1),
+                                    parse_new_rnn().splitlines(1))
         print ''.join(diff)
...
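The test prints a unified diff between the two parsed networks; an empty diff means the old v1 definition and the new v2 definition are equivalent. A tiny standalone sketch of that comparison pattern, using placeholder strings in place of the real `parse_network` output:

```python
import difflib

# Placeholder protobuf text; the test compares parse_network() output instead.
old_proto = 'layers { name: "rnn_state" }\n'
new_proto = 'layers { name: "rnn_state" }\n'

diff = difflib.unified_diff(old_proto.splitlines(1), new_proto.splitlines(1))
print(''.join(diff))   # empty output: the two configurations match
```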