提交 9d12ca95 编写于 作者: X xuwei06

Further fix the memory for Hierarchical RNN

Sequences should be sorted according to the number of subsequences they have.
上级 05a97ab5
...@@ -21,7 +21,6 @@ from .evaluators import * ...@@ -21,7 +21,6 @@ from .evaluators import *
from .poolings import MaxPooling, AvgPooling, BasePoolingType from .poolings import MaxPooling, AvgPooling, BasePoolingType
from .attrs import * from .attrs import *
from .default_decorators import * from .default_decorators import *
try: try:
import cPickle as pickle import cPickle as pickle
except ImportError: except ImportError:
...@@ -204,6 +203,25 @@ ERROR_CLIPPING = 'error_clipping_threshold' ...@@ -204,6 +203,25 @@ ERROR_CLIPPING = 'error_clipping_threshold'
DROPOUT = 'drop_rate' DROPOUT = 'drop_rate'
def check_input(input):
"""
Check input is a LayerOutput or list of LayerOutput or tuple of LayerOutput
if is a LayerOutput,
:param input: The input layer. Could be a list/tuple of input layer.
:type input: LayerOutput|list|tuple
:return: list of LayerOutput
:rtype: list of LayerOutput
"""
if isinstance(input, LayerOutput):
return [LayerOutput]
assert isinstance(input, list)
for inp in input:
assert isinstance(inp, LayerOutput)
return list(input)
def layer_support(*attrs): def layer_support(*attrs):
def decorator(method): def decorator(method):
@functools.wraps(method) @functools.wraps(method)
...@@ -731,19 +749,27 @@ def fc_layer(input, size, act=None, name=None, ...@@ -731,19 +749,27 @@ def fc_layer(input, size, act=None, name=None,
return LayerOutput(name, LayerType.FC_LAYER, input, activation=act, return LayerOutput(name, LayerType.FC_LAYER, input, activation=act,
size=size) size=size)
@wrap_name_default("print") @wrap_name_default("print")
def print_layer(input, name=None): def print_layer(input, name=None):
""" """
Print the output value of input layers. This layer is useful for debugging. Print the output value of input layers. This layer is useful for debugging.
:param name: The Layer Name.
:type name: basestring
:param input: The input layer. Could be a list/tuple of input layer.
:type input: LayerOutput|list|tuple
:return: No return
""" """
assert isinstance(input, list) check_input(input)
Layer( Layer(
name=name, name=name,
type=LayerType.PRINT_LAYER, type=LayerType.PRINT_LAYER,
inputs=[l.name for l in input], inputs=[l.name for l in input],
) )
return LayerOutput(name, LayerType.PRINT_LAYER, input) LayerOutput(name, LayerType.PRINT_LAYER, input)
@wrap_name_default("seq_pooling") @wrap_name_default("seq_pooling")
@wrap_bias_attr_default(has_bias=False) @wrap_bias_attr_default(has_bias=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册