Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fcc8a87b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fcc8a87b
编写于
6月 22, 2022
作者:
Z
zhoutianzi666
提交者:
GitHub
6月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[inference] add slice trt layer (#43648)
* add fc, multihead_mul, shape tensor infer, slice
上级
d41a9373
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
500 addition
and
219 deletion
+500
-219
paddle/fluid/inference/tensorrt/convert/fc_op.cc
paddle/fluid/inference/tensorrt/convert/fc_op.cc
+146
-60
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+102
-59
paddle/fluid/inference/tensorrt/convert/op_converter.h
paddle/fluid/inference/tensorrt/convert/op_converter.h
+76
-50
paddle/fluid/inference/tensorrt/convert/slice_op.cc
paddle/fluid/inference/tensorrt/convert/slice_op.cc
+94
-6
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+81
-39
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py
...id/tests/unittests/ir/inference/test_trt_convert_slice.py
+1
-5
未找到文件。
paddle/fluid/inference/tensorrt/convert/fc_op.cc
浏览文件 @
fcc8a87b
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -31,12 +34,17 @@ namespace tensorrt {
class
FcOpConverter
:
public
OpConverter
{
public:
nvinfer1
::
ILayer
*
reshape_before_fc
(
nvinfer1
::
ITensor
*
before_fc
,
nvinfer1
::
Dims
x_dim
,
int
x_num_col_dims
,
nvinfer1
::
Dims
x_dim
,
int
x_num_col_dims
,
std
::
string
output_name
)
{
// add shuffle before fc
nvinfer1
::
Dims
reshape_before_fc_dim
;
reshape_before_fc_dim
.
nbDims
=
x_num_col_dims
+
3
;
// padding shape "* x q x 1 x 1"
nvinfer1
::
ITensor
*
filal_reshape_before_fc_shape_tensor
=
nullptr
;
if
(
!
engine_
->
with_dynamic_shape
())
{
for
(
int
i
=
0
;
i
<
reshape_before_fc_dim
.
nbDims
;
i
++
)
{
reshape_before_fc_dim
.
d
[
i
]
=
1
;
}
...
...
@@ -44,16 +52,39 @@ class FcOpConverter : public OpConverter {
if
(
i
<
x_num_col_dims
)
{
reshape_before_fc_dim
.
d
[
i
]
=
0
;
}
else
{
if
(
x_dim
.
d
[
i
]
<
0
)
{
reshape_before_fc_dim
.
d
[
x_num_col_dims
]
=
-
1
;
break
;
}
reshape_before_fc_dim
.
d
[
x_num_col_dims
]
*=
x_dim
.
d
[
i
];
}
}
}
else
{
std
::
vector
<
nvinfer1
::
ITensor
*>
reshape_before_fc_shape_tensor
;
nvinfer1
::
ITensor
*
input_shape_tensor
=
Shape
(
before_fc
);
for
(
int
i
=
0
;
i
<
reshape_before_fc_dim
.
nbDims
;
i
++
)
{
reshape_before_fc_shape_tensor
.
push_back
(
Add1DConstantLayer
(
1
));
}
for
(
int
i
=
0
;
i
<
x_dim
.
nbDims
;
i
++
)
{
if
(
i
<
x_num_col_dims
)
{
reshape_before_fc_shape_tensor
[
i
]
=
GetEleTensorOfShape
(
input_shape_tensor
,
i
);
}
else
{
reshape_before_fc_shape_tensor
[
x_num_col_dims
]
=
Prod
(
GetEleTensorOfShape
(
input_shape_tensor
,
i
),
reshape_before_fc_shape_tensor
[
x_num_col_dims
]);
}
}
filal_reshape_before_fc_shape_tensor
=
Concat
(
reshape_before_fc_shape_tensor
);
}
auto
*
reshape_before_fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
before_fc
);
if
(
!
engine_
->
with_dynamic_shape
())
{
reshape_before_fc_layer
->
setReshapeDimensions
(
reshape_before_fc_dim
);
}
else
{
reshape_before_fc_layer
->
setInput
(
1
,
*
filal_reshape_before_fc_shape_tensor
);
}
reshape_before_fc_layer
->
setName
(
(
"fc_op_reshape_before_fc: Shuffle (Output: "
+
output_name
+
")"
)
.
c_str
());
...
...
@@ -61,21 +92,39 @@ class FcOpConverter : public OpConverter {
}
nvinfer1
::
ILayer
*
reshape_after_fc
(
nvinfer1
::
ITensor
*
after_fc
,
nvinfer1
::
Dims
x_dim
,
int
x_num_col_dims
)
{
nvinfer1
::
Dims
x_dim
,
int
x_num_col_dims
)
{
// add shuffle after fc
nvinfer1
::
Dims
reshape_after_fc_dim
;
reshape_after_fc_dim
.
nbDims
=
x_num_col_dims
+
1
;
nvinfer1
::
ITensor
*
filal_reshape_after_fc_shape_tensor
=
nullptr
;
if
(
!
engine_
->
with_dynamic_shape
())
{
for
(
int
i
=
0
;
i
<
reshape_after_fc_dim
.
nbDims
;
i
++
)
{
reshape_after_fc_dim
.
d
[
i
]
=
0
;
}
}
else
{
std
::
vector
<
int
>
gather_indices
(
x_num_col_dims
+
1
);
std
::
iota
(
gather_indices
.
begin
(),
gather_indices
.
end
(),
0
);
filal_reshape_after_fc_shape_tensor
=
Gather
(
Shape
(
after_fc
),
gather_indices
);
}
auto
*
reshape_after_fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
after_fc
);
if
(
!
engine_
->
with_dynamic_shape
())
{
reshape_after_fc_layer
->
setReshapeDimensions
(
reshape_after_fc_dim
);
}
else
{
reshape_after_fc_layer
->
setInput
(
1
,
*
filal_reshape_after_fc_shape_tensor
);
}
return
reshape_after_fc_layer
;
}
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid fc op to tensorrt fc layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
...
...
@@ -93,7 +142,8 @@ class FcOpConverter : public OpConverter {
// Declare weights
auto
*
Y_v
=
scope
.
FindVar
(
op_desc
.
Input
(
w_name
).
front
());
PADDLE_ENFORCE_NOT_NULL
(
Y_v
,
platform
::
errors
::
NotFound
(
Y_v
,
platform
::
errors
::
NotFound
(
"Can not find %s presistale var of fc in scope."
,
w_name
));
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
int
x_num_col_dims
=
...
...
@@ -125,7 +175,8 @@ class FcOpConverter : public OpConverter {
}
weight_data
=
engine_
->
GetWeightCPUData
(
op_desc
.
Input
(
w_name
).
front
(),
Y_t
);
PADDLE_ENFORCE_EQ
(
Y_t
->
dims
().
size
(),
2UL
,
PADDLE_ENFORCE_EQ
(
Y_t
->
dims
().
size
(),
2UL
,
platform
::
errors
::
InvalidArgument
(
"The fc's weight should be a matrix with 2 dims, but "
"it's %d-dimensional."
,
...
...
@@ -140,7 +191,8 @@ class FcOpConverter : public OpConverter {
}
};
auto
regist_fc
=
[
&
](
nvinfer1
::
ITensor
*
inputs
,
int
n_output
,
auto
regist_fc
=
[
&
](
nvinfer1
::
ITensor
*
inputs
,
int
n_output
,
TensorRTEngine
::
Weight
&
weight
,
TensorRTEngine
::
Weight
&
bias
)
{
if
(
enable_int8
||
support_int8
)
{
...
...
@@ -148,7 +200,8 @@ class FcOpConverter : public OpConverter {
float
out_scale
=
0
;
if
(
enable_int8
)
{
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"out_threshold"
),
true
,
op_desc
.
HasAttr
(
"out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
"must have out threshold in fc layers in int8 mode"
));
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"out_threshold"
));
...
...
@@ -156,9 +209,13 @@ class FcOpConverter : public OpConverter {
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"Out"
));
}
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
auto
*
fc_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
inputs
,
n_output
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
auto
*
fc_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
inputs
,
n_output
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
fc_layer_int8
->
setName
(
(
"fc_op_int8_conv1x1: Convolution (Output: "
+
output_name
+
")"
)
.
c_str
());
...
...
@@ -171,21 +228,29 @@ class FcOpConverter : public OpConverter {
.
c_str
());
engine_
->
SetTensorDynamicRange
(
fc_after_reshape_int8
->
getOutput
(
0
),
out_scale
);
nvinfer1
::
IActivationLayer
*
relu_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
(
fc_after_reshape_int8
->
getOutput
(
0
)),
nvinfer1
::
IActivationLayer
*
relu_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
(
fc_after_reshape_int8
->
getOutput
(
0
)),
nvinfer1
::
ActivationType
::
kRELU
);
RreplenishLayerAndOutput
(
relu_layer_int8
,
"relu_after_fc_shuffle"
,
{
output_name
},
test_mode
);
RreplenishLayerAndOutput
(
relu_layer_int8
,
"relu_after_fc_shuffle"
,
{
output_name
},
test_mode
);
}
else
{
RreplenishLayerAndOutput
(
fc_after_reshape_int8
,
"fc_op_int8_reshape_after_fc: Shuffle"
,
{
output_name
},
test_mode
);
{
output_name
},
test_mode
);
}
}
else
{
// add fc layer
auto
*
fc_layer_float
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
inputs
,
n_output
,
weight
.
get
(),
bias
.
get
());
auto
*
fc_layer_float
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
inputs
,
n_output
,
weight
.
get
(),
bias
.
get
());
fc_layer_float
->
setName
(
(
"fc_op_float: FullyConnected (Output: "
+
output_name
+
")"
)
.
c_str
());
...
...
@@ -195,14 +260,20 @@ class FcOpConverter : public OpConverter {
fc_after_reshape_float
->
setName
(
(
"float_reshape_after_fc: Shuffle (Output: "
+
output_name
+
")"
)
.
c_str
());
nvinfer1
::
IActivationLayer
*
relu_layer_float
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
(
fc_after_reshape_float
->
getOutput
(
0
)),
nvinfer1
::
IActivationLayer
*
relu_layer_float
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
(
fc_after_reshape_float
->
getOutput
(
0
)),
nvinfer1
::
ActivationType
::
kRELU
);
RreplenishLayerAndOutput
(
relu_layer_float
,
"relu_after_fc_shuffle"
,
{
output_name
},
test_mode
);
RreplenishLayerAndOutput
(
relu_layer_float
,
"relu_after_fc_shuffle"
,
{
output_name
},
test_mode
);
}
else
{
RreplenishLayerAndOutput
(
fc_after_reshape_float
,
"shuffle_after_fc"
,
{
output_name
},
test_mode
);
RreplenishLayerAndOutput
(
fc_after_reshape_float
,
"shuffle_after_fc"
,
{
output_name
},
test_mode
);
}
}
};
...
...
@@ -251,15 +322,20 @@ class FcOpConverter : public OpConverter {
if
(
enable_int8
||
support_int8
)
{
// add conv1x1 layer
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
auto
*
fc_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
X
,
n_output
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
auto
*
fc_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
X
,
n_output
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
if
(
activation_type
==
"relu"
)
{
fc_layer_int8
->
setName
(
(
"ernie_fc_op_int8: Convolution (Output: "
+
output_name
+
")"
)
.
c_str
());
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"out_threshold"
),
true
,
op_desc
.
HasAttr
(
"out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
"must have out threshold in fc layers in int8 mode"
));
float
out_scale
=
0
;
...
...
@@ -271,15 +347,20 @@ class FcOpConverter : public OpConverter {
}
engine_
->
SetTensorDynamicRange
(
fc_layer_int8
->
getOutput
(
0
),
out_scale
);
nvinfer1
::
IActivationLayer
*
relu_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
(
fc_layer_int8
->
getOutput
(
0
)),
nvinfer1
::
IActivationLayer
*
relu_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
(
fc_layer_int8
->
getOutput
(
0
)),
nvinfer1
::
ActivationType
::
kRELU
);
RreplenishLayerAndOutput
(
relu_layer_int8
,
"relu_after_ernie_fc_int8"
,
{
output_name
},
test_mode
);
RreplenishLayerAndOutput
(
relu_layer_int8
,
"relu_after_ernie_fc_int8"
,
{
output_name
},
test_mode
);
}
else
{
RreplenishLayerAndOutput
(
fc_layer_int8
,
"ernie_fc_op_int8: Convolution"
,
{
output_name
},
test_mode
);
{
output_name
},
test_mode
);
}
}
else
{
// add fc layer
...
...
@@ -288,25 +369,30 @@ class FcOpConverter : public OpConverter {
if
(
activation_type
==
"relu"
)
{
fc_layer_float
->
setName
(
(
"ernie_fc_op_float: (Output: "
+
output_name
+
")"
).
c_str
());
nvinfer1
::
IActivationLayer
*
relu_layer_float
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
(
fc_layer_float
->
getOutput
(
0
)),
nvinfer1
::
IActivationLayer
*
relu_layer_float
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Activation
,
*
(
fc_layer_float
->
getOutput
(
0
)),
nvinfer1
::
ActivationType
::
kRELU
);
RreplenishLayerAndOutput
(
relu_layer_float
,
"relu_after_ernie_fc_float"
,
{
output_name
},
"relu_after_ernie_fc_float"
,
{
output_name
},
test_mode
);
}
else
{
RreplenishLayerAndOutput
(
fc_layer_float
,
"ernie_fc_op_float"
,
{
output_name
},
test_mode
);
RreplenishLayerAndOutput
(
fc_layer_float
,
"ernie_fc_op_float"
,
{
output_name
},
test_mode
);
}
}
}
else
{
// need reshape input before and after fc
PADDLE_ENFORCE_GT
(
x_dim
.
nbDims
,
x_num_col_dims
,
x_dim
.
nbDims
,
x_num_col_dims
,
platform
::
errors
::
InvalidArgument
(
"Params and input dims mismatch. Paddle-TRT FC "
"converter expects x_dim.nbDims > x_num_col_dims, but "
"x_dim.nbDims : %d, x_num_col_dims : %d."
,
x_dim
.
nbDims
,
x_num_col_dims
));
x_dim
.
nbDims
,
x_num_col_dims
));
auto
*
reshape_before_fc_layer
=
reshape_before_fc
(
X
,
x_dim
,
x_num_col_dims
,
output_name
);
auto
*
reshape_itensor
=
reshape_before_fc_layer
->
getOutput
(
0
);
...
...
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
fcc8a87b
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See
...
...
@@ -19,7 +22,8 @@ namespace tensorrt {
class
MultiheadMatMulOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid multihead_mamul op to a corresponding tensorrt "
"network structure"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
...
...
@@ -49,8 +53,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
float
*
bias_data
=
engine_
->
GetWeightCPUData
(
bias_name
,
bias_t
);
std
::
vector
<
float
>
weight_data_tmp
;
weight_data_tmp
.
reserve
(
weight_t
->
numel
());
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
// (hidden_in, 3, hidden_out)
const
auto
&
weight_dims
=
weight_t
->
dims
();
...
...
@@ -98,14 +102,15 @@ class MultiheadMatMulOpConverter : public OpConverter {
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
float
dp_probs
=
1.0
/
127.0
;
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
input
,
n
,
nv_ksize
,
weight
,
bias
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
input
,
n
,
nv_ksize
,
weight
,
bias
);
fc_layer
->
setName
(
(
"Multihead: Convolution/FullyConnected: (Output: "
+
output_name
+
")"
)
.
c_str
());
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
"must have out_threshold in multihead layers in int8 mode"
));
float
out_scale
=
...
...
@@ -119,13 +124,19 @@ class MultiheadMatMulOpConverter : public OpConverter {
"CustomQKVToContextPluginDynamic"
,
"3"
);
assert
(
creator
!=
nullptr
);
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
}};
if
(
qkv2context_plugin_int8
)
{
fields
.
push_back
({
"dq_probs"
,
&
dp_probs
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
});
fields
.
push_back
({
"dq_probs"
,
&
dp_probs
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
});
}
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
...
...
@@ -154,7 +165,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
3
)
->
getName
());
engine_
->
SetTensorDynamicRange
(
max_seqlen_tensor
,
1.0
f
);
auto
*
shuffle_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
engine_
,
Shuffle
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
max_seqlen_tensor
));
nvinfer1
::
Dims
shape_dim
;
shape_dim
.
nbDims
=
1
;
...
...
@@ -173,8 +185,11 @@ class MultiheadMatMulOpConverter : public OpConverter {
// [3, head_number, head_size, hidden_in] -> [head_number, 3,
// head_size,
// hidden_in]
auto
transpose_weight_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
three
,
int
head_number
,
int
head_size
,
auto
transpose_weight_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
three
,
int
head_number
,
int
head_size
,
int
hidden_in
)
{
const
int
HH
=
head_size
*
hidden_in
;
for
(
int
i
=
0
;
i
<
three
;
++
i
)
{
...
...
@@ -187,8 +202,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
}
};
// [3, head_number, head_size] -> [head_number, 3, head_size]
auto
transpose_bias_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
N
,
int
H
)
{
auto
transpose_bias_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
N
,
int
H
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
h
=
0
;
h
<
H
;
++
h
)
{
...
...
@@ -197,31 +212,37 @@ class MultiheadMatMulOpConverter : public OpConverter {
}
}
};
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
transpose_weight_v2
(
weight_data_tmp
.
data
(),
weight_data
,
three
,
head_number
,
head_size
,
hidden_in
);
transpose_weight_v2
(
weight_data_tmp
.
data
(),
weight_data
,
three
,
head_number
,
head_size
,
hidden_in
);
std
::
vector
<
float
>
bias_data_tmp
;
bias_data_tmp
.
reserve
(
bias_t
->
numel
());
memcpy
(
bias_data_tmp
.
data
(),
bias_data
,
bias_t
->
numel
()
*
sizeof
(
float
));
transpose_bias_v2
(
bias_data_tmp
.
data
(),
bias_data
,
head_number
,
head_size
);
memcpy
(
bias_data_tmp
.
data
(),
bias_data
,
bias_t
->
numel
()
*
sizeof
(
float
));
transpose_bias_v2
(
bias_data_tmp
.
data
(),
bias_data
,
head_number
,
head_size
);
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
float
dp_probs
=
1.0
/
127.0
;
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
input
,
n
,
nv_ksize
,
weight
,
bias
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
input
,
n
,
nv_ksize
,
weight
,
bias
);
}
else
{
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
input
,
n
,
weight
,
bias
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
input
,
n
,
weight
,
bias
);
}
if
(
op_desc
.
HasAttr
(
"fc_out_threshold"
))
{
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
"must have out threshold in multihead layers "
"in int8 mode"
));
...
...
@@ -245,15 +266,21 @@ class MultiheadMatMulOpConverter : public OpConverter {
int
var_seqlen
=
1
;
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
}};
if
(
qkv2context_plugin_int8
)
{
fields
.
push_back
({
"dq_probs"
,
&
dp_probs
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
});
fields
.
push_back
({
"dq_probs"
,
&
dp_probs
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
});
}
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
...
...
@@ -274,7 +301,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
auto
max_seqlen_tensor
=
engine_
->
GetITensor
(
"mask_id"
);
auto
*
shuffle_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
engine_
,
Shuffle
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
max_seqlen_tensor
));
nvinfer1
::
Dims
shape_dim
;
shape_dim
.
nbDims
=
1
;
...
...
@@ -290,7 +318,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
}
}
else
{
PADDLE_ENFORCE_EQ
(
input
->
getDimensions
().
nbDims
,
3
,
input
->
getDimensions
().
nbDims
,
3
,
platform
::
errors
::
InvalidArgument
(
"The Input dim of the MultiheadMatMul should be 3, "
"but it's (%d) now."
,
...
...
@@ -309,20 +338,24 @@ class MultiheadMatMulOpConverter : public OpConverter {
static_cast
<
size_t
>
(
bias_t
->
numel
())};
// add shuffle before fc
nvinfer1
::
Dims
reshape_before_fc_dim
;
reshape_before_fc_dim
.
nbDims
=
5
;
reshape_before_fc_dim
.
d
[
0
]
=
0
;
reshape_before_fc_dim
.
d
[
1
]
=
0
;
reshape_before_fc_dim
.
d
[
2
]
=
0
;
reshape_before_fc_dim
.
d
[
3
]
=
1
;
reshape_before_fc_dim
.
d
[
4
]
=
1
;
std
::
vector
<
nvinfer1
::
ITensor
*>
reshape_before_fc_shape_tensor
;
nvinfer1
::
ITensor
*
input_shape_tensor
=
Shape
(
input
);
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
reshape_before_fc_shape_tensor
.
push_back
(
Add1DConstantLayer
(
1
));
}
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
reshape_before_fc_shape_tensor
[
i
]
=
GetEleTensorOfShape
(
input_shape_tensor
,
i
);
}
auto
*
reshape_before_fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
engine_
->
SetTensorDynamicRange
(
reshape_before_fc_layer
->
getOutput
(
0
),
in_scale
);
}
reshape_before_fc_layer
->
setReshapeDimensions
(
reshape_before_fc_dim
);
reshape_before_fc_layer
->
setInput
(
1
,
*
Concat
(
reshape_before_fc_shape_tensor
));
reshape_before_fc_layer
->
setName
(
(
"shuffle_before_multihead_mamul(Output: "
+
output_name
+
")"
)
.
c_str
());
...
...
@@ -331,18 +364,28 @@ class MultiheadMatMulOpConverter : public OpConverter {
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
reshape_before_fc_layer
->
getOutput
(
0
),
n
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
reshape_before_fc_layer
->
getOutput
(
0
),
n
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
}
else
{
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
reshape_before_fc_layer
->
getOutput
(
0
),
n
,
weight
.
get
(),
bias
.
get
());
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
reshape_before_fc_layer
->
getOutput
(
0
),
n
,
weight
.
get
(),
bias
.
get
());
}
if
(
op_desc
.
HasAttr
(
"fc_out_threshold"
))
{
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
"must have out threshold in multihead layers in int8 mode"
));
float
out_scale
=
...
...
@@ -369,8 +412,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
with_fp16
=
true
;
}
plugin
::
DynamicPluginTensorRT
*
plugin
=
new
plugin
::
QkvToContextPluginDynamic
(
hidden_in
,
head_number
,
head_size
,
scale
,
with_fp16
);
new
plugin
::
QkvToContextPluginDynamic
(
hidden_in
,
head_number
,
head_size
,
scale
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
plugin_inputs
.
data
(),
2
,
plugin
);
}
}
else
{
...
...
@@ -380,8 +423,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
"You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
"the shape information to run the dynamic shape mode."
));
}
RreplenishLayerAndOutput
(
layer
,
"multihead_matmul"
,
{
output_name
},
test_mode
);
RreplenishLayerAndOutput
(
layer
,
"multihead_matmul"
,
{
output_name
},
test_mode
);
}
};
...
...
paddle/fluid/inference/tensorrt/convert/op_converter.h
浏览文件 @
fcc8a87b
...
...
@@ -47,14 +47,16 @@ class OpConverter {
// test_mode: whether the instance executes in an unit test.
void
ConvertOp
(
const
framework
::
proto
::
OpDesc
&
op
,
const
std
::
unordered_set
<
std
::
string
>&
parameters
,
const
framework
::
Scope
&
scope
,
TensorRTEngine
*
engine
,
const
framework
::
Scope
&
scope
,
TensorRTEngine
*
engine
,
bool
test_mode
=
false
)
{
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
OpConverter
*
it
{
nullptr
};
if
(
op_desc
.
Type
()
==
"mul"
)
{
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"Y"
).
size
(),
1UL
,
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"Y"
).
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"The input op mul's Input(
\"
Y
\"
)."
"size() should equal to 1, but reveceid "
...
...
@@ -70,7 +72,8 @@ class OpConverter {
"add"
,
"mul"
,
"sub"
,
"div"
,
"max"
,
"min"
,
"pow"
};
static
std
::
unordered_set
<
std
::
string
>
add_weight_op_set
{
"add"
,
"mul"
,
"sub"
,
"div"
,
"pow"
};
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"Y"
).
size
(),
1UL
,
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"Y"
).
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"The input op's Input(
\"
Y
\"
)."
"size() should equal to 1, but reveceid "
...
...
@@ -81,63 +84,73 @@ class OpConverter {
std
::
string
Y
=
op_desc
.
Input
(
"Y"
)[
0
];
if
(
parameters
.
count
(
Y
))
{
PADDLE_ENFORCE_GT
(
add_weight_op_set
.
count
(
op_type
),
0
,
add_weight_op_set
.
count
(
op_type
),
0
,
platform
::
errors
::
Unimplemented
(
"Unsupported elementwise type %s"
,
op_type
.
c_str
()));
it
=
Registry
<
OpConverter
>::
Global
().
Lookup
(
"elementwise_"
+
op_type
+
"_weight"
);
PADDLE_ENFORCE_NOT_NULL
(
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
op_desc
.
Type
()));
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
op_desc
.
Type
()));
}
else
{
PADDLE_ENFORCE_GT
(
add_tensor_op_set
.
count
(
op_type
),
0
,
add_tensor_op_set
.
count
(
op_type
),
0
,
platform
::
errors
::
Unimplemented
(
"Unsupported elementwise type %s"
,
op_type
.
c_str
()));
it
=
Registry
<
OpConverter
>::
Global
().
Lookup
(
"elementwise_"
+
op_type
+
"_tensor"
);
}
PADDLE_ENFORCE_NOT_NULL
(
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
op_desc
.
Type
()));
}
if
(
op_desc
.
Type
()
==
"depthwise_conv2d"
)
{
it
=
Registry
<
OpConverter
>::
Global
().
Lookup
(
"conv2d"
);
PADDLE_ENFORCE_NOT_NULL
(
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
op_desc
.
Type
()));
}
if
(
op_desc
.
Type
()
==
"depthwise_conv2d_transpose"
)
{
it
=
Registry
<
OpConverter
>::
Global
().
Lookup
(
"conv2d_transpose"
);
PADDLE_ENFORCE_NOT_NULL
(
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
op_desc
.
Type
()));
}
if
(
op_desc
.
Type
()
==
"transpose2"
)
{
it
=
Registry
<
OpConverter
>::
Global
().
Lookup
(
"transpose"
);
PADDLE_ENFORCE_NOT_NULL
(
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
op_desc
.
Type
()));
}
if
(
op_desc
.
Type
()
==
"flatten2"
)
{
it
=
Registry
<
OpConverter
>::
Global
().
Lookup
(
"flatten"
);
PADDLE_ENFORCE_NOT_NULL
(
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
op_desc
.
Type
()));
}
// reshape2 == reshape
if
(
op_desc
.
Type
()
==
"reshape2"
)
{
it
=
Registry
<
OpConverter
>::
Global
().
Lookup
(
"reshape"
);
PADDLE_ENFORCE_NOT_NULL
(
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
op_desc
.
Type
()));
}
if
(
!
it
)
{
it
=
Registry
<
OpConverter
>::
Global
().
Lookup
(
op_desc
.
Type
());
}
PADDLE_ENFORCE_NOT_NULL
(
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
it
,
platform
::
errors
::
Unimplemented
(
"no OpConverter for optype [%s]"
,
op_desc
.
Type
()));
it
->
SetEngine
(
engine
);
...
...
@@ -214,7 +227,8 @@ class OpConverter {
// the INetwork's inputs and outputs should specified in some other modules.
void
ConvertBlock
(
const
framework
::
proto
::
BlockDesc
&
block
,
const
std
::
unordered_set
<
std
::
string
>&
parameters
,
const
framework
::
Scope
&
scope
,
TensorRTEngine
*
engine
)
{
const
framework
::
Scope
&
scope
,
TensorRTEngine
*
engine
)
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
mut_
);
for
(
int
i
=
0
;
i
<
block
.
ops_size
();
i
++
)
{
const
auto
&
op
=
block
.
ops
(
i
);
...
...
@@ -224,20 +238,24 @@ class OpConverter {
// The scope here should be inited with the parameter vars.
void
ConvertBlockToTRTEngine
(
framework
::
BlockDesc
*
block_desc
,
const
framework
::
Scope
&
scope
,
framework
::
BlockDesc
*
block_desc
,
const
framework
::
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
unordered_set
<
std
::
string
>&
parameters
,
const
std
::
vector
<
std
::
string
>&
outputs
,
TensorRTEngine
*
engine
)
{
const
std
::
vector
<
std
::
string
>&
outputs
,
TensorRTEngine
*
engine
)
{
engine
->
InitNetwork
();
bool
all_dynamic_shape_set
=
true
;
for
(
auto
&
input
:
inputs
)
{
if
(
parameters
.
count
(
input
))
continue
;
auto
*
var
=
block_desc
->
FindVar
(
input
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
NotFound
(
"no variable called %s in block."
,
var
,
platform
::
errors
::
NotFound
(
"no variable called %s in block."
,
input
.
c_str
()));
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
FluidDT
::
VarType_Type_LOD_TENSOR
,
var
->
GetType
(),
FluidDT
::
VarType_Type_LOD_TENSOR
,
platform
::
errors
::
InvalidArgument
(
"TensorRT engine only takes "
"LoDTensor as input"
));
auto
var_shape
=
var
->
GetShape
();
...
...
@@ -262,7 +280,8 @@ class OpConverter {
}
else
{
input_shape
.
push_back
(
min_input_shape
[
i
]);
// the i dimension should be same.
PADDLE_ENFORCE_EQ
(
min_input_shape
[
i
],
optim_input_shape
[
i
],
PADDLE_ENFORCE_EQ
(
min_input_shape
[
i
],
optim_input_shape
[
i
],
platform
::
errors
::
InvalidArgument
(
"The dim (%d) of the min_input_shape and "
"optim_input_shape should be same."
));
...
...
@@ -282,7 +301,8 @@ class OpConverter {
Vec2TRT_Dims
(
var_shape
,
input
));
}
}
PADDLE_ENFORCE_EQ
(
all_dynamic_shape_set
,
true
,
PADDLE_ENFORCE_EQ
(
all_dynamic_shape_set
,
true
,
platform
::
errors
::
InvalidArgument
(
"some trt inputs dynamic shape info not set, "
"check the INFO log above for more details."
));
...
...
@@ -297,7 +317,8 @@ class OpConverter {
// rank(result) = rank(input)
nvinfer1
::
ITensor
*
Gather
(
nvinfer1
::
ITensor
*
input
,
const
std
::
vector
<
int32_t
>
indices
,
int
axis
=
0
)
{
const
std
::
vector
<
int32_t
>
indices
,
int
axis
=
0
)
{
auto
*
indices_tensor
=
Add1DConstantLayer
(
indices
,
" "
);
auto
*
result
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Gather
,
*
input
,
*
indices_tensor
,
axis
)
...
...
@@ -326,8 +347,8 @@ class OpConverter {
// Concat not make rank changed
nvinfer1
::
ITensor
*
Concat
(
const
std
::
vector
<
nvinfer1
::
ITensor
*>&
inputs
,
int
axis
=
0
)
{
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Concatenation
,
inputs
.
data
(),
inputs
.
size
());
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Concatenation
,
inputs
.
data
(),
inputs
.
size
());
if
(
axis
!=
0
)
layer
->
setAxis
(
axis
);
nvinfer1
::
ITensor
*
c
=
layer
->
getOutput
(
0
);
return
c
;
...
...
@@ -335,48 +356,48 @@ class OpConverter {
nvinfer1
::
ITensor
*
Sum
(
nvinfer1
::
ITensor
*
a
,
nvinfer1
::
ITensor
*
b
)
{
nvinfer1
::
ITensor
*
c
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kSUM
)
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kSUM
)
->
getOutput
(
0
);
return
c
;
}
nvinfer1
::
ITensor
*
Prod
(
nvinfer1
::
ITensor
*
a
,
nvinfer1
::
ITensor
*
b
)
{
nvinfer1
::
ITensor
*
c
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kPROD
)
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kPROD
)
->
getOutput
(
0
);
return
c
;
}
nvinfer1
::
ITensor
*
Min
(
nvinfer1
::
ITensor
*
a
,
nvinfer1
::
ITensor
*
b
)
{
nvinfer1
::
ITensor
*
c
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kMIN
)
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kMIN
)
->
getOutput
(
0
);
return
c
;
}
nvinfer1
::
ITensor
*
Max
(
nvinfer1
::
ITensor
*
a
,
nvinfer1
::
ITensor
*
b
)
{
nvinfer1
::
ITensor
*
c
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kMAX
)
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kMAX
)
->
getOutput
(
0
);
return
c
;
}
nvinfer1
::
ITensor
*
Sub
(
nvinfer1
::
ITensor
*
a
,
nvinfer1
::
ITensor
*
b
)
{
nvinfer1
::
ITensor
*
c
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kSUB
)
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kSUB
)
->
getOutput
(
0
);
return
c
;
}
nvinfer1
::
ITensor
*
Div
(
nvinfer1
::
ITensor
*
a
,
nvinfer1
::
ITensor
*
b
)
{
nvinfer1
::
ITensor
*
c
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kDIV
)
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
a
,
*
b
,
nvinfer1
::
ElementWiseOperation
::
kDIV
)
->
getOutput
(
0
);
return
c
;
}
...
...
@@ -390,10 +411,14 @@ class OpConverter {
// Get element tensor of 1D shape tensor
nvinfer1
::
ITensor
*
GetEleTensorOfShape
(
nvinfer1
::
ITensor
*
shape_tensor
,
int
index
,
bool
is_scalar
=
false
)
{
int
index
,
bool
is_scalar
=
false
)
{
auto
*
tensor
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Gather
,
*
shape_tensor
,
*
Add1DConstantLayer
(
index
,
" "
,
is_scalar
),
0
)
TRT_ENGINE_ADD_LAYER
(
engine_
,
Gather
,
*
shape_tensor
,
*
Add1DConstantLayer
(
index
,
" "
,
is_scalar
),
0
)
->
getOutput
(
0
);
return
tensor
;
}
...
...
@@ -403,8 +428,8 @@ class OpConverter {
const
std
::
vector
<
int32_t
>&
weight_dims
,
const
std
::
string
&
weight_name
)
{
std
::
unique_ptr
<
framework
::
Tensor
>
tmp_tensor
(
new
framework
::
Tensor
());
int
data_size
=
std
::
accumulate
(
weight_dims
.
begin
(),
weight_dims
.
end
(),
1
,
std
::
multiplies
<
int
>
());
int
data_size
=
std
::
accumulate
(
weight_dims
.
begin
(),
weight_dims
.
end
(),
1
,
std
::
multiplies
<
int
>
());
tmp_tensor
->
Resize
({
data_size
});
auto
*
tmp_data
=
tmp_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
data_size
;
i
++
)
{
...
...
@@ -489,7 +514,8 @@ class OpConverter {
}
void
RreplenishLayerAndOutput
(
nvinfer1
::
ILayer
*
layer
,
const
std
::
string
&
layer_type
,
nvinfer1
::
ILayer
*
layer
,
const
std
::
string
&
layer_type
,
const
std
::
vector
<
std
::
string
>&
output_tensor_names
,
bool
test_mode
=
false
)
{
size_t
num_out
=
output_tensor_names
.
size
();
...
...
paddle/fluid/inference/tensorrt/convert/slice_op.cc
浏览文件 @
fcc8a87b
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -22,7 +19,8 @@ namespace tensorrt {
class
SliceOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
// This OP is implemented by trt dynamic shpae plugin.
// Dynamic shape plugin requires TRT version greater than 6.0.
VLOG
(
4
)
<<
"convert slice op to tensorrt layer"
;
...
...
@@ -63,28 +61,118 @@ class SliceOpConverter : public OpConverter {
}
ends
[
i
]
=
std
::
min
(
ends
[
i
],
input_dims
.
d
[
axes
[
i
]]);
PADDLE_ENFORCE_GT
(
ends
[
i
],
starts
[
i
],
ends
[
i
],
starts
[
i
],
platform
::
errors
::
InvalidArgument
(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received ends = %d, starts = %d."
,
ends
[
i
],
starts
[
i
]));
ends
[
i
],
starts
[
i
]));
}
}
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
#if IS_TRT_VERSION_GE(6000)
auto
nchw_input_dims
=
input
->
getDimensions
();
nvinfer1
::
Dims
trt_start_dims
;
trt_start_dims
.
nbDims
=
nchw_input_dims
.
nbDims
;
memset
(
trt_start_dims
.
d
,
0
,
sizeof
(
int32_t
)
*
nchw_input_dims
.
nbDims
);
nvinfer1
::
Dims
trt_size_dims
=
trt_start_dims
;
nvinfer1
::
Dims
trt_end_dims
=
trt_start_dims
;
nvinfer1
::
Dims
trt_step_dims
=
trt_start_dims
;
for
(
int
i
=
0
;
i
<
trt_step_dims
.
nbDims
;
i
++
)
trt_step_dims
.
d
[
i
]
=
1
;
// input : [N,C,H,W]
bool
has_neg_indices
=
false
;
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
int
trt_axis
=
axes
[
i
];
trt_start_dims
.
d
[
trt_axis
]
=
starts
[
i
];
trt_end_dims
.
d
[
trt_axis
]
=
ends
[
i
];
if
(
starts
[
i
]
<
0
||
ends
[
i
]
<
0
)
has_neg_indices
=
true
;
}
auto
*
shape_tensor
=
Shape
(
input
);
auto
*
start_tensor
=
Add1DConstantLayer
(
trt_start_dims
);
if
(
has_neg_indices
)
{
start_tensor
=
FixNegIndices
(
shape_tensor
,
start_tensor
);
}
std
::
vector
<
nvinfer1
::
ITensor
*>
end_vec_tensor
;
for
(
int
i
=
0
;
i
<
trt_end_dims
.
nbDims
;
i
++
)
{
end_vec_tensor
.
push_back
(
GetEleTensorOfShape
(
shape_tensor
,
i
));
}
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
int
trt_axis
=
axes
[
i
];
if
(
ends
[
i
]
>=
0
)
{
end_vec_tensor
[
trt_axis
]
=
Add1DConstantLayer
(
ends
[
i
]);
}
else
{
end_vec_tensor
[
trt_axis
]
=
Sum
(
end_vec_tensor
[
trt_axis
],
Add1DConstantLayer
(
ends
[
i
]));
}
}
// CI failed in trt 6015 but success in 7134, may be a trt bug
#if IS_TRT_VERSION_GE(7134)
auto
*
size_tensor
=
Sub
(
Min
(
Concat
(
end_vec_tensor
),
shape_tensor
),
start_tensor
);
#else
auto
*
size_tensor
=
Sub
(
Concat
(
end_vec_tensor
),
start_tensor
);
#endif
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Slice
,
*
input
,
trt_start_dims
,
trt_size_dims
,
trt_step_dims
);
layer
->
setInput
(
1
,
*
start_tensor
);
layer
->
setInput
(
2
,
*
size_tensor
);
if
(
decrease_axises
.
size
()
>
0
)
{
std
::
vector
<
int32_t
>
gather_indices
;
for
(
int
i
=
0
;
i
<
trt_size_dims
.
nbDims
;
i
++
)
{
if
(
decrease_axises
.
end
()
!=
std
::
find
(
decrease_axises
.
begin
(),
decrease_axises
.
end
(),
i
))
continue
;
gather_indices
.
push_back
(
i
);
}
if
(
gather_indices
.
empty
())
gather_indices
.
push_back
(
decrease_axises
[
0
]);
auto
real_size_tensor
=
Gather
(
size_tensor
,
gather_indices
);
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
layer
->
getOutput
(
0
));
layer
->
setInput
(
1
,
*
real_size_tensor
);
}
#else
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
int
decrease_axis
=
decrease_axises
.
size
()
==
0
?
-
1
:
decrease_axises
[
0
];
plugin
::
SlicePluginDynamic
*
plugin
=
new
plugin
::
SlicePluginDynamic
(
starts
,
ends
,
axes
,
decrease_axis
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
&
input
,
1
,
plugin
);
#endif
}
else
{
#if IS_TRT_VERSION_GE(6000)
auto
chw_input_dims
=
input
->
getDimensions
();
nvinfer1
::
Dims
trt_start_dims
;
trt_start_dims
.
nbDims
=
chw_input_dims
.
nbDims
;
memset
(
trt_start_dims
.
d
,
0
,
sizeof
(
int32_t
)
*
chw_input_dims
.
nbDims
);
nvinfer1
::
Dims
trt_size_dims
=
chw_input_dims
;
nvinfer1
::
Dims
trt_step_dims
;
trt_step_dims
.
nbDims
=
chw_input_dims
.
nbDims
;
for
(
int
i
=
0
;
i
<
trt_step_dims
.
nbDims
;
i
++
)
trt_step_dims
.
d
[
i
]
=
1
;
// input : [C,H,W]
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
int
trt_axis
=
axes
[
i
]
-
1
;
trt_start_dims
.
d
[
trt_axis
]
=
starts
[
i
];
trt_size_dims
.
d
[
trt_axis
]
=
ends
[
i
]
-
starts
[
i
];
}
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Slice
,
*
input
,
trt_start_dims
,
trt_size_dims
,
trt_step_dims
);
#else
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SlicePlugin
*
plugin
=
new
plugin
::
SlicePlugin
(
starts
,
ends
,
axes
,
with_fp16
);
layer
=
engine_
->
AddPlugin
(
&
input
,
1
,
plugin
);
#endif
}
RreplenishLayerAndOutput
(
layer
,
"slice"
,
{
output_name
},
test_mode
);
}
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
fcc8a87b
...
...
@@ -49,7 +49,8 @@ void TensorRTEngine::InitNetwork() {
optim_profiles_
[
i
]
=
infer_builder_
->
createOptimizationProfile
();
}
void
TensorRTEngine
::
Execute
(
int
batch_size
,
std
::
vector
<
void
*>
*
buffers
,
void
TensorRTEngine
::
Execute
(
int
batch_size
,
std
::
vector
<
void
*>
*
buffers
,
cudaStream_t
stream
)
{
freshDeviceId
();
auto
infer_context
=
context
();
...
...
@@ -129,14 +130,32 @@ void TensorRTEngine::FreezeNetwork() {
}
#if IS_TRT_VERSION_GE(5122)
auto
is_layer_int8
=
[
&
](
nvinfer1
::
ILayer
*
layer
)
->
bool
{
auto
layer_int8_fallback
=
[
&
](
nvinfer1
::
ILayer
*
layer
)
->
bool
{
if
(
layer
->
getType
()
==
nvinfer1
::
LayerType
::
kSHAPE
)
{
return
false
;
}
bool
all_int
=
true
;
for
(
int
j
=
0
;
j
<
layer
->
getNbInputs
();
j
++
)
{
auto
*
temp_in
=
layer
->
getInput
(
j
);
if
(
temp_in
->
getType
()
!=
nvinfer1
::
DataType
::
kINT32
)
{
all_int
=
false
;
}
}
for
(
int
j
=
0
;
j
<
layer
->
getNbOutputs
();
j
++
)
{
auto
*
temp_out
=
layer
->
getOutput
(
j
);
if
(
temp_out
->
getType
()
!=
nvinfer1
::
DataType
::
kINT32
)
{
all_int
=
false
;
}
}
if
(
all_int
)
return
false
;
for
(
int
j
=
0
;
j
<
layer
->
getNbInputs
();
j
++
)
{
auto
*
temp_in
=
layer
->
getInput
(
j
);
if
(
!
temp_in
->
dynamicRangeIsSet
())
{
VLOG
(
1
)
<<
"Layer(Name: "
<<
layer
->
getName
()
<<
") is set to float32 because its input("
<<
temp_in
->
getName
()
<<
") doesn't have dynamic range."
;
return
fals
e
;
return
tru
e
;
}
}
for
(
int
j
=
0
;
j
<
layer
->
getNbOutputs
();
j
++
)
{
...
...
@@ -145,10 +164,10 @@ void TensorRTEngine::FreezeNetwork() {
VLOG
(
1
)
<<
"Layer(Name: "
<<
layer
->
getName
()
<<
") is set to float32 because its output("
<<
temp_out
->
getName
()
<<
") doesn't have dynamic range."
;
return
fals
e
;
return
tru
e
;
}
}
return
tru
e
;
return
fals
e
;
};
// If a layer's output is the network's output, or not all of its inputs
// and outputs have scales,
...
...
@@ -157,7 +176,7 @@ void TensorRTEngine::FreezeNetwork() {
int
layers_no_int8
=
0
;
for
(
int
i
=
0
;
i
<
network
()
->
getNbLayers
();
i
++
)
{
auto
layer
=
network
()
->
getLayer
(
i
);
if
(
!
is_layer_int8
(
layer
))
{
if
(
layer_int8_fallback
(
layer
))
{
layer
->
setPrecision
(
nvinfer1
::
DataType
::
kFLOAT
);
++
layers_no_int8
;
}
...
...
@@ -208,7 +227,8 @@ void TensorRTEngine::FreezeNetwork() {
for
(
auto
&
input
:
min_input_shape_
)
{
#if IS_TRT_VERSION_LT(7000)
// trt6 will check all_of input > 0
if
(
!
(
std
::
all_of
(
input
.
second
.
begin
(),
input
.
second
.
end
(),
if
(
!
(
std
::
all_of
(
input
.
second
.
begin
(),
input
.
second
.
end
(),
[](
int
x
)
{
return
x
>
0
;
})
&&
std
::
all_of
(
max_input_shape_
[
input
.
first
].
begin
(),
max_input_shape_
[
input
.
first
].
end
(),
...
...
@@ -225,13 +245,16 @@ void TensorRTEngine::FreezeNetwork() {
<<
", opt: "
<<
Vec2Str
(
optim_input_shape_
[
input
.
first
]);
optim_profiles_
[
i
]
->
setDimensions
(
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kMIN
,
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kMIN
,
Vec2TRT_Dims
(
input
.
second
,
input
.
first
,
true
));
optim_profiles_
[
i
]
->
setDimensions
(
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kMAX
,
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kMAX
,
Vec2TRT_Dims
(
max_input_shape_
[
input
.
first
],
input
.
first
,
true
));
optim_profiles_
[
i
]
->
setDimensions
(
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kOPT
,
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kOPT
,
Vec2TRT_Dims
(
optim_input_shape_
[
input
.
first
],
input
.
first
,
true
));
}
infer_builder_config_
->
addOptimizationProfile
(
optim_profiles_
[
i
]);
...
...
@@ -265,7 +288,8 @@ void TensorRTEngine::FreezeNetwork() {
#endif
PADDLE_ENFORCE_NOT_NULL
(
infer_engine_
,
platform
::
errors
::
Fatal
(
infer_engine_
,
platform
::
errors
::
Fatal
(
"Build TensorRT cuda engine failed! Please recheck "
"you configurations related to paddle-TensorRT."
));
...
...
@@ -282,16 +306,19 @@ void TensorRTEngine::FreezeNetwork() {
nvinfer1
::
ITensor
*
TensorRTEngine
::
DeclareInput
(
const
std
::
string
&
name
,
nvinfer1
::
DataType
dtype
,
const
nvinfer1
::
Dims
&
dims
)
{
PADDLE_ENFORCE_EQ
(
network
()
!=
nullptr
,
true
,
PADDLE_ENFORCE_EQ
(
network
()
!=
nullptr
,
true
,
platform
::
errors
::
InvalidArgument
(
"The TRT network should be initialized first."
));
auto
*
input
=
network
()
->
addInput
(
name
.
c_str
(),
dtype
,
dims
);
PADDLE_ENFORCE_NOT_NULL
(
input
,
platform
::
errors
::
InvalidArgument
(
"Adding input %s failed in "
input
,
platform
::
errors
::
InvalidArgument
(
"Adding input %s failed in "
"TensorRT inference network. "
"Please recheck your input."
,
name
));
PADDLE_ENFORCE_EQ
(
input
->
isNetworkInput
(),
true
,
PADDLE_ENFORCE_EQ
(
input
->
isNetworkInput
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Input %s is not the input of TRT inference network. "
"Please recheck your input."
,
...
...
@@ -300,22 +327,26 @@ nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
return
input
;
}
void
TensorRTEngine
::
DeclareOutput
(
const
nvinfer1
::
ILayer
*
layer
,
int
offset
,
void
TensorRTEngine
::
DeclareOutput
(
const
nvinfer1
::
ILayer
*
layer
,
int
offset
,
const
std
::
string
&
name
)
{
auto
*
output
=
layer
->
getOutput
(
offset
);
SetITensor
(
name
,
output
);
PADDLE_ENFORCE_NOT_NULL
(
output
,
platform
::
errors
::
InvalidArgument
(
output
,
platform
::
errors
::
InvalidArgument
(
"The output %s of TRT engine should not be null."
,
name
));
output
->
setName
(
name
.
c_str
());
PADDLE_ENFORCE_EQ
(
output
->
isNetworkInput
(),
false
,
PADDLE_ENFORCE_EQ
(
output
->
isNetworkInput
(),
false
,
platform
::
errors
::
InvalidArgument
(
"The output %s of TRT engine should not be the input "
"of the network at the same time."
,
name
));
network
()
->
markOutput
(
*
output
);
PADDLE_ENFORCE_EQ
(
output
->
isNetworkOutput
(),
true
,
output
->
isNetworkOutput
(),
true
,
platform
::
errors
::
InvalidArgument
(
"The output %s of TRT engine should be the output of the network."
,
name
));
...
...
@@ -324,10 +355,12 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset,
void
TensorRTEngine
::
DeclareOutput
(
const
std
::
string
&
name
)
{
auto
*
output
=
TensorRTEngine
::
GetITensor
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
output
,
platform
::
errors
::
InvalidArgument
(
output
,
platform
::
errors
::
InvalidArgument
(
"The output %s of TRT engine should not be null."
,
name
));
output
->
setName
(
name
.
c_str
());
PADDLE_ENFORCE_EQ
(
output
->
isNetworkInput
(),
false
,
PADDLE_ENFORCE_EQ
(
output
->
isNetworkInput
(),
false
,
platform
::
errors
::
InvalidArgument
(
"The output %s of TRT engine should not be the input "
"of the network at the same time."
,
...
...
@@ -338,17 +371,20 @@ void TensorRTEngine::DeclareOutput(const std::string &name) {
void
TensorRTEngine
::
SetITensor
(
const
std
::
string
&
name
,
nvinfer1
::
ITensor
*
tensor
)
{
PADDLE_ENFORCE_NOT_NULL
(
tensor
,
platform
::
errors
::
InvalidArgument
(
tensor
,
platform
::
errors
::
InvalidArgument
(
"Tensor named %s of TRT engine should not be null."
,
name
));
PADDLE_ENFORCE_EQ
(
0
,
itensor_map_
.
count
(
name
),
0
,
itensor_map_
.
count
(
name
),
platform
::
errors
::
InvalidArgument
(
"Tensor named %s of TRT engine should not be duplicated"
,
name
));
itensor_map_
[
name
]
=
tensor
;
}
nvinfer1
::
ITensor
*
TensorRTEngine
::
GetITensor
(
const
std
::
string
&
name
)
{
PADDLE_ENFORCE_EQ
(
itensor_map_
.
count
(
name
),
true
,
PADDLE_ENFORCE_EQ
(
itensor_map_
.
count
(
name
),
true
,
platform
::
errors
::
NotFound
(
"Tensor named %s is not found in TRT engine"
,
name
));
return
itensor_map_
[
name
];
...
...
@@ -365,15 +401,16 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
std
::
string
splitter
=
"__"
;
std
::
string
name_with_suffix
=
name
+
splitter
+
name_suffix
;
platform
::
CPUPlace
cpu_place
;
PADDLE_ENFORCE_EQ
(
weight_map
.
count
(
name_with_suffix
),
0
,
PADDLE_ENFORCE_EQ
(
weight_map
.
count
(
name_with_suffix
),
0
,
platform
::
errors
::
AlreadyExists
(
"The weight named %s is set into the weight map "
"twice in TRT OP converter."
,
name_with_suffix
));
weight_map
[
name_with_suffix
].
reset
(
new
framework
::
Tensor
());
weight_map
[
name_with_suffix
]
->
Resize
(
weight_tensor
->
dims
());
paddle
::
framework
::
TensorCopySync
(
*
weight_tensor
,
cpu_place
,
weight_map
[
name_with_suffix
].
get
());
paddle
::
framework
::
TensorCopySync
(
*
weight_tensor
,
cpu_place
,
weight_map
[
name_with_suffix
].
get
());
float
*
weight_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
float
>
(
cpu_place
);
name_suffix_counter
+=
1
;
...
...
@@ -383,21 +420,24 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
int
TensorRTEngine
::
GetRuntimeBatch
()
{
return
runtime_batch_
;
}
nvinfer1
::
IPluginV2Layer
*
TensorRTEngine
::
AddPlugin
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
plugin
::
PluginTensorRT
*
plugin
)
{
owned_plugin_
.
emplace_back
(
plugin
);
return
network
()
->
addPluginV2
(
inputs
,
num_inputs
,
*
plugin
);
}
nvinfer1
::
IPluginV2Layer
*
TensorRTEngine
::
AddPluginV2Ext
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
plugin
::
PluginTensorRTV2Ext
*
plugin
)
{
owned_plugin_v2ext_
.
emplace_back
(
plugin
);
return
network
()
->
addPluginV2
(
inputs
,
num_inputs
,
*
plugin
);
}
nvinfer1
::
IPluginV2Layer
*
TensorRTEngine
::
AddPluginV2IOExt
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
nvinfer1
::
IPluginV2IOExt
*
plugin
)
{
owned_plugin_v2ioext_
.
emplace_back
(
plugin
);
return
network
()
->
addPluginV2
(
inputs
,
num_inputs
,
*
plugin
);
...
...
@@ -406,10 +446,12 @@ nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2IOExt(
void
TensorRTEngine
::
freshDeviceId
()
{
int
count
;
cudaGetDeviceCount
(
&
count
);
PADDLE_ENFORCE_LT
(
device_id_
,
count
,
PADDLE_ENFORCE_LT
(
device_id_
,
count
,
platform
::
errors
::
OutOfRange
(
"Device id %d exceeds the current device count: %d."
,
device_id_
,
count
));
device_id_
,
count
));
platform
::
SetDeviceId
(
device_id_
);
}
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py
浏览文件 @
fcc8a87b
...
...
@@ -62,7 +62,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
for
axes
in
[[
0
,
1
],
[
1
,
3
],
[
2
,
3
]]:
for
starts
in
[[
0
,
1
]]:
for
ends
in
[[
2
,
2
],
[
5
,
5
]]:
for
ends
in
[[
2
,
2
],
[
5
,
5
]
,
[
1
,
-
1
]
]:
for
decrease_axis
in
[[],
[
1
],
[
2
],
[
-
1
],
[
-
100
]]:
for
infer_flags
in
[[
-
1
]]:
dics
=
[{
...
...
@@ -118,10 +118,6 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
return
0
,
3
if
dynamic_shape
==
False
and
len
(
attrs
[
0
][
"decrease_axis"
])
!=
0
:
return
0
,
3
if
dynamic_shape
:
for
i
in
range
(
len
(
attrs
[
0
][
"starts"
])):
if
attrs
[
0
][
"starts"
][
i
]
<
0
or
attrs
[
0
][
"ends"
][
i
]
<
0
:
return
0
,
3
if
not
dynamic_shape
:
for
x
in
attrs
[
0
][
"axes"
]:
if
x
==
0
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录