// ops.td: operator definitions (TableGen records built on MgbHashableOp).
#ifndef MGB_OPS
#define MGB_OPS

include "base.td"
include "param_defs.td"

include "mlir/Interfaces/SideEffectInterfaces.td"

// Elemwise: a variadic list of inputs producing a single result.
// Marked NoSideEffect. Its display name is to_string() of the `mode` field.
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
  let inputs = (ins Variadic<AnyType>:$input);
  let results = (outs AnyType);

  let nameFunction = [{
    return to_string($_self.mode);
  }];
}

// Reduce: ReduceParam fields plus an extra boolean `keepdim` attribute.
def Reduce : MgbHashableOp<"Reduce", [ReduceParam]> {
  let extraArguments = (ins
    MgbBoolAttr:$keepdim
  );
}

// TypeCvt: one input (named `inputs`) and one result, with no param struct.
// Carries both a TypeAttr (`idtype`) and an MgbDTypeAttr (`dtype`).
// Marked NoSideEffect.
def TypeCvt : MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
  let inputs = (ins AnyType:$inputs);
  let extraArguments = (ins
    TypeAttr:$idtype,
    MgbDTypeAttr:$dtype
  );
  let results = (outs AnyType);
}

// ---- Matrix ops ----

def MatrixInverse : MgbHashableOp<"MatrixInverse", [EmptyParam]>;

// MatrixMul: MatrixMulParam plus an execution-policy param
// (ExecutionPolicyParamBase<"policy">), and two u32 attributes `dimA` / `dimB`
// (presumably the input ranks — confirm against the C++ side).
def MatrixMul : MgbHashableOp<"MatrixMul",
    [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> {
  let extraArguments = (ins
    MgbUI32Attr:$dimA,
    MgbUI32Attr:$dimB
  );
}

// Same params and extra attributes as MatrixMul.
// NOTE(review): the record is named BatchedMatrixMul but the op name string is
// "BatchedMatmul". It is left unchanged because the string may be used for
// lookup or serialization — confirm before renaming either side.
def BatchedMatrixMul : MgbHashableOp<"BatchedMatmul",
    [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> {
  let extraArguments = (ins
    MgbUI32Attr:$dimA,
    MgbUI32Attr:$dimB
  );
}

def Dot : MgbHashableOp<"Dot", [EmptyParam]>;

def SVD : MgbHashableOp<"SVD", [SVDParam]>;

// ---- Convolution family ----
// Records that list ExecutionPolicyParamBase<"policy"> carry execution-policy
// fields alongside their main param struct.

def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;

// Adds an extra `dtype` attribute (MgbDTypeAttr) on top of the params.
def ConvolutionBackwardData : MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]> {
  let extraArguments = (ins
    MgbDTypeAttr:$dtype
  );
}

def Convolution3D : MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;

def Convolution3DBackwardData : MgbHashableOp<"Convolution3DBackwardData", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;

// Reuses ConvolutionParam.
def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;

// Reuses ConvolutionParam, without an execution policy.
def GroupLocal : MgbHashableOp<"GroupLocal", [ConvolutionParam]>;

// ---- Pooling family ----

// Pooling also carries an execution-policy param.
def Pooling : MgbHashableOp<"Pooling", [PoolingParam, ExecutionPolicyParamBase<"policy">]>;

// AdaptivePooling takes an extra int32 array attribute `shape`.
def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]> {
  let extraArguments = (ins
    MgbArrayAttr<MgbI32Attr>:$shape
  );
}

def ROIPooling : MgbHashableOp<"ROIPooling", [ROIPoolingParam]>;

def DeformablePSROIPooling : MgbHashableOp<"DeformablePSROIPooling", [DeformablePSROIPoolingParam]>;

// ConvBias / BatchConvBias: param struct, execution-policy param, and an extra
// `dtype` attribute (MgbDTypeAttr).
def ConvBias : MgbHashableOp<"ConvBias", [ConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
  let extraArguments = (ins
    MgbDTypeAttr:$dtype
  );
}

def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
  let extraArguments = (ins
    MgbDTypeAttr:$dtype
  );
}

def Images2Neibs : MgbHashableOp<"Images2Neibs", [Images2NeibsParam]>;

def SlidingWindowTranspose : MgbHashableOp<"SlidingWindowTranspose", [SlidingWindowTransposeParam]>;

