indexing.oprdecl 2.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
# Declaration of the 'Diag' operator. The single documented entry is the
# diagonal offset ``k``; `params='Diag'` refers to a parameter definition by
# name (presumably declared with the other operator params -- TODO confirm).
decl_opr(
    'Diag',
    inputs=[
        Doc(
            'k',
            'Index of the diagonal: 0 (the default) refers to the main diagonal, '
            'a positive value refers to an upper diagonal, '
            'and a negative value to a lower diagonal.'
        )
    ],
    params='Diag',
    desc='Extract a diagonal or construct a diagonal array'
)

# Declaration of 'DiagBackward', described as the backward function of Diag.
# It documents the same diagonal offset ``k`` and refers to the same 'Diag'
# parameter definition by name.
decl_opr(
    'DiagBackward',
    inputs=[
        Doc(
            'k',
            'Index of the diagonal: 0 (the default) refers to the main diagonal, '
            'a positive value refers to an upper diagonal, '
            'and a negative value to a lower diagonal.'
        )
    ],
    params='Diag',
    desc='backward function of Diag'
)

23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
# Low-level declaration of the 'IndexingOneHot' operator. It is registered
# under the name `_indexing_one_hot`, which the `indexing_one_hot` raw
# operator in this file calls from its body.
decl_opr(
    'IndexingOneHot',
    pyname='_indexing_one_hot',
    # positional inputs: the source data and the index
    inputs=['src', 'index'],
    # one parameter, 'axis', whose declared kind is 'Axis'
    params=[('axis', 'Axis')]
)

# Wrapper `indexing_one_hot` around `_indexing_one_hot`. The body calls the
# low-level operator and, unless ``keepdims`` is set, removes the indexed axis
# from the result.
#
# The third/fourth positional arguments of Doc look like the argument type and
# its default value -- NOTE(review): confirm against the Doc helper.
decl_raw_opr(
    'indexing_one_hot',
    inputs=[Doc('src', 'input data, n-dimensional'),
            Doc('axis',
                'the axis on *src* for which values in *index* index', 'int'),
            Doc('index', 'index, (n-1)-dimensional, dtype must be int32_t'),
            Doc('keepdims', 'whether not to remove the axis in result',
                'bool', 'False')],
    body=[
        'output = _indexing_one_hot(src, index, axis=axis, config=config)',
        # drop the size-1 indexing axis unless the caller asked to keep it
        'if not keepdims:',
        '   output = remove_axis(output, axis)'],
    # *index* is (n-1)-dimensional, so a subscript tuple i into it has
    # components i[0] .. i[n-2]; the formula therefore ends at i[n-2].
    desc='One-hot indexing for some axis. '
    'If ``keepdims == True``, output data is n-dimensional, but its shape on '
    'indexing axis is 1. Given src, axis and index, for all valid subscript '
    'tuples i of *index*, we have: '
    '``dst[i[0], ..., i[axis-1], 0, i[axis], ..., i[n-2]] '
    '= src[i[0], ..., i[axis-1], index[i], i[axis], ..., i[n-2]]``'
)

# Low-level declaration of the 'IndexingSetOneHot' operator. It is registered
# under the name `_indexing_set_one_hot`, which the `indexing_set_one_hot`
# raw operator in this file calls from its body.
decl_opr(
    'IndexingSetOneHot',
    pyname='_indexing_set_one_hot',
    # positional inputs: the source data, the index and the value to write
    inputs=['src', 'index', 'value'],
    # one parameter, 'axis', whose declared kind is 'Axis'
    params=[('axis', 'Axis')]
)

# Wrapper `indexing_set_one_hot` around `_indexing_set_one_hot`. The body
# forwards src/index/value and passes ``axis`` through as a keyword argument.
decl_raw_opr(
    'indexing_set_one_hot',
    inputs=['src',
            # The axis description was empty; it now reuses the wording of the
            # `axis` input of `indexing_one_hot`, which the desc below refers
            # readers to for the indexing semantics.
            Doc('axis',
                'the axis on *src* for which values in *index* index', 'int'),
            'index',
            'value'],
    body=[
        # adjacent literals concatenate into a single generated source line
        'output = _indexing_set_one_hot(src, index, value, axis=axis, '
                                       'config=config)'
    ],
    desc='set subtensor given by *index* in *src* to *value*; see '
    ':func:`indexing_one_hot` for how the indexing works.')

# Declaration of the 'IndexingRemap' operator: two inputs (`src`, `map_`) and
# a parameter definition referred to by the name 'IndexingRemap'. The
# description is a raw string because it contains backslashes
# (e.g. ``\ldots``) in its :math: roles.
decl_opr(
    'IndexingRemap',
    params='IndexingRemap',
    inputs=['src', 'map_'],
    desc=Doc(
        None,
r"""
    Generate an output tensor by treating *map_* as indices to take from *src*.

    Assume shape of *src* is :math:`(s_0, s_1, \ldots, s_{n-1})` and shape of
    *map_* is :math:`(t_0, t_1, \ldots, t_{m-1})`, then the remap requires
    :math:`t_{m-1}=n`, and the output shape would be
    :math:`(t_0, \ldots, t_{m-2})`; for each output element, it has :math:`n`
    corresponding elements in *map_* that would be used as the index to look up
    in *src*.

    .. note::
        This operator accepts a special parameter *is_non_overlapping*; see
        :class:`~.opr_param_defs.IndexingRemap` for its explanation.

"""
    )
)

# vim: ft=python