# misc.oprdecl
# Shared Doc entry for the ``axis`` input of the ``argmax`` / ``argmin``
# wrappers declared below.
# NOTE(review): the four positional fields look like (name, description,
# type, default) -- confirm against the definition of Doc.
axis_inp = Doc(
    'axis',
    ('axis along which to reduce input var; '
     'if it is None, the input would be flattened'),
    'int or None',
    'None',
)
# Shared Doc entry for the ``keepdims`` input of the ``argmax`` / ``argmin``
# wrappers declared below.
# NOTE(review): the four positional fields look like (name, description,
# type, default) -- confirm against the definition of Doc.
keepdims_inp = Doc(
    'keepdims',
    'If True, the given axis would have shape 1 in the output; otherwise '
    'it would be removed',
    'bool',
    'False')
def call_reduce_like(impl):
    """Build the generated-body lines for a reduce-like wrapper.

    :param impl: text spliced in as the first argument of the generated
        ``_reduce_like`` call (e.g. ``'_argmax'``)
    :return: a single-element list holding the source line that assigns
        the ``_reduce_like(...)`` result to ``output``
    """
    return [
        'output = _reduce_like({}, src, axis, keepdims, name, comp_node, '
        'config, comp_graph)'.format(impl)
    ]

# Argmax opr declaration. Its pyname ``_argmax`` is the identifier that the
# ``argmax`` raw opr's generated body passes to ``_reduce_like``.
decl_opr(
    'Argmax',
    pyname='_argmax',
    inputs=['src'],
    params=[('axis', 'Axis')],
)

# Argmin opr declaration. Its pyname ``_argmin`` is the identifier that the
# ``argmin`` raw opr's generated body passes to ``_reduce_like``.
decl_opr(
    'Argmin',
    pyname='_argmin',
    inputs=['src'],
    params=[('axis', 'Axis')],
)

# ``argmax`` raw opr: takes ``src`` plus the shared axis/keepdims inputs;
# its body is the single ``_reduce_like(_argmax, ...)`` line built by
# call_reduce_like.
decl_raw_opr(
    'argmax',
    desc='Returns the indices of the maximum values along an axis.',
    inputs=['src', axis_inp, keepdims_inp],
    body=call_reduce_like('_argmax'),
)

# ``argmin`` raw opr: takes ``src`` plus the shared axis/keepdims inputs;
# its body is the single ``_reduce_like(_argmin, ...)`` line built by
# call_reduce_like.
decl_raw_opr(
    'argmin',
    desc='Returns the indices of the minimum values along an axis.',
    inputs=['src', axis_inp, keepdims_inp],
    body=call_reduce_like('_argmin'),
)

# Argsort opr declaration: one input ``src``, parameters taken from the
# ``Argsort`` param definition. The description below is the text passed
# to decl_opr as ``desc``.
decl_opr('Argsort',
         inputs=['src'],
         params='Argsort',
         desc='The input must be an :math:`(m, n)` matrix, and this operator '
         'sorts each row independently, so :math:`m` independent sortings are '
         'performed. Two vars are returned: the sorted array, and the '
         'indices. ')

# Cumsum opr declaration. The generated body treats axis == (1<<31)-1
# (2147483647, the value the description calls INT_MAX) as a sentinel:
# the first input is flattened and the axis is reset to 0.
decl_opr(
    'Cumsum',
    inputs=['src'],
    params='Cumsum',
    body=[
        'if param.axis == (1<<31)-1:',
        '    all_inputs[0] = all_inputs[0].flatten()',
        '    param.axis = 0',
    ],
    desc=('Return the cumulative sum of the elements along a given axis.'
          '  If axis is INT_MAX, compute on flattened input.'),
    version=1,
)

# CondTake opr declaration: two inputs (``data`` and ``mask``), parameters
# taken from the ``CondTake`` param definition.
decl_opr(
    'CondTake',
    inputs=['data', 'mask'],
    params='CondTake',
    desc=('Take elements from *data* according to *mask* and *param*. '
          'This operator has two outputs, both 1-dimensional: '
          'the first is the element values, and the second is '
          'corresponding offsets of the taken values'),
)

# TopK opr declaration: two inputs (``data`` and ``k``), parameters taken
# from the ``TopK`` param definition.
decl_opr(
    'TopK',
    inputs=['data', 'k'],
    params='TopK',
    desc='Select the top k values from sorted result.',
)

# vim: ft=python