Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a5ef246c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a5ef246c
编写于
9月 18, 2020
作者:
P
Pei Yang
提交者:
GitHub
9月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize emb_eltwise_layernorm_plugin and support fp16 (#27128)
上级
4c5cfdea
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
287 addition
and
124 deletion
+287
-124
cmake/cuda.cmake
cmake/cuda.cmake
+3
-0
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
...fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
+3
-3
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu
...inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu
+133
-81
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h
.../inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h
+144
-34
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc
...nce/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc
+4
-6
未找到文件。
cmake/cuda.cmake
浏览文件 @
a5ef246c
...
...
@@ -107,6 +107,9 @@ function(select_nvcc_arch_flags out_variable)
elseif
(
${
CUDA_ARCH_NAME
}
STREQUAL
"Maxwell"
)
set
(
cuda_arch_bin
"50"
)
elseif
(
${
CUDA_ARCH_NAME
}
STREQUAL
"Pascal"
)
if
(
NOT
${
CMAKE_CUDA_COMPILER_VERSION
}
LESS 10.0
)
add_definitions
(
"-DSUPPORTS_CUDA_FP16"
)
endif
()
set
(
cuda_arch_bin
"60 61"
)
elseif
(
${
CUDA_ARCH_NAME
}
STREQUAL
"Volta"
)
if
(
NOT
${
CMAKE_CUDA_COMPILER_VERSION
}
LESS 10.0
)
...
...
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
浏览文件 @
a5ef246c
...
...
@@ -80,10 +80,10 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
plugin
::
DynamicPluginTensorRT
*
plugin
=
nullptr
;
plugin
=
new
plugin
::
EmbEltwiseLayernormPluginDynamic
<
float
>
(
auto
use_fp16
=
engine_
->
WithFp16
()
;
auto
plugin
=
new
plugin
::
EmbEltwiseLayernormPluginDynamic
(
input_embs
,
bias
,
scale
,
emb_sizes
,
bias_size
,
scale_size
,
hidden
,
eps
);
eps
,
use_fp16
);
layer
=
engine_
->
AddPluginV2
(
input_ids
.
data
(),
input_num
,
plugin
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
...
...
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu
浏览文件 @
a5ef246c
...
...
@@ -32,13 +32,34 @@ namespace plugin {
#if IS_TRT_VERSION_GE(6000)
template
<
typename
T
>
int
EmbEltwiseLayernormPluginDynamic
<
T
>::
initialize
()
{
EmbEltwiseLayernormPluginDynamicImpl
<
T
>::~
EmbEltwiseLayernormPluginDynamicImpl
()
{
this
->
terminate
();
}
inline
half
fp32tofp16
(
float
x
)
{
return
static_cast
<
half
>
(
x
);
}
template
<
typename
T
>
int
EmbEltwiseLayernormPluginDynamicImpl
<
T
>::
initialize
()
{
embs_gpu_
.
resize
(
embs_
.
size
());
for
(
int
i
=
0
;
i
<
embs_
.
size
();
i
++
)
{
if
(
embs_
[
i
])
{
cudaMalloc
(
&
embs_gpu_
[
i
],
sizeof
(
float
)
*
emb_sizes_
[
i
]);
cudaMemcpy
(
embs_gpu_
[
i
],
embs_
[
i
],
emb_sizes_
[
i
]
*
sizeof
(
float
),
T
*
host_ptr
;
auto
size
=
emb_sizes_
[
i
];
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
host_ptr
=
new
T
[
size
];
std
::
transform
(
embs_
[
i
],
(
embs_
[
i
]
+
size
),
host_ptr
,
fp32tofp16
);
}
else
{
host_ptr
=
reinterpret_cast
<
T
*>
(
embs_
[
i
]);
}
cudaMalloc
(
&
embs_gpu_
[
i
],
sizeof
(
T
)
*
size
);
cudaMemcpy
(
embs_gpu_
[
i
],
host_ptr
,
size
*
sizeof
(
T
),
cudaMemcpyHostToDevice
);
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
delete
[]
host_ptr
;
}
}
}
...
...
@@ -53,11 +74,105 @@ int EmbEltwiseLayernormPluginDynamic<T>::initialize() {
cudaMemcpyHostToDevice
);
}
int
input_num
=
embs_
.
size
();
in_ptr_tensor_
.
Resize
({
input_num
});
emb_ptr_tensor_
.
Resize
({
input_num
});
cudaGetDevice
(
&
device_id_
);
auto
emb_ptr_gpu_d
=
emb_ptr_tensor_
.
mutable_data
<
int64_t
>
(
platform
::
CUDAPlace
(
device_id_
));
cudaMemcpy
(
emb_ptr_gpu_d
,
embs_gpu_
.
data
(),
sizeof
(
uintptr_t
)
*
input_num
,
cudaMemcpyHostToDevice
);
return
0
;
}
template
<
typename
T
>
nvinfer1
::
DimsExprs
EmbEltwiseLayernormPluginDynamic
<
T
>::
getOutputDimensions
(
void
EmbEltwiseLayernormPluginDynamicImpl
<
T
>::
terminate
()
{
for
(
int
i
=
0
;
i
<
embs_gpu_
.
size
();
++
i
)
{
if
(
embs_gpu_
[
i
])
{
cudaFree
(
embs_gpu_
[
i
]);
embs_gpu_
[
i
]
=
nullptr
;
}
}
if
(
bias_gpu_
)
{
cudaFree
(
bias_gpu_
);
bias_gpu_
=
nullptr
;
}
if
(
scale_gpu_
)
{
cudaFree
(
scale_gpu_
);
scale_gpu_
=
nullptr
;
}
}
template
<
typename
T
>
int
EmbEltwiseLayernormPluginDynamicImpl
<
T
>::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
id_dims
=
input_desc
[
0
].
dims
;
int
batch
=
id_dims
.
d
[
0
];
int
seq_len
=
id_dims
.
d
[
1
];
int
input_num
=
embs_
.
size
();
auto
in_ptr_gpu_d
=
in_ptr_tensor_
.
mutable_data
<
int64_t
>
(
platform
::
CUDAPlace
(
device_id_
));
auto
emb_ptr_gpu_d
=
emb_ptr_tensor_
.
mutable_data
<
int64_t
>
(
platform
::
CUDAPlace
(
device_id_
));
auto
new_input_ptr
=
reinterpret_cast
<
uintptr_t
>
(
inputs
[
0
]);
if
(
old_input_ptr_
!=
new_input_ptr
)
{
old_input_ptr_
=
new_input_ptr
;
cudaMemcpyAsync
(
in_ptr_gpu_d
,
reinterpret_cast
<
const
void
*>
(
inputs
),
sizeof
(
uintptr_t
)
*
input_num
,
cudaMemcpyHostToDevice
,
stream
);
}
auto
out_type
=
output_desc
[
0
].
type
;
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
PADDLE_ENFORCE_EQ
(
out_type
==
nvinfer1
::
DataType
::
kFLOAT
,
true
,
platform
::
errors
::
InvalidArgument
(
"The EmbEltwiseLayernorm Plugin only support fp32 input."
));
}
else
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
PADDLE_ENFORCE_EQ
(
out_type
==
nvinfer1
::
DataType
::
kHALF
,
true
,
platform
::
errors
::
InvalidArgument
(
"The EmbEltwiseLayernorm Plugin only support fp16 input."
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Unsupport data type, the out type of EmbEltwiseLayernorm should be "
"float or half."
));
}
auto
*
output_d
=
reinterpret_cast
<
T
*>
(
outputs
[
0
]);
operators
::
math
::
EmbEltwiseLayerNormFunctor
<
T
>
emb_eltwise_layernorm_func
;
emb_eltwise_layernorm_func
(
batch
,
seq_len
,
hidden_size_
,
in_ptr_gpu_d
,
scale_gpu_
,
bias_gpu_
,
emb_ptr_gpu_d
,
output_d
,
eps_
,
input_num
,
stream
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
template
class
EmbEltwiseLayernormPluginDynamicImpl
<
float
>;
#ifdef SUPPORTS_CUDA_FP16
template
class
EmbEltwiseLayernormPluginDynamicImpl
<
half
>;
#endif // SUPPORTS_CUDA_FP16
int
EmbEltwiseLayernormPluginDynamic
::
initialize
()
{
impl_
->
initialize
();
return
0
;
}
void
EmbEltwiseLayernormPluginDynamic
::
terminate
()
{
impl_
->
terminate
();
}
nvinfer1
::
DimsExprs
EmbEltwiseLayernormPluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
// NOLINT
PADDLE_ENFORCE_EQ
(
output_index
,
0
,
...
...
@@ -76,18 +191,7 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
return
ret
;
}
template
<
typename
T
>
void
EmbEltwiseLayernormPluginDynamic
<
T
>::
terminate
()
{
for
(
auto
ptr
:
embs_gpu_
)
{
if
(
ptr
)
cudaFree
(
ptr
);
}
if
(
bias_gpu_
)
cudaFree
(
bias_gpu_
);
if
(
scale_gpu_
)
cudaFree
(
scale_gpu_
);
}
template
<
typename
T
>
bool
EmbEltwiseLayernormPluginDynamic
<
T
>::
supportsFormatCombination
(
bool
EmbEltwiseLayernormPluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
{
PADDLE_ENFORCE_NOT_NULL
(
...
...
@@ -98,6 +202,11 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
"The EmbEltwiseLayerNorm's output should be one"
"but it's (%d) outputs."
,
nb_outputs
));
PADDLE_ENFORCE_EQ
(
nb_outputs
,
1
,
platform
::
errors
::
InvalidArgument
(
"The EmbEltwiseLayerNorm's output should be one"
"but it's (%d) outputs."
,
nb_outputs
));
PADDLE_ENFORCE_LT
(
pos
,
nb_inputs
+
nb_outputs
,
platform
::
errors
::
InvalidArgument
(
"The pos(%d) should be less than the "
...
...
@@ -122,7 +231,7 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
}
if
(
pos
==
all_nums
-
1
)
{
if
(
sizeof
(
T
)
==
sizeof
(
float
)
)
{
if
(
with_fp16_
==
false
)
{
return
desc
.
type
==
nvinfer1
::
DataType
::
kFLOAT
;
}
else
{
return
desc
.
type
==
nvinfer1
::
DataType
::
kHALF
;
...
...
@@ -131,84 +240,27 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
return
false
;
}
template
<
typename
T
>
nvinfer1
::
DataType
EmbEltwiseLayernormPluginDynamic
<
T
>::
getOutputDataType
(
nvinfer1
::
DataType
EmbEltwiseLayernormPluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
{
PADDLE_ENFORCE_EQ
(
index
,
0
,
platform
::
errors
::
InvalidArgument
(
"The EmbEltwiseLayernorm Plugin only has one input, so the "
"index value should be 0, but get %d."
,
index
));
return
nvinfer1
::
DataType
::
kFLOAT
;
if
(
with_fp16_
)
return
nvinfer1
::
DataType
::
kHALF
;
else
return
nvinfer1
::
DataType
::
kFLOAT
;
}
template
<
typename
T
>
int
EmbEltwiseLayernormPluginDynamic
<
T
>::
enqueue
(
int
EmbEltwiseLayernormPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
id_dims
=
input_desc
[
0
].
dims
;
int
batch
=
id_dims
.
d
[
0
];
int
seq_len
=
id_dims
.
d
[
1
];
int
input_num
=
embs_
.
size
();
framework
::
Tensor
in_ptr_tensor
,
emb_ptr_tensor
;
int
device_id
;
cudaGetDevice
(
&
device_id
);
in_ptr_tensor
.
Resize
({
input_num
});
emb_ptr_tensor
.
Resize
({
input_num
});
int64_t
*
in_ptr_gpu_d
=
in_ptr_tensor
.
mutable_data
<
int64_t
>
(
platform
::
CUDAPlace
(
device_id
));
int64_t
*
emb_ptr_gpu_d
=
emb_ptr_tensor
.
mutable_data
<
int64_t
>
(
platform
::
CUDAPlace
(
device_id
));
std
::
vector
<
uintptr_t
>
in_ptr
,
emb_ptr
;
for
(
int
i
=
0
;
i
<
input_num
;
i
++
)
{
in_ptr
.
push_back
(
reinterpret_cast
<
uintptr_t
>
(
inputs
[
i
]));
emb_ptr
.
push_back
(
reinterpret_cast
<
uintptr_t
>
(
embs_gpu_
[
i
]));
}
cudaMemcpyAsync
(
in_ptr_gpu_d
,
in_ptr
.
data
(),
sizeof
(
int64_t
)
*
input_num
,
cudaMemcpyHostToDevice
,
stream
);
cudaMemcpyAsync
(
emb_ptr_gpu_d
,
emb_ptr
.
data
(),
sizeof
(
int64_t
)
*
input_num
,
cudaMemcpyHostToDevice
,
stream
);
auto
out_type
=
output_desc
[
0
].
type
;
const
unsigned
tpb
=
256
;
const
dim3
grid
(
seq_len
,
batch
,
1
);
const
dim3
block
(
tpb
,
1
,
1
);
if
(
sizeof
(
T
)
==
sizeof
(
float
))
{
PADDLE_ENFORCE_EQ
(
out_type
==
nvinfer1
::
DataType
::
kFLOAT
,
true
,
platform
::
errors
::
InvalidArgument
(
"The EmbEltwiseLayernorm Plugin only support fp32 input."
));
}
else
if
(
sizeof
(
T
)
==
sizeof
(
int16_t
))
{
PADDLE_ENFORCE_EQ
(
out_type
==
nvinfer1
::
DataType
::
kHALF
,
true
,
platform
::
errors
::
InvalidArgument
(
"The EmbEltwiseLayernorm Plugin only support fp16 input."
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Unsupport data type, the out type of EmbEltwiseLayernorm should be "
"float or half."
));
}
T
*
output_d
=
static_cast
<
T
*>
(
outputs
[
0
]);
operators
::
math
::
EmbEltwiseLayerNormFunctor
<
T
>
emb_eltwise_layernorm_func
;
emb_eltwise_layernorm_func
(
batch
,
seq_len
,
hidden_size_
,
in_ptr_gpu_d
,
scale_gpu_
,
bias_gpu_
,
emb_ptr_gpu_d
,
output_d
,
eps_
,
input_num
,
stream
);
impl_
->
enqueue
(
input_desc
,
output_desc
,
inputs
,
outputs
,
workspace
,
stream
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
template
class
EmbEltwiseLayernormPluginDynamic
<
float
>;
#ifdef SUPPORTS_CUDA_FP16
template
class
EmbEltwiseLayernormPluginDynamic
<
half
>;
#endif // SUPPORTS_CUDA_FP16
#endif
}
// namespace plugin
...
...
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h
浏览文件 @
a5ef246c
...
...
@@ -27,14 +27,76 @@ namespace tensorrt {
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
class
EmbEltwiseLayernormPluginDynamicImplBase
{
public:
EmbEltwiseLayernormPluginDynamicImplBase
()
{}
virtual
~
EmbEltwiseLayernormPluginDynamicImplBase
()
{}
virtual
int
initialize
()
=
0
;
virtual
void
terminate
()
=
0
;
virtual
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
=
0
;
};
template
<
typename
T
>
class
EmbEltwiseLayernormPluginDynamicImpl
:
public
EmbEltwiseLayernormPluginDynamicImplBase
{
public:
explicit
EmbEltwiseLayernormPluginDynamicImpl
(
std
::
vector
<
float
*>
input_embs
,
float
*
bias
,
float
*
scale
,
std
::
vector
<
int
>
emb_sizes
,
int
bias_size
,
int
scale_size
,
int
hidden_size
,
float
eps
)
:
embs_
(
input_embs
),
bias_
(
bias
),
scale_
(
scale
),
emb_sizes_
(
emb_sizes
),
bias_size_
(
bias_size
),
scale_size_
(
scale_size
),
hidden_size_
(
hidden_size
),
eps_
(
eps
)
{}
~
EmbEltwiseLayernormPluginDynamicImpl
();
int
initialize
();
void
terminate
();
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
);
private:
std
::
vector
<
float
*>
embs_
;
float
*
bias_
{
nullptr
};
float
*
scale_
{
nullptr
};
// data on devices
float
*
bias_gpu_
{
nullptr
};
float
*
scale_gpu_
{
nullptr
};
std
::
vector
<
T
*>
embs_gpu_
;
std
::
vector
<
int
>
emb_sizes_
;
int
bias_size_
;
int
scale_size_
;
int
hidden_size_
;
float
eps_
;
framework
::
Tensor
in_ptr_tensor_
,
emb_ptr_tensor_
;
int
device_id_
{
0
};
uintptr_t
old_input_ptr_
{
0
};
};
class
EmbEltwiseLayernormPluginDynamic
:
public
DynamicPluginTensorRT
{
public:
explicit
EmbEltwiseLayernormPluginDynamic
(
std
::
vector
<
float
*>
input_embs
,
float
*
bias
,
float
*
scale
,
std
::
vector
<
int
>
emb_sizes
,
int
bias_size
,
int
scale_size
,
int
hidden_size
,
float
eps
)
int
hidden_size
,
float
eps
,
bool
with_fp16
)
:
embs_
(
input_embs
),
bias_
(
bias
),
scale_
(
scale
),
...
...
@@ -42,51 +104,81 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
bias_size_
(
bias_size
),
scale_size_
(
scale_size
),
hidden_size_
(
hidden_size
),
eps_
(
eps
)
{}
eps_
(
eps
),
with_fp16_
(
with_fp16
),
own_host_buff_
(
false
)
{
if
(
with_fp16
)
{
#ifdef SUPPORTS_CUDA_FP16
impl_
=
new
EmbEltwiseLayernormPluginDynamicImpl
<
half
>
(
embs_
,
bias_
,
scale_
,
emb_sizes_
,
bias_size_
,
scale_size_
,
hidden_size_
,
eps_
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Unsupported data type, current GPU doesn't support half."
));
#endif // SUPPORTS_CUDA_FP16
}
else
{
impl_
=
new
EmbEltwiseLayernormPluginDynamicImpl
<
float
>
(
embs_
,
bias_
,
scale_
,
emb_sizes_
,
bias_size_
,
scale_size_
,
hidden_size_
,
eps_
);
}
}
EmbEltwiseLayernormPluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{
size_t
serial_length
)
:
own_host_buff_
(
true
)
{
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
emb_sizes_
);
embs_gpu_
.
resize
(
emb_sizes_
.
size
());
embs_
.
resize
(
emb_sizes_
.
size
());
for
(
size_t
i
=
0
;
i
<
emb_sizes_
.
size
();
i
++
)
{
cudaMalloc
(
&
embs_gpu_
[
i
],
sizeof
(
float
)
*
emb_sizes_
[
i
]);
cudaMemcpy
(
embs_gpu_
[
i
],
serial_data
,
emb_sizes_
[
i
]
*
sizeof
(
float
),
cudaMemcpyHostToDevice
);
auto
size
=
emb_sizes_
[
i
];
auto
ptr
=
new
float
[
size
];
memcpy
(
ptr
,
serial_data
,
sizeof
(
float
)
*
size
);
embs_
[
i
]
=
ptr
;
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
emb_sizes_
[
i
]
*
sizeof
(
float
);
serial_length
-=
emb_sizes_
[
i
]
*
sizeof
(
float
);
embs_
[
i
]
=
nullptr
;
}
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
bias_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
scale_size_
);
cudaMalloc
(
&
bias_gpu_
,
sizeof
(
float
)
*
bias_size_
);
cudaMemcpy
(
bias_gpu_
,
serial_data
,
bias_size_
*
sizeof
(
float
),
cudaMemcpyHostToDevice
);
bias_
=
nullptr
;
if
(
bias_size_
)
{
bias_
=
new
float
[
bias_size_
];
memcpy
(
bias_
,
serial_data
,
sizeof
(
float
)
*
bias_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
bias_size_
*
sizeof
(
float
);
serial_length
-=
bias_size_
*
sizeof
(
float
);
cudaMalloc
(
&
scale_gpu_
,
sizeof
(
float
)
*
scale_size_
);
cudaMemcpy
(
scale_gpu_
,
serial_data
,
scale_size_
*
sizeof
(
float
),
cudaMemcpyHostToDevice
);
scale_
=
nullptr
;
if
(
scale_size_
)
{
scale_
=
new
float
[
scale_size_
];
memcpy
(
scale_
,
serial_data
,
sizeof
(
float
)
*
scale_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
scale_size_
*
sizeof
(
float
);
serial_length
-=
scale_size_
*
sizeof
(
float
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
hidden_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
eps_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
if
(
with_fp16_
)
{
#ifdef SUPPORTS_CUDA_FP16
impl_
=
new
EmbEltwiseLayernormPluginDynamicImpl
<
half
>
(
embs_
,
bias_
,
scale_
,
emb_sizes_
,
bias_size_
,
scale_size_
,
hidden_size_
,
eps_
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Unsupported data type, current GPU doesn't support half."
));
#endif // SUPPORTS_CUDA_FP16
}
else
{
impl_
=
new
EmbEltwiseLayernormPluginDynamicImpl
<
float
>
(
embs_
,
bias_
,
scale_
,
emb_sizes_
,
bias_size_
,
scale_size_
,
hidden_size_
,
eps_
);
}
}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
{
auto
ptr
=
new
EmbEltwiseLayernormPluginDynamic
(
embs_
,
bias_
,
scale_
,
emb_sizes_
,
bias_size_
,
scale_size_
,
hidden_size_
,
eps_
);
ptr
->
embs_gpu_
=
embs_gpu_
;
ptr
->
bias_gpu_
=
bias_gpu_
;
ptr
->
scale_gpu_
=
scale_gpu_
;
eps_
,
with_fp16_
);
return
ptr
;
}
...
...
@@ -95,6 +187,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
}
int
getNbOutputs
()
const
override
{
return
1
;
}
int
initialize
()
override
;
void
terminate
()
override
;
size_t
getSerializationSize
()
const
override
{
int
sum_num
=
0
;
...
...
@@ -110,24 +203,32 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
sum_num
+=
(
bias_size_
+
scale_size_
)
*
sizeof
(
float
);
sum_num
+=
SerializedSize
(
hidden_size_
);
sum_num
+=
SerializedSize
(
eps_
);
//
sum_num += SerializedSize(with_fp16_);
sum_num
+=
SerializedSize
(
with_fp16_
);
return
sum_num
;
}
void
terminate
()
override
;
void
serialize
(
void
*
buffer
)
const
override
{
// SerializeValue(&buffer, with_fp16_);
SerializeValue
(
&
buffer
,
emb_sizes_
);
for
(
size_t
i
=
0
;
i
<
emb_sizes_
.
size
();
i
++
)
{
SerializeCudaPointer
(
&
buffer
,
embs_gpu_
[
i
],
emb_sizes_
[
i
]);
auto
size
=
emb_sizes_
[
i
];
for
(
int
j
=
0
;
j
<
size
;
++
j
)
{
SerializeValue
(
&
buffer
,
embs_
[
i
][
j
]);
}
}
SerializeValue
(
&
buffer
,
bias_size_
);
SerializeValue
(
&
buffer
,
scale_size_
);
SerializeCudaPointer
(
&
buffer
,
bias_gpu_
,
bias_size_
);
SerializeCudaPointer
(
&
buffer
,
scale_gpu_
,
scale_size_
);
for
(
int
i
=
0
;
i
<
bias_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
bias_
[
i
]);
}
for
(
int
i
=
0
;
i
<
scale_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
scale_
[
i
]);
}
SerializeValue
(
&
buffer
,
hidden_size_
);
SerializeValue
(
&
buffer
,
eps_
);
SerializeValue
(
&
buffer
,
with_fp16_
);
}
nvinfer1
::
DimsExprs
getOutputDimensions
(
...
...
@@ -158,23 +259,33 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
override
;
void
destroy
()
override
{
delete
this
;
}
void
destroy
()
override
{
if
(
own_host_buff_
)
{
for
(
auto
ptr
:
embs_
)
{
delete
[]
ptr
;
}
delete
[]
bias_
;
delete
[]
scale_
;
}
delete
impl_
;
delete
this
;
}
private:
std
::
vector
<
float
*>
embs_
;
float
*
bias_
;
float
*
scale_
;
// data on devices
float
*
bias_gpu_
;
float
*
scale_gpu_
;
std
::
vector
<
float
*>
embs_gpu_
;
std
::
vector
<
int
>
emb_sizes_
;
int
bias_size_
;
int
scale_size_
;
int
hidden_size_
;
float
eps_
;
bool
with_fp16_
;
bool
own_host_buff_
{
false
};
EmbEltwiseLayernormPluginDynamicImplBase
*
impl_
{
nullptr
};
};
class
EmbEltwiseLayernormPluginV2Creator
:
public
nvinfer1
::
IPluginCreator
{
...
...
@@ -198,8 +309,7 @@ class EmbEltwiseLayernormPluginV2Creator : public nvinfer1::IPluginCreator {
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
override
{
return
new
EmbEltwiseLayernormPluginDynamic
<
float
>
(
serial_data
,
serial_length
);
return
new
EmbEltwiseLayernormPluginDynamic
(
serial_data
,
serial_length
);
}
void
setPluginNamespace
(
const
char
*
lib_namespace
)
override
{
...
...
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc
浏览文件 @
a5ef246c
...
...
@@ -151,7 +151,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
run
(
config
,
&
out_data
);
// serialize
run
(
*
config_deser
,
&
out_data
);
// deserialize
for
(
size_t
i
=
0
;
i
<
out_data
.
size
();
i
++
)
{
EXPECT_NEAR
(
result
[
i
],
out_data
[
i
],
1e-
6
);
EXPECT_NEAR
(
result
[
i
],
out_data
[
i
],
1e-
2
);
}
}
...
...
@@ -159,13 +159,11 @@ TEST(AnalysisPredictor, no_fp16) {
std
::
vector
<
float
>
result
=
{
0.597841
,
0.219972
,
0.182187
};
trt_ernie
(
false
,
result
);
}
TEST
(
AnalysisPredictor
,
fp16
)
{
#ifdef SUPPORTS_CUDA_FP16
std
::
vector
<
float
>
result
=
{
0.598336
,
0.219558
,
0.182106
};
TEST
(
AnalysisPredictor
,
fp16
)
{
std
::
vector
<
float
>
result
=
{
0.59923654
,
0.21923761
,
0.18152587
};
trt_ernie
(
true
,
result
);
#endif
}
#endif // SUPPORTS_CUDA_FP16
}
// namespace inference
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录