Unverified commit 4c38b87e authored by gem5 and committed by GitHub

add some compare and logical trt converter (#48592)

Parent fcf26279
...@@ -2238,6 +2238,12 @@ USE_TRT_CONVERTER(elementwise_max_tensor);
USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(elementwise_floordiv_tensor);
USE_TRT_CONVERTER(less_than);
USE_TRT_CONVERTER(greater_than);
USE_TRT_CONVERTER(logical_or);
USE_TRT_CONVERTER(logical_xor);
USE_TRT_CONVERTER(logical_and);
USE_TRT_CONVERTER(less_equal);
USE_TRT_CONVERTER(transpose);
USE_TRT_CONVERTER(transpose2);
USE_TRT_CONVERTER(flatten);
......
...@@ -74,8 +74,12 @@ class ElementwiseTensorOpConverter : public OpConverter {
nvinfer1::Dims dims_y = Y->getDimensions();
auto output_name = op_desc.Output("Out")[0];
int axis = -1;
// axis here is relative to explicit batch
if (op_type_ != "logical_or" && op_type_ != "logical_xor" &&
op_type_ != "logical_and") {
axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis"));
}
int real_x_rank = dims_x.nbDims;
int real_y_rank = dims_y.nbDims;
if (!engine_->with_dynamic_shape()) {
...@@ -139,17 +143,40 @@ class ElementwiseTensorOpConverter : public OpConverter {
X = tmp;
}
auto op_pair = ops.find(op_type_); if (op_type_ == "less_equal") {
PADDLE_ENFORCE_NE(op_pair, auto* less_layer =
ops.end(), TRT_ENGINE_ADD_LAYER(engine_,
platform::errors::InvalidArgument( ElementWise,
"Elementwise op's type(%s) is not supported. Please " *X,
"check if the op_type is correct.", *reshape_y_tensor,
op_type_)); nvinfer1::ElementWiseOperation::kLESS);
auto* equal_layer =
auto* layer = TRT_ENGINE_ADD_LAYER( TRT_ENGINE_ADD_LAYER(engine_,
engine_, ElementWise, *X, *reshape_y_tensor, op_pair->second); ElementWise,
RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); *X,
*reshape_y_tensor,
nvinfer1::ElementWiseOperation::kEQUAL);
auto* layer = TRT_ENGINE_ADD_LAYER(engine_,
ElementWise,
*(less_layer->getOutput(0)),
*(equal_layer->getOutput(0)),
nvinfer1::ElementWiseOperation::kOR);
RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode);
} else {
auto op_pair = ops.find(op_type_);
PADDLE_ENFORCE_NE(
op_pair,
ops.end(),
platform::errors::InvalidArgument(
"Elementwise op's type(%s) is not supported. Please "
"check if the op_type is correct.",
op_type_));
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, ElementWise, *X, *reshape_y_tensor, op_pair->second);
RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode);
}
}
protected:
...@@ -168,6 +195,11 @@ const std::unordered_map<std::string, nvinfer1::ElementWiseOperation>
{"pow", nvinfer1::ElementWiseOperation::kPOW},
{"max", nvinfer1::ElementWiseOperation::kMAX},
{"floordiv", nvinfer1::ElementWiseOperation::kFLOOR_DIV},
{"less_than", nvinfer1::ElementWiseOperation::kLESS},
{"greater_than", nvinfer1::ElementWiseOperation::kGREATER},
{"logical_or", nvinfer1::ElementWiseOperation::kOR},
{"logical_xor", nvinfer1::ElementWiseOperation::kXOR},
{"logical_and", nvinfer1::ElementWiseOperation::kAND},
};
class ElementwiseTensorAddOpConverter : public ElementwiseTensorOpConverter {
...@@ -204,13 +236,41 @@ class ElementwiseTensorPowOpConverter : public ElementwiseTensorOpConverter {
public:
ElementwiseTensorPowOpConverter() { op_type_ = "pow"; }
};
class ElementwiseTensorFloorDivOpConverter
: public ElementwiseTensorOpConverter {
public:
ElementwiseTensorFloorDivOpConverter() { op_type_ = "floordiv"; }
};
class ElementwiseTensorLessThanOpConverter
: public ElementwiseTensorOpConverter {
public:
ElementwiseTensorLessThanOpConverter() { op_type_ = "less_than"; }
};
class ElementwiseTensorGreaterThanOpConverter
: public ElementwiseTensorOpConverter {
public:
ElementwiseTensorGreaterThanOpConverter() { op_type_ = "greater_than"; }
};
class ElementwiseTensorLogicalOrOpConverter
: public ElementwiseTensorOpConverter {
public:
ElementwiseTensorLogicalOrOpConverter() { op_type_ = "logical_or"; }
};
class ElementwiseTensorLogicalXorOpConverter
: public ElementwiseTensorOpConverter {
public:
ElementwiseTensorLogicalXorOpConverter() { op_type_ = "logical_xor"; }
};
class ElementwiseTensorLogicalAndOpConverter
: public ElementwiseTensorOpConverter {
public:
ElementwiseTensorLogicalAndOpConverter() { op_type_ = "logical_and"; }
};
class ElementwiseTensorLessEqualOpConverter
: public ElementwiseTensorOpConverter {
public:
ElementwiseTensorLessEqualOpConverter() { op_type_ = "less_equal"; }
};
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
...@@ -248,3 +308,10 @@ REGISTER_TRT_OP_CONVERTER(elementwise_pow_tensor,
ElementwiseTensorPowOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_floordiv_tensor,
ElementwiseTensorFloorDivOpConverter);
REGISTER_TRT_OP_CONVERTER(less_than, ElementwiseTensorLessThanOpConverter);
REGISTER_TRT_OP_CONVERTER(greater_than,
ElementwiseTensorGreaterThanOpConverter);
REGISTER_TRT_OP_CONVERTER(logical_or, ElementwiseTensorLogicalOrOpConverter);
REGISTER_TRT_OP_CONVERTER(logical_xor, ElementwiseTensorLogicalXorOpConverter);
REGISTER_TRT_OP_CONVERTER(logical_and, ElementwiseTensorLogicalAndOpConverter);
REGISTER_TRT_OP_CONVERTER(less_equal, ElementwiseTensorLessEqualOpConverter);
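
Note on the less_equal branch in the converter above: TensorRT's elementwise layer provides kLESS, kEQUAL, and kOR but no dedicated less-or-equal operation, so the converter composes less_equal as (X < Y) OR (X == Y). A minimal numpy sketch of the same decomposition, for illustration only (not the TensorRT code path):

import numpy as np

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
y = np.array([2.0, 2.0, 2.0], dtype=np.float32)

# (x < y) OR (x == y), mirroring the kLESS / kEQUAL / kOR layer chain above.
composed = np.logical_or(np.less(x, y), np.equal(x, y))
print(composed)                          # [ True  True False]
print(np.array_equal(composed, x <= y))  # True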
...@@ -1322,6 +1322,32 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
if (op_type == "less_than" || op_type == "greater_than" ||
op_type == "logical_or" || op_type == "logical_xor" ||
op_type == "logical_and" || op_type == "less_equal") {
#if IS_TRT_VERSION_GE(8400)
if (!with_dynamic_shape) {
VLOG(3) << "these ops do not support static shape yet";
return false;
}
if (op_type == "logical_or" || op_type == "logical_xor" ||
op_type == "logical_and") {
auto* block = desc.Block();
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
auto* y_var_desc = block->FindVar(desc.Input("Y")[0]);
auto x_dtype = x_var_desc->GetDataType();
auto y_dtype = y_var_desc->GetDataType();
if (x_dtype != framework::proto::VarType::BOOL ||
y_dtype != framework::proto::VarType::BOOL) {
VLOG(3) << "the op only support input of BOOL.";
return false;
}
}
#else
VLOG(3) << "these are not supported when TensorRT < 8.4";
return false;
#endif
}
if (op_type == "elementwise_add" || op_type == "elementwise_mul" || if (op_type == "elementwise_add" || op_type == "elementwise_mul" ||
op_type == "elementwise_sub" || op_type == "elementwise_div" || op_type == "elementwise_sub" || op_type == "elementwise_div" ||
op_type == "elementwise_pow" || op_type == "elementwise_min" || op_type == "elementwise_pow" || op_type == "elementwise_min" ||
...@@ -2382,6 +2408,12 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_max",
"elementwise_floordiv",
"equal",
"less_than",
"greater_than",
"logical_or",
"logical_xor",
"logical_and",
"less_equal",
"dropout", "dropout",
"fill_any_like", "fill_any_like",
"prelu", "prelu",
...@@ -2514,6 +2546,12 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_max",
"elementwise_floordiv",
"equal",
"less_than",
"greater_than",
"logical_or",
"logical_xor",
"logical_and",
"less_equal",
"dropout", "dropout",
"fill_any_like", "fill_any_like",
"prelu", "prelu",
......
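
The op_teller change above only admits these ops when TensorRT >= 8.4 is compiled in and dynamic shape is enabled; the tests below encode the same gate by comparing ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 against 8400. A small sketch of that version check (the helper name is illustrative; the tuple would come from paddle.inference.get_trt_compile_version()):

def trt_supports_compare_and_logical(ver):
    # ver: (major, minor, patch); 8.4.0 maps to 8400, matching IS_TRT_VERSION_GE(8400).
    return ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 >= 8400

print(trt_supports_compare_and_logical((8, 4, 1)))  # True
print(trt_supports_compare_and_logical((8, 2, 3)))  # False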
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertLogicalTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
for shape in [[2, 16], [2, 16, 32], [1, 32, 16, 32]]:
for op_type in ["logical_and", "logical_or", "logical_xor"]:
for axis in [-1]:
self.dims = len(shape)
dics = [
{"axis": axis},
{"in_dtype": 5, "out_dtype": 0},
{"in_dtype": 0, "out_dtype": 5},
]
ops_config = [
{
"op_type": "cast",
"op_inputs": {"X": ["input_data1"]},
"op_outputs": {"Out": ["cast_output_data1"]},
"op_attrs": dics[1],
"outputs_dtype": {"cast_output_data1": np.bool},
},
{
"op_type": "cast",
"op_inputs": {"X": ["input_data2"]},
"op_outputs": {"Out": ["cast_output_data3"]},
"op_attrs": dics[1],
"outputs_dtype": {"cast_output_data1": np.bool},
},
{
"op_type": op_type,
"op_inputs": {
"X": ["cast_output_data1"],
"Y": ["cast_output_data3"],
},
"op_outputs": {"Out": ["cast_output_data0"]},
"op_attrs": dics[0],
},
{
"op_type": "cast",
"op_inputs": {"X": ["cast_output_data0"]},
"op_outputs": {"Out": ["output_data"]},
"op_attrs": dics[2],
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data1": TensorConfig(
data_gen=partial(generate_input, shape)
),
"input_data2": TensorConfig(
data_gen=partial(generate_input, shape)
),
},
outputs=["output_data"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
if self.dims == 2:
self.dynamic_shape.min_input_shape = {
"input_data1": [2, 16],
"input_data2": [2, 16],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [2, 16],
"input_data2": [2, 16],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 16],
"input_data2": [2, 16],
}
if self.dims == 3:
self.dynamic_shape.min_input_shape = {
"input_data1": [2, 16, 32],
"input_data2": [2, 16, 32],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [2, 16, 32],
"input_data2": [2, 16, 32],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 16, 32],
"input_data2": [2, 16, 32],
}
if self.dims == 4:
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 32, 16, 32],
"input_data2": [1, 32, 16, 32],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [1, 32, 16, 32],
"input_data2": [1, 32, 16, 32],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [1, 32, 16, 32],
"input_data2": [1, 32, 16, 32],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
if dynamic_shape:
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8400:
return 0, 7
return 1, 3
return 0, 7
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), (1e-3, 1e-3)
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), (1e-3, 1e-3)
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
class TrtConvertCompareTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
for shape in [[2, 16], [2, 16, 32], [1, 32, 16, 32]]:
for op_type in ["less_than", "greater_than"]:
for axis in [-1]:
self.dims = len(shape)
dics = [
{"axis": axis},
{"in_dtype": 0, "out_dtype": 5},
]
ops_config = [
{
"op_type": op_type,
"op_inputs": {
"X": ["input_data1"],
"Y": ["input_data2"],
},
"op_outputs": {"Out": ["cast_output_data0"]},
"op_attrs": dics[0],
},
{
"op_type": "cast",
"op_inputs": {"X": ["cast_output_data0"]},
"op_outputs": {"Out": ["output_data"]},
"op_attrs": dics[1],
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data1": TensorConfig(
data_gen=partial(generate_input, shape)
),
"input_data2": TensorConfig(
data_gen=partial(generate_input, shape)
),
},
outputs=["output_data"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
if self.dims == 2:
self.dynamic_shape.min_input_shape = {
"input_data1": [2, 16],
"input_data2": [2, 16],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [2, 16],
"input_data2": [2, 16],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 16],
"input_data2": [2, 16],
}
if self.dims == 3:
self.dynamic_shape.min_input_shape = {
"input_data1": [2, 16, 32],
"input_data2": [2, 16, 32],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [2, 16, 32],
"input_data2": [2, 16, 32],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 16, 32],
"input_data2": [2, 16, 32],
}
if self.dims == 4:
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 32, 16, 32],
"input_data2": [1, 32, 16, 32],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [1, 32, 16, 32],
"input_data2": [1, 32, 16, 32],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [1, 32, 16, 32],
"input_data2": [1, 32, 16, 32],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8400:
return 0, 5
if not dynamic_shape:
return 0, 5
return 1, 3
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), (1e-3, 1e-3)
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), (1e-3, 1e-3)
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
class TrtConvertLessEqualTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
for shape in [[2, 16], [2, 16, 32], [1, 32, 16, 32]]:
for op_type in ["less_equal"]:
for axis in [-1]:
self.dims = len(shape)
dics = [
{"axis": axis},
{"in_dtype": 5, "out_dtype": 2},
{"in_dtype": 0, "out_dtype": 5},
]
ops_config = [
{
"op_type": "cast",
"op_inputs": {"X": ["input_data1"]},
"op_outputs": {"Out": ["cast_output_data1"]},
"op_attrs": dics[1],
},
{
"op_type": "cast",
"op_inputs": {"X": ["input_data2"]},
"op_outputs": {"Out": ["cast_output_data2"]},
"op_attrs": dics[1],
},
{
"op_type": op_type,
"op_inputs": {
"X": ["cast_output_data1"],
"Y": ["cast_output_data2"],
},
"op_outputs": {"Out": ["cast_output_data0"]},
"op_attrs": dics[0],
},
{
"op_type": "cast",
"op_inputs": {"X": ["cast_output_data0"]},
"op_outputs": {"Out": ["output_data"]},
"op_attrs": dics[2],
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data1": TensorConfig(
data_gen=partial(generate_input, shape)
),
"input_data2": TensorConfig(
data_gen=partial(generate_input, shape)
),
},
outputs=["output_data"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
if self.dims == 2:
self.dynamic_shape.min_input_shape = {
"input_data1": [2, 16],
"input_data2": [2, 16],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [2, 16],
"input_data2": [2, 16],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 16],
"input_data2": [2, 16],
}
if self.dims == 3:
self.dynamic_shape.min_input_shape = {
"input_data1": [2, 16, 32],
"input_data2": [2, 16, 32],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [2, 16, 32],
"input_data2": [2, 16, 32],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 16, 32],
"input_data2": [2, 16, 32],
}
if self.dims == 4:
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 32, 16, 32],
"input_data2": [1, 32, 16, 32],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [1, 32, 16, 32],
"input_data2": [1, 32, 16, 32],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [1, 32, 16, 32],
"input_data2": [1, 32, 16, 32],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
ver = paddle_infer.get_trt_compile_version()
if (
ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8400
or not dynamic_shape
):
return 2, 5
else:
return 1, 3
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), (1e-3, 1e-3)
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), (1e-3, 1e-3)
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
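
The logical_* ops are only offloaded to TensorRT when both inputs are BOOL (see the op_teller change above), which is why the logical test graphs wrap the op in cast layers: float32 -> bool -> logical op -> float32. A minimal numpy sketch of that round trip, independent of the Paddle graph machinery:

import numpy as np

a = np.random.random([2, 16]).astype(np.float32)
b = np.random.random([2, 16]).astype(np.float32)

# cast(float32 -> bool), apply the logical op, cast(bool -> float32),
# mirroring the cast / logical_and / cast chain built in the tests above.
out = np.logical_and(a.astype(np.bool_), b.astype(np.bool_)).astype(np.float32)
print(out.shape, out.dtype)  # (2, 16) float32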