// BatchNorm and BatchNormBackward share BNParam.
def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>;

def BatchNormBackward : MgbHashableOp<"BatchNormBackward", [BNParam]>;

def ROIAlign : MgbHashableOp<"ROIAlign", [ROIAlignParam]>;

def Correlation : MgbHashableOp<"Correlation", [CorrelationParam]>;

// Warp / remap / resize ops: param struct only.
def WarpPerspective : MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>;

def WarpAffine : MgbHashableOp<"WarpAffine", [WarpAffineParam]>;

def Remap : MgbHashableOp<"Remap", [RemapParam]>;

def Resize : MgbHashableOp<"Resize", [ResizeParam]>;

// One-hot indexing ops share AxisParam.
def IndexingOneHot : MgbHashableOp<"IndexingOneHot", [AxisParam]>;

def IndexingSetOneHot : MgbHashableOp<"IndexingSetOneHot", [AxisParam]>;

// Copy / Borrow / Barrier: no param struct. Each takes a `comp_node`
// attribute (MgbCompNodeAttr).
def Copy : MgbHashableOp<"Copy"> {
  let extraArguments = (ins
    MgbCompNodeAttr:$comp_node
  );
}

def Borrow : MgbHashableOp<"Borrow"> {
  let extraArguments = (ins
    MgbCompNodeAttr:$comp_node
  );
}

// Barrier additionally takes a u32 `nr_outputs`.
def Barrier : MgbHashableOp<"Barrier"> {
  let extraArguments = (ins
    MgbCompNodeAttr:$comp_node,
    MgbUI32Attr:$nr_outputs
  );
}

// Param-struct-only records; CondTake has no param struct at all.
def Argsort  : MgbHashableOp<"Argsort", [ArgsortParam]>;
def Argmax   : MgbHashableOp<"Argmax", [AxisParam]>;
def Argmin   : MgbHashableOp<"Argmin", [AxisParam]>;
def CondTake : MgbHashableOp<"CondTake">;
def TopK     : MgbHashableOp<"TopK", [TopKParam]>;
def NvOf     : MgbHashableOp<"NvOf", [NvOfParam]>;

// UniformRNG: UniformRNGParam plus a `handle` (MgbSizeTAddr).
// The hash combines the dynamic type info, `handle` and `dtype`; equality
// compares `handle` and `dtype`. Any other param fields do not take part.
def UniformRNG : MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
  let extraArguments = (ins
    MgbSizeTAddr:$handle
  );
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash_pair_combine(
        mgb::hash($_self.handle),
        mgb::hash($_self.dtype.enumv())
      )
    );
  }];
  let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}];
}

// GaussianRNG: GaussianRNGParam plus a `handle` (MgbSizeTAddr).
// The hash combines the dynamic type info, `handle`, `mean`, `std` and
// `dtype`; equality compares the same four fields (without the type info).
def GaussianRNG : MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
  let extraArguments = (ins
    MgbSizeTAddr:$handle
  );
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash_pair_combine(
        mgb::hash($_self.handle),
        mgb::hash_pair_combine(
          mgb::hash($_self.mean),
          mgb::hash_pair_combine(
            mgb::hash($_self.std),
            mgb::hash($_self.dtype.enumv())
          )
        )
      )
    );
  }];
  let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std && $0.dtype == $1.dtype;}];
}

// GammaRNG / PoissonRNG / BetaRNG share one shape: their param struct plus a
// `handle` (MgbSizeTAddr). The hash mixes the dynamic type info with
// `handle`; equality compares `handle` only.

def GammaRNG : MgbHashableOp<"GammaRNG", [GammaRNGParam]> {
  let extraArguments = (ins MgbSizeTAddr:$handle);
  let hashFunction = [{
    auto type_hash = mgb::hash($_self.dyn_typeinfo());
    auto handle_hash = mgb::hash($_self.handle);
    return mgb::hash_pair_combine(type_hash, handle_hash);
  }];
  let cmpFunction = [{
    return $0.handle == $1.handle;
  }];
}

def PoissonRNG : MgbHashableOp<"PoissonRNG", [PoissonRNGParam]> {
  let extraArguments = (ins MgbSizeTAddr:$handle);
  let hashFunction = [{
    auto type_hash = mgb::hash($_self.dyn_typeinfo());
    auto handle_hash = mgb::hash($_self.handle);
    return mgb::hash_pair_combine(type_hash, handle_hash);
  }];
  let cmpFunction = [{
    return $0.handle == $1.handle;
  }];
}

