diff --git a/paddle/infrt/backends/tensorrt/trt_engine.cc b/paddle/infrt/backends/tensorrt/trt_engine.cc index 72d98d865a69eaed654b0c94ddc8578a58f8b298..a2d4954618986e66fccbc8cb67faf612e975a596 100644 --- a/paddle/infrt/backends/tensorrt/trt_engine.cc +++ b/paddle/infrt/backends/tensorrt/trt_engine.cc @@ -210,12 +210,16 @@ bool TrtEngine::SetupNetworkAndConfig(const BuildOptions& build, case PrecisionConstraints::kNONE: // It's the default for TensorRT. break; +#if IS_TRT_VERSION_GE(8200) case PrecisionConstraints::kOBEY: config.setFlag(BuilderFlag::kOBEY_PRECISION_CONSTRAINTS); break; case PrecisionConstraints::kPREFER: config.setFlag(BuilderFlag::kPREFER_PRECISION_CONSTRAINTS); break; +#endif // IS_TRT_VERSION_GE(8200) + default: + break; } // TODO(TRT): DLA config. diff --git a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td index 6467c1285f85e0c8bfca7b873ce64a09a52074ff..1c5ba68936837664ca39bcf3ffa7e9784e890cf8 100644 --- a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td +++ b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td @@ -6,23 +6,23 @@ include "paddle/infrt/dialect/infrt/ir/infrt_base.td" include "paddle/infrt/dialect/pd/ir/pd_ops.td" include "paddle/infrt/dialect/tensorrt/trt_ops.td" +class TRT_createNvinferEnumAttr : NativeCodeCall< + "infrt::trt::createNvinferEnumAttr($_builder, STRING_TO_ENUM_VALUE(" # enum_type # "::" # enum_value # "))">; + def PD2TRT_Matmul_Lower : Pat< (PD_MatmulOp $X, $Y, $transpose_X, $transpose_Y, ConstantAttr), (TRT_MatrixMultiplyOp $X, $transpose_X, $Y, $transpose_Y)>; -//TO DO(shangzhizhou):replace '"INFRT_createI32Attr<"0">' to enum nvinfer1::ElementWiseOperation::kSUM def PD2TRT_ElementwiseAdd_Lower : Pat< (PD_Elementwise_addOp $X, $Y, ConstantAttr), - (TRT_ElementWiseOp $X, $Y, (INFRT_createSI32Attr<"0">)/*kSUM*/)>; + (TRT_ElementWiseOp $X, $Y, (TRT_createNvinferEnumAttr<"nvinfer1::ElementWiseOperation", "kSUM">))>; -//TO DO(shangzhizhou):replace '"INFRT_createI32Attr<"0">' to enum nvinfer1::ActivationType::kRELU def PD2TRT_Relu_Lower : Pat< (PD_ReluOp $X), - (TRT_ActivationOp $X, (INFRT_createSI32Attr<"0">)/*kRELU*/, (INFRT_createF32Attr<"0.0">), (INFRT_createF32Attr<"0.0">))>; + (TRT_ActivationOp $X, (TRT_createNvinferEnumAttr<"nvinfer1::ActivationType", "kRELU">), (INFRT_createF32Attr<"0.0">), (INFRT_createF32Attr<"0.0">))>; -//TO DO(shangzhizhou):replace '"INFRT_createI32Attr<"0">' to enum nvinfer1::ActivationType::kCLIP def PD2TRT_Relu6_Lower : Pat< (PD_Relu6Op $X, $threshold), - (TRT_ActivationOp $X, (INFRT_createSI32Attr<"8">)/*kCLIP*/, (INFRT_createF32Attr<"0.0">), $threshold)>; + (TRT_ActivationOp $X, (TRT_createNvinferEnumAttr<"nvinfer1::ActivationType", "kCLIP">), (INFRT_createF32Attr<"0.0">), $threshold)>; #endif // PD_LOWER_TO_TRT diff --git a/paddle/infrt/dialect/tensorrt/trt_op_base.td b/paddle/infrt/dialect/tensorrt/trt_op_base.td index 128960ee03e03029ac3681b0184e4359c0436dde..81cb882fdf21bcda379ca8f95229c19285a8aa39 100755 --- a/paddle/infrt/dialect/tensorrt/trt_op_base.td +++ b/paddle/infrt/dialect/tensorrt/trt_op_base.td @@ -22,59 +22,8 @@ def TRT_Dialect : Dialect { class TRT_Op traits = []> : Op; - -class TRT_PaddleAttr : - Attr()">, - "PaddlePaddle " # description # " attribute">; - def TRT_EngineType : Type()">, "!trt.engine">, BuildableType<"getType<::infrt::trt::EngineType>()">; -//===----------------------------------------------------------------------===// -// PaddlePaddle type definitions -//===----------------------------------------------------------------------===// - -def TRT_TRTDialectType : Type()">, "PaddlePaddle type">; - -class TRT_PaddleType : - Type()">, - "Paddle " # description # " type">, - BuildableType<"getType()">; - -//===----------------------------------------------------------------------===// -// Integer types -def TRT_Bool : AnyTypeOf<[I<1>], "bool">; -def TRT_Int8 : AnyTypeOf<[I8], "8-bit integer">; -def TRT_Int16 : AnyTypeOf<[I16], "16-bit integer">; -def TRT_Int32 : AnyTypeOf<[I32], "32-bit integer">; -def TRT_Int64 : AnyTypeOf<[I64], "64-bit integer">; - -def TRT_UInt8 : AnyTypeOf<[UI<8>], "8-bit unsigned integer">; -def TRT_UInt16 : AnyTypeOf<[UI<16>], "16-bit unsigned integer">; -def TRT_UInt32 : AnyTypeOf<[UI<32>], "32-bit unsigned integer">; -def TRT_UInt64 : AnyTypeOf<[UI<64>], "64-bit unsigned integer">; - -def TRT_SInt : AnyTypeOf<[TRT_Int8, TRT_Int16, TRT_Int32, TRT_Int64], "signed integer">; -def TRT_UInt : AnyTypeOf<[TRT_UInt8, TRT_UInt16, TRT_UInt32, TRT_UInt64], "unsigned integer">; -def TRT_Int : AnyTypeOf<[TRT_SInt, TRT_UInt], "integer">; - -// Float types -def TRT_Float16 : AnyTypeOf<[F16], "16-bit float">; -def TRT_Float32 : AnyTypeOf<[F32], "32-bit float">; -def TRT_Float64 : AnyTypeOf<[F64], "64-bit float">; - -def TRT_Float : AnyTypeOf<[TRT_Float16, TRT_Float32, TRT_Float64], "floating-point">; - - -// Tensor types - -def TRT_ElementType : Type, - "trt.dtype">; - -def TRT_Tensor : TensorOf<[TRT_ElementType]>; - - #endif // TRT_OP_BASE diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc index 1e50b772e081705ec81bd6b093cd9be9b1987bf6..b7032a2aa25c92e5f8d04414a27da7d6fa232d98 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -27,6 +27,32 @@ namespace infrt { namespace trt { +#ifdef INFRT_WITH_TRT + +#define STRING_TO_ENUM_TYPE(enum_type) enum_type +#define STRING_TO_ENUM_VALUE(enum_value) enum_value +#include + +#else // INFRT_WITH_TRT + +#define STRING_TO_ENUM_TYPE(enum_type) std::string +#define STRING_TO_ENUM_VALUE(enum_value) #enum_value + +#endif // INFRT_WITH_TRT + +template +::mlir::IntegerAttr createNvinferEnumAttr(::mlir::PatternRewriter &rewriter, + T enum_value) { + return rewriter.getSI32IntegerAttr((int32_t)enum_value); +} + +template <> +::mlir::IntegerAttr createNvinferEnumAttr( + ::mlir::PatternRewriter &rewriter, std::string enum_value) { + (void)enum_value; + return rewriter.getSI32IntegerAttr(-1); +} + #include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT struct PD2TRT_GraphLower : public ::mlir::RewritePattern { diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td index 803a11ed5b7e5ce46211a85471536c0300d42630..d0585532adf9ea1c2ac664a726f8444f268913db 100755 --- a/paddle/infrt/dialect/tensorrt/trt_ops.td +++ b/paddle/infrt/dialect/tensorrt/trt_ops.td @@ -20,15 +20,6 @@ def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator< let results = (outs TRT_EngineType:$engine); } -def TRT_ExecuteOp : TRT_Op<"execute", [NoSideEffect]> { - let summary = "trt execute Op"; - let description = [{ - Describe a tensorrt runtime. - }]; - let arguments = (ins TRT_EngineType:$engine, Variadic:$inputs); - let results = (outs Variadic:$output); -} - def TRT_EngineComputeOp : TRT_Op<"compute", [NoSideEffect]> { let summary = "trt compute engine"; let description = [{