Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
28c36d86
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
28c36d86
编写于
3月 11, 2021
作者:
S
Shang Zhizhou
提交者:
GitHub
3月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix ernie_varlen when cutting head (#31497) (#31512)
上级
b54640bb
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
29 addition
and
22 deletion
+29
-22
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+24
-22
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
...e/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
+5
-0
未找到文件。
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
28c36d86
...
@@ -49,14 +49,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -49,14 +49,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
weight_t
->
numel
()
*
sizeof
(
float
));
// (hidden
, 3, all_head_size
)
// (hidden
_in, 3, hidden_out
)
auto
weight_dims
=
weight_t
->
dims
();
auto
weight_dims
=
weight_t
->
dims
();
int
hidden
=
weight_dims
[
0
];
// channels_in
int
hidden
_in
=
weight_dims
[
0
];
// channels_in
int
three
=
weight_dims
[
1
];
// channels_out
int
three
=
weight_dims
[
1
];
// channels_out
int
all_head_size
=
weight_dims
[
2
];
// channels_out
int
hidden_out
=
weight_dims
[
2
];
// channels_out
int
m
=
hidden
;
int
m
=
hidden
_in
;
int
n
=
three
*
all_head_size
;
int
n
=
three
*
hidden_out
;
auto
tranpose_weight
=
[](
const
float
*
src
,
float
*
dst
,
int
m
,
int
n
)
{
auto
tranpose_weight
=
[](
const
float
*
src
,
float
*
dst
,
int
m
,
int
n
)
{
for
(
int
i
=
0
;
i
<
m
;
i
++
)
{
for
(
int
i
=
0
;
i
<
m
;
i
++
)
{
for
(
int
j
=
0
;
j
<
n
;
j
++
)
{
for
(
int
j
=
0
;
j
<
n
;
j
++
)
{
...
@@ -72,21 +72,23 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -72,21 +72,23 @@ class MultiheadMatMulOpConverter : public OpConverter {
if
(
engine_
->
with_dynamic_shape
())
{
if
(
engine_
->
with_dynamic_shape
())
{
if
(
engine_
->
use_oss
())
{
if
(
engine_
->
use_oss
())
{
int
head_size
=
hidden
/
head_number
;
int
head_size
=
hidden_out
/
head_number
;
// [3, Nout, Hout, Nin, Hin] -> [Nout, 3, Hout, Nin, Hin]
// [3, head_number, head_size, hidden_in] -> [head_number, 3, head_size,
auto
transpose_weight_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
N
,
// hidden_in]
int
H
)
{
auto
transpose_weight_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
three
,
const
int
HNH
=
H
*
N
*
H
;
int
head_number
,
int
head_size
,
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
int
hidden_in
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
const
int
HH
=
head_size
*
hidden_in
;
for
(
int
hnh
=
0
;
hnh
<
HNH
;
++
hnh
)
{
for
(
int
i
=
0
;
i
<
three
;
++
i
)
{
dst
[
n
*
3
*
HNH
+
i
*
HNH
+
hnh
]
=
for
(
int
n
=
0
;
n
<
head_number
;
++
n
)
{
src
[
i
*
N
*
HNH
+
n
*
HNH
+
hnh
];
for
(
int
hh
=
0
;
hh
<
HH
;
++
hh
)
{
dst
[
n
*
three
*
HH
+
i
*
HH
+
hh
]
=
src
[
i
*
head_number
*
HH
+
n
*
HH
+
hh
];
}
}
}
}
}
}
};
};
// [3,
N, H] -> [N, 3, H
]
// [3,
head_number, head_size] -> [head_number, 3, head_size
]
auto
transpose_bias_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
N
,
auto
transpose_bias_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
N
,
int
H
)
{
int
H
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
...
@@ -99,8 +101,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -99,8 +101,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
};
};
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
weight_t
->
numel
()
*
sizeof
(
float
));
transpose_weight_v2
(
weight_data_tmp
.
data
(),
weight_data
,
head_number
,
transpose_weight_v2
(
weight_data_tmp
.
data
(),
weight_data
,
three
,
head_
size
);
head_
number
,
head_size
,
hidden_in
);
nvinfer1
::
Weights
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
nvinfer1
::
Weights
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
static_cast
<
void
*>
(
weight_data
),
static_cast
<
int32_t
>
(
weight_t
->
numel
())};
static_cast
<
int32_t
>
(
weight_t
->
numel
())};
...
@@ -130,7 +132,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -130,7 +132,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
int
var_seqlen
=
1
;
int
var_seqlen
=
1
;
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden
_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
...
@@ -185,7 +187,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -185,7 +187,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
n
,
weight
.
get
(),
bias
.
get
());
n
,
weight
.
get
(),
bias
.
get
());
auto
*
fc_out
=
fc_layer
->
getOutput
(
0
);
auto
*
fc_out
=
fc_layer
->
getOutput
(
0
);
// add qkv to context
// add qkv to context
int
head_size
=
all_head_size
/
head_number
;
int
head_size
=
hidden_out
/
head_number
;
float
scale
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"alpha"
));
float
scale
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"alpha"
));
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
...
@@ -194,7 +196,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -194,7 +196,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
bool
with_fp16
=
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
plugin
::
DynamicPluginTensorRT
*
plugin
=
plugin
::
DynamicPluginTensorRT
*
plugin
=
new
plugin
::
QkvToContextPluginDynamic
(
hidden
,
head_number
,
new
plugin
::
QkvToContextPluginDynamic
(
hidden
_in
,
head_number
,
head_size
,
scale
,
with_fp16
);
head_size
,
scale
,
with_fp16
);
layer
=
engine_
->
AddPluginV2
(
plugin_inputs
.
data
(),
2
,
plugin
);
layer
=
engine_
->
AddPluginV2
(
plugin_inputs
.
data
(),
2
,
plugin
);
}
}
...
...
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
浏览文件 @
28c36d86
...
@@ -54,7 +54,12 @@ nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
...
@@ -54,7 +54,12 @@ nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
nvinfer1
::
DimsExprs
output
(
inputs
[
0
]);
nvinfer1
::
DimsExprs
output
(
inputs
[
0
]);
output
.
nbDims
++
;
for
(
int
i
=
output
.
nbDims
-
1
;
i
>
1
;
i
--
)
{
output
.
d
[
i
]
=
inputs
[
0
].
d
[
i
-
1
];
}
auto
one
=
expr_builder
.
constant
(
1
);
auto
one
=
expr_builder
.
constant
(
1
);
output
.
d
[
1
]
=
one
;
output
.
d
[
0
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUB
,
output
.
d
[
0
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUB
,
*
inputs
[
1
].
d
[
0
],
*
one
);
*
inputs
[
1
].
d
[
0
],
*
one
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录