Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
fb3c0b2f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fb3c0b2f
编写于
9月 17, 2020
作者:
Z
zlsh80826
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
var len experimental test
上级
68e89f8a
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
526 addition
and
176 deletion
+526
-176
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
...fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
+61
-12
paddle/fluid/inference/tensorrt/convert/mul_op.cc
paddle/fluid/inference/tensorrt/convert/mul_op.cc
+32
-27
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+9
-3
paddle/fluid/inference/tensorrt/convert/scale_op.cc
paddle/fluid/inference/tensorrt/convert/scale_op.cc
+88
-79
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
+1
-1
paddle/fluid/inference/tensorrt/convert/slice_op.cc
paddle/fluid/inference/tensorrt/convert/slice_op.cc
+21
-15
paddle/fluid/inference/tensorrt/convert/stack_op.cc
paddle/fluid/inference/tensorrt/convert/stack_op.cc
+38
-35
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+2
-2
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
...e/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
+174
-0
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h
...le/fluid/inference/tensorrt/plugin/special_slice_plugin.h
+96
-0
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+3
-1
未找到文件。
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
浏览文件 @
fb3c0b2f
...
...
@@ -40,6 +40,9 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
input_ids
.
push_back
(
engine_
->
GetITensor
(
id_names
[
i
]));
}
// input_embs[0]: word_embedding
// input_embs[1]: pos_embedding
// input_embs[2]: sent_embedding
std
::
vector
<
float
*>
input_embs
;
std
::
vector
<
int
>
emb_sizes
;
...
...
@@ -54,7 +57,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
return
temp_data
;
};
int
hidden
=
0
;
//
int hidden = 0;
for
(
int
i
=
0
;
i
<
input_num
;
i
++
)
{
framework
::
DDim
emb_dims
;
float
*
emb_data
=
get_persistable_data
(
emb_names
[
i
],
&
emb_dims
);
...
...
@@ -65,7 +68,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
emb_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The fused EmbEltwiseLayerNorm's emb should be 2 dims."
));
hidden
=
emb_dims
[
1
];
//
hidden = emb_dims[1];
}
framework
::
DDim
bias_dims
,
scale_dims
;
...
...
@@ -76,21 +79,66 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
get_persistable_data
(
op_desc
.
Input
(
"Scale"
).
front
(),
&
scale_dims
);
int64_t
bias_size
=
framework
::
product
(
bias_dims
);
int64_t
scale_size
=
framework
::
product
(
scale_dims
);
float
eps
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
nvinfer1
::
ILayer
*
layer
=
nullptr
;
int
output_fp16
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
1
:
0
);
if
(
engine_
->
with_dynamic_shape
())
{
#ifdef USE_NVINFER_PLUGIN
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"bert_embeddings_layernorm_beta"
,
bias
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
bias_size
)},
{
"bert_embeddings_layernorm_gamma"
,
scale
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
scale_size
)},
{
"bert_embeddings_word_embeddings"
,
input_embs
[
0
],
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
emb_sizes
[
0
])},
{
"bert_embeddings_token_type_embeddings"
,
input_embs
[
2
],
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
emb_sizes
[
2
])},
{
"bert_embeddings_position_embeddings"
,
input_embs
[
1
],
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
emb_sizes
[
1
])},
{
"output_fp16"
,
&
output_fp16
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
};
// remember to free
nvinfer1
::
PluginFieldCollection
*
plugin_ptr
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
plugin_ptr
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
plugin_ptr
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
plugin_ptr
->
fields
=
fields
.
data
();
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_0"
));
// word_embedding, eval_placeholder_0
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_1"
));
// sent_embedding, eval_placeholder_1
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_2"
));
// cu_seqlens, eval_placeholder_2
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_3"
));
// max_seqlen, eval_placeholder_3
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomEmbLayerNormPluginDynamic"
,
"2"
);
auto
plugin_obj
=
creator
->
createPlugin
(
"CustomEmbLayerNormPluginDynamic"
,
plugin_ptr
);
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin_obj
);
layer
=
plugin_layer
;
free
(
plugin_ptr
);
#else
float
eps
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
plugin
::
DynamicPluginTensorRT
*
plugin
=
nullptr
;
plugin
=
new
plugin
::
Emb
EltwiseLayernormPluginDynamic
<
float
>
(
plugin
=
new
plugin
::
Emb
LaynormPluginDynamicV2
<
float
>
(
input_embs
,
bias
,
scale
,
emb_sizes
,
bias_size
,
scale_size
,
hidden
,
eps
);
auto
plugin_layer
=
engine_
->
AddPluginV2
(
input_ids
.
data
(),
input_num
,
plugin
);
nvinfer1
::
Permutation
permutation
{
1
,
0
,
2
,
3
,
4
};
auto
trans_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
plugin_layer
->
getOutput
(
0
));
trans_layer
->
setFirstTranspose
(
permutation
);
layer
=
trans_layer
;
layer
=
engine_
->
AddPluginV2
(
input_ids
.
data
(),
input_num
,
plugin
);
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the Ernie(Bert) model in static"
...
...
@@ -100,7 +148,8 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
}
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"emb_eltwise_layernorm"
,
{
output_name
},
RreplenishLayerAndOutput
(
layer
,
"emb_eltwise_layernorm"
,
{
output_name
,
std
::
string
(
"qkv_plugin_mask"
)},
test_mode
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
...
...
paddle/fluid/inference/tensorrt/convert/mul_op.cc
浏览文件 @
fb3c0b2f
...
...
@@ -26,33 +26,38 @@ class MulOpConverter : public OpConverter {
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid mul op to tensorrt mul layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
input2
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Y"
)[
0
]);
bool
transpose_x
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"transpose_X"
));
bool
transpose_y
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"transpose_Y"
));
#ifdef USE_NVINFER_PLUGIN
nvinfer1
::
DataType
type
=
(
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
;
plugin
::
ConvertMaskPluginDynamic
*
plugin
=
new
plugin
::
ConvertMaskPluginDynamic
(
type
);
auto
convert_mask_layer
=
engine_
->
AddPluginV2
(
&
input1
,
1
,
plugin
);
engine_
->
SetITensor
(
"qkv_plugin_mask"
,
convert_mask_layer
->
getOutput
(
0
));
#endif
// Both the input1 and input2 do not need transpose.
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
MatrixMultiply
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
transpose_x
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input2
),
transpose_y
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"matmul"
,
{
output_name
},
test_mode
);
/*
VLOG(3) << "convert a fluid mul op to tensorrt mul layer without bias";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
bool transpose_x = BOOST_GET_CONST(bool,
op_desc.GetAttr("transpose_X"));
bool transpose_y = BOOST_GET_CONST(bool,
op_desc.GetAttr("transpose_Y"));
#ifdef USE_NVINFER_PLUGIN
nvinfer1::DataType type = (engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT;
plugin::ConvertMaskPluginDynamic* plugin =
new plugin::ConvertMaskPluginDynamic(type);
auto convert_mask_layer = engine_->AddPluginV2(&input1, 1, plugin);
engine_->SetITensor("qkv_plugin_mask",
convert_mask_layer->getOutput(0));
#endif
// Both the input1 and input2 do not need transpose.
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1),
transpose_x, *const_cast<nvinfer1::ITensor*>(input2), transpose_y);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "matmul", {output_name}, test_mode);
*/
}
};
...
...
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
fb3c0b2f
...
...
@@ -119,17 +119,19 @@ class MultiheadMatMulOpConverter : public OpConverter {
auto
mask_tensor
=
engine_
->
GetITensor
(
"qkv_plugin_mask"
);
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomQKVToContextPluginDynamic"
,
"
1
"
);
"CustomQKVToContextPluginDynamic"
,
"
2
"
);
assert
(
creator
!=
nullptr
);
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
);
bool
has_mask
=
true
;
int
var_seqlen
=
1
;
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
};
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
...
...
@@ -144,8 +146,12 @@ class MultiheadMatMulOpConverter : public OpConverter {
free
(
plugin_collection
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
push_back
(
fc_layer
->
getOutput
(
0
));
plugin_inputs
.
push_back
(
mask_tensor
);
plugin_inputs
.
emplace_back
(
fc_layer
->
getOutput
(
0
));
plugin_inputs
.
emplace_back
(
mask_tensor
);
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_2"
));
// cu_seqlens, eval_placeholder_2
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_3"
));
// max_seqlen, eval_placeholder_3
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
layer
=
plugin_layer
;
...
...
paddle/fluid/inference/tensorrt/convert/scale_op.cc
浏览文件 @
fb3c0b2f
...
...
@@ -25,85 +25,94 @@ class ScaleOpConverter : public OpConverter {
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid scale op to tensorrt mul layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
;
std
::
string
input_name
=
op_desc
.
Input
(
"X"
).
front
();
std
::
string
out_name
=
op_desc
.
Output
(
"Out"
).
front
();
auto
input
=
engine_
->
GetITensor
(
input_name
);
bool
bias_after_scale
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"bias_after_scale"
));
float
bias
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"bias"
));
float
scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"scale"
));
auto
create_weights
=
[
&
](
float
data
,
std
::
string
type
)
->
float
*
{
std
::
unique_ptr
<
framework
::
Tensor
>
tmp_tensor
(
new
framework
::
Tensor
());
tmp_tensor
->
Resize
({
1
});
auto
*
tmp_data
=
tmp_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
tmp_data
[
0
]
=
data
;
engine_
->
SetWeights
(
out_name
+
"_scale_op_"
+
type
,
std
::
move
(
tmp_tensor
));
return
tmp_data
;
};
float
*
bias_ptr
=
create_weights
(
bias
,
"bias"
);
float
*
scale_ptr
=
create_weights
(
scale
,
"scale"
);
TensorRTEngine
::
Weight
scale_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
scale_ptr
),
1
};
TensorRTEngine
::
Weight
shift_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
bias_ptr
),
1
};
TensorRTEngine
::
Weight
power_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
0
};
nvinfer1
::
ILayer
*
layer
=
nullptr
;
auto
input_dim
=
input
->
getDimensions
();
PADDLE_ENFORCE_GE
(
input_dim
.
nbDims
,
3
,
platform
::
errors
::
Fatal
(
"Paddle-TRT scale mode only support dimension >= 3"
));
nvinfer1
::
IShuffleLayer
*
expand_layer
=
nullptr
;
nvinfer1
::
IShuffleLayer
*
squeeze_layer
=
nullptr
;
if
(
input_dim
.
nbDims
==
3
)
{
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
expand_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
nvinfer1
::
Dims4
target_shape
(
0
,
0
,
0
,
1
);
// expand 1 dims
expand_layer
->
setReshapeDimensions
(
target_shape
);
input
=
expand_layer
->
getOutput
(
0
);
}
if
(
bias_after_scale
)
{
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Scale
,
*
input
,
nvinfer1
::
ScaleMode
::
kUNIFORM
,
shift_weights
.
get
(),
scale_weights
.
get
(),
power_weights
.
get
());
}
else
{
// add bias
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Scale
,
*
(
input
),
nvinfer1
::
ScaleMode
::
kUNIFORM
,
shift_weights
.
get
(),
power_weights
.
get
(),
power_weights
.
get
());
// mul scale
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Scale
,
*
(
layer
->
getOutput
(
0
)),
nvinfer1
::
ScaleMode
::
kUNIFORM
,
power_weights
.
get
(),
scale_weights
.
get
(),
power_weights
.
get
());
}
PADDLE_ENFORCE_EQ
(
layer
!=
nullptr
,
true
,
platform
::
errors
::
Fatal
(
"Create scale layer failed."
));
if
(
input_dim
.
nbDims
==
3
)
{
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
squeeze_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
(
layer
->
getOutput
(
0
)));
nvinfer1
::
Dims3
target_shape
(
0
,
0
,
0
);
// expand 1 dims
squeeze_layer
->
setReshapeDimensions
(
target_shape
);
layer
=
static_cast
<
nvinfer1
::
ILayer
*>
(
squeeze_layer
);
}
RreplenishLayerAndOutput
(
layer
,
"scale"
,
{
out_name
},
test_mode
);
/*
VLOG(3) << "convert a fluid scale op to tensorrt mul layer without
bias";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
std::vector<nvinfer1::ITensor*> itensors;
std::string input_name = op_desc.Input("X").front();
std::string out_name = op_desc.Output("Out").front();
auto input = engine_->GetITensor(input_name);
bool bias_after_scale =
BOOST_GET_CONST(bool, op_desc.GetAttr("bias_after_scale"));
float bias = BOOST_GET_CONST(float, op_desc.GetAttr("bias"));
float scale = BOOST_GET_CONST(float, op_desc.GetAttr("scale"));
auto create_weights = [&](float data, std::string type) -> float* {
std::unique_ptr<framework::Tensor> tmp_tensor(new
framework::Tensor());
tmp_tensor->Resize({1});
auto* tmp_data =
tmp_tensor->mutable_data<float>(platform::CPUPlace());
tmp_data[0] = data;
engine_->SetWeights(out_name + "_scale_op_" + type,
std::move(tmp_tensor));
return tmp_data;
};
float* bias_ptr = create_weights(bias, "bias");
float* scale_ptr = create_weights(scale, "scale");
TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(scale_ptr), 1};
TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_ptr), 1};
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT,
nullptr,
0};
nvinfer1::ILayer* layer = nullptr;
auto input_dim = input->getDimensions();
PADDLE_ENFORCE_GE(input_dim.nbDims, 3,
platform::errors::Fatal(
"Paddle-TRT scale mode only support dimension >=
3"));
nvinfer1::IShuffleLayer* expand_layer = nullptr;
nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
if (input_dim.nbDims == 3) {
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
nvinfer1::Dims4 target_shape(0, 0, 0, 1); // expand 1 dims
expand_layer->setReshapeDimensions(target_shape);
input = expand_layer->getOutput(0);
}
if (bias_after_scale) {
layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *input, nvinfer1::ScaleMode::kUNIFORM,
shift_weights.get(), scale_weights.get(), power_weights.get());
} else {
// add bias
layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *(input), nvinfer1::ScaleMode::kUNIFORM,
shift_weights.get(), power_weights.get(), power_weights.get());
// mul scale
layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *(layer->getOutput(0)),
nvinfer1::ScaleMode::kUNIFORM,
power_weights.get(), scale_weights.get(), power_weights.get());
}
PADDLE_ENFORCE_EQ(layer != nullptr, true,
platform::errors::Fatal("Create scale layer
failed."));
if (input_dim.nbDims == 3) {
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
squeeze_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0)));
nvinfer1::Dims3 target_shape(0, 0, 0); // expand 1 dims
squeeze_layer->setReshapeDimensions(target_shape);
layer = static_cast<nvinfer1::ILayer*>(squeeze_layer);
}
RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode);
*/
}
};
...
...
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
浏览文件 @
fb3c0b2f
...
...
@@ -54,7 +54,7 @@ class SkipLayerNormOpConverter : public OpConverter {
if
(
engine_
->
with_dynamic_shape
())
{
#ifdef USE_NVINFER_PLUGIN
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomSkipLayerNormPluginDynamic"
,
"
1
"
);
"CustomSkipLayerNormPluginDynamic"
,
"
2
"
);
assert
(
creator
!=
nullptr
);
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
...
...
paddle/fluid/inference/tensorrt/convert/slice_op.cc
浏览文件 @
fb3c0b2f
...
...
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
// #include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -30,25 +31,30 @@ class SliceOpConverter : public OpConverter {
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Input"
)[
0
]);
std
::
vector
<
int
>
axes
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"axes"
));
std
::
vector
<
int
>
starts
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"starts"
));
std
::
vector
<
int
>
ends
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"ends"
));
/*
std::vector<int> axes =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
std::vector<int> starts =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("starts"));
std::vector<int> ends =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends"));
*/
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
nvinfer1
::
Permutation
permutation
{
1
,
0
,
2
,
3
,
4
};
auto
trans_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
trans_layer
->
setFirstTranspose
(
permutation
);
/*
nvinfer1::Permutation permutation{1, 0, 2, 3, 4};
auto trans_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
trans_layer->setFirstTranspose(permutation);
*/
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
trans_layer
->
getOutput
(
0
));
// plugin_inputs.emplace_back(trans_layer->getOutput(0));
plugin_inputs
.
emplace_back
(
input
);
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"eval_placeholder_2"
));
bool
ban_fp16
=
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SlicePluginDynamic
*
plugin
=
new
plugin
::
S
licePluginDynamic
(
starts
,
ends
,
axes
,
ban_fp16
);
//
bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin
::
S
pecialS
licePluginDynamic
*
plugin
=
new
plugin
::
S
pecialSlicePluginDynamic
(
);
layer
=
engine_
->
AddPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
plugin
);
}
else
{
...
...
paddle/fluid/inference/tensorrt/convert/stack_op.cc
浏览文件 @
fb3c0b2f
...
...
@@ -26,45 +26,48 @@ class StackOpConverter : public OpConverter {
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert fluid stack op to tensorrt stack layer"
;
/*
VLOG(4) << "convert fluid stack op to tensorrt stack layer";
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
input
=
op_desc
.
Input
(
"X"
);
int
input_num
=
input
.
size
();
nvinfer1
::
ITensor
**
inputs
=
(
nvinfer1
::
ITensor
**
)
malloc
(
input_num
*
sizeof
(
nvinfer1
::
ITensor
*
));
framework::OpDesc op_desc(op, nullptr);
auto input = op_desc.Input("X");
int input_num = input.size();
nvinfer1::ITensor** inputs =
(nvinfer1::ITensor**)malloc(input_num * sizeof(nvinfer1::ITensor*));
for
(
int
i
=
0
;
i
<
input_num
;
++
i
)
{
inputs
[
i
]
=
engine_
->
GetITensor
(
input
[
i
]);
}
for (int i = 0; i < input_num; ++i) {
inputs[i] = engine_->GetITensor(input[i]);
}
int
axis
=
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"axis"
));
if
(
axis
<
0
)
{
axis
=
axis
+
inputs
[
0
]
->
getDimensions
().
nbDims
+
1
;
}
int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
if (axis < 0) {
axis = axis + inputs[0]->getDimensions().nbDims + 1;
}
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
#if IS_TRT_VERSION_GE(6000)
plugin
::
StackPluginDynamic
*
plugin
=
new
plugin
::
StackPluginDynamic
(
axis
,
input_num
);
layer
=
engine_
->
AddPluginV2
(
inputs
,
input_num
,
plugin
);
assert
(
layer
!=
nullptr
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"
));
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the Ernie(Bert) model in static"
"shape mode, which is not supported for the time being.
\n
"
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
" to set the shape information to run the dynamic shape mode."
));
}
auto
output_name
=
op_desc
.
Output
(
"Y"
).
front
();
RreplenishLayerAndOutput
(
layer
,
"stack"
,
{
output_name
},
test_mode
);
free
(
inputs
);
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
plugin::StackPluginDynamic* plugin =
new plugin::StackPluginDynamic(axis, input_num);
layer = engine_->AddPluginV2(inputs, input_num, plugin);
assert(layer != nullptr);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that
"
"your TRT version is no less than 6.0"));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static"
"shape mode, which is not supported for the time being.\n"
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
" to set the shape information to run the dynamic shape mode."));
}
auto output_name = op_desc.Output("Y").front();
RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode);
free(inputs);
*/
}
};
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
fb3c0b2f
...
...
@@ -60,9 +60,9 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
template
<
typename
T
>
nvinfer1
::
Dims
Vec2TRT_Dims
(
const
std
::
vector
<
T
>&
shape
,
std
::
string
input
,
bool
with_dynamic_shape
=
false
)
{
PADDLE_ENFORCE_GT
(
shape
.
size
(),
1
UL
,
PADDLE_ENFORCE_GT
(
shape
.
size
(),
0
UL
,
platform
::
errors
::
InvalidArgument
(
"TensorRT's tensor input requires at least
2
"
"TensorRT's tensor input requires at least
1
"
"dimensions, but input %s has %d dims."
,
input
,
shape
.
size
()));
PADDLE_ENFORCE_LE
(
shape
.
size
(),
4UL
,
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
fb3c0b2f
...
...
@@ -2,7 +2,7 @@ nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
stack_op_plugin.cu convert_mask_plugin.cu
stack_op_plugin.cu convert_mask_plugin.cu
special_slice_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
0 → 100644
浏览文件 @
fb3c0b2f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
SpecialSlicePluginDynamic
::
SpecialSlicePluginDynamic
()
{}
SpecialSlicePluginDynamic
::
SpecialSlicePluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{}
SpecialSlicePluginDynamic
::~
SpecialSlicePluginDynamic
()
{}
nvinfer1
::
IPluginV2DynamicExt
*
SpecialSlicePluginDynamic
::
clone
()
const
{
return
new
SpecialSlicePluginDynamic
();
}
const
char
*
SpecialSlicePluginDynamic
::
getPluginType
()
const
{
return
"special_slice_plugin"
;
}
int
SpecialSlicePluginDynamic
::
getNbOutputs
()
const
{
return
1
;
}
int
SpecialSlicePluginDynamic
::
initialize
()
{
return
0
;
}
size_t
SpecialSlicePluginDynamic
::
getSerializationSize
()
const
{
size_t
serialize_size
=
0
;
return
serialize_size
;
}
void
SpecialSlicePluginDynamic
::
serialize
(
void
*
buffer
)
const
{}
nvinfer1
::
DimsExprs
SpecialSlicePluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
nvinfer1
::
DimsExprs
output
(
inputs
[
0
]);
auto
one
=
expr_builder
.
constant
(
1
);
output
.
d
[
0
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUB
,
*
inputs
[
1
].
d
[
0
],
*
one
);
return
output
;
}
void
SpecialSlicePluginDynamic
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
{}
size_t
SpecialSlicePluginDynamic
::
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
{
return
0
;
}
void
SpecialSlicePluginDynamic
::
destroy
()
{
delete
this
;
}
void
SpecialSlicePluginDynamic
::
terminate
()
{}
bool
SpecialSlicePluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
desc
,
int
nb_inputs
,
int
nb_outputs
)
{
if
(
pos
==
0
)
// slice tensor
return
(
desc
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
);
// || desc[pos].type ==
// nvinfer1::DataType::kFLOAT);
if
(
pos
==
1
)
// cu_seqlen
return
(
desc
[
pos
].
type
==
nvinfer1
::
DataType
::
kINT32
&&
desc
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
return
(
desc
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
);
// || desc[pos].type ==
// nvinfer1::DataType::kFLOAT);
}
nvinfer1
::
DataType
SpecialSlicePluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
{
PADDLE_ENFORCE_EQ
(
index
,
0
,
platform
::
errors
::
InvalidArgument
(
"The index should be equal to 0"
));
return
input_types
[
0
];
}
template
<
typename
T
>
__global__
void
SpecialSliceKernel
(
const
T
*
slice_input
,
const
int32_t
*
cu_seqlens
,
T
*
output
)
{
const
int
hidden
=
blockDim
.
x
;
const
int
batch
=
blockIdx
.
x
;
output
[
batch
*
hidden
+
threadIdx
.
x
]
=
slice_input
[
cu_seqlens
[
batch
]
*
hidden
+
threadIdx
.
x
];
}
int
SpecialSlicePluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
input_dims
=
input_desc
[
0
].
dims
;
// (sum(S), 768, 1, 1)
auto
out_dims
=
output_desc
[
0
].
dims
;
// (batch, 768, 1, 1)
assert
(
input_desc
[
0
].
type
==
nvinfer1
::
DataType
::
kHALF
);
const
int32_t
hidden
=
input_dims
.
d
[
1
];
const
int
num_blocks
=
out_dims
.
d
[
0
];
// batch size
const
int
num_threads
=
hidden
;
const
half
*
slice_input
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
int32_t
*
cu_seqlens
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
SpecialSliceKernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
slice_input
,
cu_seqlens
,
output
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
SpecialSlicePluginDynamicCreator
::
SpecialSlicePluginDynamicCreator
()
{}
const
char
*
SpecialSlicePluginDynamicCreator
::
getPluginName
()
const
{
return
"stack_plugin"
;
}
const
char
*
SpecialSlicePluginDynamicCreator
::
getPluginVersion
()
const
{
return
"1"
;
}
const
nvinfer1
::
PluginFieldCollection
*
SpecialSlicePluginDynamicCreator
::
getFieldNames
()
{
return
&
field_collection_
;
}
nvinfer1
::
IPluginV2
*
SpecialSlicePluginDynamicCreator
::
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
{
return
new
SpecialSlicePluginDynamic
();
}
nvinfer1
::
IPluginV2
*
SpecialSlicePluginDynamicCreator
::
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
{
auto
plugin
=
new
SpecialSlicePluginDynamic
(
serial_data
,
serial_length
);
return
plugin
;
}
void
SpecialSlicePluginDynamicCreator
::
setPluginNamespace
(
const
char
*
lib_namespace
)
{
plugin_namespace_
=
lib_namespace
;
}
const
char
*
SpecialSlicePluginDynamicCreator
::
getPluginNamespace
()
const
{
return
plugin_namespace_
.
c_str
();
}
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h
0 → 100644
浏览文件 @
fb3c0b2f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
class
SpecialSlicePluginDynamic
:
public
DynamicPluginTensorRT
{
public:
SpecialSlicePluginDynamic
();
SpecialSlicePluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
);
~
SpecialSlicePluginDynamic
();
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
override
;
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
override
;
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
override
;
const
char
*
getPluginType
()
const
override
;
int
getNbOutputs
()
const
override
;
int
initialize
()
override
;
void
terminate
()
override
;
size_t
getSerializationSize
()
const
override
;
void
serialize
(
void
*
buffer
)
const
override
;
void
destroy
()
override
;
private:
int
axis_
;
int
num_stack_
;
};
class
SpecialSlicePluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
public:
SpecialSlicePluginDynamicCreator
();
const
char
*
getPluginName
()
const
override
;
const
char
*
getPluginVersion
()
const
override
;
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
override
;
nvinfer1
::
IPluginV2
*
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
override
;
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
override
;
void
setPluginNamespace
(
const
char
*
lib_namespace
)
override
;
const
char
*
getPluginNamespace
()
const
override
;
private:
std
::
string
plugin_namespace_
;
nvinfer1
::
PluginFieldCollection
field_collection_
{
0
,
nullptr
};
std
::
vector
<
nvinfer1
::
PluginField
>
plugin_attributes_
;
};
REGISTER_TRT_PLUGIN_V2
(
SpecialSlicePluginDynamicCreator
);
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
fb3c0b2f
...
...
@@ -264,9 +264,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
float
>
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT64
)
{
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
int64_t
>
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT32
)
{
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
int32_t
>
());
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The TRT Engine OP only support float
and
int64_t input."
));
"The TRT Engine OP only support float
/int32_t/
int64_t input."
));
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录