提交 9d12ca95 编写于 作者: X xuwei06

Further fix the memory for Hierarchical RNN

Sequences should be sorted according to the number of subsequences they have.
上级 05a97ab5
......@@ -21,7 +21,6 @@ from .evaluators import *
from .poolings import MaxPooling, AvgPooling, BasePoolingType
from .attrs import *
from .default_decorators import *
try:
import cPickle as pickle
except ImportError:
......@@ -204,6 +203,25 @@ ERROR_CLIPPING = 'error_clipping_threshold'
DROPOUT = 'drop_rate'
def check_input(input):
"""
Check input is a LayerOutput or list of LayerOutput or tuple of LayerOutput
if is a LayerOutput,
:param input: The input layer. Could be a list/tuple of input layer.
:type input: LayerOutput|list|tuple
:return: list of LayerOutput
:rtype: list of LayerOutput
"""
if isinstance(input, LayerOutput):
return [LayerOutput]
assert isinstance(input, list)
for inp in input:
assert isinstance(inp, LayerOutput)
return list(input)
def layer_support(*attrs):
def decorator(method):
@functools.wraps(method)
......@@ -731,19 +749,27 @@ def fc_layer(input, size, act=None, name=None,
return LayerOutput(name, LayerType.FC_LAYER, input, activation=act,
size=size)
@wrap_name_default("print")
def print_layer(input, name=None):
"""
Print the output value of input layers. This layer is useful for debugging.
:param name: The Layer Name.
:type name: basestring
:param input: The input layer. Could be a list/tuple of input layer.
:type input: LayerOutput|list|tuple
:return: No return
"""
assert isinstance(input, list)
check_input(input)
Layer(
name=name,
type=LayerType.PRINT_LAYER,
inputs=[l.name for l in input],
)
return LayerOutput(name, LayerType.PRINT_LAYER, input)
LayerOutput(name, LayerType.PRINT_LAYER, input)
@wrap_name_default("seq_pooling")
@wrap_bias_attr_default(has_bias=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册