flatten.py 2.1 KB
Newer Older
R
Renwb1991 已提交
1 2 3 4
""" A custom layer for Caffe's 'Flatten' layer; ideally this should be implemented in a more standard way.
    More info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/flatten.html
"""
from .register import register
S
SunAhong 已提交
5
from functools import reduce
R
Renwb1991 已提交
6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31

def flatten_shape(input_shape, axis=1, end_axis=-1):
    """ calculate the output shape of this layer using input shape

    Caffe's Flatten collapses the dimensions in the inclusive range
    [axis, end_axis] into a single dimension; negative values count from
    the end of the shape. If any dimension inside the flattened range is
    negative (unknown, e.g. a -1 batch size), the flattened size is reported
    as -1.

    Args:
        @input_shape (list of num): a list of number which represents the input shape
        @axis (int): first axis to flatten (caffe's Flatten `axis`), may be negative
        @end_axis (int): last axis to flatten, inclusive (caffe's Flatten `end_axis`),
            may be negative

    Returns:
        @output_shape (list of num): a list of numbers represent the output shape

    Raises:
        AssertionError: if an axis is out of range or axis comes after end_axis
    """
    input_shape = list(input_shape)
    ndim = len(input_shape)

    start_axis = axis + ndim if axis < 0 else axis
    last_axis = end_axis + ndim if end_axis < 0 else end_axis

    assert 0 <= start_axis <= last_axis < ndim, \
        'invalid axis[%d] or end_axis[%d] params' % (start_axis, last_axis)

    # caffe's end_axis is inclusive, so the slice must stop one past it
    stop_axis = last_axis + 1
    flat_dims = input_shape[start_axis:stop_axis]
    if any(dim < 0 for dim in flat_dims):
        # an unknown dim makes the merged size unknown; testing each dim
        # (rather than the sign of the product) also handles two -1 dims
        flat_sz = -1
    else:
        flat_sz = reduce(lambda a, b: a * b, flat_dims, 1)

    return input_shape[:start_axis] + [flat_sz] + input_shape[stop_axis:]


def flatten_layer(input, name, axis=1, end_axis=-1):
    """ build a layer of type 'Flatten' using fluid

    The target shape is computed by `flatten_shape` and passed to
    `fluid.layers.reshape`. If the input's leading dimension is -1 and that
    dimension is NOT part of the flattened range, the output's leading entry
    is set to 0 so that reshape copies it from the input.

    Args:
        @input (variable): input fluid variable for this layer
        @name (str): name for this layer
        @axis (int): parameter from caffe's Flatten layer
        @end_axis (int): parameter from caffe's Flatten layer

    Returns:
        output (variable): output variable for this layer
    """
    import paddle.fluid as fluid

    input_shape = list(input.shape)
    output_shape = flatten_shape(input_shape, axis=axis, end_axis=end_axis)

    start_axis = axis + len(input_shape) if axis < 0 else axis
    if input_shape[0] == -1 and start_axis > 0:
        # The leading dim is kept as its own output dim, so use 0 ("copy from
        # input") instead of a -1 that could clash with another -1 in the
        # shape. When the leading dim is flattened (start_axis == 0), it is
        # folded into the merged size and must not be replaced by 0.
        output_shape[0] = 0

    output = fluid.layers.reshape(input, shape=output_shape, name=name)

    return output


# Register this module's shape function and layer builder under the
# 'Flatten' kind.
register(
    kind='Flatten',
    shape=flatten_shape,
    layer=flatten_layer,
)