Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
92ad682f
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
92ad682f
编写于
12月 13, 2021
作者:
Z
zlsh80826
提交者:
GitHub
12月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix trt de/serialization and refine the data type selection (#38057)
上级
9598b19c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
5 addition
and
5 deletion
+5
-5
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
...id/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
+5
-5
未找到文件。
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
浏览文件 @
92ad682f
...
@@ -55,7 +55,8 @@ void DeformableConvPlugin::serializeFromDevice(
...
@@ -55,7 +55,8 @@ void DeformableConvPlugin::serializeFromDevice(
PADDLE_ENFORCE_GPU_SUCCESS
(
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMemcpy
(
static_cast
<
char
*>
(
*
hostBuffer
),
deviceWeights
.
values
,
cudaMemcpy
(
static_cast
<
char
*>
(
*
hostBuffer
),
deviceWeights
.
values
,
deviceWeights
.
count
*
num_bytes
,
cudaMemcpyDeviceToHost
));
deviceWeights
.
count
*
num_bytes
,
cudaMemcpyDeviceToHost
));
hostBuffer
+=
deviceWeights
.
count
*
num_bytes
;
*
hostBuffer
=
reinterpret_cast
<
char
*>
(
*
hostBuffer
)
+
deviceWeights
.
count
*
num_bytes
;
}
}
nvinfer1
::
Weights
DeformableConvPlugin
::
deserializeToDevice
(
nvinfer1
::
Weights
DeformableConvPlugin
::
deserializeToDevice
(
...
@@ -63,7 +64,7 @@ nvinfer1::Weights DeformableConvPlugin::deserializeToDevice(
...
@@ -63,7 +64,7 @@ nvinfer1::Weights DeformableConvPlugin::deserializeToDevice(
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
int
num_bytes
=
(
data_type_
==
nvinfer1
::
DataType
::
kFLOAT
?
4
:
2
);
nvinfer1
::
Weights
w
=
nvinfer1
::
Weights
w
=
copyToDevice
(
static_cast
<
const
char
*>
(
*
hostBuffer
),
count
);
copyToDevice
(
static_cast
<
const
char
*>
(
*
hostBuffer
),
count
);
hostBuffer
+=
count
*
num_bytes
;
*
hostBuffer
=
reinterpret_cast
<
const
char
*>
(
*
hostBuffer
)
+
count
*
num_bytes
;
return
w
;
return
w
;
}
}
...
@@ -189,8 +190,7 @@ bool DeformableConvPlugin::supportsFormat(
...
@@ -189,8 +190,7 @@ bool DeformableConvPlugin::supportsFormat(
nvinfer1
::
DataType
type
,
nvinfer1
::
TensorFormat
format
)
const
TRT_NOEXCEPT
{
nvinfer1
::
DataType
type
,
nvinfer1
::
TensorFormat
format
)
const
TRT_NOEXCEPT
{
if
(
with_fp16_
)
{
if
(
with_fp16_
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return
(
type
==
nvinfer1
::
DataType
::
kFLOAT
||
return
(
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
(
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#else
#else
return
(
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
return
(
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
...
@@ -615,7 +615,7 @@ const char* DeformableConvPlugin::getPluginNamespace() const TRT_NOEXCEPT {
...
@@ -615,7 +615,7 @@ const char* DeformableConvPlugin::getPluginNamespace() const TRT_NOEXCEPT {
nvinfer1
::
DataType
DeformableConvPlugin
::
getOutputDataType
(
nvinfer1
::
DataType
DeformableConvPlugin
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_type
,
int
index
,
const
nvinfer1
::
DataType
*
input_type
,
int
nb_inputs
)
const
TRT_NOEXCEPT
{
int
nb_inputs
)
const
TRT_NOEXCEPT
{
return
data_type_
;
return
input_type
[
0
]
;
}
}
bool
DeformableConvPlugin
::
isOutputBroadcastAcrossBatch
(
bool
DeformableConvPlugin
::
isOutputBroadcastAcrossBatch
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录