Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2a9c590b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2a9c590b
编写于
10月 08, 2022
作者:
W
Wangzheee
提交者:
GitHub
10月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference] add lookup_table op_convert, add lookup_table plugin (#46613)
* add lookup_table op_convert, add lookup_table plugin
上级
19746835
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
677 addition
and
74 deletion
+677
-74
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+1
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+2
-1
paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc
...luid/inference/tensorrt/convert/fused_lookup_tables_op.cc
+123
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+12
-2
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+2
-1
paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h
paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h
+1
-4
paddle/fluid/inference/tensorrt/plugin/common/common.cuh
paddle/fluid/inference/tensorrt/plugin/common/common.cuh
+54
-51
paddle/fluid/inference/tensorrt/plugin/common/plugin.h
paddle/fluid/inference/tensorrt/plugin/common/plugin.h
+2
-3
paddle/fluid/inference/tensorrt/plugin/lookup_table.cu
paddle/fluid/inference/tensorrt/plugin/lookup_table.cu
+346
-0
paddle/fluid/inference/tensorrt/plugin/lookup_table.h
paddle/fluid/inference/tensorrt/plugin/lookup_table.h
+126
-0
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu
...nsorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu
+1
-4
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu
...nsorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu
+1
-4
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
...ce/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
+0
-1
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h
...nce/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h
+4
-1
python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py
...ests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py
+2
-2
未找到文件。
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
2a9c590b
...
...
@@ -2200,6 +2200,7 @@ USE_TRT_CONVERTER(fused_token_prune)
USE_TRT_CONVERTER
(
layernorm_shift_partition
)
USE_TRT_CONVERTER
(
generic_plugin_creater
)
USE_TRT_CONVERTER
(
custom_plugin_creater
)
USE_TRT_CONVERTER
(
lookup_table
)
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER
(
sparse_fc
)
USE_TRT_CONVERTER
(
sparse_multihead_matmul
)
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
2a9c590b
...
...
@@ -76,7 +76,8 @@ list(
fill_constant_op.cc
fused_token_prune_op.cc
layernorm_shift_partition_op.cc
generic_and_custom_plugin_creater.cc
)
generic_and_custom_plugin_creater.cc
fused_lookup_tables_op.cc
)
if
(
${
TENSORRT_MAJOR_VERSION
}
GREATER_EQUAL 7 AND NOT WIN32
)
list
(
APPEND CONVERT_FILES emb_eltwise_layernorm.cc
...
...
paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc
0 → 100644
浏览文件 @
2a9c590b
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/utils.h"
#include "paddle/fluid/inference/tensorrt/plugin/lookup_table.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
proto
{
class
OpDesc
;
}
// namespace proto
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
FusedLookupTablesOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
if
(
!
engine_
->
with_dynamic_shape
())
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"lookup_table_op must with dynamic shape"
));
}
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
ids_name
=
op_desc
.
Input
(
"Ids"
).
front
();
auto
w_name
=
op_desc
.
Input
(
"W"
).
front
();
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
bool
enable_int8
=
op_desc
.
HasAttr
(
"enable_int8"
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
auto
ids_dims
=
engine_
->
GetITensor
(
ids_name
)
->
getDimensions
();
if
(
ids_dims
.
d
[
ids_dims
.
nbDims
-
1
]
==
1
)
{
nvinfer1
::
Dims
new_ids_dims
;
new_ids_dims
.
nbDims
=
ids_dims
.
nbDims
-
1
;
for
(
int
i
=
0
;
i
<
ids_dims
.
nbDims
-
1
;
i
++
)
{
new_ids_dims
.
d
[
i
]
=
0
;
}
auto
*
reshape_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
(
engine_
->
GetITensor
(
ids_name
)));
reshape_layer
->
setReshapeDimensions
(
new_ids_dims
);
reshape_layer
->
setName
(
(
"lookup_table: Shuffle (Output: "
+
output_name
+
")"
).
c_str
());
plugin_inputs
.
push_back
(
reshape_layer
->
getOutput
(
0
));
}
else
{
plugin_inputs
.
push_back
(
engine_
->
GetITensor
(
ids_name
));
}
TensorRTEngine
::
Weight
weight
;
auto
*
w_var
=
scope
.
FindVar
(
w_name
);
auto
*
w_tensor
=
w_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
w_dims
=
w_tensor
->
dims
();
weight
=
engine_
->
GetTrtWeight
(
w_name
,
*
w_tensor
);
auto
weight_size
=
phi
::
product
(
w_dims
);
bool
output_fp16
;
if
(
engine_
->
precision
()
==
AnalysisConfig
::
Precision
::
kFloat32
)
{
output_fp16
=
false
;
}
else
{
output_fp16
=
true
;
}
int32_t
weight_width
=
static_cast
<
int32_t
>
(
w_dims
[
1
]);
std
::
vector
<
nvinfer1
::
PluginField
>
fields
;
fields
.
emplace_back
(
"lookup_table_weight"
,
weight
.
get
().
values
,
GetPluginFieldType
(
weight
.
get
().
type
),
static_cast
<
int32_t
>
(
weight_size
));
fields
.
emplace_back
(
"lookup_table_weight_width"
,
&
weight_width
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
);
fields
.
emplace_back
(
"output_fp16"
,
&
output_fp16
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
);
nvinfer1
::
PluginFieldCollection
*
plugin_ptr
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
plugin_ptr
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
plugin_ptr
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
plugin_ptr
->
fields
=
fields
.
data
();
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"LookupTablePluginDynamic"
,
"1"
);
auto
plugin_obj
=
creator
->
createPlugin
(
"LookupTablePluginDynamic"
,
plugin_ptr
);
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin_obj
);
plugin_layer
->
setName
(
(
"lookup_table: (Output: "
+
output_name
+
")"
).
c_str
());
engine_
->
SetITensor
(
output_name
,
plugin_layer
->
getOutput
(
0
));
free
(
plugin_ptr
);
if
(
enable_int8
)
{
float
out_scale
=
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"out_threshold"
));
engine_
->
SetTensorDynamicRange
(
plugin_layer
->
getOutput
(
0
),
out_scale
);
}
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
lookup_table
,
FusedLookupTablesOpConverter
);
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
2a9c590b
...
...
@@ -2083,6 +2083,14 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
if
(
op_type
==
"lookup_table"
)
{
if
(
!
with_dynamic_shape
)
{
VLOG
(
3
)
<<
"the lookup_table does not support "
"static shape yet"
;
return
false
;
}
}
if
(
use_no_calib_int8
)
{
return
int8_teller_set
.
count
(
op_type
);
}
else
{
...
...
@@ -2201,7 +2209,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"shape"
,
"squeeze2"
,
"unsqueeze2"
,
"layernorm_shift_partition"
};
"layernorm_shift_partition"
,
"lookup_table"
};
std
::
unordered_set
<
std
::
string
>
teller_set
{
"mul"
,
"matmul"
,
...
...
@@ -2312,7 +2321,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"squeeze2"
,
"unsqueeze2"
,
"fused_token_prune"
,
"layernorm_shift_partition"
};
"layernorm_shift_partition"
,
"lookup_table"
};
};
struct
GenericPluginTeller
:
public
Teller
{
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
2a9c590b
...
...
@@ -33,7 +33,8 @@ list(
preln_residual_bias_plugin.cu
fused_token_prune_op_plugin.cu
layernorm_shift_partition_op.cu
generic_plugin.cu
)
generic_plugin.cu
lookup_table.cu
)
if
(
${
TENSORRT_MAJOR_VERSION
}
GREATER_EQUAL 7 AND NOT WIN32
)
list
(
APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu
many_emb_Layernorm_varseqlen_kernelMTron.cu
...
...
paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h
浏览文件 @
2a9c590b
...
...
@@ -14,8 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_FLUID_INFERENCE_TENSORRT_PLUGIN_COMMON_BERTCOMMON_H_
#define PADDLE_FLUID_INFERENCE_TENSORRT_PLUGIN_COMMON_BERTCOMMON_H_
#pragma once
#include <cublas_v2.h>
#include <cuda_fp16.h>
...
...
@@ -220,5 +219,3 @@ inline nvinfer1::DataType fieldTypeToDataType(
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
#endif // PADDLE_FLUID_INFERENCE_TENSORRT_PLUGIN_COMMON_BERTCOMMON_H_
paddle/fluid/inference/tensorrt/plugin/common/common.cuh
浏览文件 @
2a9c590b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
// AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...
...
@@ -13,11 +14,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef COMMON_CUH
#define COMMON_CUH
#pragma once
#include "cublas_v2.h"
#include <cub/cub.cuh>
#include "cublas_v2.h"
using
kv_float
=
cub
::
KeyValuePair
<
float
,
float
>
;
using
kv_half
=
cub
::
KeyValuePair
<
half
,
half
>
;
...
...
@@ -28,22 +28,22 @@ __device__ inline T rsqrt(const T& x);
template
<
>
__device__
inline
float
rsqrt
(
const
float
&
x
)
{
return
rsqrtf
(
x
);
return
rsqrtf
(
x
);
}
__device__
inline
kv_float
operator
+
(
const
kv_float
&
a
,
const
kv_float
&
b
)
{
return
kv_float
(
a
.
key
+
b
.
key
,
a
.
value
+
b
.
value
);
return
kv_float
(
a
.
key
+
b
.
key
,
a
.
value
+
b
.
value
);
}
// Half Operations
__device__
inline
half2
__hadd2_with_fallback
(
const
half2
a
,
const
half2
b
)
{
#if __CUDA_ARCH__ >= 530
return
__hadd2
(
a
,
b
);
return
__hadd2
(
a
,
b
);
#else
float2
out
{};
out
.
x
=
__half2float
(
a
.
x
)
+
__half2float
(
b
.
x
);
out
.
y
=
__half2float
(
a
.
y
)
+
__half2float
(
b
.
y
);
return
__float22half2_rn
(
out
);
float2
out
{};
out
.
x
=
__half2float
(
a
.
x
)
+
__half2float
(
b
.
x
);
out
.
y
=
__half2float
(
a
.
y
)
+
__half2float
(
b
.
y
);
return
__float22half2_rn
(
out
);
#endif
}
#if __CUDA_ARCH__ < 530
...
...
@@ -53,14 +53,14 @@ template <typename T>
__device__
inline
T
operator
*
(
const
T
&
a
,
const
T
&
b
);
template
<
>
__device__
inline
half2
operator
+
(
const
half2
&
a
,
const
half2
&
b
)
{
return
__hadd2_with_fallback
(
a
,
b
);
return
__hadd2_with_fallback
(
a
,
b
);
}
template
<
>
__device__
inline
half2
operator
*
(
const
half2
&
a
,
const
half2
&
b
)
{
float2
out
{};
out
.
x
=
__half2float
(
a
.
x
)
*
__half2float
(
b
.
x
);
out
.
y
=
__half2float
(
a
.
y
)
*
__half2float
(
b
.
y
);
return
__float22half2_rn
(
out
);
float2
out
{};
out
.
x
=
__half2float
(
a
.
x
)
*
__half2float
(
b
.
x
);
out
.
y
=
__half2float
(
a
.
y
)
*
__half2float
(
b
.
y
);
return
__float22half2_rn
(
out
);
}
template
<
typename
T
>
__device__
inline
T
operator
+
(
const
T
&
a
,
const
T
&
b
);
...
...
@@ -74,70 +74,73 @@ template <typename T>
__device__
inline
T
operator
*
(
const
T
&
a
,
const
T
&
b
);
template
<
>
__device__
inline
half
operator
+
(
const
half
&
a
,
const
half
&
b
)
{
return
__float2half
(
__half2float
(
a
)
+
__half2float
(
b
));
return
__float2half
(
__half2float
(
a
)
+
__half2float
(
b
));
}
template
<
>
__device__
inline
half
&
operator
+=
(
half
&
a
,
const
half
&
b
)
{
a
=
__float2half
(
__half2float
(
a
)
+
__half2float
(
b
));
return
a
;
a
=
__float2half
(
__half2float
(
a
)
+
__half2float
(
b
));
return
a
;
}
template
<
>
__device__
inline
half
operator
-
(
const
half
&
a
,
const
half
&
b
)
{
return
__float2half
(
__half2float
(
a
)
-
__half2float
(
b
));
return
__float2half
(
__half2float
(
a
)
-
__half2float
(
b
));
}
template
<
>
__device__
inline
half
operator
*
(
const
half
&
a
,
const
half
&
b
)
{
return
__float2half
(
__half2float
(
a
)
*
__half2float
(
b
));
return
__float2half
(
__half2float
(
a
)
*
__half2float
(
b
));
}
template
<
>
__device__
inline
half
operator
/
(
const
half
&
a
,
const
half
&
b
)
{
return
__float2half
(
__half2float
(
a
)
/
__half2float
(
b
));
return
__float2half
(
__half2float
(
a
)
/
__half2float
(
b
));
}
#endif
template
<
>
__device__
inline
half
rsqrt
(
const
half
&
x
)
{
#if __CUDA_ARCH__ >= 530
return
hrsqrt
(
x
);
return
hrsqrt
(
x
);
#else
return
__float2half
(
rsqrt
(
__half2float
(
x
)));
return
__float2half
(
rsqrt
(
__half2float
(
x
)));
#endif
}
__device__
inline
kv_half
operator
+
(
const
kv_half
&
a
,
const
kv_half
&
b
)
{
const
half2
a2
=
__halves2half2
(
a
.
key
,
a
.
value
);
const
half2
b2
=
__halves2half2
(
b
.
key
,
b
.
value
);
const
half2
res
=
__hadd2_with_fallback
(
a2
,
b2
);
return
kv_half
(
res
.
x
,
res
.
y
);
const
half2
a2
=
__halves2half2
(
a
.
key
,
a
.
value
);
const
half2
b2
=
__halves2half2
(
b
.
key
,
b
.
value
);
const
half2
res
=
__hadd2_with_fallback
(
a2
,
b2
);
return
kv_half
(
res
.
x
,
res
.
y
);
}
__device__
inline
kv_half2
operator
+
(
const
kv_half2
&
a
,
const
kv_half2
&
b
)
{
return
kv_half2
(
__hadd2_with_fallback
(
a
.
key
,
b
.
key
),
__hadd2_with_fallback
(
a
.
value
,
b
.
value
));
return
kv_half2
(
__hadd2_with_fallback
(
a
.
key
,
b
.
key
),
__hadd2_with_fallback
(
a
.
value
,
b
.
value
));
}
// Helper Functions
template
<
typename
T
>
using
kvp
=
cub
::
KeyValuePair
<
T
,
T
>
;
template
<
typename
T
,
typename
R
,
typename
P
,
int
TPB
>
__device__
inline
void
layerNorm
(
const
kvp
<
R
>&
threadData
,
const
int
ld
,
const
int
offset
,
const
P
*
beta
,
const
P
*
gamma
,
T
*
output
)
{
// Assuming threadData is already divided by ld
using
BlockReduce
=
cub
::
BlockReduce
<
kvp
<
R
>
,
TPB
>
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
R
mu
;
// mean
__shared__
R
rsigma
;
// 1 / std.dev.
const
auto
sumKV
=
BlockReduce
(
temp_storage
).
Reduce
(
threadData
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
{
mu
=
sumKV
.
key
;
rsigma
=
rsqrt
(
sumKV
.
value
-
mu
*
mu
);
}
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
ld
;
i
+=
TPB
)
{
const
int
idx
=
offset
+
i
;
const
R
val
=
output
[
idx
];
const
R
g
(
gamma
[
i
]);
const
R
b
(
beta
[
i
]);
output
[
idx
]
=
g
*
(
val
-
mu
)
*
rsigma
+
b
;
}
__device__
inline
void
layerNorm
(
const
kvp
<
R
>&
threadData
,
const
int
ld
,
const
int
offset
,
const
P
*
beta
,
const
P
*
gamma
,
T
*
output
)
{
// Assuming threadData is already divided by ld
using
BlockReduce
=
cub
::
BlockReduce
<
kvp
<
R
>
,
TPB
>
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
R
mu
;
// mean
__shared__
R
rsigma
;
// 1 / std.dev.
const
auto
sumKV
=
BlockReduce
(
temp_storage
).
Reduce
(
threadData
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
{
mu
=
sumKV
.
key
;
rsigma
=
rsqrt
(
sumKV
.
value
-
mu
*
mu
);
}
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
ld
;
i
+=
TPB
)
{
const
int
idx
=
offset
+
i
;
const
R
val
=
output
[
idx
];
const
R
g
(
gamma
[
i
]);
const
R
b
(
beta
[
i
]);
output
[
idx
]
=
g
*
(
val
-
mu
)
*
rsigma
+
b
;
}
}
#endif // #ifndef COMMON_CUH
paddle/fluid/inference/tensorrt/plugin/common/plugin.h
浏览文件 @
2a9c590b
...
...
@@ -14,8 +14,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#
ifndef PADDLE_FLUID_INFERENCE_TENSORRT_PLUGIN_COMMON_PLUGIN_H_
#define PADDLE_FLUID_INFERENCE_TENSORRT_PLUGIN_COMMON_PLUGIN_H_
#
pragma once
#include <cuda_runtime.h>
#include <cstring>
#include <iostream>
...
...
@@ -60,4 +60,3 @@ class BaseCreator : public IPluginCreator {
};
}
// namespace nvinfer1
#endif // PADDLE_FLUID_INFERENCE_TENSORRT_PLUGIN_COMMON_PLUGIN_H_
paddle/fluid/inference/tensorrt/plugin/lookup_table.cu
0 → 100644
浏览文件 @
2a9c590b
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/inference/tensorrt/plugin/lookup_table.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
char
const
*
PLUGINVERSION
{
"1"
};
char
const
*
LOOKUPTABLEPLUGINNAME
{
"LookupTablePluginDynamic"
};
template
<
typename
T
,
unsigned
TPB
>
__global__
void
lookup_table_kernel
(
int
weight_height
,
int32_t
const
*
inputIds
,
T
const
*
wordEmb
,
int32_t
const
wordSize
,
T
*
output
)
{
// 1. lookup word and token of the block
// blockIdx.x = position in the sequence
// blockIdx.y = batch
// gridDim.x = S
// gridDim.y = B
__shared__
int
wordId
;
int32_t
const
seqPos
=
blockIdx
.
x
+
blockIdx
.
y
*
gridDim
.
x
;
if
(
threadIdx
.
x
==
0
)
{
wordId
=
inputIds
[
seqPos
];
}
__syncthreads
();
// 2. load word embeddings and add them toghether
// offset into embeddings is given by wordId * hidden_size
int32_t
const
woffset
=
wordId
*
weight_height
;
// the output offset is given by b * (S*hidden_size) + s * hidden_size
int32_t
const
outOffset
=
seqPos
*
weight_height
;
if
(
wordId
>=
0
&&
wordId
<
wordSize
)
{
for
(
int
it
=
threadIdx
.
x
;
it
<
weight_height
;
it
+=
TPB
)
{
T
const
w
(
wordEmb
[
woffset
+
it
]);
output
[
outOffset
+
it
]
=
w
;
}
}
else
{
printf
(
"Error!!!!!!(LookupTablePlugin): ID cannot be lookup "
"table: ID < 0 or ID > max "
);
return
;
}
}
template
<
typename
T
>
int
lookup_table
(
cudaStream_t
stream
,
int
weight_height
,
int
B
,
int
S
,
int32_t
const
*
inputIds
,
T
const
*
wordEmb
,
int32_t
const
wordSize
,
T
*
output
)
{
constexpr
int
tpb
=
256
;
dim3
const
grid
(
S
,
B
,
1
);
dim3
const
block
(
tpb
,
1
,
1
);
lookup_table_kernel
<
T
,
tpb
><<<
grid
,
block
,
0
,
stream
>>>
(
weight_height
,
inputIds
,
wordEmb
,
wordSize
,
output
);
return
0
;
}
// Static class fields initialization
nvinfer1
::
PluginFieldCollection
LookupTablePluginDynamicCreator
::
mFC
{};
std
::
vector
<
nvinfer1
::
PluginField
>
LookupTablePluginDynamicCreator
::
mPluginAttributes
;
LookupTablePluginDynamic
::
LookupTablePluginDynamic
(
nvinfer1
::
DataType
const
type
,
void
*
weight_dev
,
int32_t
weight_size
,
int32_t
width
)
:
mType
(
type
),
mWeightDev
(
weight_dev
),
mWeightSize
(
weight_size
),
mWeightWidth
(
width
)
{}
LookupTablePluginDynamic
::
LookupTablePluginDynamic
(
void
const
*
data
,
size_t
length
)
{
// Deserialize in the same order as serialization
deserialize_value
(
&
data
,
&
length
,
&
mType
);
deserialize_value
(
&
data
,
&
length
,
&
mWeightSize
);
deserialize_value
(
&
data
,
&
length
,
&
mWeightWidth
);
char
const
*
d
=
static_cast
<
char
const
*>
(
data
);
cudaMalloc
(
&
mWeightDev
,
mWeightSize
*
sizeof
(
mType
));
cudaMemcpy
(
mWeightDev
,
d
,
mWeightSize
*
sizeof
(
mType
),
cudaMemcpyHostToDevice
);
}
// IPluginV2DynamicExt Methods
nvinfer1
::
IPluginV2DynamicExt
*
LookupTablePluginDynamic
::
clone
()
const
noexcept
{
auto
p
=
new
LookupTablePluginDynamic
(
mType
,
mWeightDev
,
mWeightSize
,
mWeightWidth
);
p
->
setPluginNamespace
(
mNamespace
.
c_str
());
return
p
;
}
nvinfer1
::
DimsExprs
LookupTablePluginDynamic
::
getOutputDimensions
(
int32_t
outputIndex
,
nvinfer1
::
DimsExprs
const
*
inputs
,
int32_t
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
noexcept
{
nvinfer1
::
DimsExprs
ret
;
ret
.
nbDims
=
inputs
[
0
].
nbDims
+
1
;
for
(
int
i
=
0
;
i
<
inputs
[
0
].
nbDims
;
++
i
)
{
ret
.
d
[
i
]
=
inputs
[
0
].
d
[
i
];
}
ret
.
d
[
inputs
[
0
].
nbDims
]
=
exprBuilder
.
constant
(
mWeightWidth
);
return
ret
;
}
bool
LookupTablePluginDynamic
::
supportsFormatCombination
(
int32_t
pos
,
nvinfer1
::
PluginTensorDesc
const
*
inOut
,
int32_t
nbInputs
,
int32_t
nbOutputs
)
noexcept
{
nvinfer1
::
PluginTensorDesc
const
&
desc
=
inOut
[
pos
];
if
(
desc
.
format
!=
nvinfer1
::
TensorFormat
::
kLINEAR
)
{
return
false
;
}
if
(
pos
==
0
)
{
return
desc
.
type
==
nvinfer1
::
DataType
::
kINT32
;
}
if
(
pos
==
1
)
{
if
(
mType
==
nvinfer1
::
DataType
::
kFLOAT
)
{
return
desc
.
type
==
nvinfer1
::
DataType
::
kFLOAT
;
}
else
{
return
desc
.
type
==
nvinfer1
::
DataType
::
kHALF
;
}
}
}
void
LookupTablePluginDynamic
::
configurePlugin
(
nvinfer1
::
DynamicPluginTensorDesc
const
*
inputs
,
int32_t
nbInputs
,
nvinfer1
::
DynamicPluginTensorDesc
const
*
outputs
,
int32_t
nbOutputs
)
noexcept
{}
size_t
LookupTablePluginDynamic
::
getWorkspaceSize
(
nvinfer1
::
PluginTensorDesc
const
*
inputs
,
int32_t
nbInputs
,
nvinfer1
::
PluginTensorDesc
const
*
outputs
,
int32_t
nbOutputs
)
const
noexcept
{
return
0
;
}
int32_t
LookupTablePluginDynamic
::
enqueue
(
nvinfer1
::
PluginTensorDesc
const
*
inputDesc
,
nvinfer1
::
PluginTensorDesc
const
*
outputDesc
,
void
const
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
noexcept
{
int32_t
const
batchSize
=
inputDesc
->
dims
.
d
[
0
];
int32_t
S
;
if
(
inputDesc
->
dims
.
nbDims
==
1
)
{
S
=
1
;
}
else
{
S
=
inputDesc
->
dims
.
d
[
1
];
}
int32_t
mWeightHeight
=
mWeightSize
/
mWeightWidth
;
int32_t
status
=
STATUS_FAILURE
;
auto
const
inputIds
=
static_cast
<
int32_t
const
*>
(
inputs
[
0
]);
if
(
mType
==
nvinfer1
::
DataType
::
kFLOAT
)
{
auto
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
auto
const
Weight
=
static_cast
<
const
float
*>
(
mWeightDev
);
status
=
lookup_table
<
float
>
(
stream
,
static_cast
<
int32_t
>
(
mWeightWidth
),
batchSize
,
S
,
inputIds
,
Weight
,
mWeightHeight
,
output
);
}
else
if
(
mType
==
nvinfer1
::
DataType
::
kHALF
)
{
auto
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
auto
const
Weight
=
static_cast
<
const
half
*>
(
mWeightDev
);
status
=
lookup_table
<
half
>
(
stream
,
static_cast
<
int32_t
>
(
mWeightWidth
),
batchSize
,
S
,
inputIds
,
Weight
,
mWeightHeight
,
output
);
}
return
status
;
}
// IPluginV2Ext Methods
nvinfer1
::
DataType
LookupTablePluginDynamic
::
getOutputDataType
(
int32_t
index
,
nvinfer1
::
DataType
const
*
inputTypes
,
int32_t
nbInputs
)
const
noexcept
{
if
(
index
==
0
)
{
assert
(
mType
==
nvinfer1
::
DataType
::
kHALF
||
mType
==
nvinfer1
::
DataType
::
kFLOAT
);
return
mType
;
}
}
// IPluginV2 Methods
char
const
*
LookupTablePluginDynamic
::
getPluginType
()
const
noexcept
{
return
LOOKUPTABLEPLUGINNAME
;
}
char
const
*
LookupTablePluginDynamic
::
getPluginVersion
()
const
noexcept
{
return
PLUGINVERSION
;
}
int32_t
LookupTablePluginDynamic
::
getNbOutputs
()
const
noexcept
{
return
1
;
}
int32_t
LookupTablePluginDynamic
::
initialize
()
noexcept
{
return
0
;
}
void
LookupTablePluginDynamic
::
terminate
()
noexcept
{
cudaFree
(
mWeightDev
);
}
size_t
LookupTablePluginDynamic
::
getSerializationSize
()
const
noexcept
{
size_t
const
wordSize
=
getElementSize
(
mType
);
return
sizeof
(
mType
)
//
+
sizeof
(
mWeightSize
)
//
+
sizeof
(
mWeightWidth
)
//
+
wordSize
*
mWeightSize
;
//
}
void
LookupTablePluginDynamic
::
serialize
(
void
*
buffer
)
const
noexcept
{
serialize_value
(
&
buffer
,
mType
);
serialize_value
(
&
buffer
,
mWeightSize
);
serialize_value
(
&
buffer
,
mWeightWidth
);
char
*
d
=
static_cast
<
char
*>
(
buffer
);
size_t
const
wordSize
=
getElementSize
(
mType
);
serFromDev
(
&
d
,
static_cast
<
char
*>
(
mWeightDev
),
mWeightSize
*
wordSize
);
}
void
LookupTablePluginDynamic
::
destroy
()
noexcept
{
// This gets called when the network containing plugin is destroyed
delete
this
;
}
void
LookupTablePluginDynamic
::
setPluginNamespace
(
char
const
*
libNamespace
)
noexcept
{
mNamespace
=
libNamespace
;
}
char
const
*
LookupTablePluginDynamic
::
getPluginNamespace
()
const
noexcept
{
return
mNamespace
.
c_str
();
}
LookupTablePluginDynamicCreator
::
LookupTablePluginDynamicCreator
()
{}
char
const
*
LookupTablePluginDynamicCreator
::
getPluginName
()
const
noexcept
{
return
LOOKUPTABLEPLUGINNAME
;
}
char
const
*
LookupTablePluginDynamicCreator
::
getPluginVersion
()
const
noexcept
{
return
PLUGINVERSION
;
}
nvinfer1
::
PluginFieldCollection
const
*
LookupTablePluginDynamicCreator
::
getFieldNames
()
noexcept
{
return
&
mFC
;
}
bool
initializeFields
(
nvinfer1
::
PluginFieldCollection
const
*
fc
,
nvinfer1
::
Weights
*
weight
,
int32_t
&
mWeightWidth
)
{
// NOLINT
bool
output_fp16
=
false
;
for
(
int32_t
i
=
0
;
i
<
fc
->
nbFields
;
i
++
)
{
std
::
string
field_name
(
fc
->
fields
[
i
].
name
);
if
(
field_name
.
compare
(
"lookup_table_weight"
)
==
0
)
{
weight
->
values
=
fc
->
fields
[
i
].
data
;
weight
->
count
=
fc
->
fields
[
i
].
length
;
weight
->
type
=
fieldTypeToDataType
(
fc
->
fields
[
i
].
type
);
}
if
(
field_name
.
compare
(
"lookup_table_weight_width"
)
==
0
)
{
assert
(
fc
->
fields
[
i
].
type
==
nvinfer1
::
PluginFieldType
::
kINT32
);
mWeightWidth
=
const_cast
<
int32_t
*>
(
static_cast
<
int32_t
const
*>
(
fc
->
fields
[
i
].
data
))[
0
];
// NOLINT
}
if
(
field_name
.
compare
(
"output_fp16"
)
==
0
)
{
assert
(
fc
->
fields
[
i
].
type
==
nvinfer1
::
PluginFieldType
::
kINT32
);
output_fp16
=
static_cast
<
int32_t
const
*>
(
fc
->
fields
[
i
].
data
)[
0
]
!=
0
;
}
}
return
output_fp16
;
}
nvinfer1
::
IPluginV2
*
LookupTablePluginDynamicCreator
::
createPlugin
(
char
const
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
noexcept
{
nvinfer1
::
Weights
weight
;
int32_t
mWeightWidth
;
bool
output_fp16
=
initializeFields
(
fc
,
&
weight
,
mWeightWidth
);
nvinfer1
::
DataType
type
;
if
(
output_fp16
)
{
type
=
nvinfer1
::
DataType
::
kHALF
;
}
else
{
type
=
nvinfer1
::
DataType
::
kFLOAT
;
}
WeightsWithOwnership
mWeight
;
mWeight
.
convertAndCopy
(
weight
,
type
);
void
*
cudaMem
{
nullptr
};
cudaMalloc
(
&
cudaMem
,
getWeightsSize
(
mWeight
,
type
));
cudaMemcpy
(
cudaMem
,
mWeight
.
values
,
getWeightsSize
(
mWeight
,
type
),
cudaMemcpyHostToDevice
);
LookupTablePluginDynamic
*
p
=
new
LookupTablePluginDynamic
(
type
,
cudaMem
,
mWeight
.
count
,
mWeightWidth
);
return
p
;
}
nvinfer1
::
IPluginV2
*
LookupTablePluginDynamicCreator
::
deserializePlugin
(
char
const
*
name
,
void
const
*
serialData
,
size_t
serialLength
)
noexcept
{
return
new
LookupTablePluginDynamic
(
serialData
,
serialLength
);
}
void
LookupTablePluginDynamicCreator
::
setPluginNamespace
(
char
const
*
libNamespace
)
noexcept
{
mNamespace
=
libNamespace
;
}
char
const
*
LookupTablePluginDynamicCreator
::
getPluginNamespace
()
const
noexcept
{
return
mNamespace
.
c_str
();
}
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/lookup_table.h
0 → 100644
浏览文件 @
2a9c590b
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
#include "NvInferPlugin.h"
#include "NvInferRuntime.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/serialize.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
class
LookupTablePluginDynamic
:
public
nvinfer1
::
IPluginV2DynamicExt
{
public:
LookupTablePluginDynamic
(
nvinfer1
::
DataType
const
type
,
void
*
weight_dev
,
int32_t
weight_size
,
int32_t
width
);
LookupTablePluginDynamic
(
void
const
*
data
,
size_t
length
);
// It doesn't make sense to make EmbLayerNormVarSeqlenPlugin without
// arguments, so we delete default constructor.
LookupTablePluginDynamic
()
=
delete
;
// IPluginV2DynamicExt Methods
bool
supportsFormatCombination
(
int32_t
pos
,
nvinfer1
::
PluginTensorDesc
const
*
inOut
,
int32_t
nbInputs
,
int32_t
nbOutputs
)
noexcept
override
;
size_t
getWorkspaceSize
(
nvinfer1
::
PluginTensorDesc
const
*
inputs
,
int32_t
nbInputs
,
nvinfer1
::
PluginTensorDesc
const
*
outputs
,
int32_t
nbOutputs
)
const
noexcept
override
;
// IPluginV2Ext Methods
nvinfer1
::
DataType
getOutputDataType
(
int32_t
index
,
nvinfer1
::
DataType
const
*
inputTypes
,
int32_t
nbInputs
)
const
noexcept
override
;
// IPluginV2 Methods
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
noexcept
override
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int32_t
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int32_t
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
noexcept
override
;
void
configurePlugin
(
nvinfer1
::
DynamicPluginTensorDesc
const
*
in
,
int32_t
nbInputs
,
nvinfer1
::
DynamicPluginTensorDesc
const
*
out
,
int32_t
nbOutputs
)
noexcept
override
;
char
const
*
getPluginType
()
const
noexcept
override
;
int32_t
getNbOutputs
()
const
noexcept
override
;
size_t
getSerializationSize
()
const
noexcept
override
;
void
serialize
(
void
*
buffer
)
const
noexcept
override
;
void
destroy
()
noexcept
override
;
char
const
*
getPluginNamespace
()
const
noexcept
override
;
void
setPluginNamespace
(
char
const
*
pluginNamespace
)
noexcept
override
;
int32_t
enqueue
(
nvinfer1
::
PluginTensorDesc
const
*
inputDesc
,
nvinfer1
::
PluginTensorDesc
const
*
outputDesc
,
void
const
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
noexcept
override
;
int32_t
initialize
()
noexcept
override
;
void
terminate
()
noexcept
override
;
char
const
*
getPluginVersion
()
const
noexcept
override
;
protected:
std
::
string
mNamespace
;
nvinfer1
::
DataType
mType
;
void
*
mWeightDev
{
nullptr
};
int32_t
mWeightSize
;
int32_t
mWeightWidth
;
};
class
LookupTablePluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
public:
LookupTablePluginDynamicCreator
();
char
const
*
getPluginName
()
const
noexcept
override
;
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
noexcept
override
;
void
setPluginNamespace
(
char
const
*
pluginNamespace
)
noexcept
override
;
char
const
*
getPluginNamespace
()
const
noexcept
override
;
nvinfer1
::
IPluginV2
*
createPlugin
(
char
const
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
noexcept
override
;
char
const
*
getPluginVersion
()
const
noexcept
override
;
nvinfer1
::
IPluginV2
*
deserializePlugin
(
char
const
*
name
,
void
const
*
serialData
,
size_t
serialLength
)
noexcept
override
;
protected:
static
nvinfer1
::
PluginFieldCollection
mFC
;
static
std
::
vector
<
nvinfer1
::
PluginField
>
mPluginAttributes
;
std
::
string
mNamespace
;
};
REGISTER_TRT_PLUGIN_V2
(
LookupTablePluginDynamicCreator
);
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu
浏览文件 @
2a9c590b
...
...
@@ -21,10 +21,7 @@
#include <vector>
#include "NvInfer.h"
#include "common/bertCommon.h"
#include "common/common.cuh"
#include "common/plugin.h"
#include "common/serialize.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
namespace
paddle
{
...
...
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu
浏览文件 @
2a9c590b
...
...
@@ -21,10 +21,7 @@
#include <vector>
#include "NvInfer.h"
#include "common/bertCommon.h"
#include "common/common.cuh"
#include "common/plugin.h"
#include "common/serialize.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
namespace
paddle
{
...
...
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
浏览文件 @
2a9c590b
...
...
@@ -19,7 +19,6 @@
#include <cstring>
#include <vector>
#include "NvInfer.h"
#include "common/serialize.h"
namespace
paddle
{
namespace
inference
{
...
...
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h
浏览文件 @
2a9c590b
...
...
@@ -18,7 +18,10 @@
#include <cuda.h>
#include "NvInferPlugin.h"
#include "NvInferRuntime.h"
#include "common/bertCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/serialize.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py
浏览文件 @
2a9c590b
...
...
@@ -228,7 +228,7 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
max_batch_size
=
4
,
workspace_size
=
102400
,
min_subgraph_size
=
0
,
precision_mode
=
paddle_infer
.
PrecisionType
.
Float32
,
precision_mode
=
paddle_infer
.
PrecisionType
.
Half
,
use_static
=
False
,
use_calib_mode
=
False
)
yield
config
,
[
'fused_embedding_eltwise_layernorm'
],
(
1e-5
,
1e-5
)
...
...
@@ -238,7 +238,7 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
max_batch_size
=
4
,
workspace_size
=
102400
,
min_subgraph_size
=
0
,
precision_mode
=
paddle_infer
.
PrecisionType
.
Float32
,
precision_mode
=
paddle_infer
.
PrecisionType
.
Half
,
use_static
=
False
,
use_calib_mode
=
False
)
if
program_config
.
ops
[
0
].
type
==
'lookup_table'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录