Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
fb3c0b2f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fb3c0b2f
编写于
9月 17, 2020
作者:
Z
zlsh80826
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
var len experimental test
上级
68e89f8a
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
526 addition
and
176 deletion
+526
-176
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
...fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
+61
-12
paddle/fluid/inference/tensorrt/convert/mul_op.cc
paddle/fluid/inference/tensorrt/convert/mul_op.cc
+32
-27
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+9
-3
paddle/fluid/inference/tensorrt/convert/scale_op.cc
paddle/fluid/inference/tensorrt/convert/scale_op.cc
+88
-79
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
+1
-1
paddle/fluid/inference/tensorrt/convert/slice_op.cc
paddle/fluid/inference/tensorrt/convert/slice_op.cc
+21
-15
paddle/fluid/inference/tensorrt/convert/stack_op.cc
paddle/fluid/inference/tensorrt/convert/stack_op.cc
+38
-35
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+2
-2
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
...e/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
+174
-0
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h
...le/fluid/inference/tensorrt/plugin/special_slice_plugin.h
+96
-0
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+3
-1
未找到文件。
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
浏览文件 @
fb3c0b2f
...
@@ -40,6 +40,9 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -40,6 +40,9 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
input_ids
.
push_back
(
engine_
->
GetITensor
(
id_names
[
i
]));
input_ids
.
push_back
(
engine_
->
GetITensor
(
id_names
[
i
]));
}
}
// input_embs[0]: word_embedding
// input_embs[1]: pos_embedding
// input_embs[2]: sent_embedding
std
::
vector
<
float
*>
input_embs
;
std
::
vector
<
float
*>
input_embs
;
std
::
vector
<
int
>
emb_sizes
;
std
::
vector
<
int
>
emb_sizes
;
...
@@ -54,7 +57,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -54,7 +57,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
return
temp_data
;
return
temp_data
;
};
};
int
hidden
=
0
;
//
int hidden = 0;
for
(
int
i
=
0
;
i
<
input_num
;
i
++
)
{
for
(
int
i
=
0
;
i
<
input_num
;
i
++
)
{
framework
::
DDim
emb_dims
;
framework
::
DDim
emb_dims
;
float
*
emb_data
=
get_persistable_data
(
emb_names
[
i
],
&
emb_dims
);
float
*
emb_data
=
get_persistable_data
(
emb_names
[
i
],
&
emb_dims
);
...
@@ -65,7 +68,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -65,7 +68,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
emb_dims
.
size
(),
2
,
emb_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The fused EmbEltwiseLayerNorm's emb should be 2 dims."
));
"The fused EmbEltwiseLayerNorm's emb should be 2 dims."
));
hidden
=
emb_dims
[
1
];
//
hidden = emb_dims[1];
}
}
framework
::
DDim
bias_dims
,
scale_dims
;
framework
::
DDim
bias_dims
,
scale_dims
;
...
@@ -76,21 +79,66 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -76,21 +79,66 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
get_persistable_data
(
op_desc
.
Input
(
"Scale"
).
front
(),
&
scale_dims
);
get_persistable_data
(
op_desc
.
Input
(
"Scale"
).
front
(),
&
scale_dims
);
int64_t
bias_size
=
framework
::
product
(
bias_dims
);
int64_t
bias_size
=
framework
::
product
(
bias_dims
);
int64_t
scale_size
=
framework
::
product
(
scale_dims
);
int64_t
scale_size
=
framework
::
product
(
scale_dims
);
float
eps
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
nvinfer1
::
ILayer
*
layer
=
nullptr
;
nvinfer1
::
ILayer
*
layer
=
nullptr
;
int
output_fp16
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
1
:
0
);
if
(
engine_
->
with_dynamic_shape
())
{
if
(
engine_
->
with_dynamic_shape
())
{
#ifdef USE_NVINFER_PLUGIN
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"bert_embeddings_layernorm_beta"
,
bias
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
bias_size
)},
{
"bert_embeddings_layernorm_gamma"
,
scale
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
scale_size
)},
{
"bert_embeddings_word_embeddings"
,
input_embs
[
0
],
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
emb_sizes
[
0
])},
{
"bert_embeddings_token_type_embeddings"
,
input_embs
[
2
],
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
emb_sizes
[
2
])},
{
"bert_embeddings_position_embeddings"
,
input_embs
[
1
],
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
emb_sizes
[
1
])},
{
"output_fp16"
,
&
output_fp16
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
};
// remember to free
nvinfer1
::
PluginFieldCollection
*
plugin_ptr
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
plugin_ptr
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
plugin_ptr
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
plugin_ptr
->
fields
=
fields
.
data
();
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_0"
));
// word_embedding, eval_placeholder_0
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_1"
));
// sent_embedding, eval_placeholder_1
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_2"
));
// cu_seqlens, eval_placeholder_2
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_3"
));
// max_seqlen, eval_placeholder_3
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomEmbLayerNormPluginDynamic"
,
"2"
);
auto
plugin_obj
=
creator
->
createPlugin
(
"CustomEmbLayerNormPluginDynamic"
,
plugin_ptr
);
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin_obj
);
layer
=
plugin_layer
;
free
(
plugin_ptr
);
#else
float
eps
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
plugin
::
DynamicPluginTensorRT
*
plugin
=
nullptr
;
plugin
::
DynamicPluginTensorRT
*
plugin
=
nullptr
;
plugin
=
new
plugin
::
Emb
EltwiseLayernormPluginDynamic
<
float
>
(
plugin
=
new
plugin
::
Emb
LaynormPluginDynamicV2
<
float
>
(
input_embs
,
bias
,
scale
,
emb_sizes
,
bias_size
,
scale_size
,
hidden
,
input_embs
,
bias
,
scale
,
emb_sizes
,
bias_size
,
scale_size
,
hidden
,
eps
);
eps
);
auto
plugin_layer
=
layer
=
engine_
->
AddPluginV2
(
input_ids
.
data
(),
input_num
,
plugin
);
engine_
->
AddPluginV2
(
input_ids
.
data
(),
input_num
,
plugin
);
#endif
nvinfer1
::
Permutation
permutation
{
1
,
0
,
2
,
3
,
4
};
auto
trans_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
plugin_layer
->
getOutput
(
0
));
trans_layer
->
setFirstTranspose
(
permutation
);
layer
=
trans_layer
;
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the Ernie(Bert) model in static"
"You are running the Ernie(Bert) model in static"
...
@@ -100,7 +148,8 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -100,7 +148,8 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
}
}
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"emb_eltwise_layernorm"
,
{
output_name
},
RreplenishLayerAndOutput
(
layer
,
"emb_eltwise_layernorm"
,
{
output_name
,
std
::
string
(
"qkv_plugin_mask"
)},
test_mode
);
test_mode
);
#else
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
...
...
paddle/fluid/inference/tensorrt/convert/mul_op.cc
浏览文件 @
fb3c0b2f
...
@@ -26,33 +26,38 @@ class MulOpConverter : public OpConverter {
...
@@ -26,33 +26,38 @@ class MulOpConverter : public OpConverter {
public:
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid mul op to tensorrt mul layer without bias"
;
/*
VLOG(3) << "convert a fluid mul op to tensorrt mul layer without bias";
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
framework::OpDesc op_desc(op, nullptr);
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
// Declare inputs
auto
*
input2
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Y"
)[
0
]);
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
bool
transpose_x
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"transpose_X"
));
bool
transpose_y
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"transpose_Y"
));
bool transpose_x = BOOST_GET_CONST(bool,
op_desc.GetAttr("transpose_X"));
#ifdef USE_NVINFER_PLUGIN
bool transpose_y = BOOST_GET_CONST(bool,
nvinfer1
::
DataType
type
=
(
engine_
->
WithFp16
()
==
1
)
op_desc.GetAttr("transpose_Y"));
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
;
#ifdef USE_NVINFER_PLUGIN
plugin
::
ConvertMaskPluginDynamic
*
plugin
=
nvinfer1::DataType type = (engine_->WithFp16() == 1)
new
plugin
::
ConvertMaskPluginDynamic
(
type
);
? nvinfer1::DataType::kHALF
auto
convert_mask_layer
=
engine_
->
AddPluginV2
(
&
input1
,
1
,
plugin
);
: nvinfer1::DataType::kFLOAT;
engine_
->
SetITensor
(
"qkv_plugin_mask"
,
convert_mask_layer
->
getOutput
(
0
));
plugin::ConvertMaskPluginDynamic* plugin =
#endif
new plugin::ConvertMaskPluginDynamic(type);
auto convert_mask_layer = engine_->AddPluginV2(&input1, 1, plugin);
// Both the input1 and input2 do not need transpose.
engine_->SetITensor("qkv_plugin_mask",
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
convert_mask_layer->getOutput(0));
engine_
,
MatrixMultiply
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
#endif
transpose_x
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input2
),
transpose_y
);
// Both the input1 and input2 do not need transpose.
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
auto* layer = TRT_ENGINE_ADD_LAYER(
RreplenishLayerAndOutput
(
layer
,
"matmul"
,
{
output_name
},
test_mode
);
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1),
transpose_x, *const_cast<nvinfer1::ITensor*>(input2), transpose_y);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "matmul", {output_name}, test_mode);
*/
}
}
};
};
...
...
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
fb3c0b2f
...
@@ -119,17 +119,19 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -119,17 +119,19 @@ class MultiheadMatMulOpConverter : public OpConverter {
auto
mask_tensor
=
engine_
->
GetITensor
(
"qkv_plugin_mask"
);
auto
mask_tensor
=
engine_
->
GetITensor
(
"qkv_plugin_mask"
);
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomQKVToContextPluginDynamic"
,
"
1
"
);
"CustomQKVToContextPluginDynamic"
,
"
2
"
);
assert
(
creator
!=
nullptr
);
assert
(
creator
!=
nullptr
);
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
);
:
nvinfer1
::
DataType
::
kFLOAT
);
bool
has_mask
=
true
;
bool
has_mask
=
true
;
int
var_seqlen
=
1
;
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
};
};
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
...
@@ -144,8 +146,12 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -144,8 +146,12 @@ class MultiheadMatMulOpConverter : public OpConverter {
free
(
plugin_collection
);
free
(
plugin_collection
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
push_back
(
fc_layer
->
getOutput
(
0
));
plugin_inputs
.
emplace_back
(
fc_layer
->
getOutput
(
0
));
plugin_inputs
.
push_back
(
mask_tensor
);
plugin_inputs
.
emplace_back
(
mask_tensor
);
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_2"
));
// cu_seqlens, eval_placeholder_2
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_3"
));
// max_seqlen, eval_placeholder_3
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
layer
=
plugin_layer
;
layer
=
plugin_layer
;
...
...
paddle/fluid/inference/tensorrt/convert/scale_op.cc
浏览文件 @
fb3c0b2f
...
@@ -25,85 +25,94 @@ class ScaleOpConverter : public OpConverter {
...
@@ -25,85 +25,94 @@ class ScaleOpConverter : public OpConverter {
public:
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid scale op to tensorrt mul layer without bias"
;
/*
VLOG(3) << "convert a fluid scale op to tensorrt mul layer without
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
bias";
// Declare inputs
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
;
framework::OpDesc op_desc(op, nullptr);
std
::
string
input_name
=
op_desc
.
Input
(
"X"
).
front
();
// Declare inputs
std
::
string
out_name
=
op_desc
.
Output
(
"Out"
).
front
();
std::vector<nvinfer1::ITensor*> itensors;
std::string input_name = op_desc.Input("X").front();
auto
input
=
engine_
->
GetITensor
(
input_name
);
std::string out_name = op_desc.Output("Out").front();
bool
bias_after_scale
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"bias_after_scale"
));
auto input = engine_->GetITensor(input_name);
float
bias
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"bias"
));
bool bias_after_scale =
float
scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"scale"
));
BOOST_GET_CONST(bool, op_desc.GetAttr("bias_after_scale"));
auto
create_weights
=
[
&
](
float
data
,
std
::
string
type
)
->
float
*
{
float bias = BOOST_GET_CONST(float, op_desc.GetAttr("bias"));
std
::
unique_ptr
<
framework
::
Tensor
>
tmp_tensor
(
new
framework
::
Tensor
());
float scale = BOOST_GET_CONST(float, op_desc.GetAttr("scale"));
tmp_tensor
->
Resize
({
1
});
auto create_weights = [&](float data, std::string type) -> float* {
auto
*
tmp_data
=
tmp_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
std::unique_ptr<framework::Tensor> tmp_tensor(new
tmp_data
[
0
]
=
data
;
framework::Tensor());
engine_
->
SetWeights
(
out_name
+
"_scale_op_"
+
type
,
tmp_tensor->Resize({1});
std
::
move
(
tmp_tensor
));
auto* tmp_data =
return
tmp_data
;
tmp_tensor->mutable_data<float>(platform::CPUPlace());
};
tmp_data[0] = data;
engine_->SetWeights(out_name + "_scale_op_" + type,
float
*
bias_ptr
=
create_weights
(
bias
,
"bias"
);
std::move(tmp_tensor));
float
*
scale_ptr
=
create_weights
(
scale
,
"scale"
);
return tmp_data;
};
TensorRTEngine
::
Weight
scale_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
scale_ptr
),
1
};
float* bias_ptr = create_weights(bias, "bias");
TensorRTEngine
::
Weight
shift_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
float* scale_ptr = create_weights(scale, "scale");
static_cast
<
void
*>
(
bias_ptr
),
1
};
TensorRTEngine
::
Weight
power_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT,
0
};
static_cast<void*>(scale_ptr), 1};
nvinfer1
::
ILayer
*
layer
=
nullptr
;
TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_ptr), 1};
auto
input_dim
=
input
->
getDimensions
();
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT,
PADDLE_ENFORCE_GE
(
input_dim
.
nbDims
,
3
,
nullptr,
platform
::
errors
::
Fatal
(
0};
"Paddle-TRT scale mode only support dimension >= 3"
));
nvinfer1::ILayer* layer = nullptr;
nvinfer1
::
IShuffleLayer
*
expand_layer
=
nullptr
;
auto input_dim = input->getDimensions();
nvinfer1
::
IShuffleLayer
*
squeeze_layer
=
nullptr
;
PADDLE_ENFORCE_GE(input_dim.nbDims, 3,
platform::errors::Fatal(
if
(
input_dim
.
nbDims
==
3
)
{
"Paddle-TRT scale mode only support dimension >=
// TensorRT scale layer is not supporting input dims < 4 when using
3"));
// explicit batch
expand_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
nvinfer1::IShuffleLayer* expand_layer = nullptr;
nvinfer1
::
Dims4
target_shape
(
0
,
0
,
0
,
1
);
// expand 1 dims
nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
expand_layer
->
setReshapeDimensions
(
target_shape
);
input
=
expand_layer
->
getOutput
(
0
);
if (input_dim.nbDims == 3) {
}
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
if
(
bias_after_scale
)
{
expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
layer
=
TRT_ENGINE_ADD_LAYER
(
nvinfer1::Dims4 target_shape(0, 0, 0, 1); // expand 1 dims
engine_
,
Scale
,
*
input
,
nvinfer1
::
ScaleMode
::
kUNIFORM
,
expand_layer->setReshapeDimensions(target_shape);
shift_weights
.
get
(),
scale_weights
.
get
(),
power_weights
.
get
());
input = expand_layer->getOutput(0);
}
else
{
}
// add bias
layer
=
TRT_ENGINE_ADD_LAYER
(
if (bias_after_scale) {
engine_
,
Scale
,
*
(
input
),
nvinfer1
::
ScaleMode
::
kUNIFORM
,
layer = TRT_ENGINE_ADD_LAYER(
shift_weights
.
get
(),
power_weights
.
get
(),
power_weights
.
get
());
engine_, Scale, *input, nvinfer1::ScaleMode::kUNIFORM,
// mul scale
shift_weights.get(), scale_weights.get(), power_weights.get());
layer
=
TRT_ENGINE_ADD_LAYER
(
} else {
engine_
,
Scale
,
*
(
layer
->
getOutput
(
0
)),
nvinfer1
::
ScaleMode
::
kUNIFORM
,
// add bias
power_weights
.
get
(),
scale_weights
.
get
(),
power_weights
.
get
());
layer = TRT_ENGINE_ADD_LAYER(
}
engine_, Scale, *(input), nvinfer1::ScaleMode::kUNIFORM,
shift_weights.get(), power_weights.get(), power_weights.get());
PADDLE_ENFORCE_EQ
(
layer
!=
nullptr
,
true
,
// mul scale
platform
::
errors
::
Fatal
(
"Create scale layer failed."
));
layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *(layer->getOutput(0)),
if
(
input_dim
.
nbDims
==
3
)
{
nvinfer1::ScaleMode::kUNIFORM,
// TensorRT scale layer is not supporting input dims < 4 when using
power_weights.get(), scale_weights.get(), power_weights.get());
// explicit batch
}
squeeze_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
(
layer
->
getOutput
(
0
)));
PADDLE_ENFORCE_EQ(layer != nullptr, true,
nvinfer1
::
Dims3
target_shape
(
0
,
0
,
0
);
// expand 1 dims
platform::errors::Fatal("Create scale layer
squeeze_layer
->
setReshapeDimensions
(
target_shape
);
failed."));
layer
=
static_cast
<
nvinfer1
::
ILayer
*>
(
squeeze_layer
);
}
if (input_dim.nbDims == 3) {
RreplenishLayerAndOutput
(
layer
,
"scale"
,
{
out_name
},
test_mode
);
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
squeeze_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0)));
nvinfer1::Dims3 target_shape(0, 0, 0); // expand 1 dims
squeeze_layer->setReshapeDimensions(target_shape);
layer = static_cast<nvinfer1::ILayer*>(squeeze_layer);
}
RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode);
*/
}
}
};
};
...
...
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
浏览文件 @
fb3c0b2f
...
@@ -54,7 +54,7 @@ class SkipLayerNormOpConverter : public OpConverter {
...
@@ -54,7 +54,7 @@ class SkipLayerNormOpConverter : public OpConverter {
if
(
engine_
->
with_dynamic_shape
())
{
if
(
engine_
->
with_dynamic_shape
())
{
#ifdef USE_NVINFER_PLUGIN
#ifdef USE_NVINFER_PLUGIN
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomSkipLayerNormPluginDynamic"
,
"
1
"
);
"CustomSkipLayerNormPluginDynamic"
,
"
2
"
);
assert
(
creator
!=
nullptr
);
assert
(
creator
!=
nullptr
);
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
?
nvinfer1
::
DataType
::
kHALF
...
...
paddle/fluid/inference/tensorrt/convert/slice_op.cc
浏览文件 @
fb3c0b2f
...
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
// #include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -30,25 +31,30 @@ class SliceOpConverter : public OpConverter {
...
@@ -30,25 +31,30 @@ class SliceOpConverter : public OpConverter {
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
// Declare inputs
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Input"
)[
0
]);
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Input"
)[
0
]);
/*
std
::
vector
<
int
>
axes
=
std::vector<int> axes =
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"axes"
));
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
std
::
vector
<
int
>
starts
=
std::vector<int> starts =
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"starts"
));
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("starts"));
std
::
vector
<
int
>
ends
=
std::vector<int> ends =
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"ends"
));
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends"));
*/
nvinfer1
::
ILayer
*
layer
=
nullptr
;
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
if
(
engine_
->
with_dynamic_shape
())
{
nvinfer1
::
Permutation
permutation
{
1
,
0
,
2
,
3
,
4
};
/*
auto
trans_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
nvinfer1::Permutation permutation{1, 0, 2, 3, 4};
trans_layer
->
setFirstTranspose
(
permutation
);
auto trans_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
trans_layer->setFirstTranspose(permutation);
*/
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
trans_layer
->
getOutput
(
0
));
// plugin_inputs.emplace_back(trans_layer->getOutput(0));
plugin_inputs
.
emplace_back
(
input
);
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_2"
));
bool
ban_fp16
=
engine_
->
disable_trt_plugin_fp16
();
//
bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin
::
SlicePluginDynamic
*
plugin
=
plugin
::
S
pecialS
licePluginDynamic
*
plugin
=
new
plugin
::
S
licePluginDynamic
(
starts
,
ends
,
axes
,
ban_fp16
);
new
plugin
::
S
pecialSlicePluginDynamic
(
);
layer
=
engine_
->
AddPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
layer
=
engine_
->
AddPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
plugin
);
plugin
);
}
else
{
}
else
{
...
...
paddle/fluid/inference/tensorrt/convert/stack_op.cc
浏览文件 @
fb3c0b2f
...
@@ -26,45 +26,48 @@ class StackOpConverter : public OpConverter {
...
@@ -26,45 +26,48 @@ class StackOpConverter : public OpConverter {
public:
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert fluid stack op to tensorrt stack layer"
;
/*
VLOG(4) << "convert fluid stack op to tensorrt stack layer";
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework::OpDesc op_desc(op, nullptr);
auto
input
=
op_desc
.
Input
(
"X"
);
auto input = op_desc.Input("X");
int
input_num
=
input
.
size
();
int input_num = input.size();
nvinfer1
::
ITensor
**
inputs
=
nvinfer1::ITensor** inputs =
(
nvinfer1
::
ITensor
**
)
malloc
(
input_num
*
sizeof
(
nvinfer1
::
ITensor
*
));
(nvinfer1::ITensor**)malloc(input_num * sizeof(nvinfer1::ITensor*));
for
(
int
i
=
0
;
i
<
input_num
;
++
i
)
{
for (int i = 0; i < input_num; ++i) {
inputs
[
i
]
=
engine_
->
GetITensor
(
input
[
i
]);
inputs[i] = engine_->GetITensor(input[i]);
}
}
int
axis
=
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"axis"
));
int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
if
(
axis
<
0
)
{
if (axis < 0) {
axis
=
axis
+
inputs
[
0
]
->
getDimensions
().
nbDims
+
1
;
axis = axis + inputs[0]->getDimensions().nbDims + 1;
}
}
nvinfer1
::
ILayer
*
layer
=
nullptr
;
nvinfer1::ILayer* layer = nullptr;
if
(
engine_
->
with_dynamic_shape
())
{
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
#if IS_TRT_VERSION_GE(6000)
plugin
::
StackPluginDynamic
*
plugin
=
plugin::StackPluginDynamic* plugin =
new
plugin
::
StackPluginDynamic
(
axis
,
input_num
);
new plugin::StackPluginDynamic(axis, input_num);
layer
=
engine_
->
AddPluginV2
(
inputs
,
input_num
,
plugin
);
layer = engine_->AddPluginV2(inputs, input_num, plugin);
assert
(
layer
!=
nullptr
);
assert(layer != nullptr);
#else
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"You are running the TRT Dynamic Shape mode, need to confirm that
"your TRT version is no less than 6.0"
));
"
#endif
"your TRT version is no less than 6.0"));
}
else
{
#endif
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
} else {
"You are running the Ernie(Bert) model in static"
PADDLE_THROW(platform::errors::Fatal(
"shape mode, which is not supported for the time being.
\n
"
"You are running the Ernie(Bert) model in static"
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
"shape mode, which is not supported for the time being.\n"
" to set the shape information to run the dynamic shape mode."
));
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
}
" to set the shape information to run the dynamic shape mode."));
auto
output_name
=
op_desc
.
Output
(
"Y"
).
front
();
}
RreplenishLayerAndOutput
(
layer
,
"stack"
,
{
output_name
},
test_mode
);
auto output_name = op_desc.Output("Y").front();
free
(
inputs
);
RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode);
free(inputs);
*/
}
}
};
};
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
fb3c0b2f
...
@@ -60,9 +60,9 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
...
@@ -60,9 +60,9 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
template
<
typename
T
>
template
<
typename
T
>
nvinfer1
::
Dims
Vec2TRT_Dims
(
const
std
::
vector
<
T
>&
shape
,
std
::
string
input
,
nvinfer1
::
Dims
Vec2TRT_Dims
(
const
std
::
vector
<
T
>&
shape
,
std
::
string
input
,
bool
with_dynamic_shape
=
false
)
{
bool
with_dynamic_shape
=
false
)
{
PADDLE_ENFORCE_GT
(
shape
.
size
(),
1
UL
,
PADDLE_ENFORCE_GT
(
shape
.
size
(),
0
UL
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"TensorRT's tensor input requires at least
2
"
"TensorRT's tensor input requires at least
1
"
"dimensions, but input %s has %d dims."
,
"dimensions, but input %s has %d dims."
,
input
,
shape
.
size
()));
input
,
shape
.
size
()));
PADDLE_ENFORCE_LE
(
shape
.
size
(),
4UL
,
PADDLE_ENFORCE_LE
(
shape
.
size
(),
4UL
,
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
fb3c0b2f
...
@@ -2,7 +2,7 @@ nv_library(tensorrt_plugin
...
@@ -2,7 +2,7 @@ nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
stack_op_plugin.cu convert_mask_plugin.cu
stack_op_plugin.cu convert_mask_plugin.cu
special_slice_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
0 → 100644
浏览文件 @
fb3c0b2f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
SpecialSlicePluginDynamic
::
SpecialSlicePluginDynamic
()
{}
SpecialSlicePluginDynamic
::
SpecialSlicePluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{}
SpecialSlicePluginDynamic
::~
SpecialSlicePluginDynamic
()
{}
nvinfer1
::
IPluginV2DynamicExt
*
SpecialSlicePluginDynamic
::
clone
()
const
{
return
new
SpecialSlicePluginDynamic
();
}
const
char
*
SpecialSlicePluginDynamic
::
getPluginType
()
const
{
return
"special_slice_plugin"
;
}
int
SpecialSlicePluginDynamic
::
getNbOutputs
()
const
{
return
1
;
}
int
SpecialSlicePluginDynamic
::
initialize
()
{
return
0
;
}
size_t
SpecialSlicePluginDynamic
::
getSerializationSize
()
const
{
size_t
serialize_size
=
0
;
return
serialize_size
;
}
void
SpecialSlicePluginDynamic
::
serialize
(
void
*
buffer
)
const
{}
nvinfer1
::
DimsExprs
SpecialSlicePluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
nvinfer1
::
DimsExprs
output
(
inputs
[
0
]);
auto
one
=
expr_builder
.
constant
(
1
);
output
.
d
[
0
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUB
,
*
inputs
[
1
].
d
[
0
],
*
one
);
return
output
;
}
void
SpecialSlicePluginDynamic
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
{}
size_t
SpecialSlicePluginDynamic
::
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
{
return
0
;
}
void
SpecialSlicePluginDynamic
::
destroy
()
{
delete
this
;
}
void
SpecialSlicePluginDynamic
::
terminate
()
{}
bool
SpecialSlicePluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
desc
,
int
nb_inputs
,
int
nb_outputs
)
{
if
(
pos
==
0
)
// slice tensor
return
(
desc
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
);
// || desc[pos].type ==
// nvinfer1::DataType::kFLOAT);
if
(
pos
==
1
)
// cu_seqlen
return
(
desc
[
pos
].
type
==
nvinfer1
::
DataType
::
kINT32
&&
desc
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
return
(
desc
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
);
// || desc[pos].type ==
// nvinfer1::DataType::kFLOAT);
}
nvinfer1
::
DataType
SpecialSlicePluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
{
PADDLE_ENFORCE_EQ
(
index
,
0
,
platform
::
errors
::
InvalidArgument
(
"The index should be equal to 0"
));
return
input_types
[
0
];
}
template
<
typename
T
>
__global__
void
SpecialSliceKernel
(
const
T
*
slice_input
,
const
int32_t
*
cu_seqlens
,
T
*
output
)
{
const
int
hidden
=
blockDim
.
x
;
const
int
batch
=
blockIdx
.
x
;
output
[
batch
*
hidden
+
threadIdx
.
x
]
=
slice_input
[
cu_seqlens
[
batch
]
*
hidden
+
threadIdx
.
x
];
}
int
SpecialSlicePluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
input_dims
=
input_desc
[
0
].
dims
;
// (sum(S), 768, 1, 1)
auto
out_dims
=
output_desc
[
0
].
dims
;
// (batch, 768, 1, 1)
assert
(
input_desc
[
0
].
type
==
nvinfer1
::
DataType
::
kHALF
);
const
int32_t
hidden
=
input_dims
.
d
[
1
];
const
int
num_blocks
=
out_dims
.
d
[
0
];
// batch size
const
int
num_threads
=
hidden
;
const
half
*
slice_input
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
int32_t
*
cu_seqlens
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
SpecialSliceKernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
slice_input
,
cu_seqlens
,
output
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
SpecialSlicePluginDynamicCreator
::
SpecialSlicePluginDynamicCreator
()
{}
const
char
*
SpecialSlicePluginDynamicCreator
::
getPluginName
()
const
{
return
"stack_plugin"
;
}
const
char
*
SpecialSlicePluginDynamicCreator
::
getPluginVersion
()
const
{
return
"1"
;
}
const
nvinfer1
::
PluginFieldCollection
*
SpecialSlicePluginDynamicCreator
::
getFieldNames
()
{
return
&
field_collection_
;
}
nvinfer1
::
IPluginV2
*
SpecialSlicePluginDynamicCreator
::
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
{
return
new
SpecialSlicePluginDynamic
();
}
nvinfer1
::
IPluginV2
*
SpecialSlicePluginDynamicCreator
::
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
{
auto
plugin
=
new
SpecialSlicePluginDynamic
(
serial_data
,
serial_length
);
return
plugin
;
}
void
SpecialSlicePluginDynamicCreator
::
setPluginNamespace
(
const
char
*
lib_namespace
)
{
plugin_namespace_
=
lib_namespace
;
}
const
char
*
SpecialSlicePluginDynamicCreator
::
getPluginNamespace
()
const
{
return
plugin_namespace_
.
c_str
();
}
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h
0 → 100644
浏览文件 @
fb3c0b2f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
class
SpecialSlicePluginDynamic
:
public
DynamicPluginTensorRT
{
public:
SpecialSlicePluginDynamic
();
SpecialSlicePluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
);
~
SpecialSlicePluginDynamic
();
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
override
;
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
override
;
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
override
;
const
char
*
getPluginType
()
const
override
;
int
getNbOutputs
()
const
override
;
int
initialize
()
override
;
void
terminate
()
override
;
size_t
getSerializationSize
()
const
override
;
void
serialize
(
void
*
buffer
)
const
override
;
void
destroy
()
override
;
private:
int
axis_
;
int
num_stack_
;
};
class
SpecialSlicePluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
public:
SpecialSlicePluginDynamicCreator
();
const
char
*
getPluginName
()
const
override
;
const
char
*
getPluginVersion
()
const
override
;
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
override
;
nvinfer1
::
IPluginV2
*
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
override
;
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
override
;
void
setPluginNamespace
(
const
char
*
lib_namespace
)
override
;
const
char
*
getPluginNamespace
()
const
override
;
private:
std
::
string
plugin_namespace_
;
nvinfer1
::
PluginFieldCollection
field_collection_
{
0
,
nullptr
};
std
::
vector
<
nvinfer1
::
PluginField
>
plugin_attributes_
;
};
REGISTER_TRT_PLUGIN_V2
(
SpecialSlicePluginDynamicCreator
);
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
fb3c0b2f
...
@@ -264,9 +264,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -264,9 +264,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
float
>
());
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
float
>
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT64
)
{
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT64
)
{
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
int64_t
>
());
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
int64_t
>
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT32
)
{
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
int32_t
>
());
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The TRT Engine OP only support float
and
int64_t input."
));
"The TRT Engine OP only support float
/int32_t/
int64_t input."
));
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录