pd_op_base.td 2.6 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
// This file defines the basic elements of the Paddle (alias: pd) dialect.
// Much of it is adapted from the TensorFlow MLIR dialect: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td

#ifndef PD_OP_BASE
#define PD_OP_BASE

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Dialect record for PaddlePaddle, registered under the name "pd". The
// generated C++ dialect class is emitted into the ::mlir::pd namespace.
def PD_Dialect : Dialect {
  // C++ namespace for the code generated from this dialect definition.
  let cppNamespace = "::mlir::pd";

  // Dialect namespace name.
  let name = "pd";

  let description = [{
    The PaddlePaddle dialect.

    This dialect contains the PaddlePaddle operators.
  }];
}

// Base class for every operation in the PaddlePaddle dialect. It forwards the
// mnemonic and trait list to MLIR's `Op` class, binding the op to PD_Dialect.
class PD_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<PD_Dialect, mnemonic, traits>;


// Attribute constraint for a PaddlePaddle attribute class.
//   name:        C++ class name without the "Attr" suffix; the constraint
//                holds when `$_self.isa<mlir::pd::<name>Attr>()` is true.
//   description: inserted into the summary "PaddlePaddle <description> attribute".
class PD_PaddleAttr <string name, string description>
    : Attr<CPred<!strconcat("$_self.isa<mlir::pd::", name, "Attr>()")>,
           !strconcat("PaddlePaddle ", description, " attribute")>;


//===----------------------------------------------------------------------===//
// PaddlePaddle type definitions
//===----------------------------------------------------------------------===//

// Type constraint satisfied when `$_self.isa<mlir::pd::PDType>()` holds.
def PD_PDDialectType
    : Type<CPred<"$_self.isa<mlir::pd::PDType>()">, "PaddlePaddle type">;

// Type constraint for a specific PaddlePaddle dialect type class.
//   name:        C++ class name without the "Type" suffix; the constraint
//                holds when `$_self.isa<mlir::pd::<name>Type>()` is true.
//   description: inserted into the summary "PaddlePaddle <description> type".
// The type is buildable via `getType<mlir::pd::<name>Type>()`.
// The summary prefix is "PaddlePaddle " to match PD_PaddleAttr and
// PD_PDDialectType (previously it was the inconsistent "Paddle ").
class PD_PaddleType <string name, string description> :
      Type<CPred<"$_self.isa<mlir::pd::" # name # "Type>()">,
           "PaddlePaddle " # description # " type">,
      BuildableType<"getType<mlir::pd::" # name # "Type>()">;

//===----------------------------------------------------------------------===//
// Integer types
// Fixed-width integer element types; each wraps exactly one MLIR integer type.
// PD_Bool is the 1-bit integer type.
def PD_Bool  : AnyTypeOf<[I1],  "bool">;
def PD_Int8  : AnyTypeOf<[I8],  "8-bit integer">;
def PD_Int16 : AnyTypeOf<[I16], "16-bit integer">;
def PD_Int32 : AnyTypeOf<[I32], "32-bit integer">;
def PD_Int64 : AnyTypeOf<[I64], "64-bit integer">;

// Fixed-width unsigned integer element types (MLIR `UI<width>`).
def PD_UInt8  : AnyTypeOf<[UI<8>],  "8-bit unsigned integer">;
def PD_UInt16 : AnyTypeOf<[UI<16>], "16-bit unsigned integer">;
def PD_UInt32 : AnyTypeOf<[UI<32>], "32-bit unsigned integer">;
def PD_UInt64 : AnyTypeOf<[UI<64>], "64-bit unsigned integer">;

// Umbrella constraints over the fixed-width integer types above.
// NOTE(review): I8..I64 in OpBase.td are MLIR signless integer types; the
// "signed" in PD_SInt's name/summary presumably reflects how Paddle interprets
// them — confirm before relying on signedness here.
def PD_SInt : AnyTypeOf<[PD_Int8, PD_Int16, PD_Int32, PD_Int64],
                        "signed integer">;
def PD_UInt : AnyTypeOf<[PD_UInt8, PD_UInt16, PD_UInt32, PD_UInt64],
                        "unsigned integer">;

// Any integer: the union of the two groups above (PD_Bool is not included).
def PD_Int : AnyTypeOf<[PD_SInt, PD_UInt], "integer">;

// Float types
def PD_Float16 : AnyTypeOf<[F16], "16-bit float">;
def PD_Float32 : AnyTypeOf<[F32], "32-bit float">;
def PD_Float64 : AnyTypeOf<[F64], "64-bit float">;

// Any one of the float widths defined above.
def PD_Float
    : AnyTypeOf<[PD_Float16, PD_Float32, PD_Float64], "floating-point">;


// Tensor types

// Element-type constraint for PD tensors: accepts any type matched by
// PD_Float, PD_Bool, or PD_Int (their predicates OR-ed in that order).
def PD_ElementType
    : Type<Or<[PD_Float.predicate, PD_Bool.predicate, PD_Int.predicate]>,
           "pd.dtype">;

// Tensor type whose element type satisfies PD_ElementType.
def PD_Tensor
    : TensorOf<[PD_ElementType]>;


#endif // PD_OP_BASE