tensor_shape.td 1.4 KB
Newer Older
Y
Yan Chunwei committed
1 2 3 4
#ifdef INFRT_OPS
#else
#define INFRT_OPS

5
include "paddle/infrt/dialect/infrt/ir/infrt_base.td"
Y
Yan Chunwei committed
6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
include "paddle/infrt/dialect/tensor_shape_base.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Common base class for every op defined in the TensorShape dialect.
//
// Template arguments:
//   mnemonic - op mnemonic, forwarded unchanged to ODS `Op`.
//   traits   - additional ODS traits for the op (none by default).
//
// Custom hooks:
//   parser  - forwards to infrt::dialect::parse<CppClass>(parser, result),
//             where $cppClass is replaced by the concrete op's C++ class.
//   printer - forwards to infrt::dialect::printOpWithOperands(p, *this).
//
// NOTE(review): subclasses that also set `assemblyFormat` presumably get a
// generated parser/printer instead of these hooks; confirm against the MLIR
// version in use.
class TS_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<TensorShapeDialect, mnemonic, traits> {
  let printer = [{
    return infrt::dialect::printOpWithOperands(p, *this);
  }];
  let parser = [{ return infrt::dialect::parse$cppClass(parser, result); }];
}

// Builds a TS_Shape value from the 64-bit integer array attribute `value`.
// The op has no operands, and NoSideEffect marks it as pure.
// NOTE(review): each array element is presumably one extent, so the array
// length would be the rank; confirm against the op's C++ kernel.
def TS_BuildShapeOp : TS_Op<"build_shape", [NoSideEffect]> {
  let summary = "Build tensor shape operation";
  let description = [{
    An operation that builds a tensor shape of the given rank and extents.
  }];

  let arguments = (ins I64ArrayAttr:$value);
  let results = (outs TS_Shape:$output);
  let assemblyFormat = "$value attr-dict";
}

// Queries a TS_Shape value for its total element count, returned as an i64.
//
// Like TS_BuildShapeOp, this op only computes a value from its input, so it
// is marked NoSideEffect; MLIR can then erase it when its result is unused.
// The operand and result are named to match the sibling ops ($shape,
// $output), which also makes ODS generate accessors for them.
def TS_GetNumElementsOp : TS_Op<"get_num_elements", [NoSideEffect]> {
  let summary = "Returns the number of elements in the shape";

  let description = [{
    An operation that returns the number of elements in the given shape.
  }];

  let arguments = (ins TS_Shape:$shape);
  let results = (outs I64:$output);
  let assemblyFormat = "operands attr-dict";
}

// Prints a single TS_Shape operand. The op declares no results and, unlike
// TS_BuildShapeOp, carries no NoSideEffect trait, so printing is its only
// purpose.
def TS_PrintShapeOp : TS_Op<"print_shape"> {
  let summary = "Print tensor shape operation";
  let description = [{
    An operation that prints a tensor shape.
  }];

  let assemblyFormat = "operands attr-dict";
  let arguments = (ins TS_Shape:$shape);
}

#endif