From 5d49e3e9c856249b297eebdda33ec49e4bb8fe15 Mon Sep 17 00:00:00 2001 From: weishengying <63448337+weishengying@users.noreply.github.com> Date: Tue, 13 Dec 2022 10:31:15 +0800 Subject: [PATCH] Enable Generic-Plugin support FP16 (#48807) --- .../tensorrt/plugin/generic_plugin.cu | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu index 86ecca9290..88bb7e5f53 100644 --- a/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu @@ -300,27 +300,38 @@ bool GenericPlugin::supportsFormatCombination( int nb_outputs) TRT_NOEXCEPT { if (op_desc_.Type() == "gather_nd" || op_desc_.Type() == "yolo_box") { if (pos == 0) - return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + return (in_out[pos].type == nvinfer1::DataType::kFLOAT || + (isFp16Supported() && + in_out[pos].type == nvinfer1::DataType::kHALF)) && (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); if (pos == 1) return (in_out[pos].type == nvinfer1::DataType::kINT32) && (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + // output if (pos == 2) - return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && - (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + return in_out[0].type == in_out[pos].type && + in_out[0].format == in_out[pos].format; } else if (op_desc_.Type() == "scatter_nd_add") { + // input X if (pos == 0) - return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + return (in_out[pos].type == nvinfer1::DataType::kFLOAT || + (isFp16Supported() && + in_out[pos].type == nvinfer1::DataType::kHALF)) && (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + // input Index if (pos == 1) return (in_out[pos].type == nvinfer1::DataType::kINT32) && (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + // input Updates if (pos == 2) - return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + return (in_out[pos].type == nvinfer1::DataType::kFLOAT || + (isFp16Supported() && + in_out[pos].type == nvinfer1::DataType::kHALF)) && (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + // output if (pos == 3) - return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && - (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); + return in_out[0].type == in_out[pos].type && + in_out[0].format == in_out[pos].format; } else if (op_desc_.Type() == "pad3d") { return (in_out[pos].type == nvinfer1::DataType::kFLOAT || (isFp16Supported() && @@ -328,7 +339,9 @@ bool GenericPlugin::supportsFormatCombination( (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR) && (in_out[0].type == in_out[pos].type); } else { - return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && + return (in_out[pos].type == nvinfer1::DataType::kFLOAT || + (isFp16Supported() && + in_out[pos].type == nvinfer1::DataType::kHALF)) && (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); } } -- GitLab