diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index b2764ca61c11219e5546867813157b7f05ee3ce8..d53a8923af6120adb460d95fc81820b6dfa03a60 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -54,6 +54,8 @@ TRT_DT FluidDataType2TRT(FluidDT type) { return TRT_DT::kFLOAT; case FluidDT::VarType_Type_INT32: return TRT_DT::kINT32; + case FluidDT::VarType_Type_FP16: + return TRT_DT::kHALF; default: return TRT_DT::kINT32; } diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index e05b4de65214c8cf55d099fccc7c18370b2312b7..0a71875d8931ef80846aa7e0c95ce1beab86fd7c 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -79,6 +79,28 @@ static void RuntimeStaticShapeCheck(std::vector runtime_input_shape, model_input_shape_str, runtime_input_shape_str)); } +static paddle::experimental::DataType TRT2FluidDataType( + nvinfer1::DataType type) { + switch (type) { + case nvinfer1::DataType::kFLOAT: + return paddle::experimental::DataType::FLOAT32; + case nvinfer1::DataType::kINT32: + return paddle::experimental::DataType::INT32; + case nvinfer1::DataType::kHALF: + return paddle::experimental::DataType::FLOAT16; + case nvinfer1::DataType::kINT8: + return paddle::experimental::DataType::INT8; +#if IS_TRT_VERSION_GE(7000) + case nvinfer1::DataType::kBOOL: + return paddle::experimental::DataType::BOOL; +#endif + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "unknown fluid datatype in Fluid op converter")); + return paddle::experimental::DataType::FLOAT32; + } +} + static void RuntimeDynamicShapeCheck( const std::string &x, const std::vector &runtime_input_shape, const std::vector &min_input_shape, @@ -520,9 +542,12 @@ class TensorRTEngineOp : public framework::OperatorBase { buffers[bind_index] = static_cast(t.data()); } else if (type == framework::proto::VarType::INT32) { buffers[bind_index] = static_cast(t.data()); + } else if (type == framework::proto::VarType::FP16) { + buffers[bind_index] = static_cast(t.data()); } else { - PADDLE_THROW(platform::errors::Fatal( - "The TRT Engine OP only support float/int32_t/int64_t input.")); + PADDLE_THROW( + platform::errors::Fatal("The TRT Engine OP only support " + "float/int32_t/int64_t/float16 input.")); } } @@ -570,9 +595,10 @@ class TensorRTEngineOp : public framework::OperatorBase { "than the number of bindings, but got binding " "index = %d, number of bindings = %d.", bind_index, num_bindings)); - buffers[bind_index] = - static_cast(fluid_t->mutable_data(dev_place)); - + auto trt_type = engine->engine()->getBindingDataType(bind_index); + // get adr and set type + buffers[bind_index] = static_cast( + fluid_t->mutable_data(dev_place, TRT2FluidDataType(trt_type))); output_index += 1; } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_emb_eltwise_layernorm.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_emb_eltwise_layernorm.py index 356a2c942df0d8cc5d1f016a3b2f4a284227990f..1eecf9c0497a196666c4b30af721c7b68d77b0fd 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_emb_eltwise_layernorm.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_emb_eltwise_layernorm.py @@ -244,28 +244,16 @@ class TrtConvertEmbEltwiseLayernormTest1(TrtLayerAutoScanTest): self.trt_param.precision = paddle_infer.PrecisionType.Float32 yield self.create_inference_config(), (0, 5), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (0, 5), 1e-5 + yield self.create_inference_config(), (0, 5), 2e-2 # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 yield self.create_inference_config(), (1, 4), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (1, 4), 1e-5 - - def add_skip_trt_case(self): - def teller1(program_config, predictor_config): - if self.trt_param.precision == paddle_infer.PrecisionType.Half and len( - self.dynamic_shape.min_input_shape) != 0: - return True - return False - - self.add_skip_case( - teller1, SkipReasons.TRT_NOT_IMPLEMENTED, - "The output has diff between gpu and trt when dynamic fp16 mode.") + yield self.create_inference_config(), (1, 4), 2e-2 def test(self): - self.add_skip_trt_case() self.run_test()