mgb_opr_param_defs.py 8.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# ``DType`` param def: one field named ``dtype`` of type ``dtype`` whose
# default value is ``DTypeEnum::Byte``.
(pdef('DType')
 .add_fields('dtype', 'dtype', 'DTypeEnum::Byte'))

# ``PersistentOutputStorage`` param def: a single int32 ``share_key`` field
# (default -1). Per its description, oprs with equal keys share storage and
# the key -1 means private, unshared storage.
(pdef('PersistentOutputStorage')
 .add_fields(
     'int32',
     Doc('share_key',
         'This is used for controlling memory sharing. Multiple '
         "``PersistentOutputStorage'' oprs with the same ``share_key'' "
         "would share underlying tensor storage. Note that the value ``-1'' is "
         'treated specially: storage of oprs with this key would be private and '
         'would not be shared with any other opr.'),
     -1))

# ``OptionalAxis`` param def: a single int32 ``axis`` field defaulting to -1,
# where -1 is the "no axis" sentinel.
(pdef('OptionalAxis', 'optional axis: axis == -1 means no axis').
 add_fields('int32', 'axis', -1))
# ``OptionalAxisV1`` param def: unlike ``OptionalAxis`` (sentinel -1), the
# "no axis" value here is the constant INVALID_AXIS, which is defined as
# MAX_NDIM (7); ``axis`` defaults to it.
(pdef('OptionalAxisV1', 'optional axis: axis == MAX_NDIM means no axis').
 add_const('int32', 'MAX_NDIM', 7).
 add_const('int32', 'INVALID_AXIS', 'MAX_NDIM').
 add_fields('int32', 'axis', 'INVALID_AXIS'))

# Legacy (version 0, is_legacy=True) ExecutionPolicy. ``Strategy`` is a plain
# enum of mutually exclusive values; a version-1 definition with the same name
# follows, where ``Strategy`` becomes a bit-combination enum.
# ``workspace_limit`` defaults to 2**64-1 (the largest uint64), spelled as a
# C-style ``ull`` literal string -- presumably pasted into generated C++ code.
(pdef('ExecutionPolicy', version=0, is_legacy=True).
 add_enum('Strategy',
          Doc('HEURISTIC = 0', 'use heuristic to choose the fastest algorithm'),
          Doc('HEURISTIC_REPRODUCIBLE = 1', 'use heuristic to choose the fastest algorithm, '
              'and the chosen algorithm is reproducible'),
          Doc('PROFILE = 2',
              'run possible algorithms on real device to find the best'),
          Doc('PROFILE_REPRODUCIBLE = 3',
              'the fastest of profile result that is also reproducible'),
          Doc('PROFILE_HEURISTIC = 4',
              'use profile result and heuristic to choose the fastest algorithm')).
 add_fields('uint64',
            Doc('workspace_limit', 'workspace limit in bytes'),
            str(2**64-1)+'ull'))

# ExecutionPolicy, version 1. ``Strategy`` is a bit-combination enum whose
# members are single bits (1 << n), so they can be combined, e.g.
# PROFILE | REPRODUCIBLE. ``member_alias`` names such combinations; the alias
# names match the member names of the version-0 ``Strategy`` enum above.
# ``workspace_limit`` defaults to 2**64-1 (the largest uint64).
(pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator', version=1).
 add_bit_combination_enum('Strategy',
          Doc('HEURISTIC = 1 << 0', 'use heuristic to choose the fastest algorithm'),
          Doc('PROFILE = 1 << 1',
              'run possible algorithms on real device to find the best'),
          # Adjacent literals are concatenated, so the trailing space matters.
          Doc('REPRODUCIBLE = 1 << 2',
              'when selecting algorithms by profiling or heuristic, require '
              'the chosen algorithms to be reproducible'),
          Doc('OPTIMIZED = 1 << 3',
              'when profiling, require the algorithms to be optimized so that '
              'profiling is fast'),
          default=('HEURISTIC',),
          member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'),
                        (('PROFILE', 'REPRODUCIBLE'), 'PROFILE_REPRODUCIBLE'),
                        (('PROFILE', 'HEURISTIC'), 'PROFILE_HEURISTIC'),
                        # NOTE(review): 'OPTMIZED' (sic) looks like an alias
                        # deliberately kept for an old misspelled name,
                        # presumably for backward compatibility -- confirm
                        # before removing or renaming it.
                        (('OPTIMIZED',), 'OPTMIZED'),
                        ]).
 add_fields('uint64',
            Doc('workspace_limit', 'workspace limit in bytes'),
            str(2**64-1)+'ull'))

