// pd_lower_to_trt.td: rewrite patterns that lower Paddle (pd) dialect ops to
// TensorRT (trt) dialect ops.
#ifndef PD_LOWER_TO_TRT
#define PD_LOWER_TO_TRT

include "mlir/Interfaces/SideEffectInterfaces.td"
include "paddle/infrt/dialect/infrt/ir/infrt_base.td"
include "paddle/infrt/dialect/pd/ir/pd_ops.td"
include "paddle/infrt/dialect/tensorrt/trt_ops.td"

// NativeCodeCall helper for result DAGs that need an nvinfer1 enum-valued
// attribute. When the rewrite fires, it emits the C++ call
//   infrt::trt::createNvinferEnumAttr<STRING_TO_ENUM_TYPE(<enum_type>)>(
//       $_builder, STRING_TO_ENUM_VALUE(<enum_type>::<enum_value>))
//   enum_type  - C++ enum type spelling, e.g. "nvinfer1::ActivationType".
//   enum_value - enumerator name within that type, e.g. "kRELU".
// NOTE(review): STRING_TO_ENUM_TYPE / STRING_TO_ENUM_VALUE are C++ macros
// defined outside this file; confirm their exact expansion there.
class TRT_createNvinferEnumAttr<string enum_type, string enum_value>
    : NativeCodeCall<!strconcat(
          "infrt::trt::createNvinferEnumAttr<STRING_TO_ENUM_TYPE(", enum_type, ")>",
          "($_builder, STRING_TO_ENUM_VALUE(", enum_type, "::", enum_value, "))")>;

// Lower pd.matmul to trt.MatrixMultiply.
// The source pattern only matches when the fifth pd.matmul attribute
// (presumably `alpha` -- confirm in pd_ops.td) is exactly 1.0. The result op
// is built without any scale argument, so other values must not be lowered here.
// The transpose flags are forwarded unchanged, each placed right after its
// operand: (X, transpose_X, Y, transpose_Y).
def PD2TRT_Matmul_Lower : Pat<
        (PD_MatmulOp $X, $Y, $transpose_X, $transpose_Y, ConstantAttr<F32Attr, "1.0">),
        (TRT_MatrixMultiplyOp $X, $transpose_X, $Y, $transpose_Y)>;

// Lower pd.elementwise_add to trt.ElementWise with operation kSUM.
// The source pattern only matches when the op's SI32 attribute (presumably
// `axis` -- confirm in pd_ops.td) is exactly -1. Ops with any other value are
// not rewritten by this pattern.
def PD2TRT_ElementwiseAdd_Lower : Pat<
        (PD_Elementwise_addOp $X, $Y, ConstantAttr<SI32Attr, "-1">),
        (TRT_ElementWiseOp $X, $Y,
                           (TRT_createNvinferEnumAttr<"nvinfer1::ElementWiseOperation", "kSUM">))>;

// Lower pd.relu to trt.Activation with ActivationType kRELU.
// Both trailing F32 attributes are passed as 0.0. They are presumably the
// activation's alpha/beta parameters, which TensorRT does not use for kRELU
// -- confirm against trt_ops.td.
def PD2TRT_Relu_Lower : Pat<
        (PD_ReluOp $X),
        (TRT_ActivationOp $X,
                          (TRT_createNvinferEnumAttr<"nvinfer1::ActivationType", "kRELU">),
                          (INFRT_createF32Attr<"0.0">),
                          (INFRT_createF32Attr<"0.0">))>;

// Lower pd.relu6 to trt.Activation with ActivationType kCLIP.
// The two F32 parameters are 0.0 and the pd op's $threshold attribute.
// NOTE(review): presumably these are kCLIP's lower/upper bounds (alpha, beta),
// which gives clip(x, 0, threshold); confirm against the TensorRT
// ActivationType::kCLIP documentation and trt_ops.td.
def PD2TRT_Relu6_Lower : Pat<
        (PD_Relu6Op $X, $threshold),
        (TRT_ActivationOp $X,
                          (TRT_createNvinferEnumAttr<"nvinfer1::ActivationType", "kCLIP">),
                          (INFRT_createF32Attr<"0.0">),
                          $threshold)>;

#endif // PD_LOWER_TO_TRT