Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ac0553a0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ac0553a0
编写于
8月 15, 2022
作者:
Y
Yuanle Liu
提交者:
GitHub
8月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fused_embedding_eltwise_layernorm_op and skip_layernorm_op support fp16 (#44969)
上级
3512bf11
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
721 addition
and
329 deletion
+721
-329
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-1
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+6
-0
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
...fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
+30
-13
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
+27
-9
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+62
-1
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+4
-0
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu
...inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu
+12
-37
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h
.../inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h
+138
-91
paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.cu
...uid/inference/tensorrt/plugin/skip_layernorm_op_plugin.cu
+94
-66
paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h
...luid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h
+204
-32
paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h
paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h
+11
-8
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
.../api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
+2
-1
paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu
...d/operators/fused/fused_embedding_eltwise_layernorm_op.cu
+42
-13
paddle/fluid/operators/fused/skip_layernorm_op.cu
paddle/fluid/operators/fused/skip_layernorm_op.cu
+37
-9
paddle/fluid/operators/math/bert_encoder_functor.cu
paddle/fluid/operators/math/bert_encoder_functor.cu
+44
-43
paddle/fluid/operators/math/bert_encoder_functor.h
paddle/fluid/operators/math/bert_encoder_functor.h
+5
-5
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+2
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
ac0553a0
...
@@ -166,7 +166,6 @@ if(WITH_TENSORRT)
...
@@ -166,7 +166,6 @@ if(WITH_TENSORRT)
pass_library
(
trt_map_matmul_to_mul_pass inference
)
pass_library
(
trt_map_matmul_to_mul_pass inference
)
pass_library
(
trt_embedding_eltwise_layernorm_fuse_pass inference
)
pass_library
(
trt_embedding_eltwise_layernorm_fuse_pass inference
)
pass_library
(
trt_multihead_matmul_fuse_pass inference
)
pass_library
(
trt_multihead_matmul_fuse_pass inference
)
pass_library
(
trt_skip_layernorm_fuse_pass inference
)
pass_library
(
preln_embedding_eltwise_layernorm_fuse_pass inference
)
pass_library
(
preln_embedding_eltwise_layernorm_fuse_pass inference
)
pass_library
(
preln_skip_layernorm_fuse_pass inference
)
pass_library
(
preln_skip_layernorm_fuse_pass inference
)
pass_library
(
set_transformer_input_convert_pass inference
)
pass_library
(
set_transformer_input_convert_pass inference
)
...
@@ -177,6 +176,7 @@ endif()
...
@@ -177,6 +176,7 @@ endif()
if
(
WITH_GPU OR WITH_ROCM
)
if
(
WITH_GPU OR WITH_ROCM
)
pass_library
(
cudnn_placement_pass base DEPS placement_pass_base
)
pass_library
(
cudnn_placement_pass base DEPS placement_pass_base
)
pass_library
(
embedding_eltwise_layernorm_fuse_pass inference
)
pass_library
(
embedding_eltwise_layernorm_fuse_pass inference
)
pass_library
(
trt_skip_layernorm_fuse_pass inference
)
endif
()
endif
()
if
(
WITH_MKLDNN
)
if
(
WITH_MKLDNN
)
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
ac0553a0
...
@@ -165,12 +165,17 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
...
@@ -165,12 +165,17 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"gpu_cpu_map_matmul_v2_to_matmul_pass"
,
"gpu_cpu_map_matmul_v2_to_matmul_pass"
,
"fc_fuse_pass"
,
"fc_fuse_pass"
,
"fc_elementwise_layernorm_fuse_pass"
,
"fc_elementwise_layernorm_fuse_pass"
,
"embedding_eltwise_layernorm_fuse_pass"
,
"trt_skip_layernorm_fuse_pass"
,
"runtime_context_cache_pass"
,
};
};
const
std
::
vector
<
std
::
string
>
kTrtLowerPrecisionPasses
{
const
std
::
vector
<
std
::
string
>
kTrtLowerPrecisionPasses
{
"simplify_with_basic_ops_pass"
,
"simplify_with_basic_ops_pass"
,
// "conv_bn_fuse_pass",
// "conv_bn_fuse_pass",
// "conv_eltwiseadd_bn_fuse_pass",
// "conv_eltwiseadd_bn_fuse_pass",
"trt_embedding_eltwise_layernorm_fuse_pass"
,
"trt_skip_layernorm_fuse_pass"
,
"trt_map_matmul_v2_to_mul_pass"
,
"trt_map_matmul_v2_to_mul_pass"
,
"trt_map_matmul_v2_to_matmul_pass"
,
"trt_map_matmul_v2_to_matmul_pass"
,
"trt_map_matmul_to_mul_pass"
,
"trt_map_matmul_to_mul_pass"
,
...
@@ -186,6 +191,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
...
@@ -186,6 +191,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_bn_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
"embedding_eltwise_layernorm_fuse_pass"
,
//
"embedding_eltwise_layernorm_fuse_pass"
,
//
"trt_skip_layernorm_fuse_pass"
,
//
"multihead_matmul_fuse_pass_v2"
,
//
"multihead_matmul_fuse_pass_v2"
,
//
"gpu_cpu_squeeze2_matmul_fuse_pass"
,
//
"gpu_cpu_squeeze2_matmul_fuse_pass"
,
//
"gpu_cpu_reshape2_matmul_fuse_pass"
,
//
"gpu_cpu_reshape2_matmul_fuse_pass"
,
//
...
...
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
浏览文件 @
ac0553a0
...
@@ -133,6 +133,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -133,6 +133,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
return
weight
;
return
weight
;
};
};
auto
GetFp16Weight
=
[
&
](
const
std
::
string
&
var_name
,
framework
::
DDim
*
dim
)
->
TensorRTEngine
::
Weight
{
auto
*
temp_var
=
scope
.
FindVar
(
var_name
);
auto
*
temp_tensor
=
temp_var
->
GetMutable
<
framework
::
LoDTensor
>
();
*
dim
=
temp_tensor
->
dims
();
auto
weight
=
engine_
->
GetFp16TrtWeight
(
var_name
,
*
temp_tensor
);
return
weight
;
};
auto
GetFp32Weight
=
[
&
](
const
std
::
string
&
var_name
,
auto
GetFp32Weight
=
[
&
](
const
std
::
string
&
var_name
,
framework
::
DDim
*
dim
)
->
TensorRTEngine
::
Weight
{
framework
::
DDim
*
dim
)
->
TensorRTEngine
::
Weight
{
auto
*
temp_var
=
scope
.
FindVar
(
var_name
);
auto
*
temp_var
=
scope
.
FindVar
(
var_name
);
...
@@ -141,7 +150,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -141,7 +150,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
auto
weight
=
engine_
->
GetFp32TrtWeight
(
var_name
,
*
temp_tensor
);
auto
weight
=
engine_
->
GetFp32TrtWeight
(
var_name
,
*
temp_tensor
);
return
weight
;
return
weight
;
};
};
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
int
hidden
=
0
;
int
hidden
=
0
;
for
(
int
i
=
0
;
i
<
input_num
;
i
++
)
{
for
(
int
i
=
0
;
i
<
input_num
;
i
++
)
{
framework
::
DDim
emb_dims
;
framework
::
DDim
emb_dims
;
...
@@ -149,7 +158,11 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -149,7 +158,11 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
if
(
flag_varseqlen
)
{
if
(
flag_varseqlen
)
{
weight
=
GetWeight
(
emb_names
[
i
],
&
emb_dims
);
weight
=
GetWeight
(
emb_names
[
i
],
&
emb_dims
);
}
else
{
}
else
{
weight
=
GetFp32Weight
(
emb_names
[
i
],
&
emb_dims
);
if
(
with_fp16
)
{
weight
=
GetFp16Weight
(
emb_names
[
i
],
&
emb_dims
);
}
else
{
weight
=
GetFp32Weight
(
emb_names
[
i
],
&
emb_dims
);
}
}
}
input_embs
.
push_back
(
weight
.
get
());
input_embs
.
push_back
(
weight
.
get
());
emb_sizes
.
push_back
(
weight
.
get
().
count
);
emb_sizes
.
push_back
(
weight
.
get
().
count
);
...
@@ -167,8 +180,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -167,8 +180,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
bias_weight
=
GetWeight
(
op_desc
.
Input
(
"Bias"
).
front
(),
&
bias_dims
);
bias_weight
=
GetWeight
(
op_desc
.
Input
(
"Bias"
).
front
(),
&
bias_dims
);
scale_weight
=
GetWeight
(
op_desc
.
Input
(
"Scale"
).
front
(),
&
scale_dims
);
scale_weight
=
GetWeight
(
op_desc
.
Input
(
"Scale"
).
front
(),
&
scale_dims
);
}
else
{
}
else
{
bias_weight
=
GetFp32Weight
(
op_desc
.
Input
(
"Bias"
).
front
(),
&
bias_dims
);
if
(
with_fp16
)
{
scale_weight
=
GetFp32Weight
(
op_desc
.
Input
(
"Scale"
).
front
(),
&
scale_dims
);
bias_weight
=
GetFp16Weight
(
op_desc
.
Input
(
"Bias"
).
front
(),
&
bias_dims
);
scale_weight
=
GetFp16Weight
(
op_desc
.
Input
(
"Scale"
).
front
(),
&
scale_dims
);
}
else
{
bias_weight
=
GetFp32Weight
(
op_desc
.
Input
(
"Bias"
).
front
(),
&
bias_dims
);
scale_weight
=
GetFp32Weight
(
op_desc
.
Input
(
"Scale"
).
front
(),
&
scale_dims
);
}
}
}
int64_t
bias_size
=
phi
::
product
(
bias_dims
);
int64_t
bias_size
=
phi
::
product
(
bias_dims
);
...
@@ -282,21 +302,18 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -282,21 +302,18 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
test_mode
);
test_mode
);
}
}
}
else
{
}
else
{
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
float
eps
=
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
float
eps
=
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
plugin
::
DynamicPluginTensorRT
*
plugin
=
nullptr
;
plugin
::
DynamicPluginTensorRT
*
plugin
=
nullptr
;
std
::
vector
<
float
*>
input_embs_data
;
std
::
vector
<
void
*>
input_embs_data
;
for
(
size_t
i
=
0
;
i
<
input_embs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
input_embs
.
size
();
++
i
)
{
input_embs_data
.
push_back
(
const_cast
<
float
*>
(
input_embs_data
.
push_back
(
const_cast
<
void
*>
(
static_cast
<
const
float
*>
(
input_embs
[
i
].
values
)));
reinterpret_cast
<
const
void
*>
(
input_embs
[
i
].
values
)));
}
}
plugin
=
new
plugin
::
EmbEltwiseLayernormPluginDynamic
(
plugin
=
new
plugin
::
EmbEltwiseLayernormPluginDynamic
(
input_embs_data
,
input_embs_data
,
const_cast
<
float
*>
(
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias_weight
.
get
().
values
)),
static_cast
<
const
float
*>
(
bias_weight
.
get
().
values
)),
const_cast
<
void
*>
(
const_cast
<
float
*>
(
static_cast
<
const
void
*>
(
scale_weight
.
get
().
values
)),
static_cast
<
const
float
*>
(
scale_weight
.
get
().
values
)),
emb_sizes
,
emb_sizes
,
bias_size
,
bias_size
,
scale_size
,
scale_size
,
...
...
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
浏览文件 @
ac0553a0
...
@@ -150,6 +150,15 @@ class SkipLayerNormOpConverter : public OpConverter {
...
@@ -150,6 +150,15 @@ class SkipLayerNormOpConverter : public OpConverter {
layer
=
plugin_layer
;
layer
=
plugin_layer
;
}
}
}
else
{
}
else
{
auto
GetFp16Weight
=
[
&
](
const
std
::
string
&
arg_name
)
->
TensorRTEngine
::
Weight
{
std
::
string
var_name
=
op_desc
.
Input
(
arg_name
).
front
();
auto
*
temp_var
=
scope
.
FindVar
(
var_name
);
auto
*
temp_tensor
=
temp_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
weight
=
engine_
->
GetFp16TrtWeight
(
var_name
,
*
temp_tensor
);
return
weight
;
};
auto
GetFp32Weight
=
auto
GetFp32Weight
=
[
&
](
const
std
::
string
&
arg_name
)
->
TensorRTEngine
::
Weight
{
[
&
](
const
std
::
string
&
arg_name
)
->
TensorRTEngine
::
Weight
{
std
::
string
var_name
=
op_desc
.
Input
(
arg_name
).
front
();
std
::
string
var_name
=
op_desc
.
Input
(
arg_name
).
front
();
...
@@ -159,20 +168,29 @@ class SkipLayerNormOpConverter : public OpConverter {
...
@@ -159,20 +168,29 @@ class SkipLayerNormOpConverter : public OpConverter {
return
weight
;
return
weight
;
};
};
auto
bias_weight
=
GetFp32Weight
(
"Bias"
).
get
();
// bool with_fp16 = engine_->WithFp16() &&
auto
scale_weight
=
GetFp32Weight
(
"Scale"
).
get
();
// !engine_->disable_trt_plugin_fp16() &&
// (input1->getType() == nvinfer1::DataType::kHALF);
bool
with_fp16
=
false
;
TensorRTEngine
::
Weight
bias_weight
,
scale_weight
;
if
(
with_fp16
)
{
bias_weight
=
GetFp16Weight
(
"Bias"
);
scale_weight
=
GetFp16Weight
(
"Scale"
);
}
else
{
bias_weight
=
GetFp32Weight
(
"Bias"
);
scale_weight
=
GetFp32Weight
(
"Scale"
);
}
float
eps
=
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
float
eps
=
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
// bool with_fp16 =
// engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
bool
with_fp16
=
false
;
plugin
::
SkipLayerNormPluginDynamic
*
plugin
=
plugin
::
SkipLayerNormPluginDynamic
*
plugin
=
new
plugin
::
SkipLayerNormPluginDynamic
(
new
plugin
::
SkipLayerNormPluginDynamic
(
static_cast
<
const
float
*>
(
bias_weight
.
values
),
const_cast
<
void
*>
(
static_cast
<
const
float
*>
(
scale_weight
.
values
),
static_cast
<
const
void
*>
(
bias_weight
.
get
().
values
)),
bias_weight
.
count
,
const_cast
<
void
*>
(
scale_weight
.
count
,
static_cast
<
const
void
*>
(
scale_weight
.
get
().
values
)),
bias_weight
.
get
().
count
,
scale_weight
.
get
().
count
,
eps
,
eps
,
with_fp16
);
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
inputs
.
data
(),
2
,
plugin
);
layer
=
engine_
->
AddDynamicPlugin
(
inputs
.
data
(),
2
,
plugin
);
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
ac0553a0
...
@@ -31,7 +31,7 @@ namespace inference {
...
@@ -31,7 +31,7 @@ namespace inference {
namespace
tensorrt
{
namespace
tensorrt
{
void
TensorRTEngine
::
Weight
::
SetDataType
(
phi
::
DataType
type
)
{
void
TensorRTEngine
::
Weight
::
SetDataType
(
phi
::
DataType
type
)
{
nvinfer1
::
DataType
nv_type
;
nvinfer1
::
DataType
nv_type
=
nvinfer1
::
DataType
::
kFLOAT
;
switch
(
type
)
{
switch
(
type
)
{
case
phi
::
DataType
::
FLOAT32
:
case
phi
::
DataType
::
FLOAT32
:
nv_type
=
nvinfer1
::
DataType
::
kFLOAT
;
nv_type
=
nvinfer1
::
DataType
::
kFLOAT
;
...
@@ -455,6 +455,67 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
...
@@ -455,6 +455,67 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
runtime_batch_
=
batch_size
;
runtime_batch_
=
batch_size
;
}
}
// Note: Only for support plugin.
TensorRTEngine
::
Weight
TensorRTEngine
::
GetFp16TrtWeight
(
const
std
::
string
&
name
,
const
framework
::
Tensor
&
weight_tensor
)
{
static
int
name_suffix_counter
=
0
;
std
::
string
name_suffix
=
std
::
to_string
(
name_suffix_counter
);
std
::
string
splitter
=
"__"
;
std
::
string
name_with_suffix
=
name
+
splitter
+
name_suffix
;
platform
::
CPUPlace
cpu_place
;
PADDLE_ENFORCE_EQ
(
weight_map
.
count
(
name_with_suffix
),
0
,
platform
::
errors
::
AlreadyExists
(
"The weight named %s is set into the weight map "
"twice in TRT OP converter."
,
name_with_suffix
));
weight_map
[
name_with_suffix
].
reset
(
new
framework
::
Tensor
());
weight_map
[
name_with_suffix
]
->
Resize
(
weight_tensor
.
dims
());
TensorRTEngine
::
Weight
weight
;
weight
.
SetCount
(
weight_tensor
.
numel
());
weight
.
SetDataType
(
nvinfer1
::
DataType
::
kHALF
);
// weight_tensor.dims().;
// if trt not support dtype, we need to cast to fp16.
if
(
weight_tensor
.
dtype
()
==
phi
::
DataType
::
BFLOAT16
)
{
framework
::
Tensor
bf16_tensor
;
bf16_tensor
.
clear
();
paddle
::
framework
::
TensorCopySync
(
weight_tensor
,
platform
::
CPUPlace
(),
&
bf16_tensor
);
weight_map
[
name_with_suffix
]
->
set_type
(
paddle
::
experimental
::
DataType
::
FLOAT16
);
weight_map
[
name_with_suffix
]
->
Resize
(
weight_tensor
.
dims
());
auto
*
fp16_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
float16
>
(
platform
::
CPUPlace
());
auto
*
bf16_data
=
bf16_tensor
.
mutable_data
<
bfloat16
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
weight_tensor
.
numel
();
i
++
)
{
fp16_data
[
i
]
=
static_cast
<
float16
>
(
bf16_data
[
i
]);
}
}
else
if
(
weight_tensor
.
dtype
()
==
phi
::
DataType
::
FLOAT32
)
{
framework
::
Tensor
fp32_tensor
;
fp32_tensor
.
clear
();
paddle
::
framework
::
TensorCopySync
(
weight_tensor
,
platform
::
CPUPlace
(),
&
fp32_tensor
);
weight_map
[
name_with_suffix
]
->
set_type
(
paddle
::
experimental
::
DataType
::
FLOAT16
);
weight_map
[
name_with_suffix
]
->
Resize
(
weight_tensor
.
dims
());
auto
*
fp16_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
float16
>
(
platform
::
CPUPlace
());
auto
*
fp32_data
=
fp32_tensor
.
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
weight_tensor
.
numel
();
i
++
)
{
fp16_data
[
i
]
=
static_cast
<
float16
>
(
fp32_data
[
i
]);
}
}
else
{
paddle
::
framework
::
TensorCopySync
(
weight_tensor
,
cpu_place
,
weight_map
[
name_with_suffix
].
get
());
}
weight
.
SetValues
(
weight_map
[
name_with_suffix
]
->
data
());
name_suffix_counter
+=
1
;
return
weight
;
}
// Note: Only for support plugin.
TensorRTEngine
::
Weight
TensorRTEngine
::
GetFp32TrtWeight
(
TensorRTEngine
::
Weight
TensorRTEngine
::
GetFp32TrtWeight
(
const
std
::
string
&
name
,
const
framework
::
Tensor
&
weight_tensor
)
{
const
std
::
string
&
name
,
const
framework
::
Tensor
&
weight_tensor
)
{
static
int
name_suffix_counter
=
0
;
static
int
name_suffix_counter
=
0
;
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
ac0553a0
...
@@ -421,6 +421,10 @@ class TensorRTEngine {
...
@@ -421,6 +421,10 @@ class TensorRTEngine {
quant_dynamic_range_
[
tensor
]
=
range
;
quant_dynamic_range_
[
tensor
]
=
range
;
}
}
// Get fp16 trt weight. If src weight is not fp16, we will cast.
Weight
GetFp16TrtWeight
(
const
std
::
string
&
name
,
const
framework
::
Tensor
&
weight_tensor
);
// Get fp32 trt weight. If src weight is not fp32, we will cast.
// Get fp32 trt weight. If src weight is not fp32, we will cast.
Weight
GetFp32TrtWeight
(
const
std
::
string
&
name
,
Weight
GetFp32TrtWeight
(
const
std
::
string
&
name
,
const
framework
::
Tensor
&
weight_tensor
);
const
framework
::
Tensor
&
weight_tensor
);
...
...
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu
浏览文件 @
ac0553a0
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <cassert>
#include <cassert>
#include <cub/cub.cuh> // NOLINT
#include <cub/cub.cuh> // NOLINT
#include <type_traits>
#include <vector>
#include <vector>
#include "glog/logging.h"
#include "glog/logging.h"
...
@@ -32,12 +33,6 @@ namespace plugin {
...
@@ -32,12 +33,6 @@ namespace plugin {
// Dynamic shape plugin requires TRT version greater than 6.0.
// Dynamic shape plugin requires TRT version greater than 6.0.
#if IS_TRT_VERSION_GE(6000)
#if IS_TRT_VERSION_GE(6000)
template
<
typename
T
>
EmbEltwiseLayernormPluginDynamicImpl
<
T
>::~
EmbEltwiseLayernormPluginDynamicImpl
()
{}
inline
half
fp32tofp16
(
float
x
)
{
return
static_cast
<
half
>
(
x
);
}
template
<
typename
T
>
template
<
typename
T
>
void
EmbEltwiseLayernormPluginDynamicImpl
<
T
>::
shareGPUData
(
void
EmbEltwiseLayernormPluginDynamicImpl
<
T
>::
shareGPUData
(
const
EmbEltwiseLayernormPluginDynamicImplBase
*
anthor
)
{
const
EmbEltwiseLayernormPluginDynamicImplBase
*
anthor
)
{
...
@@ -62,36 +57,24 @@ int EmbEltwiseLayernormPluginDynamicImpl<T>::initialize() {
...
@@ -62,36 +57,24 @@ int EmbEltwiseLayernormPluginDynamicImpl<T>::initialize() {
embs_gpu_
.
resize
(
embs_
.
size
());
embs_gpu_
.
resize
(
embs_
.
size
());
for
(
int
i
=
0
;
i
<
embs_
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
embs_
.
size
();
i
++
)
{
if
(
embs_
[
i
])
{
if
(
embs_
[
i
])
{
T
*
host_ptr
;
T
*
host_ptr
=
embs_
[
i
]
;
auto
size
=
emb_sizes_
[
i
];
auto
size
=
emb_sizes_
[
i
];
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
host_ptr
=
new
T
[
size
];
std
::
transform
(
embs_
[
i
],
(
embs_
[
i
]
+
size
),
host_ptr
,
fp32tofp16
);
}
else
{
host_ptr
=
reinterpret_cast
<
T
*>
(
embs_
[
i
]);
}
cudaMalloc
(
&
embs_gpu_
[
i
],
sizeof
(
T
)
*
size
);
cudaMalloc
(
&
embs_gpu_
[
i
],
sizeof
(
T
)
*
size
);
cudaMemcpy
(
cudaMemcpy
(
embs_gpu_
[
i
],
host_ptr
,
size
*
sizeof
(
T
),
cudaMemcpyHostToDevice
);
embs_gpu_
[
i
],
host_ptr
,
size
*
sizeof
(
T
),
cudaMemcpyHostToDevice
);
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
delete
[]
host_ptr
;
}
}
}
}
}
if
(
bias_
)
{
if
(
bias_
)
{
cudaMalloc
(
&
bias_gpu_
,
sizeof
(
float
)
*
bias_size_
);
cudaMalloc
(
&
bias_gpu_
,
sizeof
(
T
)
*
bias_size_
);
cudaMemcpy
(
cudaMemcpy
(
bias_gpu_
,
bias_
,
bias_size_
*
sizeof
(
float
),
cudaMemcpyHostToDevice
);
bias_gpu_
,
bias_
,
bias_size_
*
sizeof
(
T
),
cudaMemcpyHostToDevice
);
}
}
if
(
scale_
)
{
if
(
scale_
)
{
cudaMalloc
(
&
scale_gpu_
,
sizeof
(
float
)
*
scale_size_
);
cudaMalloc
(
&
scale_gpu_
,
sizeof
(
T
)
*
scale_size_
);
cudaMemcpy
(
scale_gpu_
,
cudaMemcpy
(
scale_
,
scale_gpu_
,
scale_
,
scale_size_
*
sizeof
(
T
),
cudaMemcpyHostToDevice
);
scale_size_
*
sizeof
(
float
),
cudaMemcpyHostToDevice
);
}
}
int
input_num
=
embs_
.
size
();
int
input_num
=
embs_
.
size
();
...
@@ -239,22 +222,14 @@ bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
...
@@ -239,22 +222,14 @@ bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
"The EmbEltwiseLayerNorm's output should be one"
"The EmbEltwiseLayerNorm's output should be one"
"but it's (%d) outputs."
,
"but it's (%d) outputs."
,
nb_outputs
));
nb_outputs
));
PADDLE_ENFORCE_EQ
(
nb_outputs
,
int
all_nums
=
nb_inputs
+
nb_outputs
;
1
,
platform
::
errors
::
InvalidArgument
(
"The EmbEltwiseLayerNorm's output should be one"
"but it's (%d) outputs."
,
nb_outputs
));
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
pos
,
pos
,
nb_inputs
+
nb_output
s
,
all_num
s
,
platform
::
errors
::
InvalidArgument
(
"The pos(%d) should be less than the "
platform
::
errors
::
InvalidArgument
(
"The pos(%d) should be less than the "
"num(%d) of the input and the output."
,
"num(%d) of the input and the output."
,
pos
,
pos
,
nb_inputs
+
nb_outputs
));
all_nums
));
int
all_nums
=
nb_inputs
+
nb_outputs
;
const
nvinfer1
::
PluginTensorDesc
&
desc
=
in_out
[
pos
];
const
nvinfer1
::
PluginTensorDesc
&
desc
=
in_out
[
pos
];
if
(
desc
.
format
!=
nvinfer1
::
TensorFormat
::
kLINEAR
)
{
if
(
desc
.
format
!=
nvinfer1
::
TensorFormat
::
kLINEAR
)
{
return
false
;
return
false
;
...
@@ -269,7 +244,7 @@ bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
...
@@ -269,7 +244,7 @@ bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
return
desc
.
type
==
nvinfer1
::
DataType
::
kINT32
&&
return
desc
.
type
==
nvinfer1
::
DataType
::
kINT32
&&
desc
.
dims
.
d
[
0
]
==
prev
.
dims
.
d
[
0
]
&&
desc
.
dims
.
d
[
1
]
==
prev
.
dims
.
d
[
1
];
desc
.
dims
.
d
[
0
]
==
prev
.
dims
.
d
[
0
]
&&
desc
.
dims
.
d
[
1
]
==
prev
.
dims
.
d
[
1
];
}
}
// output
if
(
pos
==
all_nums
-
1
)
{
if
(
pos
==
all_nums
-
1
)
{
if
(
with_fp16_
==
false
)
{
if
(
with_fp16_
==
false
)
{
return
desc
.
type
==
nvinfer1
::
DataType
::
kFLOAT
;
return
desc
.
type
==
nvinfer1
::
DataType
::
kFLOAT
;
...
@@ -288,7 +263,7 @@ nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType(
...
@@ -288,7 +263,7 @@ nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType(
index
,
index
,
0
,
0
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The EmbEltwiseLayernorm Plugin only has one
in
put, so the "
"The EmbEltwiseLayernorm Plugin only has one
out
put, so the "
"index value should be 0, but get %d."
,
"index value should be 0, but get %d."
,
index
));
index
));
if
(
with_fp16_
)
if
(
with_fp16_
)
...
...
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h
浏览文件 @
ac0553a0
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include <algorithm>
#include <algorithm>
#include <cstddef>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -49,9 +50,9 @@ template <typename T>
...
@@ -49,9 +50,9 @@ template <typename T>
class
EmbEltwiseLayernormPluginDynamicImpl
class
EmbEltwiseLayernormPluginDynamicImpl
:
public
EmbEltwiseLayernormPluginDynamicImplBase
{
:
public
EmbEltwiseLayernormPluginDynamicImplBase
{
public:
public:
explicit
EmbEltwiseLayernormPluginDynamicImpl
(
std
::
vector
<
float
*>
input_embs
,
explicit
EmbEltwiseLayernormPluginDynamicImpl
(
std
::
vector
<
T
*>
input_embs
,
float
*
bias
,
T
*
bias
,
float
*
scale
,
T
*
scale
,
std
::
vector
<
int
>
emb_sizes
,
std
::
vector
<
int
>
emb_sizes
,
int
bias_size
,
int
bias_size
,
int
scale_size
,
int
scale_size
,
...
@@ -66,7 +67,7 @@ class EmbEltwiseLayernormPluginDynamicImpl
...
@@ -66,7 +67,7 @@ class EmbEltwiseLayernormPluginDynamicImpl
hidden_size_
(
hidden_size
),
hidden_size_
(
hidden_size
),
eps_
(
eps
)
{}
eps_
(
eps
)
{}
~
EmbEltwiseLayernormPluginDynamicImpl
()
;
~
EmbEltwiseLayernormPluginDynamicImpl
()
{}
int
initialize
();
int
initialize
();
void
terminate
();
void
terminate
();
...
@@ -79,13 +80,13 @@ class EmbEltwiseLayernormPluginDynamicImpl
...
@@ -79,13 +80,13 @@ class EmbEltwiseLayernormPluginDynamicImpl
void
shareGPUData
(
const
EmbEltwiseLayernormPluginDynamicImplBase
*
anthor
);
void
shareGPUData
(
const
EmbEltwiseLayernormPluginDynamicImplBase
*
anthor
);
private:
private:
std
::
vector
<
float
*>
embs_
;
std
::
vector
<
T
*>
embs_
;
float
*
bias_
{
nullptr
};
T
*
bias_
{
nullptr
};
float
*
scale_
{
nullptr
};
T
*
scale_
{
nullptr
};
// data on devices
// data on devices
float
*
bias_gpu_
{
nullptr
};
T
*
bias_gpu_
{
nullptr
};
float
*
scale_gpu_
{
nullptr
};
T
*
scale_gpu_
{
nullptr
};
std
::
vector
<
T
*>
embs_gpu_
;
std
::
vector
<
T
*>
embs_gpu_
;
std
::
vector
<
int
>
emb_sizes_
;
std
::
vector
<
int
>
emb_sizes_
;
...
@@ -101,9 +102,9 @@ class EmbEltwiseLayernormPluginDynamicImpl
...
@@ -101,9 +102,9 @@ class EmbEltwiseLayernormPluginDynamicImpl
class
EmbEltwiseLayernormPluginDynamic
:
public
DynamicPluginTensorRT
{
class
EmbEltwiseLayernormPluginDynamic
:
public
DynamicPluginTensorRT
{
public:
public:
explicit
EmbEltwiseLayernormPluginDynamic
(
std
::
vector
<
float
*>
input_embs
,
explicit
EmbEltwiseLayernormPluginDynamic
(
std
::
vector
<
void
*>
input_embs
,
float
*
bias
,
void
*
bias
,
float
*
scale
,
void
*
scale
,
std
::
vector
<
int
>
emb_sizes
,
std
::
vector
<
int
>
emb_sizes
,
int
bias_size
,
int
bias_size
,
int
scale_size
,
int
scale_size
,
...
@@ -123,14 +124,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
...
@@ -123,14 +124,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
if
(
with_fp16_
)
{
if
(
with_fp16_
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG
(
1
)
<<
"TRT Plugin DataType selected. EmbEltwiseLayerNorm-->fp16"
;
VLOG
(
1
)
<<
"TRT Plugin DataType selected. EmbEltwiseLayerNorm-->fp16"
;
impl_
=
new
EmbEltwiseLayernormPluginDynamicImpl
<
half
>
(
embs_
,
instantiateImpl
<
half
>
();
bias_
,
scale_
,
emb_sizes_
,
bias_size_
,
scale_size_
,
hidden_size_
,
eps_
);
#else
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Ernie(Bert) tensorRT plugin should be "
"The Ernie(Bert) tensorRT plugin should be "
...
@@ -141,63 +135,74 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
...
@@ -141,63 +135,74 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
#endif
#endif
}
else
{
}
else
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. EmbEltwiseLayerNorm-->fp32"
;
VLOG
(
1
)
<<
"TRT Plugin DataType selected. EmbEltwiseLayerNorm-->fp32"
;
impl_
=
new
EmbEltwiseLayernormPluginDynamicImpl
<
float
>
(
embs_
,
instantiateImpl
<
float
>
();
bias_
,
scale_
,
emb_sizes_
,
bias_size_
,
scale_size_
,
hidden_size_
,
eps_
);
}
}
}
}
EmbEltwiseLayernormPluginDynamic
(
void
const
*
serial_data
,
EmbEltwiseLayernormPluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
size_t
serial_length
)
:
own_host_buff_
(
true
)
{
:
own_host_buff_
(
true
)
{
// the first var is with_fp16, we will use it.
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
emb_sizes_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
emb_sizes_
);
embs_
.
resize
(
emb_sizes_
.
size
());
for
(
size_t
i
=
0
;
i
<
emb_sizes_
.
size
();
i
++
)
{
auto
size
=
emb_sizes_
[
i
];
auto
ptr
=
new
float
[
size
];
memcpy
(
ptr
,
serial_data
,
sizeof
(
float
)
*
size
);
embs_
[
i
]
=
ptr
;
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
emb_sizes_
[
i
]
*
sizeof
(
float
);
serial_length
-=
emb_sizes_
[
i
]
*
sizeof
(
float
);
}
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
bias_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
bias_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
scale_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
scale_size_
);
if
(
bias_size_
)
{
embs_
.
resize
(
emb_sizes_
.
size
());
bias_
=
new
float
[
bias_size_
];
memcpy
(
bias_
,
serial_data
,
sizeof
(
float
)
*
bias_size_
);
if
(
with_fp16_
)
{
}
for
(
size_t
i
=
0
;
i
<
emb_sizes_
.
size
();
i
++
)
{
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
bias_size_
*
sizeof
(
float
);
auto
size
=
emb_sizes_
[
i
];
serial_length
-=
bias_size_
*
sizeof
(
float
);
auto
ptr
=
new
half
[
size
];
memcpy
(
ptr
,
serial_data
,
sizeof
(
half
)
*
size
);
embs_
[
i
]
=
ptr
;
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
size
*
sizeof
(
half
);
serial_length
-=
size
*
sizeof
(
half
);
}
if
(
bias_size_
)
{
bias_
=
new
half
[
bias_size_
];
memcpy
(
bias_
,
serial_data
,
sizeof
(
half
)
*
bias_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
bias_size_
*
sizeof
(
half
);
serial_length
-=
bias_size_
*
sizeof
(
half
);
if
(
scale_size_
)
{
if
(
scale_size_
)
{
scale_
=
new
float
[
scale_size_
];
scale_
=
new
half
[
scale_size_
];
memcpy
(
scale_
,
serial_data
,
sizeof
(
float
)
*
scale_size_
);
memcpy
(
scale_
,
serial_data
,
sizeof
(
half
)
*
scale_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
scale_size_
*
sizeof
(
half
);
serial_length
-=
scale_size_
*
sizeof
(
half
);
}
else
{
for
(
size_t
i
=
0
;
i
<
emb_sizes_
.
size
();
i
++
)
{
auto
size
=
emb_sizes_
[
i
];
auto
ptr
=
new
float
[
size
];
memcpy
(
ptr
,
serial_data
,
sizeof
(
float
)
*
size
);
embs_
[
i
]
=
ptr
;
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
size
*
sizeof
(
float
);
serial_length
-=
size
*
sizeof
(
float
);
}
if
(
bias_size_
)
{
bias_
=
new
float
[
bias_size_
];
memcpy
(
bias_
,
serial_data
,
sizeof
(
float
)
*
bias_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
bias_size_
*
sizeof
(
float
);
serial_length
-=
bias_size_
*
sizeof
(
float
);
if
(
scale_size_
)
{
scale_
=
new
float
[
scale_size_
];
memcpy
(
scale_
,
serial_data
,
sizeof
(
float
)
*
scale_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
scale_size_
*
sizeof
(
float
);
serial_length
-=
scale_size_
*
sizeof
(
float
);
}
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
scale_size_
*
sizeof
(
float
);
serial_length
-=
scale_size_
*
sizeof
(
float
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
hidden_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
hidden_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
eps_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
eps_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
if
(
with_fp16_
)
{
if
(
with_fp16_
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
#ifdef TRT_PLUGIN_FP16_AVALIABLE
impl_
=
new
EmbEltwiseLayernormPluginDynamicImpl
<
half
>
(
embs_
,
instantiateImpl
<
half
>
();
bias_
,
scale_
,
emb_sizes_
,
bias_size_
,
scale_size_
,
hidden_size_
,
eps_
);
#else
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Ernie(Bert) tensorRT plugin should be "
"The Ernie(Bert) tensorRT plugin should be "
...
@@ -207,14 +212,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
...
@@ -207,14 +212,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
"AnalysisConfig::Precision::kFloat32, false, false) "
));
"AnalysisConfig::Precision::kFloat32, false, false) "
));
#endif
#endif
}
else
{
}
else
{
impl_
=
new
EmbEltwiseLayernormPluginDynamicImpl
<
float
>
(
embs_
,
instantiateImpl
<
float
>
();
bias_
,
scale_
,
emb_sizes_
,
bias_size_
,
scale_size_
,
hidden_size_
,
eps_
);
}
}
}
}
...
@@ -241,44 +239,68 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
...
@@ -241,44 +239,68 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
override
{
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
override
{
int
sum_num
=
0
;
int
sum_num
=
0
;
sum_num
+=
SerializedSize
(
with_fp16_
);
sum_num
+=
SerializedSize
(
emb_sizes_
);
sum_num
+=
SerializedSize
(
emb_sizes_
);
for
(
size_t
i
=
0
;
i
<
emb_sizes_
.
size
();
i
++
)
{
if
(
with_fp16_
)
{
sum_num
+=
emb_sizes_
[
i
]
*
sizeof
(
float
);
for
(
size_t
i
=
0
;
i
<
emb_sizes_
.
size
();
i
++
)
{
sum_num
+=
emb_sizes_
[
i
]
*
sizeof
(
half
);
}
sum_num
+=
(
bias_size_
+
scale_size_
)
*
sizeof
(
half
);
}
else
{
for
(
size_t
i
=
0
;
i
<
emb_sizes_
.
size
();
i
++
)
{
sum_num
+=
emb_sizes_
[
i
]
*
sizeof
(
float
);
}
sum_num
+=
(
bias_size_
+
scale_size_
)
*
sizeof
(
float
);
}
}
sum_num
+=
SerializedSize
(
bias_size_
);
sum_num
+=
SerializedSize
(
bias_size_
);
sum_num
+=
SerializedSize
(
scale_size_
);
sum_num
+=
SerializedSize
(
scale_size_
);
sum_num
+=
(
bias_size_
+
scale_size_
)
*
sizeof
(
float
);
sum_num
+=
SerializedSize
(
hidden_size_
);
sum_num
+=
SerializedSize
(
hidden_size_
);
sum_num
+=
SerializedSize
(
eps_
);
sum_num
+=
SerializedSize
(
eps_
);
sum_num
+=
SerializedSize
(
with_fp16_
);
return
sum_num
;
return
sum_num
;
}
}
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
override
{
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
override
{
// the first var is for with_fp16, we will use it later;
SerializeValue
(
&
buffer
,
with_fp16_
);
SerializeValue
(
&
buffer
,
emb_sizes_
);
SerializeValue
(
&
buffer
,
emb_sizes_
);
for
(
size_t
i
=
0
;
i
<
emb_sizes_
.
size
();
i
++
)
{
auto
size
=
emb_sizes_
[
i
];
for
(
int
j
=
0
;
j
<
size
;
++
j
)
{
SerializeValue
(
&
buffer
,
embs_
[
i
][
j
]);
}
}
SerializeValue
(
&
buffer
,
bias_size_
);
SerializeValue
(
&
buffer
,
bias_size_
);
SerializeValue
(
&
buffer
,
scale_size_
);
SerializeValue
(
&
buffer
,
scale_size_
);
for
(
int
i
=
0
;
i
<
bias_size_
;
++
i
)
{
if
(
with_fp16_
)
{
SerializeValue
(
&
buffer
,
bias_
[
i
]);
for
(
size_t
i
=
0
;
i
<
emb_sizes_
.
size
();
i
++
)
{
}
auto
size
=
emb_sizes_
[
i
];
for
(
int
j
=
0
;
j
<
size
;
++
j
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
half
*>
(
embs_
[
i
])[
j
]);
}
}
for
(
int
i
=
0
;
i
<
bias_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
half
*>
(
bias_
)[
i
]);
}
for
(
int
i
=
0
;
i
<
scale_size_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
scale_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
scale_
[
i
]);
SerializeValue
(
&
buffer
,
reinterpret_cast
<
half
*>
(
scale_
)[
i
]);
}
}
else
{
for
(
size_t
i
=
0
;
i
<
emb_sizes_
.
size
();
i
++
)
{
auto
size
=
emb_sizes_
[
i
];
for
(
int
j
=
0
;
j
<
size
;
++
j
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
float
*>
(
embs_
[
i
])[
j
]);
}
}
for
(
int
i
=
0
;
i
<
bias_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
float
*>
(
bias_
)[
i
]);
}
for
(
int
i
=
0
;
i
<
scale_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
float
*>
(
scale_
)[
i
]);
}
}
}
SerializeValue
(
&
buffer
,
hidden_size_
);
SerializeValue
(
&
buffer
,
hidden_size_
);
SerializeValue
(
&
buffer
,
eps_
);
SerializeValue
(
&
buffer
,
eps_
);
SerializeValue
(
&
buffer
,
with_fp16_
);
}
}
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
output_index
,
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
output_index
,
...
@@ -317,21 +339,28 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
...
@@ -317,21 +339,28 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
void
destroy
()
TRT_NOEXCEPT
override
{
void
destroy
()
TRT_NOEXCEPT
override
{
if
(
own_host_buff_
)
{
if
(
own_host_buff_
)
{
for
(
auto
ptr
:
embs_
)
{
if
(
with_fp16_
)
{
delete
[]
ptr
;
for
(
auto
ptr
:
embs_
)
{
delete
[]
reinterpret_cast
<
half
*>
(
ptr
);
}
delete
[]
reinterpret_cast
<
half
*>
(
bias_
);
delete
[]
reinterpret_cast
<
half
*>
(
scale_
);
}
else
{
for
(
auto
ptr
:
embs_
)
{
delete
[]
reinterpret_cast
<
float
*>
(
ptr
);
}
delete
[]
reinterpret_cast
<
float
*>
(
bias_
);
delete
[]
reinterpret_cast
<
float
*>
(
scale_
);
}
}
delete
[]
bias_
;
delete
[]
scale_
;
}
}
delete
impl_
;
delete
impl_
;
delete
this
;
delete
this
;
}
}
private:
private:
std
::
vector
<
float
*>
embs_
;
std
::
vector
<
void
*>
embs_
;
float
*
bias_
;
void
*
bias_
{
nullptr
}
;
float
*
scale_
;
void
*
scale_
{
nullptr
}
;
std
::
vector
<
int
>
emb_sizes_
;
std
::
vector
<
int
>
emb_sizes_
;
int
bias_size_
;
int
bias_size_
;
...
@@ -345,6 +374,24 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
...
@@ -345,6 +374,24 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
void
shareGPUData
(
const
EmbEltwiseLayernormPluginDynamic
*
anthor
)
{
void
shareGPUData
(
const
EmbEltwiseLayernormPluginDynamic
*
anthor
)
{
impl_
->
shareGPUData
(
anthor
->
impl_
);
impl_
->
shareGPUData
(
anthor
->
impl_
);
}
}
template
<
typename
U
>
void
instantiateImpl
()
{
std
::
vector
<
U
*>
embs
;
embs
.
resize
(
embs_
.
size
());
for
(
size_t
i
=
0
;
i
<
embs_
.
size
();
++
i
)
{
embs
[
i
]
=
reinterpret_cast
<
U
*>
(
embs_
[
i
]);
}
impl_
=
new
EmbEltwiseLayernormPluginDynamicImpl
<
U
>
(
embs
,
reinterpret_cast
<
U
*>
(
bias_
),
reinterpret_cast
<
U
*>
(
scale_
),
emb_sizes_
,
bias_size_
,
scale_size_
,
hidden_size_
,
eps_
);
}
};
};
class
EmbEltwiseLayernormPluginDynamicCreator
class
EmbEltwiseLayernormPluginDynamicCreator
...
...
paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.cu
浏览文件 @
ac0553a0
...
@@ -31,31 +31,61 @@ namespace plugin {
...
@@ -31,31 +31,61 @@ namespace plugin {
// Dynamic Plugin below.
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
#if IS_TRT_VERSION_GE(6000)
int
SkipLayerNormPluginDynamic
::
initialize
()
TRT_NOEXCEPT
{
template
<
typename
T
>
cudaMalloc
(
&
bias_gpu_
,
sizeof
(
float
)
*
bias_size_
);
void
SkipLayerNormPluginDynamicImpl
<
T
>::
shareGPUData
(
cudaMemcpy
(
bias_gpu_
,
const
SkipLayerNormPluginDynamicImplBase
*
anthor
)
{
bias_
.
data
(),
auto
*
ptr
=
dynamic_cast
<
const
SkipLayerNormPluginDynamicImpl
<
T
>
*>
(
anthor
);
bias_size_
*
sizeof
(
float
),
if
(
!
ptr
->
is_initialized_
)
{
cudaMemcpyHostToDevice
);
return
;
cudaMalloc
(
&
scale_gpu_
,
sizeof
(
float
)
*
scale_size_
);
}
cudaMemcpy
(
scale_gpu_
,
scale_gpu_
=
ptr
->
scale_gpu_
;
scale_
.
data
(),
bias_gpu_
=
ptr
->
bias_gpu_
;
scale_size_
*
sizeof
(
float
),
}
cudaMemcpyHostToDevice
);
template
<
typename
T
>
int
SkipLayerNormPluginDynamicImpl
<
T
>::
initialize
()
{
if
(
is_initialized_
)
{
return
0
;
}
if
(
bias_
)
{
cudaMalloc
(
&
bias_gpu_
,
sizeof
(
T
)
*
bias_size_
);
cudaMemcpy
(
bias_gpu_
,
bias_
,
bias_size_
*
sizeof
(
T
),
cudaMemcpyHostToDevice
);
}
if
(
scale_
)
{
cudaMalloc
(
&
scale_gpu_
,
sizeof
(
T
)
*
scale_size_
);
cudaMemcpy
(
scale_gpu_
,
scale_
,
scale_size_
*
sizeof
(
T
),
cudaMemcpyHostToDevice
);
}
is_initialized_
=
true
;
return
0
;
return
0
;
}
}
void
SkipLayerNormPluginDynamic
::
terminate
()
TRT_NOEXCEPT
{
template
<
typename
T
>
void
SkipLayerNormPluginDynamicImpl
<
T
>::
terminate
()
{
if
(
bias_gpu_
)
{
if
(
bias_gpu_
)
{
cudaFree
(
bias_gpu_
);
cudaFree
(
bias_gpu_
);
bias_gpu_
=
nullptr
;
bias_gpu_
=
nullptr
;
}
}
if
(
scale_gpu_
)
{
if
(
scale_gpu_
)
{
cudaFree
(
scale_gpu_
);
cudaFree
(
scale_gpu_
);
scale_gpu_
=
nullptr
;
scale_gpu_
=
nullptr
;
}
}
}
}
int
SkipLayerNormPluginDynamic
::
initialize
()
TRT_NOEXCEPT
{
impl_
->
initialize
();
return
0
;
}
void
SkipLayerNormPluginDynamic
::
terminate
()
TRT_NOEXCEPT
{
impl_
->
terminate
();
}
nvinfer1
::
DimsExprs
SkipLayerNormPluginDynamic
::
getOutputDimensions
(
nvinfer1
::
DimsExprs
SkipLayerNormPluginDynamic
::
getOutputDimensions
(
int
output_index
,
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
const
nvinfer1
::
DimsExprs
*
inputs
,
...
@@ -73,6 +103,12 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
...
@@ -73,6 +103,12 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
in_out
,
in_out
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The input of swish plugin shoule not be nullptr."
));
"The input of swish plugin shoule not be nullptr."
));
PADDLE_ENFORCE_EQ
(
nb_outputs
,
1
,
platform
::
errors
::
InvalidArgument
(
"The SkipLayerNorm's output should be one"
"but it's (%d) outputs."
,
nb_outputs
));
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
pos
,
pos
,
...
@@ -82,30 +118,27 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
...
@@ -82,30 +118,27 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
pos
,
pos
,
nb_inputs
+
nb_outputs
));
nb_inputs
+
nb_outputs
));
const
nvinfer1
::
PluginTensorDesc
&
in
=
in_out
[
pos
];
const
nvinfer1
::
PluginTensorDesc
&
desc
=
in_out
[
pos
];
if
(
pos
==
0
)
{
if
(
pos
==
0
)
{
if
(
with_fp16_
)
{
if
(
with_fp16_
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
||
return
(
desc
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
desc
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#else
#else
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
return
(
desc
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
(
desc
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#endif
#endif
}
else
{
}
else
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
return
(
desc
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
(
desc
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
}
}
}
const
nvinfer1
::
PluginTensorDesc
&
prev
=
in_out
[
pos
-
1
];
const
nvinfer1
::
PluginTensorDesc
&
prev
=
in_out
[
pos
-
1
];
if
(
pos
==
1
)
{
if
(
pos
==
1
)
{
return
in
.
type
==
prev
.
type
&&
in
.
format
==
prev
.
format
;
return
desc
.
type
==
prev
.
type
&&
desc
.
format
==
prev
.
format
;
}
}
// output
// output
return
in
.
type
==
prev
.
type
&&
in
.
format
==
prev
.
format
;
return
desc
.
type
==
prev
.
type
&&
desc
.
format
==
prev
.
format
;
}
}
nvinfer1
::
DataType
SkipLayerNormPluginDynamic
::
getOutputDataType
(
nvinfer1
::
DataType
SkipLayerNormPluginDynamic
::
getOutputDataType
(
...
@@ -115,7 +148,7 @@ nvinfer1::DataType SkipLayerNormPluginDynamic::getOutputDataType(
...
@@ -115,7 +148,7 @@ nvinfer1::DataType SkipLayerNormPluginDynamic::getOutputDataType(
PADDLE_ENFORCE_EQ
(
index
,
PADDLE_ENFORCE_EQ
(
index
,
0
,
0
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The SkipLayerNorm Plugin only has one
in
put, so the "
"The SkipLayerNorm Plugin only has one
out
put, so the "
"index value should be 0, but get %d."
,
"index value should be 0, but get %d."
,
index
));
index
));
PADDLE_ENFORCE_EQ
((
input_types
[
0
]
==
nvinfer1
::
DataType
::
kFLOAT
||
PADDLE_ENFORCE_EQ
((
input_types
[
0
]
==
nvinfer1
::
DataType
::
kFLOAT
||
...
@@ -126,7 +159,8 @@ nvinfer1::DataType SkipLayerNormPluginDynamic::getOutputDataType(
...
@@ -126,7 +159,8 @@ nvinfer1::DataType SkipLayerNormPluginDynamic::getOutputDataType(
return
input_types
[
0
];
return
input_types
[
0
];
}
}
int
SkipLayerNormPluginDynamic
::
enqueue
(
template
<
typename
T
>
int
SkipLayerNormPluginDynamicImpl
<
T
>::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
const
void
*
const
*
inputs
,
...
@@ -138,51 +172,45 @@ int SkipLayerNormPluginDynamic::enqueue(
...
@@ -138,51 +172,45 @@ int SkipLayerNormPluginDynamic::enqueue(
int
hidden
=
input_dims
.
d
[
2
];
int
hidden
=
input_dims
.
d
[
2
];
auto
input_type
=
input_desc
[
0
].
type
;
auto
input_type
=
input_desc
[
0
].
type
;
if
(
input_type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. SkipLayerNorm-->fp32"
;
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
const
float
*
input1
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
PADDLE_ENFORCE_EQ
(
input_type
==
nvinfer1
::
DataType
::
kFLOAT
,
const
float
*
input2
=
static_cast
<
const
float
*>
(
inputs
[
1
]);
true
,
float
*
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
platform
::
errors
::
InvalidArgument
(
operators
::
math
::
SkipLayerNormFunctor
<
float
>
skip_layer_norm_func
;
"The SkipLayernorm Plugin only support fp32 input."
));
skip_layer_norm_func
(
num
,
}
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
hidden
,
PADDLE_ENFORCE_EQ
(
input_type
==
nvinfer1
::
DataType
::
kHALF
,
input1
,
true
,
input2
,
platform
::
errors
::
InvalidArgument
(
scale_gpu_
,
"The SkipLayernorm Plugin only support fp16 input."
));
bias_gpu_
,
output
,
eps_
,
stream
);
}
else
if
(
input_type
==
nvinfer1
::
DataType
::
kHALF
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG
(
1
)
<<
"TRT Plugin DataType selected. SkipLayerNorm-->fp16"
;
const
half
*
input1
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
half
*
input2
=
static_cast
<
const
half
*>
(
inputs
[
1
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
operators
::
math
::
SkipLayerNormFunctor
<
half
>
skip_layer_norm_func
;
skip_layer_norm_func
(
num
,
hidden
,
input1
,
input2
,
scale_gpu_
,
bias_gpu_
,
output
,
static_cast
<
half
>
(
eps_
),
stream
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Ernie(Bert) tensorRT plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.SetTRTDynamicShapeInfo(min_input_shape, "
"max_input_shape, opt_input_shape, true"
));
#endif
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The SkipLayerNorm TRT Plugin's input type should be float or half."
));
"Unsupport data type, the out type of SkipLayernorm should be "
"float or half."
));
}
}
auto
*
output_d
=
reinterpret_cast
<
T
*>
(
outputs
[
0
]);
const
T
*
input1
=
reinterpret_cast
<
const
T
*>
(
inputs
[
0
]);
const
T
*
input2
=
reinterpret_cast
<
const
T
*>
(
inputs
[
1
]);
auto
*
output
=
reinterpret_cast
<
T
*>
(
outputs
[
0
]);
operators
::
math
::
SkipLayerNormFunctor
<
T
>
skip_layer_norm_func
;
skip_layer_norm_func
(
num
,
hidden
,
input1
,
input2
,
scale_gpu_
,
bias_gpu_
,
output
,
eps_
,
stream
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
int
SkipLayerNormPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
impl_
->
enqueue
(
input_desc
,
output_desc
,
inputs
,
outputs
,
workspace
,
stream
);
return
cudaGetLastError
()
!=
cudaSuccess
;
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
#endif
#endif
}
// namespace plugin
}
// namespace plugin
...
...
paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h
浏览文件 @
ac0553a0
...
@@ -15,11 +15,13 @@
...
@@ -15,11 +15,13 @@
#pragma once
#pragma once
#include <algorithm>
#include <algorithm>
#include <cstddef>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/phi/common/data_type.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -27,36 +29,155 @@ namespace tensorrt {
...
@@ -27,36 +29,155 @@ namespace tensorrt {
namespace
plugin
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
#if IS_TRT_VERSION_GE(6000)
class
SkipLayerNormPluginDynamicImplBase
{
public:
SkipLayerNormPluginDynamicImplBase
()
{}
virtual
~
SkipLayerNormPluginDynamicImplBase
()
{}
virtual
int
initialize
()
=
0
;
virtual
void
terminate
()
=
0
;
virtual
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
=
0
;
virtual
void
shareGPUData
(
const
SkipLayerNormPluginDynamicImplBase
*
anthor
)
=
0
;
};
template
<
typename
T
>
class
SkipLayerNormPluginDynamicImpl
:
public
SkipLayerNormPluginDynamicImplBase
{
public:
explicit
SkipLayerNormPluginDynamicImpl
(
T
*
bias
,
T
*
scale
,
int
bias_size
,
int
scale_size
,
const
float
eps
)
:
bias_
(
bias
),
scale_
(
scale
),
bias_size_
(
bias_size
),
scale_size_
(
scale_size
),
eps_
(
eps
)
{}
~
SkipLayerNormPluginDynamicImpl
()
{}
int
initialize
();
void
terminate
();
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
;
void
shareGPUData
(
const
SkipLayerNormPluginDynamicImplBase
*
anthor
);
private:
T
*
bias_
{
nullptr
};
T
*
scale_
{
nullptr
};
// data on devices
T
*
bias_gpu_
{
nullptr
};
T
*
scale_gpu_
{
nullptr
};
int
bias_size_
;
int
scale_size_
;
float
eps_
;
bool
is_initialized_
{
false
};
};
class
SkipLayerNormPluginDynamic
:
public
DynamicPluginTensorRT
{
class
SkipLayerNormPluginDynamic
:
public
DynamicPluginTensorRT
{
public:
public:
explicit
SkipLayerNormPluginDynamic
(
const
float
*
bias
,
explicit
SkipLayerNormPluginDynamic
(
void
*
bias
,
const
float
*
scale
,
void
*
scale
,
int
bias_size
,
int
bias_size
,
int
scale_size
,
int
scale_size
,
const
float
eps
,
float
eps
,
bool
with_fp16
)
bool
with_fp16
)
:
bias_size_
(
bias_size
),
scale_size_
(
scale_size
),
eps_
(
eps
)
{
:
bias_
(
bias
),
scale_
(
scale
),
bias_size_
(
bias_size
),
scale_size_
(
scale_size
),
eps_
(
eps
),
own_host_buff_
(
false
)
{
with_fp16_
=
with_fp16
;
with_fp16_
=
with_fp16
;
bias_
.
resize
(
bias_size
);
if
(
with_fp16_
)
{
scale_
.
resize
(
scale_size
);
#ifdef TRT_PLUGIN_FP16_AVALIABLE
std
::
copy
(
bias
,
bias
+
bias_size
,
bias_
.
data
());
VLOG
(
1
)
<<
"TRT Plugin DataType selected. SkipLayerNorm-->fp16"
;
std
::
copy
(
scale
,
scale
+
scale_size
,
scale_
.
data
());
instantiateImpl
<
half
>
();
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Ernie(Bert) tensorRT plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.EnableTensorRtEngine(1 << 30, 1, 5, "
"AnalysisConfig::Precision::kFloat32, false, false) "
));
#endif
}
else
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. SkipLayerNorm-->fp32"
;
instantiateImpl
<
float
>
();
}
}
}
SkipLayerNormPluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{
SkipLayerNormPluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
bias_
);
:
own_host_buff_
(
true
)
{
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
scale_
);
// the first var is with_fp16, we will use it.
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
bias_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
bias_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
scale_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
scale_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
eps_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
eps_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
if
(
with_fp16_
)
{
if
(
bias_size_
)
{
bias_
=
new
half
[
bias_size_
];
memcpy
(
bias_
,
serial_data
,
sizeof
(
half
)
*
bias_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
bias_size_
*
sizeof
(
half
);
serial_length
-=
bias_size_
*
sizeof
(
half
);
if
(
scale_size_
)
{
scale_
=
new
half
[
scale_size_
];
memcpy
(
scale_
,
serial_data
,
sizeof
(
half
)
*
scale_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
scale_size_
*
sizeof
(
half
);
serial_length
-=
scale_size_
*
sizeof
(
half
);
}
else
{
if
(
bias_size_
)
{
bias_
=
new
float
[
bias_size_
];
memcpy
(
bias_
,
serial_data
,
sizeof
(
float
)
*
bias_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
bias_size_
*
sizeof
(
float
);
serial_length
-=
bias_size_
*
sizeof
(
float
);
if
(
scale_size_
)
{
scale_
=
new
float
[
scale_size_
];
memcpy
(
scale_
,
serial_data
,
sizeof
(
float
)
*
scale_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
scale_size_
*
sizeof
(
float
);
serial_length
-=
scale_size_
*
sizeof
(
float
);
}
if
(
with_fp16_
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
instantiateImpl
<
half
>
();
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Ernie(Bert) tensorRT plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.EnableTensorRtEngine(1 << 30, 1, 5, "
"AnalysisConfig::Precision::kFloat32, false, false) "
));
#endif
}
else
{
instantiateImpl
<
float
>
();
}
}
}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
TRT_NOEXCEPT
override
{
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
TRT_NOEXCEPT
override
{
auto
ptr
=
new
SkipLayerNormPluginDynamic
(
auto
ptr
=
new
SkipLayerNormPluginDynamic
(
bias_
.
data
(),
scale_
.
data
(),
bias_size_
,
scale_size_
,
eps_
,
with_fp16_
);
bias_
,
scale_
,
bias_size_
,
scale_size_
,
eps_
,
with_fp16_
);
ptr
->
bias_gpu_
=
bias_gpu_
;
ptr
->
shareGPUData
(
this
);
ptr
->
scale_gpu_
=
scale_gpu_
;
return
ptr
;
return
ptr
;
}
}
...
@@ -65,20 +186,48 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
...
@@ -65,20 +186,48 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
}
}
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
{
return
1
;
}
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
{
return
1
;
}
int
initialize
()
TRT_NOEXCEPT
override
;
int
initialize
()
TRT_NOEXCEPT
override
;
void
terminate
()
TRT_NOEXCEPT
override
;
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
override
{
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
override
{
size_t
ser_size
=
SerializedSize
(
bias_
)
+
SerializedSize
(
scale_
)
+
size_t
sum_num
=
0
;
SerializedSize
(
bias_size_
)
+
SerializedSize
(
scale_size_
)
+
sum_num
+=
SerializedSize
(
with_fp16_
);
SerializedSize
(
eps_
)
+
SerializedSize
(
with_fp16_
);
return
ser_size
;
if
(
with_fp16_
)
{
sum_num
+=
(
bias_size_
+
scale_size_
)
*
sizeof
(
half
);
}
else
{
sum_num
+=
(
bias_size_
+
scale_size_
)
*
sizeof
(
float
);
}
sum_num
+=
SerializedSize
(
bias_size_
);
sum_num
+=
SerializedSize
(
scale_size_
);
sum_num
+=
SerializedSize
(
eps_
);
return
sum_num
;
}
}
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
override
{
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
override
{
SerializeValue
(
&
buffer
,
bias_
)
;
// the first var is for with_fp16, we will use it later
;
SerializeValue
(
&
buffer
,
scale
_
);
SerializeValue
(
&
buffer
,
with_fp16
_
);
SerializeValue
(
&
buffer
,
bias_size_
);
SerializeValue
(
&
buffer
,
bias_size_
);
SerializeValue
(
&
buffer
,
scale_size_
);
SerializeValue
(
&
buffer
,
scale_size_
);
SerializeValue
(
&
buffer
,
eps_
);
SerializeValue
(
&
buffer
,
eps_
);
SerializeValue
(
&
buffer
,
with_fp16_
);
if
(
with_fp16_
)
{
for
(
int
i
=
0
;
i
<
bias_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
half
*>
(
bias_
)[
i
]);
}
for
(
int
i
=
0
;
i
<
scale_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
half
*>
(
scale_
)[
i
]);
}
}
else
{
for
(
int
i
=
0
;
i
<
bias_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
float
*>
(
bias_
)[
i
]);
}
for
(
int
i
=
0
;
i
<
scale_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
float
*>
(
scale_
)[
i
]);
}
}
}
}
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
output_index
,
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
output_index
,
...
@@ -115,20 +264,43 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
...
@@ -115,20 +264,43 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
int
nb_inputs
)
const
int
nb_inputs
)
const
TRT_NOEXCEPT
override
;
TRT_NOEXCEPT
override
;
void
destroy
()
TRT_NOEXCEPT
override
{
delete
this
;
}
void
destroy
()
TRT_NOEXCEPT
override
{
void
terminate
()
TRT_NOEXCEPT
override
;
if
(
own_host_buff_
)
{
if
(
with_fp16_
)
{
delete
[]
reinterpret_cast
<
half
*>
(
bias_
);
delete
[]
reinterpret_cast
<
half
*>
(
scale_
);
}
else
{
delete
[]
reinterpret_cast
<
float
*>
(
bias_
);
delete
[]
reinterpret_cast
<
float
*>
(
scale_
);
}
}
delete
impl_
;
delete
this
;
}
private:
private:
std
::
vector
<
float
>
bias_
;
void
*
bias_
{
nullptr
};
std
::
vector
<
float
>
scale_
;
void
*
scale_
{
nullptr
};
float
*
bias_gpu_
{
nullptr
};
float
*
scale_gpu_
{
nullptr
};
int
bias_size_
;
int
bias_size_
;
int
scale_size_
;
int
scale_size_
;
float
eps_
;
float
eps_
;
bool
own_host_buff_
{
false
};
SkipLayerNormPluginDynamicImplBase
*
impl_
{
nullptr
};
void
shareGPUData
(
const
SkipLayerNormPluginDynamic
*
anthor
)
{
impl_
->
shareGPUData
(
anthor
->
impl_
);
}
template
<
typename
U
>
void
instantiateImpl
()
{
impl_
=
new
SkipLayerNormPluginDynamicImpl
<
U
>
(
reinterpret_cast
<
U
*>
(
bias_
),
reinterpret_cast
<
U
*>
(
scale_
),
bias_size_
,
scale_size_
,
eps_
);
}
};
};
class
SkipLayerNormPluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
class
SkipLayerNormPluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
...
@@ -154,8 +326,7 @@ class SkipLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
...
@@ -154,8 +326,7 @@ class SkipLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
const
void
*
serial_data
,
const
void
*
serial_data
,
size_t
serial_length
)
size_t
serial_length
)
TRT_NOEXCEPT
override
{
TRT_NOEXCEPT
override
{
auto
plugin
=
new
SkipLayerNormPluginDynamic
(
serial_data
,
serial_length
);
return
new
SkipLayerNormPluginDynamic
(
serial_data
,
serial_length
);
return
plugin
;
}
}
void
setPluginNamespace
(
const
char
*
lib_namespace
)
TRT_NOEXCEPT
override
{
void
setPluginNamespace
(
const
char
*
lib_namespace
)
TRT_NOEXCEPT
override
{
...
@@ -173,6 +344,7 @@ class SkipLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
...
@@ -173,6 +344,7 @@ class SkipLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
std
::
vector
<
nvinfer1
::
PluginField
>
plugin_attributes_
;
std
::
vector
<
nvinfer1
::
PluginField
>
plugin_attributes_
;
};
};
REGISTER_TRT_PLUGIN_V2
(
SkipLayerNormPluginDynamicCreator
);
REGISTER_TRT_PLUGIN_V2
(
SkipLayerNormPluginDynamicCreator
);
#endif
#endif
}
// namespace plugin
}
// namespace plugin
...
...
paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h
浏览文件 @
ac0553a0
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include <cuda_fp16.h>
#include <cstring>
#include <cstring>
#include <string>
#include <string>
#include <type_traits>
#include <type_traits>
...
@@ -46,10 +47,11 @@ template <typename T, class Enable = void>
...
@@ -46,10 +47,11 @@ template <typename T, class Enable = void>
struct
Serializer
{};
struct
Serializer
{};
template
<
typename
T
>
template
<
typename
T
>
struct
Serializer
<
T
,
struct
Serializer
<
typename
std
::
enable_if
<
std
::
is_arithmetic
<
T
>::
value
||
T
,
std
::
is_enum
<
T
>::
value
||
typename
std
::
enable_if
<
std
::
is_arithmetic
<
T
>::
value
||
std
::
is_pod
<
T
>::
value
>::
type
>
{
std
::
is_enum
<
T
>::
value
||
std
::
is_pod
<
T
>::
value
||
std
::
is_same
<
T
,
half
>::
value
>::
type
>
{
static
size_t
SerializedSize
(
T
const
&
value
)
{
return
sizeof
(
T
);
}
static
size_t
SerializedSize
(
T
const
&
value
)
{
return
sizeof
(
T
);
}
static
void
Serialize
(
void
**
buffer
,
T
const
&
value
)
{
static
void
Serialize
(
void
**
buffer
,
T
const
&
value
)
{
...
@@ -86,10 +88,11 @@ struct Serializer<const char*> {
...
@@ -86,10 +88,11 @@ struct Serializer<const char*> {
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
Serializer
<
std
::
vector
<
T
>
,
struct
Serializer
<
typename
std
::
enable_if
<
std
::
is_arithmetic
<
T
>::
value
||
std
::
vector
<
T
>
,
std
::
is_enum
<
T
>::
value
||
typename
std
::
enable_if
<
std
::
is_arithmetic
<
T
>::
value
||
std
::
is_pod
<
T
>::
value
>::
type
>
{
std
::
is_enum
<
T
>::
value
||
std
::
is_pod
<
T
>::
value
||
std
::
is_same
<
T
,
half
>::
value
>::
type
>
{
static
size_t
SerializedSize
(
std
::
vector
<
T
>
const
&
value
)
{
static
size_t
SerializedSize
(
std
::
vector
<
T
>
const
&
value
)
{
return
sizeof
(
value
.
size
())
+
value
.
size
()
*
sizeof
(
T
);
return
sizeof
(
value
.
size
())
+
value
.
size
()
*
sizeof
(
T
);
}
}
...
...
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
浏览文件 @
ac0553a0
...
@@ -98,8 +98,9 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) {
...
@@ -98,8 +98,9 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) {
std
::
string
model_dir
=
FLAGS_infer_model
;
std
::
string
model_dir
=
FLAGS_infer_model
;
// Delete serialization cache to perform serialization first rather than
// Delete serialization cache to perform serialization first rather than
// deserialization.
// deserialization.
std
::
string
opt_cache_dir
=
FLAGS_infer_model
+
"/
_
opt_cache"
;
std
::
string
opt_cache_dir
=
FLAGS_infer_model
+
"/opt_cache"
;
delete_cache_files
(
opt_cache_dir
);
delete_cache_files
(
opt_cache_dir
);
config
.
SetOptimCacheDir
(
opt_cache_dir
);
SetConfig
(
&
config
,
model_dir
,
true
/* use_gpu */
);
SetConfig
(
&
config
,
model_dir
,
true
/* use_gpu */
);
...
...
paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu
浏览文件 @
ac0553a0
...
@@ -15,11 +15,14 @@
...
@@ -15,11 +15,14 @@
#include <paddle/fluid/platform/device_context.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include <algorithm>
#include <type_traits>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -99,19 +102,37 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
...
@@ -99,19 +102,37 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
auto
*
output_d
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
output_d
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
eps
=
context
.
Attr
<
float
>
(
"epsilon"
);
float
eps
=
context
.
Attr
<
float
>
(
"epsilon"
);
int
shared_bytes
=
input_num
*
sizeof
(
int64_t
);
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
math
::
EmbEltwiseLayerNormFunctor
<
T
>
emb_eltwise_layernorm_func
;
const
half
*
scale_new
=
reinterpret_cast
<
const
half
*>
(
scale_d
);
emb_eltwise_layernorm_func
(
batch
,
const
half
*
bias_new
=
reinterpret_cast
<
const
half
*>
(
bias_d
);
seq_len
,
half
*
output_new
=
reinterpret_cast
<
half
*>
(
output_d
);
hidden
,
in_ids_d
,
math
::
EmbEltwiseLayerNormFunctor
<
half
>
emb_eltwise_layernorm_func
;
scale_d
,
emb_eltwise_layernorm_func
(
batch
,
bias_d
,
seq_len
,
in_embs_d
,
hidden
,
output_d
,
in_ids_d
,
eps
,
scale_new
,
input_num
,
bias_new
,
device_ctx
.
stream
());
in_embs_d
,
output_new
,
eps
,
input_num
,
device_ctx
.
stream
());
}
else
{
math
::
EmbEltwiseLayerNormFunctor
<
T
>
emb_eltwise_layernorm_func
;
emb_eltwise_layernorm_func
(
batch
,
seq_len
,
hidden
,
in_ids_d
,
scale_d
,
bias_d
,
in_embs_d
,
output_d
,
eps
,
input_num
,
device_ctx
.
stream
());
}
}
}
};
};
...
@@ -119,6 +140,14 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
...
@@ -119,6 +140,14 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
REGISTER_OP_CUDA_KERNEL
(
fused_embedding_eltwise_layernorm
,
ops
::
EmbeddingEltWiseLayerNormKernel
<
phi
::
GPUContext
,
float
>
,
ops
::
EmbeddingEltWiseLayerNormKernel
<
phi
::
GPUContext
,
paddle
::
platform
::
float16
>
);
#else
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
fused_embedding_eltwise_layernorm
,
fused_embedding_eltwise_layernorm
,
ops
::
EmbeddingEltWiseLayerNormKernel
<
phi
::
GPUContext
,
float
>
);
ops
::
EmbeddingEltWiseLayerNormKernel
<
phi
::
GPUContext
,
float
>
);
#endif
paddle/fluid/operators/fused/skip_layernorm_op.cu
浏览文件 @
ac0553a0
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include <paddle/fluid/platform/device_context.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include <algorithm>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/malloc.h"
...
@@ -53,15 +54,34 @@ class SkipLayerNormKernel : public framework::OpKernel<T> {
...
@@ -53,15 +54,34 @@ class SkipLayerNormKernel : public framework::OpKernel<T> {
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
operators
::
math
::
SkipLayerNormFunctor
<
T
>
skip_layer_norm_func
;
operators
::
math
::
SkipLayerNormFunctor
<
T
>
skip_layer_norm_func
;
skip_layer_norm_func
(
num
,
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
hidden
,
const
half
*
X_new
=
reinterpret_cast
<
const
half
*>
(
X_d
);
X_d
,
const
half
*
Y_new
=
reinterpret_cast
<
const
half
*>
(
Y_d
);
Y_d
,
const
half
*
scale_new
=
reinterpret_cast
<
const
half
*>
(
scale_d
);
scale_d
,
const
half
*
bias_new
=
reinterpret_cast
<
const
half
*>
(
bias_d
);
bias_d
,
half
*
output_new
=
reinterpret_cast
<
half
*>
(
output_d
);
output_d
,
operators
::
math
::
SkipLayerNormFunctor
<
half
>
skip_layer_norm_func
;
epsilon
,
skip_layer_norm_func
(
num
,
device_ctx
.
stream
());
hidden
,
X_new
,
Y_new
,
scale_new
,
bias_new
,
output_new
,
epsilon
,
device_ctx
.
stream
());
}
else
{
operators
::
math
::
SkipLayerNormFunctor
<
T
>
skip_layer_norm_func
;
skip_layer_norm_func
(
num
,
hidden
,
X_d
,
Y_d
,
scale_d
,
bias_d
,
output_d
,
epsilon
,
device_ctx
.
stream
());
}
}
}
};
};
...
@@ -69,5 +89,13 @@ class SkipLayerNormKernel : public framework::OpKernel<T> {
...
@@ -69,5 +89,13 @@ class SkipLayerNormKernel : public framework::OpKernel<T> {
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
REGISTER_OP_CUDA_KERNEL
(
skip_layernorm
,
ops
::
SkipLayerNormKernel
<
phi
::
GPUContext
,
float
>
,
ops
::
SkipLayerNormKernel
<
phi
::
GPUContext
,
paddle
::
platform
::
float16
>
);
#else
REGISTER_OP_CUDA_KERNEL
(
skip_layernorm
,
REGISTER_OP_CUDA_KERNEL
(
skip_layernorm
,
ops
::
SkipLayerNormKernel
<
phi
::
GPUContext
,
float
>
);
ops
::
SkipLayerNormKernel
<
phi
::
GPUContext
,
float
>
);
#endif
paddle/fluid/operators/math/bert_encoder_functor.cu
浏览文件 @
ac0553a0
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include <algorithm>
#include <type_traits>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/tensor_util.h"
...
@@ -42,8 +43,8 @@ __device__ inline void LayerNormSmall(T val,
...
@@ -42,8 +43,8 @@ __device__ inline void LayerNormSmall(T val,
const
phi
::
funcs
::
kvp
<
T
>
&
thread_data
,
const
phi
::
funcs
::
kvp
<
T
>
&
thread_data
,
const
int
ld
,
const
int
ld
,
const
int
idx
,
const
int
idx
,
const
float
*
bias
,
const
T
*
bias
,
const
float
*
scale
,
const
T
*
scale
,
T
*
output
,
T
*
output
,
T
eps
)
{
T
eps
)
{
using
BlockReduce
=
cub
::
BlockReduce
<
phi
::
funcs
::
kvp
<
T
>
,
TPB
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
phi
::
funcs
::
kvp
<
T
>
,
TPB
>
;
...
@@ -70,8 +71,8 @@ template <typename T, int TPB>
...
@@ -70,8 +71,8 @@ template <typename T, int TPB>
__device__
inline
void
LayerNorm
(
const
phi
::
funcs
::
kvp
<
T
>
&
thread_data
,
__device__
inline
void
LayerNorm
(
const
phi
::
funcs
::
kvp
<
T
>
&
thread_data
,
const
int
ld
,
const
int
ld
,
const
int
offset
,
const
int
offset
,
const
float
*
bias
,
const
T
*
bias
,
const
float
*
scale
,
const
T
*
scale
,
T
*
output
,
T
*
output
,
T
eps
)
{
T
eps
)
{
using
BlockReduce
=
cub
::
BlockReduce
<
phi
::
funcs
::
kvp
<
T
>
,
TPB
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
phi
::
funcs
::
kvp
<
T
>
,
TPB
>
;
...
@@ -100,8 +101,8 @@ template <typename T, typename T2, int TPB>
...
@@ -100,8 +101,8 @@ template <typename T, typename T2, int TPB>
__device__
inline
void
LayerNorm2
(
const
phi
::
funcs
::
kvp
<
T
>
&
thread_data
,
__device__
inline
void
LayerNorm2
(
const
phi
::
funcs
::
kvp
<
T
>
&
thread_data
,
const
int
ld
,
const
int
ld
,
const
int
offset
,
const
int
offset
,
const
float
2
*
bias
,
const
T
2
*
bias
,
const
float
2
*
scale
,
const
T
2
*
scale
,
T2
*
output
,
T2
*
output
,
T
eps
)
{
T
eps
)
{
using
BlockReduce
=
cub
::
BlockReduce
<
phi
::
funcs
::
kvp
<
T
>
,
TPB
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
phi
::
funcs
::
kvp
<
T
>
,
TPB
>
;
...
@@ -120,8 +121,8 @@ __device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data,
...
@@ -120,8 +121,8 @@ __device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data,
for
(
int
i
=
threadIdx
.
x
;
i
<
ld
;
i
+=
TPB
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
ld
;
i
+=
TPB
)
{
const
int
idx
=
offset
+
i
;
const
int
idx
=
offset
+
i
;
T2
val
=
output
[
idx
];
T2
val
=
output
[
idx
];
const
float
2
g
=
scale
[
i
];
const
T
2
g
=
scale
[
i
];
const
float
2
b
=
bias
[
i
];
const
T
2
b
=
bias
[
i
];
val
.
x
=
T
(
g
.
x
)
*
(
val
.
x
-
mu
)
*
rsigma
+
T
(
b
.
x
);
val
.
x
=
T
(
g
.
x
)
*
(
val
.
x
-
mu
)
*
rsigma
+
T
(
b
.
x
);
val
.
y
=
T
(
g
.
y
)
*
(
val
.
y
-
mu
)
*
rsigma
+
T
(
b
.
y
);
val
.
y
=
T
(
g
.
y
)
*
(
val
.
y
-
mu
)
*
rsigma
+
T
(
b
.
y
);
output
[
idx
]
=
val
;
output
[
idx
]
=
val
;
...
@@ -131,11 +132,11 @@ __device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data,
...
@@ -131,11 +132,11 @@ __device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data,
template
<
typename
T
,
unsigned
TPB
>
template
<
typename
T
,
unsigned
TPB
>
__global__
void
EmbEltwiseLayernormKernel
(
int
hidden
,
__global__
void
EmbEltwiseLayernormKernel
(
int
hidden
,
const
int64_t
*
ids
,
const
int64_t
*
ids
,
const
float
*
scale
,
const
T
*
scale
,
const
float
*
bias
,
const
T
*
bias
,
const
int64_t
*
embs
,
const
int64_t
*
embs
,
T
*
output
,
T
*
output
,
float
eps
,
T
eps
,
int
input_num
)
{
int
input_num
)
{
cub
::
Sum
pair_sum
;
cub
::
Sum
pair_sum
;
// blockIdx.x: position in the sequence
// blockIdx.x: position in the sequence
...
@@ -179,11 +180,11 @@ __global__ void EmbEltwiseLayernormKernel(int hidden,
...
@@ -179,11 +180,11 @@ __global__ void EmbEltwiseLayernormKernel(int hidden,
template
<
>
template
<
>
__global__
void
EmbEltwiseLayernormKernel
<
half
,
256
>
(
int
hidden
,
__global__
void
EmbEltwiseLayernormKernel
<
half
,
256
>
(
int
hidden
,
const
int64_t
*
ids
,
const
int64_t
*
ids
,
const
float
*
scale
,
const
half
*
scale
,
const
float
*
bias
,
const
half
*
bias
,
const
int64_t
*
embs
,
const
int64_t
*
embs
,
half
*
output
,
half
*
output
,
float
eps
,
half
eps
,
int
input_num
)
{
int
input_num
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
cub
::
Sum
pair_sum
;
cub
::
Sum
pair_sum
;
...
@@ -231,8 +232,8 @@ void EmbEltwiseLayerNormFunctor<T>::operator()(int batch,
...
@@ -231,8 +232,8 @@ void EmbEltwiseLayerNormFunctor<T>::operator()(int batch,
int
seq_len
,
int
seq_len
,
int
hidden
,
int
hidden
,
const
int64_t
*
ids
,
const
int64_t
*
ids
,
const
float
*
scale
,
const
T
*
scale
,
const
float
*
bias
,
const
T
*
bias
,
const
int64_t
*
embs
,
const
int64_t
*
embs
,
T
*
output
,
T
*
output
,
float
eps
,
float
eps
,
...
@@ -720,9 +721,9 @@ __global__ void SkipLayerNormSmallKernel(int num,
...
@@ -720,9 +721,9 @@ __global__ void SkipLayerNormSmallKernel(int num,
const
T
*
input1
,
const
T
*
input1
,
const
T
*
input2
,
const
T
*
input2
,
T
*
output
,
T
*
output
,
const
float
*
scale
,
const
T
*
scale
,
const
float
*
bias
,
const
T
*
bias
,
float
eps
)
{
T
eps
)
{
const
T
rld
=
T
(
1
)
/
T
(
hidden
);
const
T
rld
=
T
(
1
)
/
T
(
hidden
);
const
int
offset
=
blockIdx
.
x
*
hidden
;
const
int
offset
=
blockIdx
.
x
*
hidden
;
cub
::
Sum
pair_sum
;
cub
::
Sum
pair_sum
;
...
@@ -747,9 +748,9 @@ __global__ void SkipLayerNormSmallKernel<half, 32>(int num,
...
@@ -747,9 +748,9 @@ __global__ void SkipLayerNormSmallKernel<half, 32>(int num,
const
half
*
input1
,
const
half
*
input1
,
const
half
*
input2
,
const
half
*
input2
,
half
*
output
,
half
*
output
,
const
float
*
scale
,
const
half
*
scale
,
const
float
*
bias
,
const
half
*
bias
,
float
eps
)
{
half
eps
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const
half
rld
=
half
(
1
)
/
half
(
hidden
);
const
half
rld
=
half
(
1
)
/
half
(
hidden
);
const
int
offset
=
blockIdx
.
x
*
hidden
;
const
int
offset
=
blockIdx
.
x
*
hidden
;
...
@@ -774,9 +775,9 @@ __global__ void SkipLayerNormSmallKernel<half, 128>(int num,
...
@@ -774,9 +775,9 @@ __global__ void SkipLayerNormSmallKernel<half, 128>(int num,
const
half
*
input1
,
const
half
*
input1
,
const
half
*
input2
,
const
half
*
input2
,
half
*
output
,
half
*
output
,
const
float
*
scale
,
const
half
*
scale
,
const
float
*
bias
,
const
half
*
bias
,
float
eps
)
{
half
eps
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const
half
rld
=
half
(
1
)
/
half
(
hidden
);
const
half
rld
=
half
(
1
)
/
half
(
hidden
);
const
int
offset
=
blockIdx
.
x
*
hidden
;
const
int
offset
=
blockIdx
.
x
*
hidden
;
...
@@ -801,9 +802,9 @@ __global__ void SkipLayerNormSmallKernel<half, 384>(int num,
...
@@ -801,9 +802,9 @@ __global__ void SkipLayerNormSmallKernel<half, 384>(int num,
const
half
*
input1
,
const
half
*
input1
,
const
half
*
input2
,
const
half
*
input2
,
half
*
output
,
half
*
output
,
const
float
*
scale
,
const
half
*
scale
,
const
float
*
bias
,
const
half
*
bias
,
float
eps
)
{
half
eps
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const
half
rld
=
half
(
1
)
/
half
(
hidden
);
const
half
rld
=
half
(
1
)
/
half
(
hidden
);
const
int
offset
=
blockIdx
.
x
*
hidden
;
const
int
offset
=
blockIdx
.
x
*
hidden
;
...
@@ -829,9 +830,9 @@ __global__ void SkipLayerNormKernel(int num,
...
@@ -829,9 +830,9 @@ __global__ void SkipLayerNormKernel(int num,
const
T
*
input1
,
const
T
*
input1
,
const
T
*
input2
,
const
T
*
input2
,
T
*
output
,
T
*
output
,
const
float
*
scale
,
const
T
*
scale
,
const
float
*
bias
,
const
T
*
bias
,
float
eps
)
{
T
eps
)
{
const
T
rld
=
T
(
1
)
/
T
(
hidden
);
const
T
rld
=
T
(
1
)
/
T
(
hidden
);
const
int
offset
=
blockIdx
.
x
*
hidden
;
const
int
offset
=
blockIdx
.
x
*
hidden
;
cub
::
Sum
pair_sum
;
cub
::
Sum
pair_sum
;
...
@@ -856,9 +857,9 @@ __global__ void SkipLayerNormKernel<half, 256>(int num,
...
@@ -856,9 +857,9 @@ __global__ void SkipLayerNormKernel<half, 256>(int num,
const
half
*
input1
,
const
half
*
input1
,
const
half
*
input2
,
const
half
*
input2
,
half
*
output
,
half
*
output
,
const
float
*
scale
,
const
half
*
scale
,
const
float
*
bias
,
const
half
*
bias
,
float
eps
)
{
half
eps
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const
half
rld
=
half
(
1
)
/
half
(
hidden
);
const
half
rld
=
half
(
1
)
/
half
(
hidden
);
const
int
offset
=
blockIdx
.
x
*
hidden
;
const
int
offset
=
blockIdx
.
x
*
hidden
;
...
@@ -884,8 +885,8 @@ __global__ void SkipLayerNormKernel2(int num,
...
@@ -884,8 +885,8 @@ __global__ void SkipLayerNormKernel2(int num,
const
T2
*
input1
,
const
T2
*
input1
,
const
T2
*
input2
,
const
T2
*
input2
,
T2
*
output
,
T2
*
output
,
const
float
2
*
scale
,
const
T
2
*
scale
,
const
float
2
*
bias
,
const
T
2
*
bias
,
float
eps
)
{
float
eps
)
{
const
T
rld
=
T
(
0.5
f
/
hidden
);
// because hidden is hidden/2
const
T
rld
=
T
(
0.5
f
/
hidden
);
// because hidden is hidden/2
const
int
offset
=
blockIdx
.
x
*
hidden
;
const
int
offset
=
blockIdx
.
x
*
hidden
;
...
@@ -912,8 +913,8 @@ __global__ void SkipLayerNormKernel2<half, half2, 256>(int num,
...
@@ -912,8 +913,8 @@ __global__ void SkipLayerNormKernel2<half, half2, 256>(int num,
const
half2
*
input1
,
const
half2
*
input1
,
const
half2
*
input2
,
const
half2
*
input2
,
half2
*
output
,
half2
*
output
,
const
float
2
*
scale
,
const
half
2
*
scale
,
const
float
2
*
bias
,
const
half
2
*
bias
,
float
eps
)
{
float
eps
)
{
// operator "+" of half only suppotted after cuda version 10.0
// operator "+" of half only suppotted after cuda version 10.0
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000
...
@@ -942,10 +943,10 @@ void SkipLayerNormFunctor<T>::operator()(const int num,
...
@@ -942,10 +943,10 @@ void SkipLayerNormFunctor<T>::operator()(const int num,
const
int
hidden
,
const
int
hidden
,
const
T
*
input1
,
const
T
*
input1
,
const
T
*
input2
,
const
T
*
input2
,
const
float
*
scale
,
const
T
*
scale
,
const
float
*
bias
,
const
T
*
bias
,
T
*
output
,
T
*
output
,
T
eps
,
float
eps
,
gpuStream_t
stream
)
{
gpuStream_t
stream
)
{
int
block
=
num
/
hidden
;
int
block
=
num
/
hidden
;
if
(
hidden
<=
32
)
{
if
(
hidden
<=
32
)
{
...
@@ -984,8 +985,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num,
...
@@ -984,8 +985,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num,
reinterpret_cast
<
const
__half2
*>
(
input1
),
reinterpret_cast
<
const
__half2
*>
(
input1
),
reinterpret_cast
<
const
__half2
*>
(
input2
),
reinterpret_cast
<
const
__half2
*>
(
input2
),
reinterpret_cast
<
__half2
*>
(
output
),
reinterpret_cast
<
__half2
*>
(
output
),
reinterpret_cast
<
const
float
2
*>
(
scale
),
reinterpret_cast
<
const
__half
2
*>
(
scale
),
reinterpret_cast
<
const
float
2
*>
(
bias
),
reinterpret_cast
<
const
__half
2
*>
(
bias
),
eps
);
eps
);
#endif
#endif
}
else
{
}
else
{
...
...
paddle/fluid/operators/math/bert_encoder_functor.h
浏览文件 @
ac0553a0
...
@@ -68,8 +68,8 @@ class EmbEltwiseLayerNormFunctor {
...
@@ -68,8 +68,8 @@ class EmbEltwiseLayerNormFunctor {
int
seq_len
,
int
seq_len
,
int
hidden
,
int
hidden
,
const
int64_t
*
ids
,
const
int64_t
*
ids
,
const
float
*
scale
,
const
T
*
scale
,
const
float
*
bias
,
const
T
*
bias
,
const
int64_t
*
embs
,
const
int64_t
*
embs
,
T
*
output
,
T
*
output
,
float
eps
,
float
eps
,
...
@@ -125,10 +125,10 @@ class SkipLayerNormFunctor {
...
@@ -125,10 +125,10 @@ class SkipLayerNormFunctor {
const
int
hidden
,
const
int
hidden
,
const
T
*
input1
,
const
T
*
input1
,
const
T
*
input2
,
const
T
*
input2
,
const
float
*
scale
,
const
T
*
scale
,
const
float
*
bias
,
const
T
*
bias
,
T
*
output
,
T
*
output
,
T
eps
,
float
eps
,
gpuStream_t
stream
);
gpuStream_t
stream
);
};
};
#endif
#endif
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
ac0553a0
...
@@ -562,6 +562,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -562,6 +562,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
}
}
runtime_batch
=
t_shape
[
0
];
runtime_batch
=
t_shape
[
0
];
VLOG
(
1
)
<<
"trt input ["
<<
x
<<
"] dtype is "
<<
t
.
dtype
();
VLOG
(
1
)
<<
"trt input ["
<<
x
<<
"] dtype is "
<<
t
.
dtype
();
auto
indata_type
=
inference
::
tensorrt
::
PhiType2NvType
(
t
.
dtype
());
auto
indata_type
=
inference
::
tensorrt
::
PhiType2NvType
(
t
.
dtype
());
auto
intrt_index
=
engine
->
engine
()
->
getBindingIndex
(
x
.
c_str
());
auto
intrt_index
=
engine
->
engine
()
->
getBindingIndex
(
x
.
c_str
());
auto
intrt_type
=
engine
->
engine
()
->
getBindingDataType
(
intrt_index
);
auto
intrt_type
=
engine
->
engine
()
->
getBindingDataType
(
intrt_index
);
...
@@ -570,6 +571,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -570,6 +571,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The TRT Engine OP's input type should equal "
"The TRT Engine OP's input type should equal "
"to the input data type"
));
"to the input data type"
));
auto
type
=
framework
::
TransToProtoVarType
(
t
.
dtype
());
auto
type
=
framework
::
TransToProtoVarType
(
t
.
dtype
());
if
(
type
==
framework
::
proto
::
VarType
::
FP32
)
{
if
(
type
==
framework
::
proto
::
VarType
::
FP32
)
{
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
float
>
());
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
float
>
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录