# ``AssertEqual`` param def: float32 ``maxerr`` tolerance (default 1e-4) and a
# bool ``verbose`` flag whose default is given as the literal string 'false'.
(pdef('AssertEqual').
 add_fields('float32',
            Doc('maxerr', 'max allowed error; error is defined as the minimum '
                'of absolute and relative error'),
            1e-4).
 add_fields('bool', Doc('verbose', 'whether to print maxerr to stdout '
                        'during opr exec'),
            'false')
 )

# ``CollectiveComm`` param def: a single ``Mode`` enum selecting the
# collective operation. ``name_field='mode'`` presumably sets the name of the
# generated field that holds the enum value.
(pdef('CollectiveComm', 'collective communication between multiple computing '
      'nodes on localhost')
 .add_enum(Doc('Mode', 'mode of collective communication'),
           Doc('REDUCE_SUM = 0', 'reduce by sum to output computing node'),
           Doc('BROADCAST = 1', 'copy input value to each output computing node'),
           Doc('ALL_GATHER = 2', 'each output comp node gets the concatenated '
               'value of all inputs'),
           Doc('REDUCE_SCATTER_SUM = 3',
               'reduce inputs by sum and each output gets one part of it'),
           Doc('ALL_REDUCE_SUM = 4', 'every output gets the sum of all inputs'),
           Doc('ALL_REDUCE_MAX = 5', 'every output gets the max of all inputs'),
           Doc('ALL_REDUCE_MIN = 6', 'every output gets the min of all inputs'),
           Doc('ALL_REDUCE_PROD = 7', 'every output gets the prod of all inputs'),
           Doc('GATHER = 8', 'concat inputs to one node'),
           Doc('SCATTER = 9', 'scatter input to each output computing node'),
           Doc('ALL_TO_ALL = 10', 'scatter inputs and gather them on each computing node'),
           name_field='mode'))
# Placeholder param def with no fields; per its description it exists only to
# reserve a spare tag and must not be used.
(pdef('FakeSerializedDType',
      'HACK: The tag of this param def is actually used for another '
      'non-generated param def SerializedDType, the sole purpose of this param '
      'def is to provide a spare tag. Do not use.'
))

# ``CondExecPred`` param def: a ``Mode`` enum choosing how the predicate var is
# compared with branch keys, plus a float32 ``eps`` (default 1e-4) used as the
# equality threshold for floating point values.
(pdef('CondExecPred',
      'evaluate a predicate and branch keys to set up ExecutionMask objects '
      'with associated predicate proxy vars (PPVs)')
 .add_enum(Doc('Mode', 'how to compare predicate var with branch keys'),
           Doc('CASE = 0',
               'The outputs correspond to branch keys, '
               'and the one which equals predicate would be activated. '
               'This behaves like a case-statement in many languages.'),
           Doc('CASE_FALLBACK = 1', 'like :attr:`CASE`, but add an extra output '
               'that would be activated if no branch is matched'),
           # NOTE(review): the first math fragment below is a raw string with
           # doubled backslashes (yielding two backslashes before "infty" and
           # "ldots"), while the following fragment is a normal string
           # (yielding one backslash before "infty"). The two halves of the
           # same formula are escaped inconsistently -- confirm which form the
           # doc generator expects before normalizing.
           Doc('PIECEWISE = 2', 'One more output would be produced than the '
               'number of branch keys, representing the interval in which the '
               'predicate var fits in. The intervals are defined as '
               r':math:`(-\\infty, k_0), [k_0, k_1), \\ldots, '
               '[k_{n-2}, k_{n-1}), [k_{n-1}, \\infty)`. '
               'The keys must be given in ascending order.')
           )
 .add_fields('float32',
             Doc('eps',
                 'threshold for checking equality of floating point values'),
             1e-4)
 )

