Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
82fb63eb
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
82fb63eb
编写于
10月 29, 2021
作者:
W
wangxinxin08
提交者:
GitHub
10月 29, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix dcnv2 trt8 compile error (#36850)
上级
f3ee5c99
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
4 addition
and
6 deletion
+4
-6
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
...id/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
+1
-3
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
...uid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
+3
-3
未找到文件。
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
浏览文件 @
82fb63eb
...
...
@@ -360,7 +360,7 @@ void gemm_impl<half>(cublasHandle_t handle, cublasOperation_t transa,
template
<
typename
T
>
int
DeformableConvPlugin
::
enqueue_impl
(
int
batch_size
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
const
T
*
input
=
reinterpret_cast
<
const
T
*>
(
inputs
[
0
]);
const
T
*
offset
=
reinterpret_cast
<
const
T
*>
(
inputs
[
1
]);
...
...
@@ -527,8 +527,6 @@ nvinfer1::IPluginV2Ext* DeformableConvPlugin::clone() const TRT_NOEXCEPT {
offset_dim_
,
mask_dim_
,
output_dim_
);
}
DeformableConvPluginCreator
::
DeformableConvPluginCreator
()
TRT_NOEXCEPT
{}
void
DeformableConvPluginCreator
::
setPluginNamespace
(
const
char
*
lib_namespace
)
TRT_NOEXCEPT
{
namespace_
=
std
::
string
(
lib_namespace
);
...
...
paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
浏览文件 @
82fb63eb
...
...
@@ -91,8 +91,8 @@ class DeformableConvPlugin : public nvinfer1::IPluginV2Ext {
private:
template
<
typename
T
>
int
enqueue_impl
(
int
batch_size
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
);
int
enqueue_impl
(
int
batch_size
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
);
nvinfer1
::
Weights
copyToDevice
(
const
void
*
hostData
,
size_t
count
);
void
serializeFromDevice
(
void
**
hostBuffer
,
const
nvinfer1
::
Weights
&
deviceWeights
)
const
;
...
...
@@ -119,7 +119,7 @@ class DeformableConvPlugin : public nvinfer1::IPluginV2Ext {
class
DeformableConvPluginCreator
:
public
nvinfer1
::
IPluginCreator
{
public:
DeformableConvPluginCreator
();
DeformableConvPluginCreator
()
=
default
;
~
DeformableConvPluginCreator
()
override
=
default
;
void
setPluginNamespace
(
const
char
*
lib_namespace
)
TRT_NOEXCEPT
override
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录