select.py 1.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
""" a custom layer for 'select' which is used to replace standard 'Slice' layer 
    for converting layer with multiple different output tensors
"""
from .register import register


class _InvalidSlicePointError(ValueError, AssertionError):
    """ raised when 'slice_point' does not describe a valid [start, end) range

    Inherits from AssertionError as well as ValueError so callers written
    against the former ``assert``-based check keep catching it.
    """


def select_shape(input_shape, slice_point, axis=1):
    """ calculate the output shape of this layer using input shape

    The selected range along `axis` is [start, end): start is slice_point[0];
    end is slice_point[1] when exactly two points are given, otherwise the
    full size of that axis.

    Args:
        @input_shape (list of num): a list of number which represents the input shape
        @slice_point (list): parameter from caffe's Slice layer
        @axis (int): parameter from caffe's Slice layer

    Returns:
        @output_shape (list of num): a new list equal to input_shape except that
            dimension `axis` becomes end - start; input_shape is not modified

    Raises:
        _InvalidSlicePointError (both a ValueError and an AssertionError):
            if start is negative, end <= start, or end exceeds a known
            (non-negative) size of `axis`
    """
    # Copy so the caller's shape (possibly a tuple) is never mutated.
    output_shape = list(input_shape)
    dim = output_shape[axis]
    start = slice_point[0]
    end = slice_point[1] if len(slice_point) == 2 else dim

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if not 0 <= start < end:
        raise _InvalidSlicePointError(
            "invalid slice_point with [start:%d, end:%d]" % (start, end))
    # Presumably a negative size (e.g. -1) denotes an unknown dimension, so
    # only bound-check sizes that are known.
    if 0 <= dim < end:
        raise _InvalidSlicePointError(
            "slice_point end %d exceeds size %d of axis %d" % (end, dim, axis))

    output_shape[axis] = end - start
    return output_shape


def select_layer(input, name, slice_point, axis=1):
    """ build a layer of type 'Select' using fluid

    The input is split along `axis` into at most three pieces,
    [0:start), [start:end) and [end:size), and only the [start:end) piece
    is returned. start is slice_point[0]; end is slice_point[1] when exactly
    two points are given, otherwise the full size of that axis.

    Args:
        @input (variable): input fluid variable for this layer
        @name (str): name for this layer
        @slice_point (list): parameter from caffe's Slice layer
        @axis (int): parameter from caffe's Slice layer

    Returns:
        output (variable): output variable for this layer

    Raises:
        ValueError: if start is negative, end <= start, or end exceeds a
            known (non-negative) size of `axis`
    """
    input_shape = list(input.shape)
    dim = input_shape[axis]
    start = slice_point[0]
    end = slice_point[1] if len(slice_point) == 2 else dim

    # Validate before building anything: a bad range would otherwise reach
    # fluid.layers.split as negative or inconsistent section sizes.
    if not 0 <= start < end:
        raise ValueError(
            "invalid slice_point with [start:%d, end:%d]" % (start, end))
    # Presumably a negative size (e.g. -1) denotes an unknown dimension, so
    # only bound-check sizes that are known.
    if 0 <= dim < end:
        raise ValueError(
            "slice_point end %d exceeds size %d of axis %d" % (end, dim, axis))

    # Kept as a local import (as originally written); deferring it past the
    # checks lets invalid parameters fail fast.
    import paddle.fluid as fluid

    sections = [start] if start > 0 else []
    pos = len(sections)  # index of the [start:end) piece among the outputs
    sections.append(end - start)
    if end != dim:
        sections.append(dim - end)

    outputs = fluid.layers.split(input, sections, dim=axis, name=name)
    return outputs[pos]


# Register the 'Select' kind together with its shape-inference function
# (select_shape) and its fluid layer builder (select_layer).
register(
    kind='Select',
    shape=select_shape,
    layer=select_layer,
)