未验证 提交 92ad682f 编写于 作者: Z zlsh80826 提交者: GitHub

fix trt de/serialization and refine the data type selection (#38057)

上级 9598b19c
......@@ -55,7 +55,8 @@ void DeformableConvPlugin::serializeFromDevice(
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemcpy(static_cast<char*>(*hostBuffer), deviceWeights.values,
deviceWeights.count * num_bytes, cudaMemcpyDeviceToHost));
hostBuffer += deviceWeights.count * num_bytes;
*hostBuffer =
reinterpret_cast<char*>(*hostBuffer) + deviceWeights.count * num_bytes;
}
nvinfer1::Weights DeformableConvPlugin::deserializeToDevice(
......@@ -63,7 +64,7 @@ nvinfer1::Weights DeformableConvPlugin::deserializeToDevice(
int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2);
nvinfer1::Weights w =
copyToDevice(static_cast<const char*>(*hostBuffer), count);
hostBuffer += count * num_bytes;
*hostBuffer = reinterpret_cast<const char*>(*hostBuffer) + count * num_bytes;
return w;
}
......@@ -189,8 +190,7 @@ bool DeformableConvPlugin::supportsFormat(
nvinfer1::DataType type, nvinfer1::TensorFormat format) const TRT_NOEXCEPT {
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return (type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
return (type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::TensorFormat::kLINEAR);
#else
return (type == nvinfer1::DataType::kFLOAT) &&
......@@ -615,7 +615,7 @@ const char* DeformableConvPlugin::getPluginNamespace() const TRT_NOEXCEPT {
nvinfer1::DataType DeformableConvPlugin::getOutputDataType(
int index, const nvinfer1::DataType* input_type,
int nb_inputs) const TRT_NOEXCEPT {
return data_type_;
return input_type[0];
}
bool DeformableConvPlugin::isOutputBroadcastAcrossBatch(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册