diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index dfcabfdb8430b334b916ee6db7a3bdf76b83fb91..bda0b4f5d60e82c1d577b0063fd5e164bf6117c3 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -21,7 +21,6 @@ from .evaluators import * from .poolings import MaxPooling, AvgPooling, BasePoolingType from .attrs import * from .default_decorators import * - try: import cPickle as pickle except ImportError: @@ -204,6 +203,25 @@ ERROR_CLIPPING = 'error_clipping_threshold' DROPOUT = 'drop_rate' +def check_input(input): + """ + Check input is a LayerOutput or list of LayerOutput or tuple of LayerOutput + if is a LayerOutput, + + :param input: The input layer. Could be a list/tuple of input layer. + :type input: LayerOutput|list|tuple + :return: list of LayerOutput + :rtype: list of LayerOutput + """ + + if isinstance(input, LayerOutput): + return [LayerOutput] + assert isinstance(input, list) + for inp in input: + assert isinstance(inp, LayerOutput) + return list(input) + + def layer_support(*attrs): def decorator(method): @functools.wraps(method) @@ -731,19 +749,27 @@ def fc_layer(input, size, act=None, name=None, return LayerOutput(name, LayerType.FC_LAYER, input, activation=act, size=size) + @wrap_name_default("print") def print_layer(input, name=None): """ Print the output value of input layers. This layer is useful for debugging. + + :param name: The Layer Name. + :type name: basestring + :param input: The input layer. Could be a list/tuple of input layer. + :type input: LayerOutput|list|tuple + :return: No return """ - assert isinstance(input, list) + check_input(input) Layer( name=name, type=LayerType.PRINT_LAYER, inputs=[l.name for l in input], ) - return LayerOutput(name, LayerType.PRINT_LAYER, input) + LayerOutput(name, LayerType.PRINT_LAYER, input) + @wrap_name_default("seq_pooling") @wrap_bias_attr_default(has_bias=False)