// This file defines the basic building blocks of the Paddle (alias pd) dialect.
// Much of it is adapted from the TensorFlow MLIR dialect: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td

#ifndef PD_OP_BASE
#define PD_OP_BASE

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "paddle/infrt/dialect/infrt/infrt_ops_base.td"

// The `pd` dialect: container for the PaddlePaddle operators. Generated C++
// declarations for this dialect are placed in namespace mlir::pd.
def PD_Dialect : Dialect {
  let name = "pd";

  let description = [{
    The PaddlePaddle dialect.

    This dialect contains the PaddlePaddle operators.
  }];

  let cppNamespace = "mlir::pd";
}

// Base class for every operation in the pd dialect.
//   mnemonic - op name within the dialect (combined with the dialect name "pd").
//   traits   - op traits, forwarded unchanged to MLIR's Op class.
class PD_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<PD_Dialect, mnemonic, traits>;


// Attribute constraint satisfied by the C++ attribute class
// mlir::pd::<name>Attr. `description` is spliced into the constraint summary
// "PaddlePaddle <description> attribute".
class PD_PaddleAttr <string name, string description>
    : Attr<CPred<!strconcat("$_self.isa<mlir::pd::", name, "Attr>()")>,
           !strconcat("PaddlePaddle ", description, " attribute")>;


//===----------------------------------------------------------------------===//
// PaddlePaddle type definitions
//===----------------------------------------------------------------------===//

// Matches any type for which `isa<mlir::pd::PDType>()` holds.
def PD_PDDialectType
    : Type<CPred<"$_self.isa<mlir::pd::PDType>()">,
           "PaddlePaddle type">;

// Buildable type constraint for the C++ type class mlir::pd::<name>Type.
// The summary reads "Paddle <description> type".
// NOTE(review): the builder call below has no `$_builder.` prefix; confirm the
// MLIR version in use supplies the builder receiver itself.
class PD_PaddleType <string name, string description>
    : Type<CPred<!strconcat("$_self.isa<mlir::pd::", name, "Type>()")>,
           !strconcat("Paddle ", description, " type")>,
      BuildableType<!strconcat("getType<mlir::pd::", name, "Type>()")>;

//===----------------------------------------------------------------------===//
// Integer types
// Single-type wrappers around MLIR builtin integer types; they are composed
// into the grouped constraints further down.
// NOTE(review): I8/I16/I32/I64 are MLIR signless integer types, yet they are
// grouped under PD_SInt ("signed integer") below; confirm that is intended.
def PD_Bool   : AnyTypeOf<[I<1>], "bool">;
def PD_Int8   : AnyTypeOf<[I8],   "8-bit integer">;
def PD_Int16  : AnyTypeOf<[I16],  "16-bit integer">;
def PD_Int32  : AnyTypeOf<[I32],  "32-bit integer">;
def PD_Int64  : AnyTypeOf<[I64],  "64-bit integer">;

def PD_UInt8  : AnyTypeOf<[UI<8>],  "8-bit unsigned integer">;
def PD_UInt16 : AnyTypeOf<[UI<16>], "16-bit unsigned integer">;
def PD_UInt32 : AnyTypeOf<[UI<32>], "32-bit unsigned integer">;
def PD_UInt64 : AnyTypeOf<[UI<64>], "64-bit unsigned integer">;

// Groupings of the wrappers above.
def PD_SInt : AnyTypeOf<[PD_Int8, PD_Int16, PD_Int32, PD_Int64],
                        "signed integer">;
def PD_UInt : AnyTypeOf<[PD_UInt8, PD_UInt16, PD_UInt32, PD_UInt64],
                        "unsigned integer">;
// Any integer from either group; PD_Bool is deliberately not part of it.
def PD_Int  : AnyTypeOf<[PD_SInt, PD_UInt], "integer">;

// Float types
// Single-type wrappers around MLIR builtin float types.
def PD_Float16 : AnyTypeOf<[F16], "16-bit float">;
def PD_Float32 : AnyTypeOf<[F32], "32-bit float">;
def PD_Float64 : AnyTypeOf<[F64], "64-bit float">;

// Any of the float wrappers above.
def PD_Float   : AnyTypeOf<[PD_Float16, PD_Float32, PD_Float64],
                           "floating-point">;


// Tensor types

// Element types permitted inside a pd tensor: the union (in this order) of the
// float, bool and integer constraints defined above.
def PD_ElementType
    : Type<Or<[PD_Float.predicate, PD_Bool.predicate, PD_Int.predicate]>,
           "pd.dtype">;

// Builtin MLIR tensor whose element type satisfies PD_ElementType.
def PD_Tensor1 : TensorOf<[PD_ElementType]>;

// Any tensor accepted by pd ops: a builtin tensor (PD_Tensor1) or a LoDTensor.
// LoDTensor is not defined in this file; presumably it comes from the included
// infrt_ops_base.td.
def PD_Tensor : AnyTypeOf<[PD_Tensor1, LoDTensor], "pd.ttype">;

// NOTE(review): VectorOf constrains MLIR *vector* types by element type, and a
// vector cannot have tensor elements, so this constraint likely never matches.
// If a variable-length list of tensors is intended, Variadic<PD_Tensor> (or a
// dedicated tensor-array type) may be what is wanted. Left unchanged here to
// avoid altering the signatures of ops that use it.
def PD_Tensor_Array : VectorOf<[PD_Tensor]>;

#endif // PD_OP_BASE