Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
1255e7d6
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1255e7d6
编写于
2月 24, 2022
作者:
W
Wangzheee
提交者:
GitHub
2月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle-Inference] fix special_slice plugin (#39875)
* fix plugin: special slice for ernie
上级
ce207c3a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
10 addition
and
9 deletion
+10
-9
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
...e/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
+10
-9
未找到文件。
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
浏览文件 @
1255e7d6
...
@@ -113,12 +113,12 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
...
@@ -113,12 +113,12 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
template
<
typename
T
>
template
<
typename
T
>
__global__
void
SpecialSliceKernel
(
const
T
*
slice_input
,
__global__
void
SpecialSliceKernel
(
const
T
*
slice_input
,
const
int32_t
*
cu_seqlens
,
T
*
output
)
{
const
int32_t
*
cu_seqlens
,
T
*
output
)
{
const
int
hidden
=
blockDim
.
x
*
gridDim
.
y
;
const
int
hidden
=
blockDim
.
x
*
gridDim
.
x
;
const
int
batch
=
block
Idx
.
x
;
const
int
hidden_id
=
blockIdx
.
x
*
blockDim
.
x
+
thread
Idx
.
x
;
const
int
local_idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
x
;
const
int
batch_id
=
blockIdx
.
y
;
output
[
batch
*
hidden
+
local_idx
]
=
output
[
batch
_id
*
hidden
+
hidden_id
]
=
slice_input
[
cu_seqlens
[
batch
]
*
hidden
+
local_idx
];
slice_input
[
cu_seqlens
[
batch
_id
]
*
hidden
+
hidden_id
];
}
}
int
SpecialSlicePluginDynamic
::
enqueue
(
int
SpecialSlicePluginDynamic
::
enqueue
(
...
@@ -137,15 +137,16 @@ int SpecialSlicePluginDynamic::enqueue(
...
@@ -137,15 +137,16 @@ int SpecialSlicePluginDynamic::enqueue(
"hidden should be multiple of 128."
));
"hidden should be multiple of 128."
));
constexpr
int
num_threads
=
128
;
constexpr
int
num_threads
=
128
;
const
dim3
blocks
(
out_dims
.
d
[
0
],
hidden
/
num_threads
);
const
half
*
slice_input
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
half
*
slice_input
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
int32_t
*
cu_seqlens
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
const
int32_t
*
cu_seqlens
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
SpecialSliceKernel
<<<
blocks
,
num_threads
,
0
,
stream
>>>
(
slice_input
,
const
int32_t
num_blocks_x
=
hidden
/
num_threads
;
cu_seqlens
,
output
);
const
int32_t
num_blocks_y
=
out_dims
.
d
[
0
];
// batchs
const
dim3
num_blocks
(
num_blocks_x
,
num_blocks_y
);
// blocks
SpecialSliceKernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
slice_input
,
cu_seqlens
,
output
);
return
cudaGetLastError
()
!=
cudaSuccess
;
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录