def BetaRNG : MgbHashableOp<"BetaRNG", [BetaRNGParam]> {
  let extraArguments = (ins MgbSizeTAddr:$handle);
  let hashFunction = [{
    auto type_hash = mgb::hash($_self.dyn_typeinfo());
    auto handle_hash = mgb::hash($_self.handle);
    return mgb::hash_pair_combine(type_hash, handle_hash);
  }];
  let cmpFunction = [{
    return $0.handle == $1.handle;
  }];
}

// PermutationRNG: PermutationRNGParam plus a `handle` (MgbSizeTAddr).
// The hash combines the dynamic type info, `handle` and `dtype`; equality
// compares `handle` and `dtype`.
def PermutationRNG : MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> {
  let extraArguments = (ins
    MgbSizeTAddr:$handle
  );
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash_pair_combine(
        mgb::hash($_self.handle),
        mgb::hash($_self.dtype.enumv())
      )
    );
  }];
  let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}];
}

// ShuffleRNG: ShuffleRNGParam plus a `handle` (MgbSizeTAddr).
// The hash mixes the dynamic type info with `handle`; equality compares
// `handle` only.
def ShuffleRNG : MgbHashableOp<"ShuffleRNG", [ShuffleRNGParam]> {
  let extraArguments = (ins
    MgbSizeTAddr:$handle
  );
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash($_self.handle)
    );
  }];
  let cmpFunction = [{return $0.handle == $1.handle;}];
}

// Linspace / Eye: param struct plus a `comp_node` attribute (MgbCompNodeAttr).
def Linspace : MgbHashableOp<"Linspace", [LinspaceParam]> {
  let extraArguments = (ins
    MgbCompNodeAttr:$comp_node
  );
}

def Eye : MgbHashableOp<"Eye", [EyeParam]> {
  let extraArguments = (ins
    MgbCompNodeAttr:$comp_node
  );
}

def Diag : MgbHashableOp<"Diag", [DiagParam]>;

def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>;

// Concat: AxisParam plus a `comp_node` attribute (MgbCompNodeAttr).
def Concat : MgbHashableOp<"Concat", [AxisParam]> {
  let extraArguments = (ins
    MgbCompNodeAttr:$comp_node
  );
}

// Broadcast: no param fields (EmptyParam) and an int32 array attribute `shape`.
def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]> {
  let extraArguments = (ins
    MgbArrayAttr<MgbI32Attr>:$shape
  );
}

// Identity: no param struct and no extra attributes.
def Identity : MgbHashableOp<"Identity">;

// CollectiveComm: CollectiveCommParam plus the extra attributes below.
// The attribute order is kept exactly; it presumably fixes the generated
// field/constructor order.
// NOTE(review): `comp_node` is an MgbStringAttr here, whereas Copy/Concat use
// MgbCompNodeAttr. Presumably intentional — confirm before changing.
def CollectiveComm : MgbHashableOp<"CollectiveComm", [CollectiveCommParam]> {
  let extraArguments = (ins
    MgbStringAttr:$key, MgbUI32Attr:$nr_devices, MgbUI32Attr:$rank,
    MgbBoolAttr:$is_root, MgbBoolAttr:$local_grad,
    MgbStringAttr:$addr, MgbUI32Attr:$port,
    MgbDTypeAttr:$dtype,
    MgbStringAttr:$backend, MgbStringAttr:$comp_node
  );
}

// RemoteSend / RemoteRecv: no param struct; everything is in the extra
// attributes. The receiver also declares `cn` (MgbCompNodeAttr), an int32
// array `shape` and a `dtype`.
def RemoteSend : MgbHashableOp<"RemoteSend"> {
  let extraArguments = (ins
    MgbStringAttr:$key,
    MgbStringAttr:$addr,
    MgbUI32Attr:$port,
    MgbUI32Attr:$rank_to,
    MgbStringAttr:$backend
  );
}

def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
  let extraArguments = (ins
    MgbStringAttr:$key,
    MgbStringAttr:$addr,
    MgbUI32Attr:$port,
    MgbUI32Attr:$rank_from,
    MgbCompNodeAttr:$cn,
    MgbArrayAttr<MgbI32Attr>:$shape,
    MgbDTypeAttr:$dtype,
    MgbStringAttr:$backend
  );
}

