# flatten.py
""" A custom layer for Caffe's 'Flatten'; maybe we should implement this in a standard way.
    More info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/flatten.html
"""
from .register import register


def flatten_shape(input_shape, axis=1, end_axis=-1):
    """ calculate the output shape of this layer using input shape

    Follows Caffe's Flatten semantics: every axis in the INCLUSIVE range
    [axis, end_axis] is collapsed into one dimension whose size is the
    product of the collapsed sizes. Negative indices count from the end
    (-1 is the last axis). Axes outside the range are kept unchanged.

    Args:
        @input_shape (list of num): a list of number which represents the input shape
        @axis (int): first axis to flatten (caffe's Flatten `axis`)
        @end_axis (int): last axis to flatten, inclusive (caffe's Flatten `end_axis`)

    Returns:
        @output_shape (list of num): a list of numbers represent the output shape

    Raises:
        AssertionError: if the normalized axes fall outside the input's rank
            or axis comes after end_axis
    """
    from functools import reduce  # not a builtin on Python 3

    input_shape = list(input_shape)
    ndim = len(input_shape)

    # Normalize negative indices identically for both ends; end_axis stays
    # inclusive, exactly as in Caffe (so -1 means "the last axis").
    start_axis = axis + ndim if axis < 0 else axis
    end_axis = end_axis + ndim if end_axis < 0 else end_axis

    assert 0 <= start_axis <= end_axis < ndim, \
        'invalid axis[%d] or end_axis[%d] params' % (start_axis, end_axis)

    stop = end_axis + 1  # exclusive slice bound for the flattened range
    flat_sz = reduce(lambda a, b: a * b, input_shape[start_axis:stop], 1)

    # Keep every trailing axis after end_axis (the old `[end:-1]` slice
    # wrongly dropped the last one).
    return input_shape[:start_axis] + [flat_sz] + input_shape[stop:]


def flatten_layer(input, name, axis=1, end_axis=-1):
    """ build a layer of type 'Flatten' using fluid

    The output shape is computed with `flatten_shape` and applied via
    `fluid.layers.reshape`.

    Args:
        @input (variable): input fluid variable for this layer
        @name (str): name for this layer
        @axis (int): parameter from caffe's Flatten layer
        @end_axis (int): parameter from caffe's Flatten layer

    Returns:
        output (variable): output variable for this layer
    """
    import paddle.fluid as fluid

    input_shape = list(input.shape)

    # A batch dimension of -1 (unknown) must not enter the flattened product
    # (e.g. when axis == 0). Compute the shape with a placeholder of 1, then
    # put -1 back into the leading output dimension so reshape infers it.
    dynamic_batch = bool(input_shape) and input_shape[0] == -1
    if dynamic_batch:
        input_shape[0] = 1

    output_shape = flatten_shape(input_shape, axis=axis, end_axis=end_axis)

    if dynamic_batch:
        output_shape[0] = -1

    output = fluid.layers.reshape(input, shape=output_shape, name=name)
    return output


# Register the shape-inference function and the fluid layer builder above
# under the 'Flatten' layer kind.
register(
    kind='Flatten',
    shape=flatten_shape,
    layer=flatten_layer,
)