Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6b59a073
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6b59a073
编写于
9月 19, 2022
作者:
W
Wangzheee
提交者:
GitHub
9月 19, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix_recover_remove_padding kernel (#46050) (#46198)
上级
db368d5b
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
44 addition
and
4 deletion
+44
-4
paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu
...fluid/inference/tensorrt/plugin/recover_padding_plugin.cu
+22
-1
paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu
.../fluid/inference/tensorrt/plugin/remove_padding_plugin.cu
+22
-3
未找到文件。
paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu
浏览文件 @
6b59a073
...
...
@@ -118,7 +118,28 @@ int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const
int32_t
*
input1
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
// pos_id_tensor
float
*
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
const
int32_t
num_threads
=
256
;
int32_t
num_threads
;
if
(
input0_desc
.
dims
.
d
[
1
]
%
512
==
0
)
{
num_threads
=
512
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
num_threads
=
1
;
}
const
dim3
num_blocks
(
input1_desc
.
dims
.
d
[
0
]
-
1
,
input2_desc
.
dims
.
d
[
1
],
...
...
paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu
浏览文件 @
6b59a073
...
...
@@ -110,10 +110,29 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const
int32_t
*
input1
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
// pos_id_tensor
float
*
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
const
auto
input0_desc
=
inputDesc
[
0
];
const
int32_t
num_threads
=
256
;
int32_t
num_threads
;
if
(
input0_desc
.
dims
.
d
[
2
]
%
512
==
0
)
{
num_threads
=
512
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
num_threads
=
1
;
}
const
dim3
num_blocks
(
input0_desc
.
dims
.
d
[
0
],
input0_desc
.
dims
.
d
[
1
],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录