// NMSKeep: no param struct; `iou_thresh` (f32) and `max_output` (u32).
def NMSKeep : MgbHashableOp<"NMSKeep"> {
  let extraArguments = (ins MgbF32Attr:$iou_thresh, MgbUI32Attr:$max_output);
}

// ParamPackSplit / ParamPackConcat: no param struct. Both take an int32 array
// `offsets`, and the split op also takes nested `shapes`.
def ParamPackSplit : MgbHashableOp<"ParamPackSplit"> {
  let extraArguments = (ins
    MgbArrayAttr<MgbI32Attr>:$offsets,
    MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$shapes);
}

def ParamPackConcat : MgbHashableOp<"ParamPackConcat"> {
  let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$offsets);
}

// Dimshuffle: AnyMemRef-typed input and result, plus an int32 array `pattern`.
def Dimshuffle : MgbHashableOp<"Dimshuffle"> {
  let inputs = (ins AnyMemRef:$input);
  let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$pattern);
  let results = (outs AnyMemRef);
}

// Reshape: OptionalAxisV1Param plus an int32 array attribute `shape`.
def Reshape : MgbHashableOp<"Reshape", [OptionalAxisV1Param]> {
  let extraArguments = (ins
    MgbArrayAttr<MgbI32Attr>:$shape
  );
}

// AddAxis / RemoveAxis: no param struct; each takes an int32 array `axis`.
// TODO: should these be merged into a single AxisAddRemove op, as in megbrain?
def AddAxis : MgbHashableOp<"AddAxis"> {
  let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$axis);
}

def RemoveAxis : MgbHashableOp<"RemoveAxis"> {
  let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$axis);
}

// Shared base for the indexing ops below. Each takes an array `items` of
// 5-tuples (one MgbI8Attr followed by four MgbBoolAttr flags).
// NOTE(review): presumably (axis, begin, end, step, idx) — confirm against the
// C++ consumer.
class FancyIndexingBase<string name> : MgbHashableOp<name> {
  let extraArguments = (ins
    MgbArrayAttr<
      MgbTupleAttr<[MgbI8Attr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr]>
    >:$items
  );
}

// Subtensor variants.
def Subtensor                : FancyIndexingBase<"Subtensor">;
def SetSubtensor             : FancyIndexingBase<"SetSubtensor">;
def IncrSubtensor            : FancyIndexingBase<"IncrSubtensor">;
// Multi-axis vector indexing variants.
def IndexingMultiAxisVec     : FancyIndexingBase<"IndexingMultiAxisVec">;
def IndexingSetMultiAxisVec  : FancyIndexingBase<"IndexingSetMultiAxisVec">;
def IndexingIncrMultiAxisVec : FancyIndexingBase<"IndexingIncrMultiAxisVec">;
// Mesh indexing variants (plain and batched).
def MeshIndexing             : FancyIndexingBase<"MeshIndexing">;
def IncrMeshIndexing         : FancyIndexingBase<"IncrMeshIndexing">;
def SetMeshIndexing          : FancyIndexingBase<"SetMeshIndexing">;
def BatchedMeshIndexing      : FancyIndexingBase<"BatchedMeshIndexing">;
def BatchedIncrMeshIndexing  : FancyIndexingBase<"BatchedIncrMeshIndexing">;
def BatchedSetMeshIndexing   : FancyIndexingBase<"BatchedSetMeshIndexing">;

// Param-struct-only records.
def FakeQuant : MgbHashableOp<"FakeQuant", [FakeQuantParam]>;

def AssertEqual : MgbHashableOp<"AssertEqual", [AssertEqualParam]>;

def TQT : MgbHashableOp<"TQT", [TQTParam]>;

def LSQ : MgbHashableOp<"LSQ", [LSQParam]>;

def Softmax : MgbHashableOp<"Softmax", [SoftmaxParam]>;

// ElemwiseMultiType: ElemwiseMultiTypeParam plus an extra `dtype` attribute.
// Its display name is to_string() of the `mode` field.
def ElemwiseMultiType : MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
  let extraArguments = (ins
    MgbDTypeAttr:$dtype
  );
  let nameFunction = [{
    return to_string($_self.mode);
  }];
}

def InplaceAdd : MgbHashableOp<"InplaceAdd", [EmptyParam]>;

