Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
245005d4
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
245005d4
编写于
8月 04, 2022
作者:
Z
zhoutianzi666
提交者:
GitHub
8月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle-TRT][cherry pick] Slice to 2.3 (#44757)
* slice_to_2.3
上级
7cdce09b
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
494 addition
and
252 deletion
+494
-252
paddle/fluid/inference/tensorrt/convert/fc_op.cc
paddle/fluid/inference/tensorrt/convert/fc_op.cc
+143
-60
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+100
-59
paddle/fluid/inference/tensorrt/convert/slice_op.cc
paddle/fluid/inference/tensorrt/convert/slice_op.cc
+119
-44
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+81
-39
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+25
-17
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py
...id/tests/unittests/ir/inference/test_trt_convert_slice.py
+26
-33
未找到文件。
paddle/fluid/inference/tensorrt/convert/fc_op.cc
浏览文件 @
245005d4
...
@@ -34,29 +34,59 @@ namespace tensorrt {
...
@@ -34,29 +34,59 @@ namespace tensorrt {
class
FcOpConverter
:
public
OpConverter
{
class
FcOpConverter
:
public
OpConverter
{
public:
public:
nvinfer1
::
ILayer
*
reshape_before_fc
(
nvinfer1
::
ITensor
*
before_fc
,
nvinfer1
::
ILayer
*
reshape_before_fc
(
nvinfer1
::
ITensor
*
before_fc
,
nvinfer1
::
Dims
x_dim
,
int
x_num_col_dims
,
nvinfer1
::
Dims
x_dim
,
int
x_num_col_dims
,
std
::
string
output_name
)
{
std
::
string
output_name
)
{
// add shuffle before fc
// add shuffle before fc
nvinfer1
::
Dims
reshape_before_fc_dim
;
nvinfer1
::
Dims
reshape_before_fc_dim
;
reshape_before_fc_dim
.
nbDims
=
x_num_col_dims
+
3
;
reshape_before_fc_dim
.
nbDims
=
x_num_col_dims
+
3
;
// padding shape "* x q x 1 x 1"
// padding shape "* x q x 1 x 1"
for
(
int
i
=
0
;
i
<
reshape_before_fc_dim
.
nbDims
;
i
++
)
{
reshape_before_fc_dim
.
d
[
i
]
=
1
;
nvinfer1
::
ITensor
*
filal_reshape_before_fc_shape_tensor
=
nullptr
;
}
for
(
int
i
=
0
;
i
<
x_dim
.
nbDims
;
i
++
)
{
if
(
!
engine_
->
with_dynamic_shape
())
{
if
(
i
<
x_num_col_dims
)
{
for
(
int
i
=
0
;
i
<
reshape_before_fc_dim
.
nbDims
;
i
++
)
{
reshape_before_fc_dim
.
d
[
i
]
=
0
;
reshape_before_fc_dim
.
d
[
i
]
=
1
;
}
else
{
}
if
(
x_dim
.
d
[
i
]
<
0
)
{
for
(
int
i
=
0
;
i
<
x_dim
.
nbDims
;
i
++
)
{
reshape_before_fc_dim
.
d
[
x_num_col_dims
]
=
-
1
;
if
(
i
<
x_num_col_dims
)
{
break
;
reshape_before_fc_dim
.
d
[
i
]
=
0
;
}
else
{
reshape_before_fc_dim
.
d
[
x_num_col_dims
]
*=
x_dim
.
d
[
i
];
}
}
}
else
{
std
::
vector
<
nvinfer1
::
ITensor
*>
reshape_before_fc_shape_tensor
;
nvinfer1
::
ITensor
*
input_shape_tensor
=
Shape
(
before_fc
);
for
(
int
i
=
0
;
i
<
reshape_before_fc_dim
.
nbDims
;
i
++
)
{
reshape_before_fc_shape_tensor
.
push_back
(
Add1DConstantLayer
(
1
));
}
for
(
int
i
=
0
;
i
<
x_dim
.
nbDims
;
i
++
)
{
if
(
i
<
x_num_col_dims
)
{
reshape_before_fc_shape_tensor
[
i
]
=
GetEleTensorOfShape
(
input_shape_tensor
,
i
);
}
else
{
reshape_before_fc_shape_tensor
[
x_num_col_dims
]
=
Prod
(
GetEleTensorOfShape
(
input_shape_tensor
,
i
),
reshape_before_fc_shape_tensor
[
x_num_col_dims
]);
// If not set, test_trt_matmul_quant_dequant in trt 6015 will fail
reshape_before_fc_shape_tensor
[
x_num_col_dims
]
->
setType
(
nvinfer1
::
DataType
::
kINT32
);
}
}
reshape_before_fc_dim
.
d
[
x_num_col_dims
]
*=
x_dim
.
d
[
i
];
}
}
filal_reshape_before_fc_shape_tensor
=
Concat
(
reshape_before_fc_shape_tensor
);
}
}
auto
*
reshape_before_fc_layer
=
auto
*
reshape_before_fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
before_fc
);
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
before_fc
);
reshape_before_fc_layer
->
setReshapeDimensions
(
reshape_before_fc_dim
);
if
(
!
engine_
->
with_dynamic_shape
())
{
reshape_before_fc_layer
->
setReshapeDimensions
(
reshape_before_fc_dim
);
}
else
{
reshape_before_fc_layer
->
setInput
(
1
,
*
filal_reshape_before_fc_shape_tensor
);
}
reshape_before_fc_layer
->
setName
(
reshape_before_fc_layer
->
setName
(
(
"fc_op_reshape_before_fc: Shuffle (Output: "
+
output_name
+
")"
)
(
"fc_op_reshape_before_fc: Shuffle (Output: "
+
output_name
+
")"
)
.
c_str
());
.
c_str
());
...
@@ -64,21 +94,37 @@ class FcOpConverter : public OpConverter {
...
@@ -64,21 +94,37 @@ class FcOpConverter : public OpConverter {
}
}
nvinfer1
::
ILayer
*
reshape_after_fc
(
nvinfer1
::
ITensor
*
after_fc
,
nvinfer1
::
ILayer
*
reshape_after_fc
(
nvinfer1
::
ITensor
*
after_fc
,
nvinfer1
::
Dims
x_dim
,
int
x_num_col_dims
)
{
nvinfer1
::
Dims
x_dim
,
int
x_num_col_dims
)
{
// add shuffle after fc
// add shuffle after fc
nvinfer1
::
Dims
reshape_after_fc_dim
;
nvinfer1
::
Dims
reshape_after_fc_dim
;
reshape_after_fc_dim
.
nbDims
=
x_num_col_dims
+
1
;
reshape_after_fc_dim
.
nbDims
=
x_num_col_dims
+
1
;
for
(
int
i
=
0
;
i
<
reshape_after_fc_dim
.
nbDims
;
i
++
)
{
reshape_after_fc_dim
.
d
[
i
]
=
0
;
nvinfer1
::
ITensor
*
filal_reshape_after_fc_shape_tensor
=
nullptr
;
if
(
!
engine_
->
with_dynamic_shape
())
{
for
(
int
i
=
0
;
i
<
reshape_after_fc_dim
.
nbDims
;
i
++
)
{
reshape_after_fc_dim
.
d
[
i
]
=
0
;
}
}
else
{
std
::
vector
<
int
>
gather_indices
(
x_num_col_dims
+
1
);
std
::
iota
(
gather_indices
.
begin
(),
gather_indices
.
end
(),
0
);
filal_reshape_after_fc_shape_tensor
=
Gather
(
Shape
(
after_fc
),
gather_indices
);
}
}
auto
*
reshape_after_fc_layer
=
auto
*
reshape_after_fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
after_fc
);
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
after_fc
);
reshape_after_fc_layer
->
setReshapeDimensions
(
reshape_after_fc_dim
);
if
(
!
engine_
->
with_dynamic_shape
())
{
reshape_after_fc_layer
->
setReshapeDimensions
(
reshape_after_fc_dim
);
}
else
{
reshape_after_fc_layer
->
setInput
(
1
,
*
filal_reshape_after_fc_shape_tensor
);
}
return
reshape_after_fc_layer
;
return
reshape_after_fc_layer
;
}
}
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid fc op to tensorrt fc layer without bias"
;
VLOG
(
3
)
<<
"convert a fluid fc op to tensorrt fc layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
...
@@ -96,8 +142,9 @@ class FcOpConverter : public OpConverter {
...
@@ -96,8 +142,9 @@ class FcOpConverter : public OpConverter {
// Declare weights
// Declare weights
auto
*
Y_v
=
scope
.
FindVar
(
op_desc
.
Input
(
w_name
).
front
());
auto
*
Y_v
=
scope
.
FindVar
(
op_desc
.
Input
(
w_name
).
front
());
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
Y_v
,
platform
::
errors
::
NotFound
(
Y_v
,
"Can not find %s presistale var of fc in scope."
,
w_name
));
platform
::
errors
::
NotFound
(
"Can not find %s presistale var of fc in scope."
,
w_name
));
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
int
x_num_col_dims
=
int
x_num_col_dims
=
op_desc
.
HasAttr
(
"x_num_col_dims"
)
op_desc
.
HasAttr
(
"x_num_col_dims"
)
...
@@ -128,7 +175,8 @@ class FcOpConverter : public OpConverter {
...
@@ -128,7 +175,8 @@ class FcOpConverter : public OpConverter {
}
}
weight_data
=
engine_
->
GetWeightCPUData
(
op_desc
.
Input
(
w_name
).
front
(),
Y_t
);
weight_data
=
engine_
->
GetWeightCPUData
(
op_desc
.
Input
(
w_name
).
front
(),
Y_t
);
PADDLE_ENFORCE_EQ
(
Y_t
->
dims
().
size
(),
2UL
,
PADDLE_ENFORCE_EQ
(
Y_t
->
dims
().
size
(),
2UL
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The fc's weight should be a matrix with 2 dims, but "
"The fc's weight should be a matrix with 2 dims, but "
"it's %d-dimensional."
,
"it's %d-dimensional."
,
...
@@ -143,7 +191,8 @@ class FcOpConverter : public OpConverter {
...
@@ -143,7 +191,8 @@ class FcOpConverter : public OpConverter {
}
}
};
};
auto
regist_fc
=
[
&
](
nvinfer1
::
ITensor
*
inputs
,
int
n_output
,
auto
regist_fc
=
[
&
](
nvinfer1
::
ITensor
*
inputs
,
int
n_output
,
TensorRTEngine
::
Weight
&
weight
,
TensorRTEngine
::
Weight
&
weight
,
TensorRTEngine
::
Weight
&
bias
)
{
TensorRTEngine
::
Weight
&
bias
)
{
if
(
enable_int8
||
support_int8
)
{
if
(
enable_int8
||
support_int8
)
{
...
@@ -151,7 +200,8 @@ class FcOpConverter : public OpConverter {
...
@@ -151,7 +200,8 @@ class FcOpConverter : public OpConverter {
float
out_scale
=
0
;
float
out_scale
=
0
;
if
(
enable_int8
)
{
if
(
enable_int8
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"out_threshold"
),
true
,
op_desc
.
HasAttr
(
"out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"must have out threshold in fc layers in int8 mode"
));
"must have out threshold in fc layers in int8 mode"
));
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"out_threshold"
));
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"out_threshold"
));
...
@@ -159,9 +209,13 @@ class FcOpConverter : public OpConverter {
...
@@ -159,9 +209,13 @@ class FcOpConverter : public OpConverter {
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"Out"
));
out_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"Out"
));
}
}
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
auto
*
fc_layer_int8
=
auto
*
fc_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
inputs
,
n_output
,
Convolution
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
*
inputs
,
n_output
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
fc_layer_int8
->
setName
(
fc_layer_int8
->
setName
(
(
"fc_op_int8_conv1x1: Convolution (Output: "
+
output_name
+
")"
)
(
"fc_op_int8_conv1x1: Convolution (Output: "
+
output_name
+
")"
)
.
c_str
());
.
c_str
());
...
@@ -174,21 +228,29 @@ class FcOpConverter : public OpConverter {
...
@@ -174,21 +228,29 @@ class FcOpConverter : public OpConverter {
.
c_str
());
.
c_str
());
engine_
->
SetTensorDynamicRange
(
fc_after_reshape_int8
->
getOutput
(
0
),
engine_
->
SetTensorDynamicRange
(
fc_after_reshape_int8
->
getOutput
(
0
),
out_scale
);
out_scale
);
nvinfer1
::
IActivationLayer
*
relu_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
nvinfer1
::
IActivationLayer
*
relu_layer_int8
=
engine_
,
Activation
,
*
(
fc_after_reshape_int8
->
getOutput
(
0
)),
TRT_ENGINE_ADD_LAYER
(
engine_
,
nvinfer1
::
ActivationType
::
kRELU
);
Activation
,
RreplenishLayerAndOutput
(
relu_layer_int8
,
"relu_after_fc_shuffle"
,
*
(
fc_after_reshape_int8
->
getOutput
(
0
)),
{
output_name
},
test_mode
);
nvinfer1
::
ActivationType
::
kRELU
);
RreplenishLayerAndOutput
(
relu_layer_int8
,
"relu_after_fc_shuffle"
,
{
output_name
},
test_mode
);
}
else
{
}
else
{
RreplenishLayerAndOutput
(
fc_after_reshape_int8
,
RreplenishLayerAndOutput
(
fc_after_reshape_int8
,
"fc_op_int8_reshape_after_fc: Shuffle"
,
"fc_op_int8_reshape_after_fc: Shuffle"
,
{
output_name
},
test_mode
);
{
output_name
},
test_mode
);
}
}
}
else
{
}
else
{
// add fc layer
// add fc layer
auto
*
fc_layer_float
=
auto
*
fc_layer_float
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
inputs
,
n_output
,
FullyConnected
,
weight
.
get
(),
bias
.
get
());
*
inputs
,
n_output
,
weight
.
get
(),
bias
.
get
());
fc_layer_float
->
setName
(
fc_layer_float
->
setName
(
(
"fc_op_float: FullyConnected (Output: "
+
output_name
+
")"
)
(
"fc_op_float: FullyConnected (Output: "
+
output_name
+
")"
)
.
c_str
());
.
c_str
());
...
@@ -198,14 +260,20 @@ class FcOpConverter : public OpConverter {
...
@@ -198,14 +260,20 @@ class FcOpConverter : public OpConverter {
fc_after_reshape_float
->
setName
(
fc_after_reshape_float
->
setName
(
(
"float_reshape_after_fc: Shuffle (Output: "
+
output_name
+
")"
)
(
"float_reshape_after_fc: Shuffle (Output: "
+
output_name
+
")"
)
.
c_str
());
.
c_str
());
nvinfer1
::
IActivationLayer
*
relu_layer_float
=
TRT_ENGINE_ADD_LAYER
(
nvinfer1
::
IActivationLayer
*
relu_layer_float
=
engine_
,
Activation
,
*
(
fc_after_reshape_float
->
getOutput
(
0
)),
TRT_ENGINE_ADD_LAYER
(
engine_
,
nvinfer1
::
ActivationType
::
kRELU
);
Activation
,
RreplenishLayerAndOutput
(
relu_layer_float
,
"relu_after_fc_shuffle"
,
*
(
fc_after_reshape_float
->
getOutput
(
0
)),
{
output_name
},
test_mode
);
nvinfer1
::
ActivationType
::
kRELU
);
RreplenishLayerAndOutput
(
relu_layer_float
,
"relu_after_fc_shuffle"
,
{
output_name
},
test_mode
);
}
else
{
}
else
{
RreplenishLayerAndOutput
(
fc_after_reshape_float
,
"shuffle_after_fc"
,
RreplenishLayerAndOutput
(
fc_after_reshape_float
,
{
output_name
},
test_mode
);
"shuffle_after_fc"
,
{
output_name
},
test_mode
);
}
}
}
}
};
};
...
@@ -255,15 +323,20 @@ class FcOpConverter : public OpConverter {
...
@@ -255,15 +323,20 @@ class FcOpConverter : public OpConverter {
if
(
enable_int8
||
support_int8
)
{
if
(
enable_int8
||
support_int8
)
{
// add conv1x1 layer
// add conv1x1 layer
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
auto
*
fc_layer_int8
=
auto
*
fc_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
X
,
n_output
,
nv_ksize
,
Convolution
,
weight
.
get
(),
bias
.
get
());
*
X
,
n_output
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
if
(
activation_type
==
"relu"
)
{
if
(
activation_type
==
"relu"
)
{
fc_layer_int8
->
setName
(
fc_layer_int8
->
setName
(
(
"ernie_fc_op_int8: Convolution (Output: "
+
output_name
+
")"
)
(
"ernie_fc_op_int8: Convolution (Output: "
+
output_name
+
")"
)
.
c_str
());
.
c_str
());
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"out_threshold"
),
true
,
op_desc
.
HasAttr
(
"out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"must have out threshold in fc layers in int8 mode"
));
"must have out threshold in fc layers in int8 mode"
));
float
out_scale
=
0
;
float
out_scale
=
0
;
...
@@ -275,15 +348,20 @@ class FcOpConverter : public OpConverter {
...
@@ -275,15 +348,20 @@ class FcOpConverter : public OpConverter {
}
}
engine_
->
SetTensorDynamicRange
(
fc_layer_int8
->
getOutput
(
0
),
engine_
->
SetTensorDynamicRange
(
fc_layer_int8
->
getOutput
(
0
),
out_scale
);
out_scale
);
nvinfer1
::
IActivationLayer
*
relu_layer_int8
=
TRT_ENGINE_ADD_LAYER
(
nvinfer1
::
IActivationLayer
*
relu_layer_int8
=
engine_
,
Activation
,
*
(
fc_layer_int8
->
getOutput
(
0
)),
TRT_ENGINE_ADD_LAYER
(
engine_
,
nvinfer1
::
ActivationType
::
kRELU
);
Activation
,
RreplenishLayerAndOutput
(
relu_layer_int8
,
"relu_after_ernie_fc_int8"
,
*
(
fc_layer_int8
->
getOutput
(
0
)),
{
output_name
},
test_mode
);
nvinfer1
::
ActivationType
::
kRELU
);
RreplenishLayerAndOutput
(
relu_layer_int8
,
"relu_after_ernie_fc_int8"
,
{
output_name
},
test_mode
);
}
else
{
}
else
{
RreplenishLayerAndOutput
(
fc_layer_int8
,
RreplenishLayerAndOutput
(
fc_layer_int8
,
"ernie_fc_op_int8: Convolution"
,
"ernie_fc_op_int8: Convolution"
,
{
output_name
},
test_mode
);
{
output_name
},
test_mode
);
}
}
}
else
{
}
else
{
// add fc layer
// add fc layer
...
@@ -292,25 +370,30 @@ class FcOpConverter : public OpConverter {
...
@@ -292,25 +370,30 @@ class FcOpConverter : public OpConverter {
if
(
activation_type
==
"relu"
)
{
if
(
activation_type
==
"relu"
)
{
fc_layer_float
->
setName
(
fc_layer_float
->
setName
(
(
"ernie_fc_op_float: (Output: "
+
output_name
+
")"
).
c_str
());
(
"ernie_fc_op_float: (Output: "
+
output_name
+
")"
).
c_str
());
nvinfer1
::
IActivationLayer
*
relu_layer_float
=
TRT_ENGINE_ADD_LAYER
(
nvinfer1
::
IActivationLayer
*
relu_layer_float
=
engine_
,
Activation
,
*
(
fc_layer_float
->
getOutput
(
0
)),
TRT_ENGINE_ADD_LAYER
(
engine_
,
nvinfer1
::
ActivationType
::
kRELU
);
Activation
,
*
(
fc_layer_float
->
getOutput
(
0
)),
nvinfer1
::
ActivationType
::
kRELU
);
RreplenishLayerAndOutput
(
relu_layer_float
,
RreplenishLayerAndOutput
(
relu_layer_float
,
"relu_after_ernie_fc_float"
,
{
output_name
},
"relu_after_ernie_fc_float"
,
{
output_name
},
test_mode
);
test_mode
);
}
else
{
}
else
{
RreplenishLayerAndOutput
(
fc_layer_float
,
"ernie_fc_op_float"
,
RreplenishLayerAndOutput
(
{
output_name
},
test_mode
);
fc_layer_float
,
"ernie_fc_op_float"
,
{
output_name
},
test_mode
);
}
}
}
}
}
else
{
// need reshape input before and after fc
}
else
{
// need reshape input before and after fc
PADDLE_ENFORCE_GT
(
PADDLE_ENFORCE_GT
(
x_dim
.
nbDims
,
x_num_col_dims
,
x_dim
.
nbDims
,
x_num_col_dims
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Params and input dims mismatch. Paddle-TRT FC "
"Params and input dims mismatch. Paddle-TRT FC "
"converter expects x_dim.nbDims > x_num_col_dims, but "
"converter expects x_dim.nbDims > x_num_col_dims, but "
"x_dim.nbDims : %d, x_num_col_dims : %d."
,
"x_dim.nbDims : %d, x_num_col_dims : %d."
,
x_dim
.
nbDims
,
x_num_col_dims
));
x_dim
.
nbDims
,
x_num_col_dims
));
auto
*
reshape_before_fc_layer
=
auto
*
reshape_before_fc_layer
=
reshape_before_fc
(
X
,
x_dim
,
x_num_col_dims
,
output_name
);
reshape_before_fc
(
X
,
x_dim
,
x_num_col_dims
,
output_name
);
auto
*
reshape_itensor
=
reshape_before_fc_layer
->
getOutput
(
0
);
auto
*
reshape_itensor
=
reshape_before_fc_layer
->
getOutput
(
0
);
...
...
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
245005d4
...
@@ -22,7 +22,8 @@ namespace tensorrt {
...
@@ -22,7 +22,8 @@ namespace tensorrt {
class
MultiheadMatMulOpConverter
:
public
OpConverter
{
class
MultiheadMatMulOpConverter
:
public
OpConverter
{
public:
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid multihead_mamul op to a corresponding tensorrt "
VLOG
(
3
)
<<
"convert a fluid multihead_mamul op to a corresponding tensorrt "
"network structure"
;
"network structure"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
...
@@ -52,8 +53,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -52,8 +53,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
float
*
bias_data
=
engine_
->
GetWeightCPUData
(
bias_name
,
bias_t
);
float
*
bias_data
=
engine_
->
GetWeightCPUData
(
bias_name
,
bias_t
);
std
::
vector
<
float
>
weight_data_tmp
;
std
::
vector
<
float
>
weight_data_tmp
;
weight_data_tmp
.
reserve
(
weight_t
->
numel
());
weight_data_tmp
.
reserve
(
weight_t
->
numel
());
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
memcpy
(
weight_t
->
numel
()
*
sizeof
(
float
));
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
// (hidden_in, 3, hidden_out)
// (hidden_in, 3, hidden_out)
auto
weight_dims
=
weight_t
->
dims
();
auto
weight_dims
=
weight_t
->
dims
();
...
@@ -98,14 +99,15 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -98,14 +99,15 @@ class MultiheadMatMulOpConverter : public OpConverter {
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
float
dp_probs
=
1.0
/
127.0
;
float
dp_probs
=
1.0
/
127.0
;
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
input
,
n
,
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
nv_ksize
,
weight
,
bias
);
engine_
,
Convolution
,
*
input
,
n
,
nv_ksize
,
weight
,
bias
);
fc_layer
->
setName
(
fc_layer
->
setName
(
(
"Multihead: Convolution/FullyConnected: (Output: "
+
(
"Multihead: Convolution/FullyConnected: (Output: "
+
output_name
+
")"
)
output_name
+
")"
)
.
c_str
());
.
c_str
());
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"must have out_threshold in multihead layers in int8 mode"
));
"must have out_threshold in multihead layers in int8 mode"
));
float
out_scale
=
float
out_scale
=
...
@@ -119,13 +121,19 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -119,13 +121,19 @@ class MultiheadMatMulOpConverter : public OpConverter {
"CustomQKVToContextPluginDynamic"
,
"3"
);
"CustomQKVToContextPluginDynamic"
,
"3"
);
assert
(
creator
!=
nullptr
);
assert
(
creator
!=
nullptr
);
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
}};
1
}};
if
(
qkv2context_plugin_int8
)
{
if
(
qkv2context_plugin_int8
)
{
fields
.
push_back
({
"dq_probs"
,
&
dp_probs
,
fields
.
push_back
({
"dq_probs"
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
});
&
dp_probs
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
});
}
}
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
...
@@ -154,7 +162,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -154,7 +162,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
3
)
->
getName
());
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
3
)
->
getName
());
engine_
->
SetTensorDynamicRange
(
max_seqlen_tensor
,
1.0
f
);
engine_
->
SetTensorDynamicRange
(
max_seqlen_tensor
,
1.0
f
);
auto
*
shuffle_layer
=
TRT_ENGINE_ADD_LAYER
(
auto
*
shuffle_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
engine_
,
Shuffle
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
max_seqlen_tensor
));
*
const_cast
<
nvinfer1
::
ITensor
*>
(
max_seqlen_tensor
));
nvinfer1
::
Dims
shape_dim
;
nvinfer1
::
Dims
shape_dim
;
shape_dim
.
nbDims
=
1
;
shape_dim
.
nbDims
=
1
;
...
@@ -173,8 +182,11 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -173,8 +182,11 @@ class MultiheadMatMulOpConverter : public OpConverter {
// [3, head_number, head_size, hidden_in] -> [head_number, 3,
// [3, head_number, head_size, hidden_in] -> [head_number, 3,
// head_size,
// head_size,
// hidden_in]
// hidden_in]
auto
transpose_weight_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
three
,
auto
transpose_weight_v2
=
[](
const
float
*
src
,
int
head_number
,
int
head_size
,
float
*
dst
,
int
three
,
int
head_number
,
int
head_size
,
int
hidden_in
)
{
int
hidden_in
)
{
const
int
HH
=
head_size
*
hidden_in
;
const
int
HH
=
head_size
*
hidden_in
;
for
(
int
i
=
0
;
i
<
three
;
++
i
)
{
for
(
int
i
=
0
;
i
<
three
;
++
i
)
{
...
@@ -187,41 +199,47 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -187,41 +199,47 @@ class MultiheadMatMulOpConverter : public OpConverter {
}
}
};
};
// [3, head_number, head_size] -> [head_number, 3, head_size]
// [3, head_number, head_size] -> [head_number, 3, head_size]
auto
transpose_bias_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
N
,
auto
transpose_bias_v2
=
int
H
)
{
[](
const
float
*
src
,
float
*
dst
,
int
N
,
int
H
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
h
=
0
;
h
<
H
;
++
h
)
{
for
(
int
h
=
0
;
h
<
H
;
++
h
)
{
dst
[
n
*
3
*
H
+
i
*
H
+
h
]
=
src
[
i
*
N
*
H
+
n
*
H
+
h
];
dst
[
n
*
3
*
H
+
i
*
H
+
h
]
=
src
[
i
*
N
*
H
+
n
*
H
+
h
];
}
}
}
}
}
};
}
memcpy
(
weight_data_tmp
.
data
(),
};
weight_data
,
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
weight_t
->
numel
()
*
sizeof
(
float
));
transpose_weight_v2
(
weight_data_tmp
.
data
(),
weight_data
,
three
,
transpose_weight_v2
(
weight_data_tmp
.
data
(),
head_number
,
head_size
,
hidden_in
);
weight_data
,
three
,
head_number
,
head_size
,
hidden_in
);
std
::
vector
<
float
>
bias_data_tmp
;
std
::
vector
<
float
>
bias_data_tmp
;
bias_data_tmp
.
reserve
(
bias_t
->
numel
());
bias_data_tmp
.
reserve
(
bias_t
->
numel
());
memcpy
(
bias_data_tmp
.
data
(),
bias_data
,
memcpy
(
bias_t
->
numel
()
*
sizeof
(
float
));
bias_data_tmp
.
data
(),
bias_data
,
bias_t
->
numel
()
*
sizeof
(
float
));
transpose_bias_v2
(
bias_data_tmp
.
data
(),
bias_data
,
head_number
,
transpose_bias_v2
(
head_size
);
bias_data_tmp
.
data
(),
bias_data
,
head_number
,
head_size
);
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
float
dp_probs
=
1.0
/
127.0
;
float
dp_probs
=
1.0
/
127.0
;
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
input
,
n
,
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
nv_ksize
,
weight
,
bias
);
engine_
,
Convolution
,
*
input
,
n
,
nv_ksize
,
weight
,
bias
);
}
else
{
}
else
{
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
input
,
n
,
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
weight
,
bias
);
engine_
,
FullyConnected
,
*
input
,
n
,
weight
,
bias
);
}
}
if
(
op_desc
.
HasAttr
(
"fc_out_threshold"
))
{
if
(
op_desc
.
HasAttr
(
"fc_out_threshold"
))
{
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"must have out threshold in multihead layers "
"must have out threshold in multihead layers "
"in int8 mode"
));
"in int8 mode"
));
...
@@ -248,15 +266,21 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -248,15 +266,21 @@ class MultiheadMatMulOpConverter : public OpConverter {
int
var_seqlen
=
1
;
int
var_seqlen
=
1
;
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
{
"hidden_size"
,
&
hidden_out
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
}};
1
}};
if
(
qkv2context_plugin_int8
)
{
if
(
qkv2context_plugin_int8
)
{
fields
.
push_back
({
"dq_probs"
,
&
dp_probs
,
fields
.
push_back
({
"dq_probs"
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
});
&
dp_probs
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
1
});
}
}
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
...
@@ -285,7 +309,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -285,7 +309,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
auto
max_seqlen_tensor
=
auto
max_seqlen_tensor
=
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
3
)
->
getName
());
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
3
)
->
getName
());
auto
*
shuffle_layer
=
TRT_ENGINE_ADD_LAYER
(
auto
*
shuffle_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
engine_
,
Shuffle
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
max_seqlen_tensor
));
*
const_cast
<
nvinfer1
::
ITensor
*>
(
max_seqlen_tensor
));
nvinfer1
::
Dims
shape_dim
;
nvinfer1
::
Dims
shape_dim
;
shape_dim
.
nbDims
=
1
;
shape_dim
.
nbDims
=
1
;
...
@@ -301,7 +326,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -301,7 +326,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
}
}
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
input
->
getDimensions
().
nbDims
,
3
,
input
->
getDimensions
().
nbDims
,
3
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The Input dim of the MultiheadMatMul should be 3, "
"The Input dim of the MultiheadMatMul should be 3, "
"but it's (%d) now."
,
"but it's (%d) now."
,
...
@@ -320,20 +346,25 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -320,20 +346,25 @@ class MultiheadMatMulOpConverter : public OpConverter {
static_cast
<
size_t
>
(
bias_t
->
numel
())};
static_cast
<
size_t
>
(
bias_t
->
numel
())};
// add shuffle before fc
// add shuffle before fc
nvinfer1
::
Dims
reshape_before_fc_dim
;
std
::
vector
<
nvinfer1
::
ITensor
*>
reshape_before_fc_shape_tensor
;
reshape_before_fc_dim
.
nbDims
=
5
;
nvinfer1
::
ITensor
*
input_shape_tensor
=
Shape
(
input
);
reshape_before_fc_dim
.
d
[
0
]
=
0
;
reshape_before_fc_dim
.
d
[
1
]
=
0
;
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
reshape_before_fc_dim
.
d
[
2
]
=
0
;
reshape_before_fc_shape_tensor
.
push_back
(
Add1DConstantLayer
(
1
));
reshape_before_fc_dim
.
d
[
3
]
=
1
;
}
reshape_before_fc_dim
.
d
[
4
]
=
1
;
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
reshape_before_fc_shape_tensor
[
i
]
=
GetEleTensorOfShape
(
input_shape_tensor
,
i
);
}
auto
*
reshape_before_fc_layer
=
auto
*
reshape_before_fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
engine_
->
SetTensorDynamicRange
(
reshape_before_fc_layer
->
getOutput
(
0
),
engine_
->
SetTensorDynamicRange
(
reshape_before_fc_layer
->
getOutput
(
0
),
in_scale
);
in_scale
);
}
}
reshape_before_fc_layer
->
setReshapeDimensions
(
reshape_before_fc_dim
);
reshape_before_fc_layer
->
setInput
(
1
,
*
Concat
(
reshape_before_fc_shape_tensor
));
reshape_before_fc_layer
->
setName
(
reshape_before_fc_layer
->
setName
(
(
"shuffle_before_multihead_mamul(Output: "
+
output_name
+
")"
)
(
"shuffle_before_multihead_mamul(Output: "
+
output_name
+
")"
)
.
c_str
());
.
c_str
());
...
@@ -342,18 +373,28 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -342,18 +373,28 @@ class MultiheadMatMulOpConverter : public OpConverter {
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
nvinfer1
::
ILayer
*
fc_layer
=
nullptr
;
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
fc_layer
=
engine_
,
Convolution
,
*
reshape_before_fc_layer
->
getOutput
(
0
),
n
,
TRT_ENGINE_ADD_LAYER
(
engine_
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
Convolution
,
*
reshape_before_fc_layer
->
getOutput
(
0
),
n
,
nv_ksize
,
weight
.
get
(),
bias
.
get
());
}
else
{
}
else
{
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
fc_layer
=
engine_
,
FullyConnected
,
*
reshape_before_fc_layer
->
getOutput
(
0
),
TRT_ENGINE_ADD_LAYER
(
engine_
,
n
,
weight
.
get
(),
bias
.
get
());
FullyConnected
,
*
reshape_before_fc_layer
->
getOutput
(
0
),
n
,
weight
.
get
(),
bias
.
get
());
}
}
if
(
op_desc
.
HasAttr
(
"fc_out_threshold"
))
{
if
(
op_desc
.
HasAttr
(
"fc_out_threshold"
))
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
op_desc
.
HasAttr
(
"fc_out_threshold"
),
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"must have out threshold in multihead layers in int8 mode"
));
"must have out threshold in multihead layers in int8 mode"
));
float
out_scale
=
float
out_scale
=
...
@@ -380,8 +421,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -380,8 +421,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
with_fp16
=
true
;
with_fp16
=
true
;
}
}
plugin
::
DynamicPluginTensorRT
*
plugin
=
plugin
::
DynamicPluginTensorRT
*
plugin
=
new
plugin
::
QkvToContextPluginDynamic
(
hidden_in
,
head_number
,
new
plugin
::
QkvToContextPluginDynamic
(
head_size
,
scale
,
with_fp16
);
hidden_in
,
head_number
,
head_size
,
scale
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
plugin_inputs
.
data
(),
2
,
plugin
);
layer
=
engine_
->
AddDynamicPlugin
(
plugin_inputs
.
data
(),
2
,
plugin
);
}
}
}
else
{
}
else
{
...
@@ -391,8 +432,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -391,8 +432,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
"You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
"You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
"the shape information to run the dynamic shape mode."
));
"the shape information to run the dynamic shape mode."
));
}
}
RreplenishLayerAndOutput
(
layer
,
"multihead_matmul"
,
{
output_name
},
RreplenishLayerAndOutput
(
test_mode
);
layer
,
"multihead_matmul"
,
{
output_name
},
test_mode
);
}
}
};
};
...
...
paddle/fluid/inference/tensorrt/convert/slice_op.cc
浏览文件 @
245005d4
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -14,7 +11,6 @@ limitations under the License. */
...
@@ -14,7 +11,6 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -23,7 +19,8 @@ namespace tensorrt {
...
@@ -23,7 +19,8 @@ namespace tensorrt {
class
SliceOpConverter
:
public
OpConverter
{
class
SliceOpConverter
:
public
OpConverter
{
public:
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
// This OP is implemented by trt dynamic shpae plugin.
// This OP is implemented by trt dynamic shpae plugin.
// Dynamic shape plugin requires TRT version greater than 6.0.
// Dynamic shape plugin requires TRT version greater than 6.0.
VLOG
(
4
)
<<
"convert slice op to tensorrt layer"
;
VLOG
(
4
)
<<
"convert slice op to tensorrt layer"
;
...
@@ -64,63 +61,141 @@ class SliceOpConverter : public OpConverter {
...
@@ -64,63 +61,141 @@ class SliceOpConverter : public OpConverter {
}
}
ends
[
i
]
=
std
::
min
(
ends
[
i
],
input_dims
.
d
[
axes
[
i
]]);
ends
[
i
]
=
std
::
min
(
ends
[
i
],
input_dims
.
d
[
axes
[
i
]]);
PADDLE_ENFORCE_GT
(
PADDLE_ENFORCE_GT
(
ends
[
i
],
starts
[
i
],
ends
[
i
],
starts
[
i
],
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Attr(ends) should be greater than attr(starts) in "
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received ends = %d, starts = %d."
,
"slice op. But received ends = %d, starts = %d."
,
ends
[
i
],
starts
[
i
]));
ends
[
i
],
starts
[
i
]));
}
}
}
}
nvinfer1
::
ILayer
*
layer
=
nullptr
;
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
if
(
engine_
->
with_dynamic_shape
())
{
if
(
engine_
->
use_oss
()
&&
engine_
->
with_ernie
()
&&
#if IS_TRT_VERSION_GE(6000)
input_dims
.
nbDims
==
4
)
{
auto
nchw_input_dims
=
input
->
getDimensions
();
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
nvinfer1
::
Dims
trt_start_dims
;
if
(
engine_
->
with_interleaved
())
{
trt_start_dims
.
nbDims
=
nchw_input_dims
.
nbDims
;
auto
*
shuffler_slice
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
memset
(
trt_start_dims
.
d
,
0
,
sizeof
(
int32_t
)
*
nchw_input_dims
.
nbDims
);
nvinfer1
::
Permutation
transpose_embed
{
2
,
1
,
0
,
3
};
nvinfer1
::
Dims
trt_size_dims
=
trt_start_dims
;
shuffler_slice
->
setSecondTranspose
(
transpose_embed
);
nvinfer1
::
Dims
trt_end_dims
=
trt_start_dims
;
engine_
->
SetTensorDynamicRange
(
shuffler_slice
->
getOutput
(
0
),
nvinfer1
::
Dims
trt_step_dims
=
trt_start_dims
;
out_scale
);
for
(
int
i
=
0
;
i
<
trt_step_dims
.
nbDims
;
i
++
)
trt_step_dims
.
d
[
i
]
=
1
;
shuffler_slice
->
setName
(
(
"SpecialSlice_interleaved: transpose: (Output: "
+
output_name
+
// input : [N,C,H,W]
")"
)
bool
has_neg_indices
=
false
;
.
c_str
());
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
plugin_inputs
.
emplace_back
(
shuffler_slice
->
getOutput
(
0
));
int
trt_axis
=
axes
[
i
];
trt_start_dims
.
d
[
trt_axis
]
=
starts
[
i
];
trt_end_dims
.
d
[
trt_axis
]
=
ends
[
i
];
if
(
starts
[
i
]
<
0
||
ends
[
i
]
<
0
)
has_neg_indices
=
true
;
}
auto
*
shape_tensor
=
Shape
(
input
);
auto
*
start_tensor
=
Add1DConstantLayer
(
trt_start_dims
);
if
(
has_neg_indices
)
{
start_tensor
=
FixNegIndices
(
shape_tensor
,
start_tensor
);
}
std
::
vector
<
nvinfer1
::
ITensor
*>
end_vec_tensor
;
for
(
int
i
=
0
;
i
<
trt_end_dims
.
nbDims
;
i
++
)
{
end_vec_tensor
.
push_back
(
GetEleTensorOfShape
(
shape_tensor
,
i
));
}
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
int
trt_axis
=
axes
[
i
];
if
(
ends
[
i
]
>=
0
)
{
end_vec_tensor
[
trt_axis
]
=
Add1DConstantLayer
(
ends
[
i
]);
}
else
{
}
else
{
plugin_inputs
.
emplace_back
(
input
);
end_vec_tensor
[
trt_axis
]
=
Sum
(
end_vec_tensor
[
trt_axis
],
Add1DConstantLayer
(
ends
[
i
]));
}
}
std
::
string
pos_name
;
}
if
(
engine_
->
Has
(
"ernie_pos_name"
))
{
pos_name
=
engine_
->
Get
<
std
::
string
>
(
"ernie_pos_name"
);
// CI failed in trt 6015 but success in 7134, may be a trt bug
}
else
{
#if IS_TRT_VERSION_GE(7134)
// hard code for compatibility
auto
*
size_tensor
=
pos_name
=
engine_
->
network
()
->
getInput
(
2
)
->
getName
();
Sub
(
Min
(
Concat
(
end_vec_tensor
),
shape_tensor
),
start_tensor
);
#else
auto
*
size_tensor
=
Sub
(
Concat
(
end_vec_tensor
),
start_tensor
);
#endif
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Slice
,
*
input
,
trt_start_dims
,
trt_size_dims
,
trt_step_dims
);
layer
->
setInput
(
1
,
*
start_tensor
);
layer
->
setInput
(
2
,
*
size_tensor
);
if
(
decrease_axises
.
size
()
>
0
)
{
std
::
vector
<
int32_t
>
gather_indices
;
for
(
int
i
=
0
;
i
<
trt_size_dims
.
nbDims
;
i
++
)
{
if
(
decrease_axises
.
end
()
!=
std
::
find
(
decrease_axises
.
begin
(),
decrease_axises
.
end
(),
i
))
continue
;
gather_indices
.
push_back
(
i
);
}
}
plugin_inputs
.
emplace_back
(
if
(
gather_indices
.
empty
())
engine_
->
GetITensor
(
pos_name
));
// cu_seqlens, eval_placeholder_2
gather_indices
.
push_back
(
decrease_axises
[
0
]);
auto
real_size_tensor
=
Gather
(
size_tensor
,
gather_indices
);
// bool ban_fp16 = engine_->disable_trt_plugin_fp16();
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
layer
->
getOutput
(
0
));
plugin
::
SpecialSlicePluginDynamic
*
plugin
=
layer
->
setInput
(
1
,
*
real_size_tensor
);
new
plugin
::
SpecialSlicePluginDynamic
();
layer
=
engine_
->
AddDynamicPlugin
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
plugin
);
}
else
{
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
int
decrease_axis
=
decrease_axises
.
size
()
==
0
?
-
1
:
decrease_axises
[
0
];
plugin
::
SlicePluginDynamic
*
plugin
=
new
plugin
::
SlicePluginDynamic
(
starts
,
ends
,
axes
,
decrease_axis
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
&
input
,
1
,
plugin
);
}
}
#else
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
int
decrease_axis
=
decrease_axises
.
size
()
==
0
?
-
1
:
decrease_axises
[
0
];
plugin
::
SlicePluginDynamic
*
plugin
=
new
plugin
::
SlicePluginDynamic
(
starts
,
ends
,
axes
,
decrease_axis
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
&
input
,
1
,
plugin
);
#endif
}
else
{
}
else
{
#if IS_TRT_VERSION_GE(6000)
auto
chw_input_dims
=
input
->
getDimensions
();
nvinfer1
::
Dims
trt_start_dims
;
trt_start_dims
.
nbDims
=
chw_input_dims
.
nbDims
;
memset
(
trt_start_dims
.
d
,
0
,
sizeof
(
int32_t
)
*
chw_input_dims
.
nbDims
);
nvinfer1
::
Dims
trt_size_dims
=
chw_input_dims
;
nvinfer1
::
Dims
trt_step_dims
;
trt_step_dims
.
nbDims
=
chw_input_dims
.
nbDims
;
for
(
int
i
=
0
;
i
<
trt_step_dims
.
nbDims
;
i
++
)
trt_step_dims
.
d
[
i
]
=
1
;
// input : [C,H,W]
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
int
trt_axis
=
axes
[
i
]
-
1
;
trt_start_dims
.
d
[
trt_axis
]
=
starts
[
i
];
trt_size_dims
.
d
[
trt_axis
]
=
ends
[
i
]
-
starts
[
i
];
}
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Slice
,
*
input
,
trt_start_dims
,
trt_size_dims
,
trt_step_dims
);
nvinfer1
::
Dims
real_trt_size_dims
;
real_trt_size_dims
.
nbDims
=
0
;
if
(
decrease_axises
.
size
()
>
0
)
{
for
(
size_t
i
=
0
;
i
<
decrease_axises
.
size
();
i
++
)
{
decrease_axises
[
i
]
--
;
}
for
(
int
i
=
0
;
i
<
trt_size_dims
.
nbDims
;
i
++
)
{
if
(
decrease_axises
.
end
()
!=
std
::
find
(
decrease_axises
.
begin
(),
decrease_axises
.
end
(),
i
))
continue
;
real_trt_size_dims
.
d
[
real_trt_size_dims
.
nbDims
]
=
trt_size_dims
.
d
[
i
];
real_trt_size_dims
.
nbDims
++
;
}
if
(
real_trt_size_dims
.
nbDims
==
0
)
{
real_trt_size_dims
.
nbDims
=
1
;
real_trt_size_dims
.
d
[
0
]
=
1
;
}
auto
reshape_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
layer
->
getOutput
(
0
));
reshape_layer
->
setReshapeDimensions
(
real_trt_size_dims
);
layer
=
static_cast
<
nvinfer1
::
ILayer
*>
(
reshape_layer
);
}
#else
bool
with_fp16
=
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SlicePlugin
*
plugin
=
plugin
::
SlicePlugin
*
plugin
=
new
plugin
::
SlicePlugin
(
starts
,
ends
,
axes
,
with_fp16
);
new
plugin
::
SlicePlugin
(
starts
,
ends
,
axes
,
with_fp16
);
layer
=
engine_
->
AddPlugin
(
&
input
,
1
,
plugin
);
layer
=
engine_
->
AddPlugin
(
&
input
,
1
,
plugin
);
#endif
}
}
RreplenishLayerAndOutput
(
layer
,
"slice"
,
{
output_name
},
test_mode
);
RreplenishLayerAndOutput
(
layer
,
"slice"
,
{
output_name
},
test_mode
);
}
}
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
245005d4
...
@@ -48,7 +48,8 @@ void TensorRTEngine::InitNetwork() {
...
@@ -48,7 +48,8 @@ void TensorRTEngine::InitNetwork() {
optim_profiles_
[
i
]
=
infer_builder_
->
createOptimizationProfile
();
optim_profiles_
[
i
]
=
infer_builder_
->
createOptimizationProfile
();
}
}
void
TensorRTEngine
::
Execute
(
int
batch_size
,
std
::
vector
<
void
*>
*
buffers
,
void
TensorRTEngine
::
Execute
(
int
batch_size
,
std
::
vector
<
void
*>
*
buffers
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
freshDeviceId
();
freshDeviceId
();
auto
infer_context
=
context
();
auto
infer_context
=
context
();
...
@@ -126,14 +127,32 @@ void TensorRTEngine::FreezeNetwork() {
...
@@ -126,14 +127,32 @@ void TensorRTEngine::FreezeNetwork() {
}
}
#if IS_TRT_VERSION_GE(5122)
#if IS_TRT_VERSION_GE(5122)
auto
is_layer_int8
=
[
&
](
nvinfer1
::
ILayer
*
layer
)
->
bool
{
auto
layer_int8_fallback
=
[
&
](
nvinfer1
::
ILayer
*
layer
)
->
bool
{
if
(
layer
->
getType
()
==
nvinfer1
::
LayerType
::
kSHAPE
)
{
return
false
;
}
bool
all_int
=
true
;
for
(
int
j
=
0
;
j
<
layer
->
getNbInputs
();
j
++
)
{
auto
*
temp_in
=
layer
->
getInput
(
j
);
if
(
temp_in
->
getType
()
!=
nvinfer1
::
DataType
::
kINT32
)
{
all_int
=
false
;
}
}
for
(
int
j
=
0
;
j
<
layer
->
getNbOutputs
();
j
++
)
{
auto
*
temp_out
=
layer
->
getOutput
(
j
);
if
(
temp_out
->
getType
()
!=
nvinfer1
::
DataType
::
kINT32
)
{
all_int
=
false
;
}
}
if
(
all_int
)
return
false
;
for
(
int
j
=
0
;
j
<
layer
->
getNbInputs
();
j
++
)
{
for
(
int
j
=
0
;
j
<
layer
->
getNbInputs
();
j
++
)
{
auto
*
temp_in
=
layer
->
getInput
(
j
);
auto
*
temp_in
=
layer
->
getInput
(
j
);
if
(
!
temp_in
->
dynamicRangeIsSet
())
{
if
(
!
temp_in
->
dynamicRangeIsSet
())
{
VLOG
(
1
)
<<
"Layer(Name: "
<<
layer
->
getName
()
VLOG
(
1
)
<<
"Layer(Name: "
<<
layer
->
getName
()
<<
") is set to float32 because its input("
<<
") is set to float32 because its input("
<<
temp_in
->
getName
()
<<
") doesn't have dynamic range."
;
<<
temp_in
->
getName
()
<<
") doesn't have dynamic range."
;
return
fals
e
;
return
tru
e
;
}
}
}
}
for
(
int
j
=
0
;
j
<
layer
->
getNbOutputs
();
j
++
)
{
for
(
int
j
=
0
;
j
<
layer
->
getNbOutputs
();
j
++
)
{
...
@@ -142,10 +161,10 @@ void TensorRTEngine::FreezeNetwork() {
...
@@ -142,10 +161,10 @@ void TensorRTEngine::FreezeNetwork() {
VLOG
(
1
)
<<
"Layer(Name: "
<<
layer
->
getName
()
VLOG
(
1
)
<<
"Layer(Name: "
<<
layer
->
getName
()
<<
") is set to float32 because its output("
<<
") is set to float32 because its output("
<<
temp_out
->
getName
()
<<
") doesn't have dynamic range."
;
<<
temp_out
->
getName
()
<<
") doesn't have dynamic range."
;
return
fals
e
;
return
tru
e
;
}
}
}
}
return
tru
e
;
return
fals
e
;
};
};
// If a layer's output is the network's output, or not all of its inputs
// If a layer's output is the network's output, or not all of its inputs
// and outputs have scales,
// and outputs have scales,
...
@@ -154,7 +173,7 @@ void TensorRTEngine::FreezeNetwork() {
...
@@ -154,7 +173,7 @@ void TensorRTEngine::FreezeNetwork() {
int
layers_no_int8
=
0
;
int
layers_no_int8
=
0
;
for
(
int
i
=
0
;
i
<
network
()
->
getNbLayers
();
i
++
)
{
for
(
int
i
=
0
;
i
<
network
()
->
getNbLayers
();
i
++
)
{
auto
layer
=
network
()
->
getLayer
(
i
);
auto
layer
=
network
()
->
getLayer
(
i
);
if
(
!
is_layer_int8
(
layer
))
{
if
(
layer_int8_fallback
(
layer
))
{
layer
->
setPrecision
(
nvinfer1
::
DataType
::
kFLOAT
);
layer
->
setPrecision
(
nvinfer1
::
DataType
::
kFLOAT
);
++
layers_no_int8
;
++
layers_no_int8
;
}
}
...
@@ -205,7 +224,8 @@ void TensorRTEngine::FreezeNetwork() {
...
@@ -205,7 +224,8 @@ void TensorRTEngine::FreezeNetwork() {
for
(
auto
&
input
:
min_input_shape_
)
{
for
(
auto
&
input
:
min_input_shape_
)
{
#if IS_TRT_VERSION_LT(7000)
#if IS_TRT_VERSION_LT(7000)
// trt6 will check all_of input > 0
// trt6 will check all_of input > 0
if
(
!
(
std
::
all_of
(
input
.
second
.
begin
(),
input
.
second
.
end
(),
if
(
!
(
std
::
all_of
(
input
.
second
.
begin
(),
input
.
second
.
end
(),
[](
int
x
)
{
return
x
>
0
;
})
&&
[](
int
x
)
{
return
x
>
0
;
})
&&
std
::
all_of
(
max_input_shape_
[
input
.
first
].
begin
(),
std
::
all_of
(
max_input_shape_
[
input
.
first
].
begin
(),
max_input_shape_
[
input
.
first
].
end
(),
max_input_shape_
[
input
.
first
].
end
(),
...
@@ -222,13 +242,16 @@ void TensorRTEngine::FreezeNetwork() {
...
@@ -222,13 +242,16 @@ void TensorRTEngine::FreezeNetwork() {
<<
", opt: "
<<
Vec2Str
(
optim_input_shape_
[
input
.
first
]);
<<
", opt: "
<<
Vec2Str
(
optim_input_shape_
[
input
.
first
]);
optim_profiles_
[
i
]
->
setDimensions
(
optim_profiles_
[
i
]
->
setDimensions
(
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kMIN
,
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kMIN
,
Vec2TRT_Dims
(
input
.
second
,
input
.
first
,
true
));
Vec2TRT_Dims
(
input
.
second
,
input
.
first
,
true
));
optim_profiles_
[
i
]
->
setDimensions
(
optim_profiles_
[
i
]
->
setDimensions
(
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kMAX
,
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kMAX
,
Vec2TRT_Dims
(
max_input_shape_
[
input
.
first
],
input
.
first
,
true
));
Vec2TRT_Dims
(
max_input_shape_
[
input
.
first
],
input
.
first
,
true
));
optim_profiles_
[
i
]
->
setDimensions
(
optim_profiles_
[
i
]
->
setDimensions
(
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kOPT
,
input
.
first
.
c_str
(),
nvinfer1
::
OptProfileSelector
::
kOPT
,
Vec2TRT_Dims
(
optim_input_shape_
[
input
.
first
],
input
.
first
,
true
));
Vec2TRT_Dims
(
optim_input_shape_
[
input
.
first
],
input
.
first
,
true
));
}
}
infer_builder_config_
->
addOptimizationProfile
(
optim_profiles_
[
i
]);
infer_builder_config_
->
addOptimizationProfile
(
optim_profiles_
[
i
]);
...
@@ -262,9 +285,10 @@ void TensorRTEngine::FreezeNetwork() {
...
@@ -262,9 +285,10 @@ void TensorRTEngine::FreezeNetwork() {
#endif
#endif
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
infer_engine_
,
platform
::
errors
::
Fatal
(
infer_engine_
,
"Build TensorRT cuda engine failed! Please recheck "
platform
::
errors
::
Fatal
(
"you configurations related to paddle-TensorRT."
));
"Build TensorRT cuda engine failed! Please recheck "
"you configurations related to paddle-TensorRT."
));
binding_num_
=
infer_engine_
->
getNbBindings
();
binding_num_
=
infer_engine_
->
getNbBindings
();
// reset status for dynamic shape clone
// reset status for dynamic shape clone
...
@@ -279,16 +303,19 @@ void TensorRTEngine::FreezeNetwork() {
...
@@ -279,16 +303,19 @@ void TensorRTEngine::FreezeNetwork() {
nvinfer1
::
ITensor
*
TensorRTEngine
::
DeclareInput
(
const
std
::
string
&
name
,
nvinfer1
::
ITensor
*
TensorRTEngine
::
DeclareInput
(
const
std
::
string
&
name
,
nvinfer1
::
DataType
dtype
,
nvinfer1
::
DataType
dtype
,
const
nvinfer1
::
Dims
&
dims
)
{
const
nvinfer1
::
Dims
&
dims
)
{
PADDLE_ENFORCE_EQ
(
network
()
!=
nullptr
,
true
,
PADDLE_ENFORCE_EQ
(
network
()
!=
nullptr
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The TRT network should be initialized first."
));
"The TRT network should be initialized first."
));
auto
*
input
=
network
()
->
addInput
(
name
.
c_str
(),
dtype
,
dims
);
auto
*
input
=
network
()
->
addInput
(
name
.
c_str
(),
dtype
,
dims
);
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
input
,
platform
::
errors
::
InvalidArgument
(
"Adding input %s failed in "
input
,
"TensorRT inference network. "
platform
::
errors
::
InvalidArgument
(
"Adding input %s failed in "
"Please recheck your input."
,
"TensorRT inference network. "
name
));
"Please recheck your input."
,
PADDLE_ENFORCE_EQ
(
input
->
isNetworkInput
(),
true
,
name
));
PADDLE_ENFORCE_EQ
(
input
->
isNetworkInput
(),
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Input %s is not the input of TRT inference network. "
"Input %s is not the input of TRT inference network. "
"Please recheck your input."
,
"Please recheck your input."
,
...
@@ -297,22 +324,26 @@ nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
...
@@ -297,22 +324,26 @@ nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
return
input
;
return
input
;
}
}
void
TensorRTEngine
::
DeclareOutput
(
const
nvinfer1
::
ILayer
*
layer
,
int
offset
,
void
TensorRTEngine
::
DeclareOutput
(
const
nvinfer1
::
ILayer
*
layer
,
int
offset
,
const
std
::
string
&
name
)
{
const
std
::
string
&
name
)
{
auto
*
output
=
layer
->
getOutput
(
offset
);
auto
*
output
=
layer
->
getOutput
(
offset
);
SetITensor
(
name
,
output
);
SetITensor
(
name
,
output
);
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
output
,
platform
::
errors
::
InvalidArgument
(
output
,
"The output %s of TRT engine should not be null."
,
name
));
platform
::
errors
::
InvalidArgument
(
"The output %s of TRT engine should not be null."
,
name
));
output
->
setName
(
name
.
c_str
());
output
->
setName
(
name
.
c_str
());
PADDLE_ENFORCE_EQ
(
output
->
isNetworkInput
(),
false
,
PADDLE_ENFORCE_EQ
(
output
->
isNetworkInput
(),
false
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The output %s of TRT engine should not be the input "
"The output %s of TRT engine should not be the input "
"of the network at the same time."
,
"of the network at the same time."
,
name
));
name
));
network
()
->
markOutput
(
*
output
);
network
()
->
markOutput
(
*
output
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
output
->
isNetworkOutput
(),
true
,
output
->
isNetworkOutput
(),
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The output %s of TRT engine should be the output of the network."
,
"The output %s of TRT engine should be the output of the network."
,
name
));
name
));
...
@@ -321,10 +352,12 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset,
...
@@ -321,10 +352,12 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset,
void
TensorRTEngine
::
DeclareOutput
(
const
std
::
string
&
name
)
{
void
TensorRTEngine
::
DeclareOutput
(
const
std
::
string
&
name
)
{
auto
*
output
=
TensorRTEngine
::
GetITensor
(
name
);
auto
*
output
=
TensorRTEngine
::
GetITensor
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
output
,
platform
::
errors
::
InvalidArgument
(
output
,
"The output %s of TRT engine should not be null."
,
name
));
platform
::
errors
::
InvalidArgument
(
"The output %s of TRT engine should not be null."
,
name
));
output
->
setName
(
name
.
c_str
());
output
->
setName
(
name
.
c_str
());
PADDLE_ENFORCE_EQ
(
output
->
isNetworkInput
(),
false
,
PADDLE_ENFORCE_EQ
(
output
->
isNetworkInput
(),
false
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The output %s of TRT engine should not be the input "
"The output %s of TRT engine should not be the input "
"of the network at the same time."
,
"of the network at the same time."
,
...
@@ -335,17 +368,20 @@ void TensorRTEngine::DeclareOutput(const std::string &name) {
...
@@ -335,17 +368,20 @@ void TensorRTEngine::DeclareOutput(const std::string &name) {
void
TensorRTEngine
::
SetITensor
(
const
std
::
string
&
name
,
void
TensorRTEngine
::
SetITensor
(
const
std
::
string
&
name
,
nvinfer1
::
ITensor
*
tensor
)
{
nvinfer1
::
ITensor
*
tensor
)
{
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
tensor
,
platform
::
errors
::
InvalidArgument
(
tensor
,
"Tensor named %s of TRT engine should not be null."
,
name
));
platform
::
errors
::
InvalidArgument
(
"Tensor named %s of TRT engine should not be null."
,
name
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
0
,
itensor_map_
.
count
(
name
),
0
,
itensor_map_
.
count
(
name
),
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Tensor named %s of TRT engine should not be duplicated"
,
name
));
"Tensor named %s of TRT engine should not be duplicated"
,
name
));
itensor_map_
[
name
]
=
tensor
;
itensor_map_
[
name
]
=
tensor
;
}
}
nvinfer1
::
ITensor
*
TensorRTEngine
::
GetITensor
(
const
std
::
string
&
name
)
{
nvinfer1
::
ITensor
*
TensorRTEngine
::
GetITensor
(
const
std
::
string
&
name
)
{
PADDLE_ENFORCE_EQ
(
itensor_map_
.
count
(
name
),
true
,
PADDLE_ENFORCE_EQ
(
itensor_map_
.
count
(
name
),
true
,
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Tensor named %s is not found in TRT engine"
,
name
));
"Tensor named %s is not found in TRT engine"
,
name
));
return
itensor_map_
[
name
];
return
itensor_map_
[
name
];
...
@@ -362,15 +398,16 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
...
@@ -362,15 +398,16 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
std
::
string
splitter
=
"__"
;
std
::
string
splitter
=
"__"
;
std
::
string
name_with_suffix
=
name
+
splitter
+
name_suffix
;
std
::
string
name_with_suffix
=
name
+
splitter
+
name_suffix
;
platform
::
CPUPlace
cpu_place
;
platform
::
CPUPlace
cpu_place
;
PADDLE_ENFORCE_EQ
(
weight_map
.
count
(
name_with_suffix
),
0
,
PADDLE_ENFORCE_EQ
(
weight_map
.
count
(
name_with_suffix
),
0
,
platform
::
errors
::
AlreadyExists
(
platform
::
errors
::
AlreadyExists
(
"The weight named %s is set into the weight map "
"The weight named %s is set into the weight map "
"twice in TRT OP converter."
,
"twice in TRT OP converter."
,
name_with_suffix
));
name_with_suffix
));
weight_map
[
name_with_suffix
].
reset
(
new
framework
::
Tensor
());
weight_map
[
name_with_suffix
].
reset
(
new
framework
::
Tensor
());
weight_map
[
name_with_suffix
]
->
Resize
(
weight_tensor
->
dims
());
weight_map
[
name_with_suffix
]
->
Resize
(
weight_tensor
->
dims
());
paddle
::
framework
::
TensorCopySync
(
*
weight_tensor
,
cpu_place
,
paddle
::
framework
::
TensorCopySync
(
weight_map
[
name_with_suffix
].
get
());
*
weight_tensor
,
cpu_place
,
weight_map
[
name_with_suffix
].
get
());
float
*
weight_data
=
float
*
weight_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
float
>
(
cpu_place
);
weight_map
[
name_with_suffix
]
->
mutable_data
<
float
>
(
cpu_place
);
name_suffix_counter
+=
1
;
name_suffix_counter
+=
1
;
...
@@ -380,21 +417,24 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
...
@@ -380,21 +417,24 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
int
TensorRTEngine
::
GetRuntimeBatch
()
{
return
runtime_batch_
;
}
int
TensorRTEngine
::
GetRuntimeBatch
()
{
return
runtime_batch_
;
}
nvinfer1
::
IPluginV2Layer
*
TensorRTEngine
::
AddPlugin
(
nvinfer1
::
IPluginV2Layer
*
TensorRTEngine
::
AddPlugin
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
plugin
::
PluginTensorRT
*
plugin
)
{
plugin
::
PluginTensorRT
*
plugin
)
{
owned_plugin_
.
emplace_back
(
plugin
);
owned_plugin_
.
emplace_back
(
plugin
);
return
network
()
->
addPluginV2
(
inputs
,
num_inputs
,
*
plugin
);
return
network
()
->
addPluginV2
(
inputs
,
num_inputs
,
*
plugin
);
}
}
nvinfer1
::
IPluginV2Layer
*
TensorRTEngine
::
AddPluginV2Ext
(
nvinfer1
::
IPluginV2Layer
*
TensorRTEngine
::
AddPluginV2Ext
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
plugin
::
PluginTensorRTV2Ext
*
plugin
)
{
plugin
::
PluginTensorRTV2Ext
*
plugin
)
{
owned_plugin_v2ext_
.
emplace_back
(
plugin
);
owned_plugin_v2ext_
.
emplace_back
(
plugin
);
return
network
()
->
addPluginV2
(
inputs
,
num_inputs
,
*
plugin
);
return
network
()
->
addPluginV2
(
inputs
,
num_inputs
,
*
plugin
);
}
}
nvinfer1
::
IPluginV2Layer
*
TensorRTEngine
::
AddPluginV2IOExt
(
nvinfer1
::
IPluginV2Layer
*
TensorRTEngine
::
AddPluginV2IOExt
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
nvinfer1
::
ITensor
*
const
*
inputs
,
int
num_inputs
,
nvinfer1
::
IPluginV2IOExt
*
plugin
)
{
nvinfer1
::
IPluginV2IOExt
*
plugin
)
{
owned_plugin_v2ioext_
.
emplace_back
(
plugin
);
owned_plugin_v2ioext_
.
emplace_back
(
plugin
);
return
network
()
->
addPluginV2
(
inputs
,
num_inputs
,
*
plugin
);
return
network
()
->
addPluginV2
(
inputs
,
num_inputs
,
*
plugin
);
...
@@ -403,10 +443,12 @@ nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2IOExt(
...
@@ -403,10 +443,12 @@ nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2IOExt(
void
TensorRTEngine
::
freshDeviceId
()
{
void
TensorRTEngine
::
freshDeviceId
()
{
int
count
;
int
count
;
cudaGetDeviceCount
(
&
count
);
cudaGetDeviceCount
(
&
count
);
PADDLE_ENFORCE_LT
(
device_id_
,
count
,
PADDLE_ENFORCE_LT
(
device_id_
,
count
,
platform
::
errors
::
OutOfRange
(
platform
::
errors
::
OutOfRange
(
"Device id %d exceeds the current device count: %d."
,
"Device id %d exceeds the current device count: %d."
,
device_id_
,
count
));
device_id_
,
count
));
platform
::
SetDeviceId
(
device_id_
);
platform
::
SetDeviceId
(
device_id_
);
}
}
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
245005d4
...
@@ -1063,14 +1063,9 @@ bool OpTeller::Tell(const framework::ir::Node* node,
...
@@ -1063,14 +1063,9 @@ bool OpTeller::Tell(const framework::ir::Node* node,
if
(
desc
.
HasAttr
(
"decrease_axis"
))
{
if
(
desc
.
HasAttr
(
"decrease_axis"
))
{
std
::
vector
<
int
>
decrease_axis
=
std
::
vector
<
int
>
decrease_axis
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"decrease_axis"
));
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"decrease_axis"
));
if
(
with_dynamic_shape
)
{
if
(
!
with_dynamic_shape
)
{
if
(
decrease_axis
.
size
()
>
1
)
{
if
(
decrease_axis
.
end
()
!=
return
false
;
std
::
find
(
decrease_axis
.
begin
(),
decrease_axis
.
end
(),
0
))
{
}
}
else
{
if
(
decrease_axis
.
size
()
>
0
)
{
VLOG
(
3
)
<<
"Invalid slice decrease_axis. decrease_axis.size() > 0"
"is not supported in TensorRT"
;
return
false
;
return
false
;
}
}
}
}
...
@@ -1102,15 +1097,28 @@ bool OpTeller::Tell(const framework::ir::Node* node,
...
@@ -1102,15 +1097,28 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return
false
;
return
false
;
}
}
}
}
}
else
{
}
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
}
if
(
starts
[
i
]
<
0
||
ends
[
i
]
<
0
)
{
// not support following four inputs for slice in paddle-trt
VLOG
(
3
)
<<
"Invalid slice attribute 'starts' or 'ends'. "
auto
slice_inputs
=
desc
.
Inputs
();
// its size == 5
"Negative starts or ends not supported in TensorRT "
if
(
slice_inputs
.
find
(
"StartsTensor"
)
!=
slice_inputs
.
end
())
{
"when running in dynamic shape mode."
;
if
(
desc
.
Input
(
"StartsTensor"
).
size
())
{
return
false
;
return
false
;
}
}
}
}
if
(
slice_inputs
.
find
(
"EndsTensor"
)
!=
slice_inputs
.
end
())
{
if
(
desc
.
Input
(
"EndsTensor"
).
size
())
{
return
false
;
}
}
if
(
slice_inputs
.
find
(
"StartsTensorList"
)
!=
slice_inputs
.
end
())
{
if
(
desc
.
Input
(
"StartsTensorList"
).
size
())
{
return
false
;
}
}
if
(
slice_inputs
.
find
(
"EndsTensorList"
)
!=
slice_inputs
.
end
())
{
if
(
desc
.
Input
(
"EndsTensorList"
).
size
())
{
return
false
;
}
}
}
}
}
}
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py
浏览文件 @
245005d4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -22,17 +22,14 @@ import unittest
...
@@ -22,17 +22,14 @@ import unittest
class
TrtConvertSliceTest
(
TrtLayerAutoScanTest
):
class
TrtConvertSliceTest
(
TrtLayerAutoScanTest
):
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
inputs
=
program_config
.
inputs
inputs
=
program_config
.
inputs
weights
=
program_config
.
weights
weights
=
program_config
.
weights
attrs
=
[
attrs
=
[
program_config
.
ops
[
i
].
attrs
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
for
i
in
range
(
len
(
program_config
.
ops
))
]
]
out_shape
=
list
(
inputs
[
'input_data'
].
shape
)
for
x
in
attrs
[
0
][
"decrease_axis"
]:
if
x
<
0
:
return
False
for
x
in
range
(
len
(
attrs
[
0
][
"axes"
])):
for
x
in
range
(
len
(
attrs
[
0
][
"axes"
])):
start
=
0
start
=
0
end
=
0
end
=
0
...
@@ -42,24 +39,30 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
...
@@ -42,24 +39,30 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
else
:
else
:
start
=
attrs
[
0
][
"starts"
][
x
]
start
=
attrs
[
0
][
"starts"
][
x
]
if
attrs
[
0
][
"ends"
][
x
]
<
0
:
if
attrs
[
0
][
"ends"
][
x
]
<
0
:
end
=
attrs
[
0
][
"ends"
][
x
]
+
inputs
[
'input_data'
].
shape
[
attrs
[
0
][
end
=
attrs
[
0
][
"ends"
][
x
]
+
inputs
[
'input_data'
].
shape
[
"axes"
][
x
]]
attrs
[
0
][
"axes"
][
x
]]
else
:
else
:
end
=
attrs
[
0
][
"ends"
][
x
]
end
=
attrs
[
0
][
"ends"
][
x
]
start
=
max
(
0
,
start
)
start
=
max
(
0
,
start
)
end
=
max
(
0
,
end
)
end
=
max
(
0
,
end
)
out_shape
[
attrs
[
0
][
"axes"
][
x
]]
=
end
-
start
if
start
>=
end
:
if
start
>=
end
:
return
False
return
False
for
x
in
attrs
[
0
][
"decrease_axis"
]:
if
x
<
0
:
return
False
if
(
out_shape
[
x
]
!=
1
):
return
False
return
True
return
True
def
sample_program_configs
(
self
):
def
sample_program_configs
(
self
):
def
generate_input1
(
attrs
:
List
[
Dict
[
str
,
Any
]]):
def
generate_input1
(
attrs
:
List
[
Dict
[
str
,
Any
]]):
return
np
.
ones
([
6
,
6
,
64
,
64
]).
astype
(
np
.
float32
)
return
np
.
random
.
random
([
6
,
6
,
64
,
64
]).
astype
(
np
.
float32
)
for
axes
in
[[
0
,
1
],
[
1
,
3
],
[
2
,
3
]]:
for
axes
in
[[
0
,
1
],
[
1
,
3
],
[
2
,
3
]]:
for
starts
in
[[
0
,
1
]]:
for
starts
in
[[
0
,
1
]]:
for
ends
in
[[
2
,
2
],
[
5
,
5
]]:
for
ends
in
[[
2
,
2
],
[
5
,
5
]
,
[
1
,
-
1
]
]:
for
decrease_axis
in
[[],
[
1
],
[
2
],
[
-
1
],
[
-
100
]]:
for
decrease_axis
in
[[],
[
1
],
[
2
],
[
-
1
],
[
-
100
]]:
for
infer_flags
in
[[
-
1
]]:
for
infer_flags
in
[[
-
1
]]:
dics
=
[{
dics
=
[{
...
@@ -86,8 +89,9 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
...
@@ -86,8 +89,9 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
ops
=
ops
,
ops
=
ops
,
weights
=
{},
weights
=
{},
inputs
=
{
inputs
=
{
"input_data"
:
TensorConfig
(
data_gen
=
partial
(
"input_data"
:
generate_input1
,
dics
))
TensorConfig
(
data_gen
=
partial
(
generate_input1
,
dics
))
},
},
outputs
=
[
"slice_output_data"
])
outputs
=
[
"slice_output_data"
])
...
@@ -95,6 +99,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
...
@@ -95,6 +99,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
def
sample_predictor_configs
(
def
sample_predictor_configs
(
self
,
program_config
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
self
,
program_config
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
generate_dynamic_shape
(
attrs
):
def
generate_dynamic_shape
(
attrs
):
self
.
dynamic_shape
.
min_input_shape
=
{
"input_data"
:
[
1
,
3
,
32
,
32
]}
self
.
dynamic_shape
.
min_input_shape
=
{
"input_data"
:
[
1
,
3
,
32
,
32
]}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
8
,
8
,
64
,
64
]}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
8
,
8
,
64
,
64
]}
...
@@ -106,17 +111,6 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
...
@@ -106,17 +111,6 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
self
.
dynamic_shape
.
opt_input_shape
=
{}
self
.
dynamic_shape
.
opt_input_shape
=
{}
def
generate_trt_nodes_num
(
attrs
,
dynamic_shape
):
def
generate_trt_nodes_num
(
attrs
,
dynamic_shape
):
inputs
=
program_config
.
inputs
if
dynamic_shape
==
True
and
len
(
attrs
[
0
][
"decrease_axis"
])
==
0
:
return
1
,
2
if
dynamic_shape
==
True
and
len
(
attrs
[
0
][
"decrease_axis"
])
!=
1
:
return
0
,
3
if
dynamic_shape
==
False
and
len
(
attrs
[
0
][
"decrease_axis"
])
!=
0
:
return
0
,
3
if
dynamic_shape
:
for
i
in
range
(
len
(
attrs
[
0
][
"starts"
])):
if
attrs
[
0
][
"starts"
][
i
]
<
0
or
attrs
[
0
][
"ends"
][
i
]
<
0
:
return
0
,
3
if
not
dynamic_shape
:
if
not
dynamic_shape
:
for
x
in
attrs
[
0
][
"axes"
]:
for
x
in
attrs
[
0
][
"axes"
]:
if
x
==
0
:
if
x
==
0
:
...
@@ -124,8 +118,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
...
@@ -124,8 +118,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
return
1
,
2
return
1
,
2
attrs
=
[
attrs
=
[
program_config
.
ops
[
i
].
attrs
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
for
i
in
range
(
len
(
program_config
.
ops
))
]
]
self
.
trt_param
.
max_batch_size
=
9
self
.
trt_param
.
max_batch_size
=
9
# for static_shape
# for static_shape
...
@@ -140,11 +133,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
...
@@ -140,11 +133,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
# for dynamic_shape
# for dynamic_shape
generate_dynamic_shape
(
attrs
)
generate_dynamic_shape
(
attrs
)
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
True
),
1e-5
attrs
,
True
),
1e-5
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Half
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Half
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
True
),
1e-4
attrs
,
True
),
1e-4
def
test
(
self
):
def
test
(
self
):
# TODO(inference): fix.
# TODO(inference): fix.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录