Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2ca3fe5d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2ca3fe5d
编写于
8月 25, 2020
作者:
Z
zlsh80826
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
multihead att plugin
上级
954ebda1
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
349 addition
and
25 deletion
+349
-25
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+130
-22
paddle/fluid/inference/tensorrt/convert/op_converter.h
paddle/fluid/inference/tensorrt/convert/op_converter.h
+2
-0
paddle/fluid/inference/tensorrt/convert/slice_op.cc
paddle/fluid/inference/tensorrt/convert/slice_op.cc
+4
-3
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+2
-0
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+1
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+1
-0
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu
+85
-0
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h
+120
-0
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+3
-0
未找到文件。
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
2ca3fe5d
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
namespace
paddle
{
...
...
@@ -30,7 +31,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
// Declare inputs
// Shouble be a 5 dims tensor.
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Input"
).
front
());
auto
*
input_bias_qk
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"BiasQK"
).
front
());
// fc weights and fc bias
auto
weight_name
=
op_desc
.
Input
(
"W"
).
front
();
...
...
@@ -65,14 +65,124 @@ class MultiheadMatMulOpConverter : public OpConverter {
}
}
};
tranpose_weight
(
weight_data_tmp
.
data
(),
weight_data
,
m
,
n
);
int
head_number
=
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"head_number"
));
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
#ifdef USE_NVINFER_PLUGIN
int
head_size
=
hidden
/
head_number
;
// [3, Nout, Hout, Nin, Hin] -> [Nout, 3, Hout, Nin, Hin]
auto
transpose_weight_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
N
,
int
H
)
{
const
int
HNH
=
H
*
N
*
H
;
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
hnh
=
0
;
hnh
<
HNH
;
++
hnh
)
{
dst
[
n
*
3
*
HNH
+
i
*
HNH
+
hnh
]
=
src
[
i
*
N
*
HNH
+
n
*
HNH
+
hnh
];
}
}
}
};
// [3, N, H] -> [N, 3, H]
auto
transpose_bias_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
N
,
int
H
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
h
=
0
;
h
<
H
;
++
h
)
{
dst
[
n
*
3
*
H
+
i
*
H
+
h
]
=
src
[
i
*
N
*
H
+
n
*
H
+
h
];
}
}
}
};
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
transpose_weight_v2
(
weight_data_tmp
.
data
(),
weight_data
,
head_number
,
head_size
);
nvinfer1
::
Weights
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
static_cast
<
int32_t
>
(
weight_t
->
numel
())};
std
::
vector
<
float
>
bias_data_tmp
;
bias_data_tmp
.
reserve
(
bias_t
->
numel
());
memcpy
(
bias_data_tmp
.
data
(),
bias_data
,
bias_t
->
numel
()
*
sizeof
(
float
));
transpose_bias_v2
(
bias_data_tmp
.
data
(),
bias_data
,
head_number
,
head_size
);
nvinfer1
::
Weights
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
bias_data
),
static_cast
<
int32_t
>
(
bias_t
->
numel
())};
nvinfer1
::
Permutation
permutation
{
1
,
0
,
2
,
3
,
4
};
auto
trans_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
trans_layer
->
setFirstTranspose
(
permutation
);
auto
*
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
trans_layer
->
getOutput
(
0
),
n
,
weight
,
bias
);
auto
pos_tensor
=
engine_
->
GetITensor
(
"eval_placeholder_2"
);
plugin
::
CastIntPluginDynamic
*
cast_plugin
=
new
plugin
::
CastIntPluginDynamic
();
auto
cast_layer
=
engine_
->
AddPluginV2
(
&
pos_tensor
,
1
,
cast_plugin
);
auto
casted_pos_tensor
=
cast_layer
->
getOutput
(
0
);
auto
reshape_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
casted_pos_tensor
);
nvinfer1
::
Dims2
reshape_dim
(
0
,
0
);
nvinfer1
::
Permutation
perm
{
1
,
0
,
2
};
reshape_layer
->
setFirstTranspose
(
perm
);
reshape_layer
->
setReshapeDimensions
(
reshape_dim
);
auto
reduce_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
reshape_layer
->
getOutput
(
0
),
nvinfer1
::
ReduceOperation
::
kMAX
,
1
,
false
);
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomQKVToContextPluginDynamic"
,
"1"
);
assert
(
creator
!=
nullptr
);
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
);
bool
has_mask
=
true
;
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
// no bool type
};
nvinfer1
::
PluginFieldCollection
*
pluginPtr
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
pluginPtr
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
// remember to free
pluginPtr
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
pluginPtr
->
fields
=
fields
.
data
();
auto
pluginObj
=
creator
->
createPlugin
(
"CustomQKVToContextPluginDynamic"
,
pluginPtr
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
push_back
(
fc_layer
->
getOutput
(
0
));
plugin_inputs
.
push_back
(
reduce_layer
->
getOutput
(
0
));
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
pluginObj
);
assert
(
plugin_layer
!=
nullptr
);
auto
trans_r_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
plugin_layer
->
getOutput
(
0
));
assert
(
trans_r_layer
!=
nullptr
);
trans_r_layer
->
setFirstTranspose
(
permutation
);
layer
=
trans_r_layer
;
#else
// transpose weight_data from m * n to n * m
tranpose_weight
(
weight_data_tmp
.
data
(),
weight_data
,
m
,
n
);
auto
*
input_bias_qk
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"BiasQK"
).
front
());
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
static_cast
<
size_t
>
(
weight_t
->
numel
())};
weight
.
dims
.
assign
({
n
,
m
});
TensorRTEngine
::
Weight
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
bias_data
),
static_cast
<
size_t
>
(
bias_t
->
numel
())};
...
...
@@ -81,20 +191,18 @@ class MultiheadMatMulOpConverter : public OpConverter {
weight
.
get
(),
bias
.
get
());
auto
*
fc_out
=
fc_layer
->
getOutput
(
0
);
// add qkv to context
int
head_number
=
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"head_number"
));
int
head_size
=
all_head_size
/
head_number
;
float
scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"alpha"
));
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
push_back
(
fc_out
);
plugin_inputs
.
push_back
(
input_bias_qk
);
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
bool
ban_fp16
=
engine_
->
disable_trt_plugin_fp16
();
plugin
::
DynamicPluginTensorRT
*
plugin
=
new
plugin
::
QkvToContextPluginDynamic
(
hidden
,
head_number
,
head_size
,
scale
,
ban_fp16
);
layer
=
engine_
->
AddPluginV2
(
plugin_inputs
.
data
(),
2
,
plugin
);
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the Ernie(Bert) model in static shape mode, which "
...
...
paddle/fluid/inference/tensorrt/convert/op_converter.h
浏览文件 @
2ca3fe5d
...
...
@@ -173,6 +173,8 @@ class OpConverter {
"optim_input_shape should be same."
));
}
}
std
::
cerr
<<
"Declare input: "
<<
input
<<
std
::
endl
;
if
(
input
.
find
(
"stack_0.tmp_0"
)
!=
std
::
string
::
npos
)
continue
;
engine
->
DeclareInput
(
input
,
FluidDataType2TRT
(
var
->
Proto
()
->
type
().
lod_tensor
().
tensor
().
data_type
()),
...
...
paddle/fluid/inference/tensorrt/convert/slice_op.cc
浏览文件 @
2ca3fe5d
...
...
@@ -23,8 +23,9 @@ class SliceOpConverter : public OpConverter {
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
// This OP is implemented by trt dynamic shpae plugin.
// Dynamic shape plugin requires TRT version greater than 6.0.
// This OP is implemented by trt dynamic shpae plugin.
// Dynamic shape plugin requires TRT version greater than 6.0.
std
::
cerr
<<
"slice op converter
\n
"
<<
std
::
endl
;
#if IS_TRT_VERSION_GE(6000)
VLOG
(
4
)
<<
"convert slice op to tensorrt layer"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
...
...
@@ -42,7 +43,7 @@ class SliceOpConverter : public OpConverter {
if
(
engine_
->
with_dynamic_shape
())
{
bool
ban_fp16
=
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SlicePluginDynamic
*
plugin
=
new
plugin
::
SlicePluginDynamic
(
starts
,
ends
,
end
s
,
ban_fp16
);
new
plugin
::
SlicePluginDynamic
(
starts
,
ends
,
axe
s
,
ban_fp16
);
layer
=
engine_
->
AddPluginV2
(
&
input
,
1
,
plugin
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
...
...
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
2ca3fe5d
...
...
@@ -183,6 +183,8 @@ class TRTConvertValidation {
std
::
vector
<
void
*>
buffers
(
num_bindings
);
for
(
const
std
::
string
&
name
:
input_output_names
)
{
// std::cerr << "Binding name: " << name << std::endl;
if
(
name
.
find
(
"stack_0.tmp_0"
)
!=
std
::
string
::
npos
)
continue
;
auto
*
var
=
scope_
.
FindVar
(
name
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
const
int
bind_index
=
engine_
->
engine
()
->
getBindingIndex
(
name
.
c_str
());
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
2ca3fe5d
...
...
@@ -71,6 +71,7 @@ void TensorRTEngine::FreezeNetwork() {
// build engine.
infer_builder_
->
setMaxBatchSize
(
max_batch_
);
infer_builder_
->
setMaxWorkspaceSize
(
max_workspace_
);
infer_builder_config_
->
setMaxWorkspaceSize
(
max_workspace_
);
bool
enable_fp16
=
(
precision_
==
AnalysisConfig
::
Precision
::
kHalf
);
#if IS_TRT_VERSION_GE(5000)
if
(
enable_fp16
)
{
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
2ca3fe5d
...
...
@@ -85,6 +85,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"gelu"
,
"layer_norm"
,
"scale"
,
"slice"
,
};
};
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
2ca3fe5d
...
...
@@ -2,6 +2,7 @@ nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
cast_int_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu
0 → 100644
浏览文件 @
2ca3fe5d
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
nvinfer1
::
DimsExprs
CastIntPluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
assert
(
output_index
==
0
);
return
inputs
[
0
];
}
bool
CastIntPluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
{
const
nvinfer1
::
PluginTensorDesc
&
in
=
in_out
[
pos
];
return
(
in
.
type
==
nvinfer1
::
DataType
::
kINT32
);
}
nvinfer1
::
DataType
CastIntPluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
{
PADDLE_ENFORCE_EQ
(
index
,
0
,
platform
::
errors
::
InvalidArgument
(
"The Cast Int only has one input, so the "
"index value should be 0, but get %d."
,
index
));
return
input_types
[
index
];
}
__global__
void
castIntKernel
(
const
int64_t
*
input
,
int32_t
*
output
,
size_t
num_elements
)
{
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
>=
num_elements
)
return
;
output
[
idx
]
=
input
[
idx
]
+
1
;
}
int
CastIntPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
input_dims
=
input_desc
[
0
].
dims
;
auto
output_dims
=
output_desc
[
0
].
dims
;
size_t
num_elements
=
ProductDim
(
input_dims
);
size_t
out_num_elements
=
ProductDim
(
output_dims
);
assert
(
input_type
==
nvinfer1
::
DataType
::
kINT32
);
// although the input is int64_t
assert
(
num_elements
==
out_num_elements
);
const
size_t
num_threads
=
256
;
castIntKernel
<<<
num_elements
/
num_threads
+
1
,
num_threads
>>>
(
static_cast
<
const
int64_t
*>
(
inputs
[
0
]),
static_cast
<
int32_t
*>
(
outputs
[
0
]),
num_elements
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h
0 → 100644
浏览文件 @
2ca3fe5d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
class
CastIntPluginDynamic
:
public
DynamicPluginTensorRT
{
public:
CastIntPluginDynamic
()
{}
CastIntPluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{}
~
CastIntPluginDynamic
()
{}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
{
return
new
CastIntPluginDynamic
();
}
const
char
*
getPluginType
()
const
override
{
return
"cast_int_plugin"
;
}
int
getNbOutputs
()
const
override
{
return
1
;
}
int
initialize
()
override
{
return
0
;
}
size_t
getSerializationSize
()
const
override
{
return
0
;
}
void
serialize
(
void
*
buffer
)
const
override
{}
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nb_inputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nb_outputs
)
override
{}
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nb_inputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nb_outputs
)
const
override
{
return
0
;
}
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
override
;
void
destroy
()
override
{
delete
this
;
}
};
class
CastIntPluginV2Creator
:
public
nvinfer1
::
IPluginCreator
{
public:
CastIntPluginV2Creator
()
{}
const
char
*
getPluginName
()
const
override
{
return
"cast_int_plugin"
;
}
const
char
*
getPluginVersion
()
const
override
{
return
"1"
;
}
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
override
{
return
&
field_collection_
;
}
nvinfer1
::
IPluginV2
*
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
override
{
return
nullptr
;
}
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
override
{
auto
plugin
=
new
CastIntPluginDynamic
(
serial_data
,
serial_length
);
return
plugin
;
}
void
setPluginNamespace
(
const
char
*
lib_namespace
)
override
{
plugin_namespace_
=
lib_namespace
;
}
const
char
*
getPluginNamespace
()
const
override
{
return
plugin_namespace_
.
c_str
();
}
private:
std
::
string
plugin_namespace_
;
std
::
string
plugin_name_
;
nvinfer1
::
PluginFieldCollection
field_collection_
{
0
,
nullptr
};
std
::
vector
<
nvinfer1
::
PluginField
>
plugin_attributes_
;
};
REGISTER_TRT_PLUGIN_V2
(
CastIntPluginV2Creator
);
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
2ca3fe5d
...
...
@@ -221,11 +221,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
num_inputs
+=
1
;
}
const
int
num_bindings
=
num_inputs
+
Outputs
(
"Ys"
).
size
();
// std::cerr << "num bindings: " << num_bindings << std::endl;
std
::
vector
<
void
*>
buffers
(
num_bindings
);
// Bind input tensor to TRT.
for
(
const
auto
&
x
:
Inputs
(
"Xs"
))
{
if
(
param_names_
.
count
(
x
))
continue
;
// std::cerr << "runTRT name: " << x << std::endl;
if
(
x
.
find
(
"stack_0.tmp_0"
)
!=
std
::
string
::
npos
)
continue
;
// convert input and copy to TRT engine's buffer
auto
&
t
=
inference
::
analysis
::
GetFromScope
<
framework
::
LoDTensor
>
(
scope
,
x
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录