提交 7ad83630 编写于 作者: Q qiaolongfei

support boot_layer

上级 6b1a91f9
...@@ -3110,7 +3110,8 @@ def recurrent_group(step, ...@@ -3110,7 +3110,8 @@ def recurrent_group(step,
name=None, name=None,
targetInlink=None, targetInlink=None,
is_generating=False, is_generating=False,
in_args_converter=None): in_args_converter=None,
boot_layer=None):
""" """
Recurrent layer group is an extremely flexible recurrent unit in Recurrent layer group is an extremely flexible recurrent unit in
PaddlePaddle. As long as the user defines the calculation done within a PaddlePaddle. As long as the user defines the calculation done within a
...@@ -3256,6 +3257,9 @@ def recurrent_group(step, ...@@ -3256,6 +3257,9 @@ def recurrent_group(step,
if in_args_converter is None: if in_args_converter is None:
layer_outs = step(*in_args) layer_outs = step(*in_args)
else: else:
# append boot_layer to the last of in_args
if boot_layer is not None:
in_args.append(boot_layer)
layer_outs = step(*in_args_converter(*in_args)).to_proto(dict()) layer_outs = step(*in_args_converter(*in_args)).to_proto(dict())
if isinstance(layer_outs, LayerOutput): if isinstance(layer_outs, LayerOutput):
......
...@@ -140,10 +140,13 @@ class Layer(object): ...@@ -140,10 +140,13 @@ class Layer(object):
if self.name is None: if self.name is None:
return self.to_proto_impl(**kwargs) return self.to_proto_impl(**kwargs)
elif isinstance(self, MemoryV2): elif isinstance(self, MemoryV2):
return self.to_proto_impl(**kwargs) name = self.name + "#__memory__"
elif self.name not in context: if name not in context:
context[self.name] = self.to_proto_impl(**kwargs) context[name] = self.to_proto_impl(**kwargs)
return context[name]
if self.name not in context:
context[self.name] = self.to_proto_impl(**kwargs)
return context[self.name] return context[self.name]
def to_proto_impl(self, **kwargs): def to_proto_impl(self, **kwargs):
...@@ -256,9 +259,32 @@ class LayerOutputV2(Layer): ...@@ -256,9 +259,32 @@ class LayerOutputV2(Layer):
return self.layer_output return self.layer_output
class StaticInputV2(Layer):
def __init__(self, **kwargs):
self.__parent_names__ = ['input']
other_kwargs = dict()
parent_layers = dict()
for pname in self.__parent_names__:
if kwargs.has_key(pname):
parent_layers[pname] = kwargs[pname]
for key in kwargs.keys():
if key not in self.__parent_names__:
other_kwargs[key] = kwargs[key]
self.__kwargs__ = other_kwargs
super(StaticInputV2, self).__init__(parent_layers=parent_layers)
def to_proto_impl(self, **kwargs):
args = dict()
for each in kwargs:
args[each] = kwargs[each]
for each in self.__kwargs__:
args[each] = self.__kwargs__[each]
return conf_helps.StaticInput(**args)
class RecurrentGroupV2(Layer): class RecurrentGroupV2(Layer):
def __init__(self, name, **kwargs): def __init__(self, name, **kwargs):
self.__parent_names__ = ['input'] self.__parent_names__ = ['input', 'boot_layer']
other_kwargs = dict() other_kwargs = dict()
parent_layers = dict() parent_layers = dict()
for pname in self.__parent_names__: for pname in self.__parent_names__:
...@@ -443,7 +469,8 @@ layer_list = [ ...@@ -443,7 +469,8 @@ layer_list = [
['nce', 'nce_layer', ['input', 'label']], ['nce', 'nce_layer', ['input', 'label']],
['hsigmoid', 'hsigmoid', ['input', 'label']], ['hsigmoid', 'hsigmoid', ['input', 'label']],
# check layers # check layers
['eos', 'eos_layer', ['input']] ['eos', 'eos_layer', ['input']],
['gru_step_layer', 'gru_step_layer', ['input', 'output_mem']]
] ]
for l in layer_list: for l in layer_list:
globals()[l[0]] = __convert_to_v2__(l[1], l[2]) globals()[l[0]] = __convert_to_v2__(l[1], l[2])
......
...@@ -10,7 +10,6 @@ add_test(NAME test_v2_rnn_layer ...@@ -10,7 +10,6 @@ add_test(NAME test_v2_rnn_layer
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/ COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/v2/tests/test_rnn_layer.py) ${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/v2/tests/test_rnn_layer.py)
add_test(NAME test_topology add_test(NAME test_topology
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/ COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/v2/tests/test_topology.py ${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/v2/tests/test_topology.py
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册