From 92ad682f11e87f8cac0537ee62fb5ad3f673c014 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Mon, 13 Dec 2021 10:02:54 +0800 Subject: [PATCH] fix trt de/serialization and refine the data type selection (#38057) --- .../tensorrt/plugin/deformable_conv_op_plugin.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu index 70e5a7bcc7..6128f8f0e4 100644 --- a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu @@ -55,7 +55,8 @@ void DeformableConvPlugin::serializeFromDevice( PADDLE_ENFORCE_GPU_SUCCESS( cudaMemcpy(static_cast(*hostBuffer), deviceWeights.values, deviceWeights.count * num_bytes, cudaMemcpyDeviceToHost)); - hostBuffer += deviceWeights.count * num_bytes; + *hostBuffer = + reinterpret_cast(*hostBuffer) + deviceWeights.count * num_bytes; } nvinfer1::Weights DeformableConvPlugin::deserializeToDevice( @@ -63,7 +64,7 @@ nvinfer1::Weights DeformableConvPlugin::deserializeToDevice( int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2); nvinfer1::Weights w = copyToDevice(static_cast(*hostBuffer), count); - hostBuffer += count * num_bytes; + *hostBuffer = reinterpret_cast(*hostBuffer) + count * num_bytes; return w; } @@ -189,8 +190,7 @@ bool DeformableConvPlugin::supportsFormat( nvinfer1::DataType type, nvinfer1::TensorFormat format) const TRT_NOEXCEPT { if (with_fp16_) { #ifdef TRT_PLUGIN_FP16_AVALIABLE - return (type == nvinfer1::DataType::kFLOAT || - type == nvinfer1::DataType::kHALF) && + return (type == nvinfer1::DataType::kHALF) && (format == nvinfer1::TensorFormat::kLINEAR); #else return (type == nvinfer1::DataType::kFLOAT) && @@ -615,7 +615,7 @@ const char* DeformableConvPlugin::getPluginNamespace() const TRT_NOEXCEPT { nvinfer1::DataType DeformableConvPlugin::getOutputDataType( int index, const nvinfer1::DataType* input_type, int nb_inputs) const TRT_NOEXCEPT { - return data_type_; + return input_type[0]; } bool DeformableConvPlugin::isOutputBroadcastAcrossBatch( -- GitLab