Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b779d2b8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b779d2b8
编写于
5月 31, 2022
作者:
W
Wilber
提交者:
GitHub
5月 31, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix slice plugin (#43110)
上级
12d8a567
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
16 addition
and
36 deletion
+16
-36
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
+14
-32
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h
+2
-4
未找到文件。
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
浏览文件 @
b779d2b8
...
...
@@ -56,8 +56,6 @@ SlicePlugin::SlicePlugin(std::vector<int> starts, std::vector<int> ends,
std
::
vector
<
int
>
axes
,
bool
with_fp16
)
:
starts_
(
starts
),
ends_
(
ends
),
axes_
(
axes
)
{
with_fp16_
=
with_fp16
;
cudaEventCreate
(
&
copy_event_
);
cudaStreamCreate
(
&
copy_stream_
);
}
SlicePlugin
::
SlicePlugin
(
void
const
*
serial_data
,
size_t
serial_length
)
{
...
...
@@ -66,15 +64,10 @@ SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) {
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
ends_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
axes_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
cudaEventCreate
(
&
copy_event_
);
cudaStreamCreate
(
&
copy_stream_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
offset_info_
);
}
SlicePlugin
::~
SlicePlugin
()
{
cudaStreamDestroy
(
copy_stream_
);
cudaEventDestroy
(
copy_event_
);
cudaFree
(
offset_temp_data_
);
}
SlicePlugin
::~
SlicePlugin
()
{
cudaFree
(
offset_temp_data_
);
}
SlicePlugin
*
SlicePlugin
::
clone
()
const
TRT_NOEXCEPT
{
return
new
SlicePlugin
(
starts_
,
ends_
,
axes_
,
with_fp16_
);
...
...
@@ -159,11 +152,7 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
}
cudaMemcpyAsync
(
offset_temp_data_
,
offset_info
.
data
(),
sizeof
(
int
)
*
3
*
num_dims
,
cudaMemcpyHostToDevice
,
copy_stream_
);
cudaEventRecord
(
copy_event_
,
copy_stream_
);
cudaStreamWaitEvent
(
stream
,
copy_event_
,
0
);
sizeof
(
int
)
*
3
*
num_dims
,
cudaMemcpyHostToDevice
,
stream
);
int
threads
=
256
;
int
blocks
=
(
out_num
+
threads
-
1
)
/
threads
;
...
...
@@ -190,7 +179,7 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
size_t
SlicePlugin
::
getSerializationSize
()
const
TRT_NOEXCEPT
{
return
getBaseSerializationSize
()
+
SerializedSize
(
starts_
)
+
SerializedSize
(
ends_
)
+
SerializedSize
(
axes_
)
+
SerializedSize
(
with_fp16_
);
SerializedSize
(
with_fp16_
)
+
SerializedSize
(
offset_info_
)
;
}
void
SlicePlugin
::
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
{
...
...
@@ -199,6 +188,7 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
SerializeValue
(
&
buffer
,
ends_
);
SerializeValue
(
&
buffer
,
axes_
);
SerializeValue
(
&
buffer
,
with_fp16_
);
SerializeValue
(
&
buffer
,
offset_info_
);
}
// Dynamic Plugin below.
...
...
@@ -209,8 +199,6 @@ SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
bool
with_fp16
)
:
starts_
(
starts
),
ends_
(
ends
),
axes_
(
axes
),
decrease_axis_
(
decrease_axis
)
{
with_fp16_
=
with_fp16
;
cudaEventCreate
(
&
copy_event_
);
cudaStreamCreate
(
&
copy_stream_
);
}
SlicePluginDynamic
::
SlicePluginDynamic
(
void
const
*
serialData
,
...
...
@@ -220,13 +208,10 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
axes_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
decrease_axis_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
with_fp16_
);
cudaEventCreate
(
&
copy_event_
);
cudaStreamCreate
(
&
copy_stream_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
offset_info_
);
}
void
SlicePluginDynamic
::
destroy
()
TRT_NOEXCEPT
{
cudaStreamDestroy
(
copy_stream_
);
cudaEventDestroy
(
copy_event_
);
cudaFree
(
offset_temp_data_
);
delete
this
;
}
...
...
@@ -236,7 +221,7 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
size_t
SlicePluginDynamic
::
getSerializationSize
()
const
TRT_NOEXCEPT
{
size_t
size
=
SerializedSize
(
starts_
)
+
SerializedSize
(
ends_
)
+
SerializedSize
(
axes_
)
+
SerializedSize
(
decrease_axis_
)
+
SerializedSize
(
with_fp16_
);
SerializedSize
(
with_fp16_
)
+
SerializedSize
(
offset_info_
)
;
return
size
;
}
...
...
@@ -247,6 +232,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
SerializeValue
(
&
buffer
,
axes_
);
SerializeValue
(
&
buffer
,
decrease_axis_
);
SerializeValue
(
&
buffer
,
with_fp16_
);
SerializeValue
(
&
buffer
,
offset_info_
);
}
nvinfer1
::
DimsExprs
SlicePluginDynamic
::
getOutputDimensions
(
...
...
@@ -361,23 +347,19 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
offsets
[
axes_
[
i
]]
=
starts_
[
i
];
}
std
::
vector
<
int
>
offset_info
;
offset_info_
.
resize
(
num_dims
*
3
)
;
for
(
size_t
i
=
0
;
i
<
num_dims
;
++
i
)
{
offset_info
.
push_back
(
offsets
[
i
])
;
offset_info
.
push_back
(
extends
[
i
])
;
offset_info
.
push_back
(
seg_offsets
[
i
])
;
offset_info
_
[
i
*
3
+
0
]
=
offsets
[
i
]
;
offset_info
_
[
i
*
3
+
1
]
=
extends
[
i
]
;
offset_info
_
[
i
*
3
+
2
]
=
seg_offsets
[
i
]
;
}
if
(
offset_temp_data_
==
nullptr
)
{
cudaMalloc
(
&
offset_temp_data_
,
3
*
num_dims
*
sizeof
(
int
));
}
cudaMemcpyAsync
(
offset_temp_data_
,
offset_info
.
data
(),
sizeof
(
int
)
*
3
*
num_dims
,
cudaMemcpyHostToDevice
,
copy_stream_
);
cudaEventRecord
(
copy_event_
,
copy_stream_
);
cudaStreamWaitEvent
(
stream
,
copy_event_
,
0
);
cudaMemcpyAsync
(
offset_temp_data_
,
offset_info_
.
data
(),
sizeof
(
int
)
*
3
*
num_dims
,
cudaMemcpyHostToDevice
,
stream
);
int
threads
=
256
;
int
blocks
=
(
out_num
+
threads
-
1
)
/
threads
;
...
...
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h
浏览文件 @
b779d2b8
...
...
@@ -64,8 +64,7 @@ class SlicePlugin : public PluginTensorRT {
std
::
vector
<
int
>
ends_
;
std
::
vector
<
int
>
axes_
;
int
*
offset_temp_data_
{
nullptr
};
cudaEvent_t
copy_event_
;
cudaStream_t
copy_stream_
;
std
::
vector
<
int
>
offset_info_
;
};
class
SlicePluginCreator
:
public
TensorRTPluginCreator
{
...
...
@@ -144,8 +143,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT {
std
::
vector
<
int
>
axes_
;
int
decrease_axis_
;
int
*
offset_temp_data_
{
nullptr
};
cudaEvent_t
copy_event_
;
cudaStream_t
copy_stream_
;
std
::
vector
<
int
>
offset_info_
;
};
class
SlicePluginDynamicCreator
:
public
TensorRTPluginCreator
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录