From 033d736dbffc954670b89b72add5d7d6ece0ac0b Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Fri, 9 Jul 2021 15:53:20 +0800 Subject: [PATCH] fix output data type selection (#34040) --- .../inference/tensorrt/plugin/anchor_generator_op_plugin.cu | 4 ++-- paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu | 2 +- paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu index 8e9845183b3..30fcc9e7014 100644 --- a/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu @@ -219,7 +219,7 @@ const char* AnchorGeneratorPlugin::getPluginNamespace() const { nvinfer1::DataType AnchorGeneratorPlugin::getOutputDataType( int index, const nvinfer1::DataType* input_type, int nb_inputs) const { - return data_type_; + return input_type[0]; } bool AnchorGeneratorPlugin::isOutputBroadcastAcrossBatch( @@ -460,7 +460,7 @@ int AnchorGeneratorPluginDynamic::enqueue( nvinfer1::DataType AnchorGeneratorPluginDynamic::getOutputDataType( int index, const nvinfer1::DataType* inputTypes, int nbInputs) const { - return data_type_; + return inputTypes[0]; } const char* AnchorGeneratorPluginDynamic::getPluginType() const { diff --git a/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu index 6e7ed0054f5..61e9144b9c8 100644 --- a/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu @@ -304,7 +304,7 @@ int RoiAlignPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, nvinfer1::DataType RoiAlignPluginDynamic::getOutputDataType( int index, const nvinfer1::DataType* inputTypes, int nbInputs) const { - return data_type_; + return inputTypes[0]; } const char* RoiAlignPluginDynamic::getPluginType() const { diff --git a/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu index f9767f38559..05ecc283628 100644 --- a/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu @@ -299,7 +299,7 @@ const char* YoloBoxPlugin::getPluginNamespace() const { nvinfer1::DataType YoloBoxPlugin::getOutputDataType( int index, const nvinfer1::DataType* input_type, int nb_inputs) const { - return data_type_; + return input_type[0]; } bool YoloBoxPlugin::isOutputBroadcastAcrossBatch(int output_index, -- GitLab