# ``CondExecPredLogical`` param def: a ``Mode`` enum selecting the logical
# function (or/and/xor and their negations) applied to a set of PPVs.
(pdef('CondExecPredLogical',
      'compute a logical function over a set of PPVs')
 .add_enum('Mode', Doc('OR = 0', 'logical or'),
           Doc('AND = 1', 'logical and'),
           Doc('XOR = 2', 'exclusive-or'),
           Doc('NOR = 3', 'not or(inputs)'),
           Doc('NAND = 4', 'not and(inputs)'),
           Doc('XNOR = 5', 'not xor(inputs)'))
 )

# ``CondExecMark`` param def with two enums: ``GradMode`` (gradient handling)
# and ``StaticInfer`` (static inference workaround). The ``name_field``
# arguments presumably set the generated field names (``grad_mode`` and
# ``static_infer``).
(pdef('CondExecMark',
      'add ExecutionMask of the input PPV to this opr and readers of the '
      'outputs of this opr')
 .add_enum(Doc('GradMode', 'mode for computing the gradient'),
           Doc('SUM = 0', 'normal gradient mode: sum all the activated components'),
           Doc('SUM_COND_OUT = 1', 'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
               'oprs that depend on the gradient opr would not be executed '
               'if the forward var is not used.'),
           name_field='grad_mode')
 .add_enum(Doc('StaticInfer',
               """static inference option. **Note:** This is a workaround: since
               currently static inference in MegBrain does not take conditional
               execution into account, this option can be used to bypass static
               inference errors. This is currently only used by automatically
               generated gradient oprs."""),
           Doc('SHAPE_VALUE = 0', 'enable both shape and value inference'),
           Doc('SHAPE_ONLY = 1',
               'only enable shape inference (disable value inference)'),
           Doc('NONE = 2', 'disable both shape and value inference'),
           name_field='static_infer')
 )

# ``CondExecMerge`` param def: a uint32 ``nr_output`` (default 1, vars per
# branch) and a ``Mode`` enum describing how active branches are merged.
(pdef('CondExecMerge', 'merge multiple conditional execution branches')
 .add_fields('uint32', Doc('nr_output',
                           'number of output vars (i.e. vars per branch)'),
             1)
 .add_enum('Mode',
           Doc('EXACT_ONE = 0', 'copy the var whose mask is activated to the output'
               ', requiring that exactly one branch is active'),
           Doc('EXACT_ONE_SAME_SHAPE = 1', 'like :attr:`EXACT_ONE` with the '
               'requirement that all branches have the same shape, so shape '
               'inference can be easier'),
           Doc('SUM = 2', 'sum all the active branches into output var; require '
               'all branches to have the same shape. Extra shape vars are '
               'needed in this mode, so the outputs can be initialized to zero '
               'when no input is active (and their shapes are probably '
               'unknown).'),
           # Adjacent literals are concatenated: the previous piece already
           # ends with a space, so the next one must not start with one.
           Doc('SUM_COND_OUT = 3', 'like :attr:`SUM` but also add an ExecutionMask'
               ' to the readers of output vars, so they would be skipped if '
               'no branch is taken')
           )
 )

# ``NvOf`` param def: a single uint32 ``precision`` field (default 1).
(pdef('NvOf', 'opr that implements the NVIDIA Optical Flow SDK.').add_fields('uint32', 'precision', 1))