// Runtime ops: no param struct. Each takes `buf` (string) and `buf_size`
// (MgbSizeTAddr); CambriconRuntime also takes `symbol` and
// `tensor_dim_mutable`.
def TensorRTRuntime : MgbHashableOp<"TensorRTRuntime"> {
  let extraArguments = (ins
    MgbStringAttr:$buf,
    MgbSizeTAddr:$buf_size
  );
}

def AtlasRuntime : MgbHashableOp<"AtlasRuntime"> {
  let extraArguments = (ins
    MgbStringAttr:$buf,
    MgbSizeTAddr:$buf_size
  );
}

def CambriconRuntime : MgbHashableOp<"CambriconRuntime"> {
  let extraArguments = (ins
    MgbStringAttr:$buf,
    MgbSizeTAddr:$buf_size,
    MgbStringAttr:$symbol,
    MgbBoolAttr:$tensor_dim_mutable
  );
}

def MagicMindRuntime : MgbHashableOp<"MagicMindRuntime"> {
  let extraArguments = (ins
    MgbStringAttr:$buf,
    MgbSizeTAddr:$buf_size
  );
}

def CvtColor : MgbHashableOp<"CvtColor", [CvtColorParam]>;

def CheckNonFinite : MgbHashableOp<"CheckNonFinite", [CheckNonFiniteParam]>;

// FastpathCopy: no param struct and no extra attributes.
def FastpathCopy : MgbHashableOp<"FastpathCopy">;

// PixelShuffle and PixelShuffleBackward: no param struct; an int32 `factor`.
def PixelShuffle : MgbHashableOp<"PixelShuffle"> {
  let extraArguments = (ins
    MgbI32Attr:$factor
  );
}

def PixelShuffleBackward : MgbHashableOp<"PixelShuffleBackward"> {
  let extraArguments = (ins
    MgbI32Attr:$factor
  );
}

// ExternOpr: no param struct. It has attributes for output shapes/dtypes plus
// `name`, `data` and `data_len`.
// The hash covers only the dynamic type info, `name` and `data`. No
// cmpFunction is given, so equality falls back to the MgbHashableOp default.
// NOTE(review): confirm that default compares at least `name` and `data`;
// otherwise equal ops could hash differently.
def ExternOpr : MgbHashableOp<"ExternOpr"> {
  let extraArguments = (ins
    MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$output_shapes,
    MgbStringAttr:$name,
    MgbStringAttr:$data,
    MgbSizeTAddr:$data_len,
    MgbArrayAttr<MgbDTypeAttr>:$output_dtypes
  );
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash_pair_combine(
        mgb::hash($_self.name),
        mgb::hash($_self.data)
      )
    );
  }];
}

def Cumsum : MgbHashableOp<"Cumsum", [CumsumParam]>;

// Split: no param fields (EmptyParam) and int32 `axis` / `nsections` attributes.
def Split : MgbHashableOp<"Split", [EmptyParam]> {
  let extraArguments = (ins
    MgbI32Attr:$axis,
    MgbI32Attr:$nsections
  );
}

def Padding : MgbHashableOp<"Padding", [PaddingParam]>;

def LRN : MgbHashableOp<"LRN", [LRNParam]>;

def LayerNorm : MgbHashableOp<"LayerNorm", [LayerNormParam]>;

def LAMBUpdate : MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>;

// Recurrent ops: param struct only. LSTMCell has no param fields (EmptyParam).
def RNNCell : MgbHashableOp<"RNNCell", [RNNCellParam]>;

def LSTMCell : MgbHashableOp<"LSTMCell", [EmptyParam]>;

def RNN : MgbHashableOp<"RNN", [RNNParam]>;

def LSTM : MgbHashableOp<"LSTM", [LSTMParam]>;

// Dropout: DropoutParam plus a `handle` (MgbSizeTAddr).
// The hash combines the dynamic type info, `drop_prob` and `handle`; equality
// compares `handle` and `drop_prob`.
def Dropout : MgbHashableOp<"Dropout", [DropoutParam]> {
  let extraArguments = (ins
    MgbSizeTAddr:$handle
  );
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash_pair_combine(
        mgb::hash($_self.drop_prob),
        mgb::hash($_self.handle)
      )
    );
  }];
  let cmpFunction = [{return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;}];
}

#endif // MGB_OPS