未验证 提交 92ad682f 编写于 作者: Z zlsh80826 提交者: GitHub

Fix TensorRT plugin serialization/deserialization and refine the data type selection (#38057)

上级 9598b19c
...@@ -55,7 +55,8 @@ void DeformableConvPlugin::serializeFromDevice( ...@@ -55,7 +55,8 @@ void DeformableConvPlugin::serializeFromDevice(
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemcpy(static_cast<char*>(*hostBuffer), deviceWeights.values, cudaMemcpy(static_cast<char*>(*hostBuffer), deviceWeights.values,
deviceWeights.count * num_bytes, cudaMemcpyDeviceToHost)); deviceWeights.count * num_bytes, cudaMemcpyDeviceToHost));
hostBuffer += deviceWeights.count * num_bytes; *hostBuffer =
reinterpret_cast<char*>(*hostBuffer) + deviceWeights.count * num_bytes;
} }
nvinfer1::Weights DeformableConvPlugin::deserializeToDevice( nvinfer1::Weights DeformableConvPlugin::deserializeToDevice(
...@@ -63,7 +64,7 @@ nvinfer1::Weights DeformableConvPlugin::deserializeToDevice( ...@@ -63,7 +64,7 @@ nvinfer1::Weights DeformableConvPlugin::deserializeToDevice(
int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2); int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2);
nvinfer1::Weights w = nvinfer1::Weights w =
copyToDevice(static_cast<const char*>(*hostBuffer), count); copyToDevice(static_cast<const char*>(*hostBuffer), count);
hostBuffer += count * num_bytes; *hostBuffer = reinterpret_cast<const char*>(*hostBuffer) + count * num_bytes;
return w; return w;
} }
...@@ -189,8 +190,7 @@ bool DeformableConvPlugin::supportsFormat( ...@@ -189,8 +190,7 @@ bool DeformableConvPlugin::supportsFormat(
nvinfer1::DataType type, nvinfer1::TensorFormat format) const TRT_NOEXCEPT { nvinfer1::DataType type, nvinfer1::TensorFormat format) const TRT_NOEXCEPT {
if (with_fp16_) { if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE #ifdef TRT_PLUGIN_FP16_AVALIABLE
return (type == nvinfer1::DataType::kFLOAT || return (type == nvinfer1::DataType::kHALF) &&
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::TensorFormat::kLINEAR); (format == nvinfer1::TensorFormat::kLINEAR);
#else #else
return (type == nvinfer1::DataType::kFLOAT) && return (type == nvinfer1::DataType::kFLOAT) &&
...@@ -615,7 +615,7 @@ const char* DeformableConvPlugin::getPluginNamespace() const TRT_NOEXCEPT { ...@@ -615,7 +615,7 @@ const char* DeformableConvPlugin::getPluginNamespace() const TRT_NOEXCEPT {
nvinfer1::DataType DeformableConvPlugin::getOutputDataType( nvinfer1::DataType DeformableConvPlugin::getOutputDataType(
int index, const nvinfer1::DataType* input_type, int index, const nvinfer1::DataType* input_type,
int nb_inputs) const TRT_NOEXCEPT { int nb_inputs) const TRT_NOEXCEPT {
return data_type_; return input_type[0];
} }
bool DeformableConvPlugin::isOutputBroadcastAcrossBatch( bool DeformableConvPlugin::isOutputBroadcastAcrossBatch(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册