Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b7a86e92
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b7a86e92
编写于
8月 19, 2020
作者:
Z
Zhaolong Xing
提交者:
GitHub
8月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix dy shape bug in trt7.1 (#26273)
test=develop
上级
c45481d7
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
37 addition
and
5 deletion
+37
-5
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+6
-1
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu
...inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu
+11
-1
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h
.../inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h
+6
-1
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
+6
-0
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h
+4
-1
paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h
...luid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h
+4
-1
未找到文件。
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
b7a86e92
...
...
@@ -83,7 +83,12 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
}
else
if
(
shape
.
size
()
==
3UL
)
{
return
nvinfer1
::
Dims3
(
shape
[
0
],
shape
[
1
],
shape
[
2
]);
}
return
nvinfer1
::
Dims4
(
shape
[
0
],
shape
[
1
],
1
,
1
);
nvinfer1
::
Dims
dims
;
dims
.
nbDims
=
shape
.
size
();
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
i
++
)
{
dims
.
d
[
i
]
=
shape
[
i
];
}
return
dims
;
}
}
}
// NOLINT
...
...
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu
浏览文件 @
b7a86e92
...
...
@@ -76,6 +76,16 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
return
ret
;
}
template
<
typename
T
>
void
EmbEltwiseLayernormPluginDynamic
<
T
>::
terminate
()
{
for
(
auto
ptr
:
embs_gpu_
)
{
if
(
ptr
)
cudaFree
(
ptr
);
}
if
(
bias_gpu_
)
cudaFree
(
bias_gpu_
);
if
(
scale_gpu_
)
cudaFree
(
scale_gpu_
);
}
template
<
typename
T
>
bool
EmbEltwiseLayernormPluginDynamic
<
T
>::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
...
...
@@ -153,7 +163,7 @@ int EmbEltwiseLayernormPluginDynamic<T>::enqueue(
int64_t
*
emb_ptr_gpu_d
=
emb_ptr_tensor
.
mutable_data
<
int64_t
>
(
platform
::
CUDAPlace
(
device_id
));
std
::
vector
<
int64
_t
>
in_ptr
,
emb_ptr
;
std
::
vector
<
uintptr
_t
>
in_ptr
,
emb_ptr
;
for
(
int
i
=
0
;
i
<
input_num
;
i
++
)
{
in_ptr
.
push_back
(
reinterpret_cast
<
uintptr_t
>
(
inputs
[
i
]));
emb_ptr
.
push_back
(
reinterpret_cast
<
uintptr_t
>
(
embs_gpu_
[
i
]));
...
...
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h
浏览文件 @
b7a86e92
...
...
@@ -81,9 +81,13 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
{
return
new
EmbEltwiseLayernormPluginDynamic
(
auto
ptr
=
new
EmbEltwiseLayernormPluginDynamic
(
embs_
,
bias_
,
scale_
,
emb_sizes_
,
bias_size_
,
scale_size_
,
hidden_size_
,
eps_
);
ptr
->
embs_gpu_
=
embs_gpu_
;
ptr
->
bias_gpu_
=
bias_gpu_
;
ptr
->
scale_gpu_
=
scale_gpu_
;
return
ptr
;
}
const
char
*
getPluginType
()
const
override
{
...
...
@@ -111,6 +115,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
return
sum_num
;
}
void
terminate
()
override
;
void
serialize
(
void
*
buffer
)
const
override
{
// SerializeValue(&buffer, with_fp16_);
SerializeValue
(
&
buffer
,
emb_sizes_
);
...
...
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
浏览文件 @
b7a86e92
...
...
@@ -80,6 +80,12 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
#if IS_TRT_VERSION_GE(6000)
void
PReluPluginDynamic
::
terminate
()
{
if
(
p_gpu_weight_
)
{
cudaFree
(
p_gpu_weight_
);
}
}
int
PReluPluginDynamic
::
initialize
()
{
cudaMalloc
(
&
p_gpu_weight_
,
sizeof
(
float
)
*
weight_
.
size
());
cudaMemcpy
(
p_gpu_weight_
,
weight_
.
data
(),
weight_
.
size
()
*
sizeof
(
float
),
...
...
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h
浏览文件 @
b7a86e92
...
...
@@ -102,12 +102,15 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
}
~
PReluPluginDynamic
()
{
cudaFree
(
p_gpu_weight_
);
}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
{
return
new
PReluPluginDynamic
(
weight_
.
data
(),
weight_
.
size
(),
mode_
);
auto
ptr
=
new
PReluPluginDynamic
(
weight_
.
data
(),
weight_
.
size
(),
mode_
);
ptr
->
p_gpu_weight_
=
p_gpu_weight_
;
return
ptr
;
}
const
char
*
getPluginType
()
const
override
{
return
"prelu_plugin"
;
}
int
getNbOutputs
()
const
override
{
return
1
;
}
int
initialize
()
override
;
void
terminate
()
override
;
size_t
getSerializationSize
()
const
override
;
void
serialize
(
void
*
buffer
)
const
override
;
...
...
paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h
浏览文件 @
b7a86e92
...
...
@@ -51,8 +51,11 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
{
return
new
SkipLayerNormPluginDynamic
(
auto
ptr
=
new
SkipLayerNormPluginDynamic
(
bias_
.
data
(),
scale_
.
data
(),
bias_size_
,
scale_size_
,
eps_
,
ban_fp16_
);
ptr
->
bias_gpu_
=
bias_gpu_
;
ptr
->
scale_gpu_
=
bias_gpu_
;
return
ptr
;
}
const
char
*
getPluginType
()
const
override
{
return
"skip_layernorm_plugin"
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录