提交 7ad83630 编写于 作者: Q qiaolongfei

support boot_layer

上级 6b1a91f9
......@@ -3110,7 +3110,8 @@ def recurrent_group(step,
name=None,
targetInlink=None,
is_generating=False,
in_args_converter=None):
in_args_converter=None,
boot_layer=None):
"""
Recurrent layer group is an extremely flexible recurrent unit in
PaddlePaddle. As long as the user defines the calculation done within a
......@@ -3256,6 +3257,9 @@ def recurrent_group(step,
if in_args_converter is None:
layer_outs = step(*in_args)
else:
# append boot_layer to the last of in_args
if boot_layer is not None:
in_args.append(boot_layer)
layer_outs = step(*in_args_converter(*in_args)).to_proto(dict())
if isinstance(layer_outs, LayerOutput):
......
......@@ -140,10 +140,13 @@ class Layer(object):
if self.name is None:
return self.to_proto_impl(**kwargs)
elif isinstance(self, MemoryV2):
return self.to_proto_impl(**kwargs)
elif self.name not in context:
context[self.name] = self.to_proto_impl(**kwargs)
name = self.name + "#__memory__"
if name not in context:
context[name] = self.to_proto_impl(**kwargs)
return context[name]
if self.name not in context:
context[self.name] = self.to_proto_impl(**kwargs)
return context[self.name]
def to_proto_impl(self, **kwargs):
......@@ -256,9 +259,32 @@ class LayerOutputV2(Layer):
return self.layer_output
class StaticInputV2(Layer):
def __init__(self, **kwargs):
self.__parent_names__ = ['input']
other_kwargs = dict()
parent_layers = dict()
for pname in self.__parent_names__:
if kwargs.has_key(pname):
parent_layers[pname] = kwargs[pname]
for key in kwargs.keys():
if key not in self.__parent_names__:
other_kwargs[key] = kwargs[key]
self.__kwargs__ = other_kwargs
super(StaticInputV2, self).__init__(parent_layers=parent_layers)
def to_proto_impl(self, **kwargs):
args = dict()
for each in kwargs:
args[each] = kwargs[each]
for each in self.__kwargs__:
args[each] = self.__kwargs__[each]
return conf_helps.StaticInput(**args)
class RecurrentGroupV2(Layer):
def __init__(self, name, **kwargs):
self.__parent_names__ = ['input']
self.__parent_names__ = ['input', 'boot_layer']
other_kwargs = dict()
parent_layers = dict()
for pname in self.__parent_names__:
......@@ -443,7 +469,8 @@ layer_list = [
['nce', 'nce_layer', ['input', 'label']],
['hsigmoid', 'hsigmoid', ['input', 'label']],
# check layers
['eos', 'eos_layer', ['input']]
['eos', 'eos_layer', ['input']],
['gru_step_layer', 'gru_step_layer', ['input', 'output_mem']]
]
for l in layer_list:
globals()[l[0]] = __convert_to_v2__(l[1], l[2])
......
......@@ -10,7 +10,6 @@ add_test(NAME test_v2_rnn_layer
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/v2/tests/test_rnn_layer.py)
add_test(NAME test_topology
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/v2/tests/test_topology.py
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册