Unverified commit 690d7a69, authored by Zhang Jun, committed by GitHub

[inference][trt]set output data type of trt network (#49712)

* update trt engine to set in/out data type

* update

* Update engine.cc

* Update engine.cc

* update

* set engine output type before freezing the network

* update

* update trt autoscan ut

* update

* update ut

* fix equal bug, update ut

* fix cast and equal ut

* update cast ut using TRT < 8.4

* set datatype from scope

* check output var is nullptr

* Update op_converter.h

* update tensorrt_engine_op_test ut

* update
Parent a923a757
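In short, this change stops the engine from force-casting every output to FP32 under mixed precision and instead pins each network output to the data type recorded in the block's VarDesc before FreezeNetwork() builds the engine. A minimal sketch of the underlying TensorRT mechanism (illustration only, not Paddle's exact code; the function name and parameter names here are hypothetical):

#include <NvInfer.h>

// Sketch: register `tensor` as a network output and pin its data type.
// ITensor::setType only takes effect on network input/output tensors and
// must be called before the engine is built (i.e., before "freeze").
void DeclareTypedOutput(nvinfer1::INetworkDefinition* network,
                        nvinfer1::ITensor* tensor,
                        nvinfer1::DataType dtype) {
  network->markOutput(*tensor);  // expose the tensor as an engine output
  tensor->setType(dtype);        // e.g. nvinfer1::DataType::kINT32
}

The new TensorRTEngine::DeclareOutput(name, dtype) overload in the diff below performs this same markOutput-then-setType sequence.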
@@ -44,19 +44,15 @@ class CastOpConverter : public OpConverter {
     switch (out_dtype) {
       case 0:  // BOOL = 0
         layer->setOutputType(0, nvinfer1::DataType::kBOOL);
-        layer->getOutput(0)->setType(nvinfer1::DataType::kBOOL);
         break;
       case 2:  // INT32 = 2
         layer->setOutputType(0, nvinfer1::DataType::kINT32);
-        layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
         break;
       case 4:  // FP16 = 4
         layer->setOutputType(0, nvinfer1::DataType::kHALF);
-        layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);
         break;
       case 5:  // FP32 = 5
         layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
-        layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
         break;
       default:
         LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype
...
@@ -363,9 +363,26 @@ class OpConverter {
             "check the INFO log above for more details."));
     framework::proto::BlockDesc* block_proto = block_desc->Proto();
     ConvertBlock(*block_proto, parameters, scope, engine);
     for (auto& output : outputs) {
-      engine->DeclareOutput(output);
+      auto* var = block_desc->FindVar(output);
+      PADDLE_ENFORCE_NOT_NULL(
+          var,
+          platform::errors::NotFound("no variable called %s in block.",
+                                     output.c_str()));
+      PADDLE_ENFORCE_EQ(
+          var->GetType(),
+          FluidDT::VarType_Type_LOD_TENSOR,
+          platform::errors::InvalidArgument(
+              "The output tensor in TensorRT subgraph should be LoDTensor"));
+      engine->DeclareOutput(
+          output,
+          FluidDataType2TRT(
+              var->Proto()->type().lod_tensor().tensor().data_type()));
+      VLOG(6) << "DeclareOutput(name: " << output << ", dtype: "
+              << var->Proto()->type().lod_tensor().tensor().data_type() << ")";
     }
     engine->FreezeNetwork();
     engine->ClearWeights();
   }
...
@@ -207,18 +207,6 @@ void TensorRTEngine::FreezeNetwork() {
       }
     }
-  // If model is mixed precision, then we should cast all float output to
-  // float32 precision. Otherwise, we can not confirm the output precision of
-  // the trt engine.
-  if (model_precision_ != phi::DataType::FLOAT32) {
-    for (int i = 0; i < network()->getNbOutputs(); ++i) {
-      network()->getOutput(i)->setAllowedFormats(
-          static_cast<nvinfer1::TensorFormats>(
-              1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR)));
-      network()->getOutput(i)->setType(nvinfer1::DataType::kFLOAT);
-    }
-  }
-
   if (use_dla_) {
     if (!enable_int8 && !enable_fp16) {
       LOG(WARNING) << "TensorRT DLA must be used with int8 or fp16, but you "
@@ -422,6 +410,14 @@ void TensorRTEngine::DeclareOutput(const std::string &name) {
                         name));
   network()->markOutput(*output);
 }
+
+void TensorRTEngine::DeclareOutput(const std::string &name,
+                                   nvinfer1::DataType dtype) {
+  auto *output = TensorRTEngine::GetITensor(name);
+  DeclareOutput(name);
+  output->setType(dtype);
+}
+
 void TensorRTEngine::DeleteITensor(const std::string &name,
                                    nvinfer1::ITensor *tensor) {
   PADDLE_ENFORCE_NOT_NULL(
...
@@ -292,6 +292,9 @@ class TensorRTEngine {
                        const std::string& name);
   // Set the itensor_map_[name] as the network's output, and set its name.
   void DeclareOutput(const std::string& name);
+  // Set the itensor_map_[name] as the network's output, and set its name and
+  // data type.
+  void DeclareOutput(const std::string& name, nvinfer1::DataType dtype);
   void ClearTensorMap() { itensor_map_.clear(); }
   void DeleteITensor(const std::string& name, nvinfer1::ITensor* tensor);
...
@@ -26,6 +26,7 @@
 #include "paddle/fluid/platform/dynload/tensorrt.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/utils/data_type.h"

 namespace paddle {
 namespace inference {
...
@@ -1354,8 +1354,9 @@ struct SimpleOpTypeSetTeller : public Teller {
                op_type == "logical_or" || op_type == "logical_xor" ||
                op_type == "logical_and" || op_type == "less_equal") {
 #if IS_TRT_VERSION_GE(8400)
+      // TRT does not support kEQUAL/kGREATER/kLESS work with implicit batch
       if (!with_dynamic_shape) {
-        VLOG(3) << "these ops do not support static shape yet";
+        VLOG(3) << "Ops(" << op_type << ") do not support static shape yet.";
         return false;
       }
       if (op_type == "logical_or" || op_type == "logical_xor" ||
@@ -2277,24 +2278,15 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
       int in_dtype = PADDLE_GET_CONST(int, desc.GetAttr("in_dtype"));
       int out_dtype = PADDLE_GET_CONST(int, desc.GetAttr("out_dtype"));
-      if ((in_dtype == 4 || in_dtype == 5) && out_dtype == 4) {
-        VLOG(3) << "unsupport data type conversion";
-        return false;
-      }
-#if IS_TRT_VERSION_GE(8400)
       if (in_dtype == 0 || out_dtype == 0) {
+#if IS_TRT_VERSION_GE(8400)
         if (with_dynamic_shape) {
           VLOG(3) << "the cast op supports inputs and outputs of BOOL by "
                      "trt8.4 above ";
           return true;
         }
-      }
 #endif
-      if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 3 ||
-             in_dtype == 2) &&
-            (out_dtype == 5 || out_dtype == 4 || out_dtype == 2))) {
-        VLOG(3) << "only valid conversions are: "
-                   "(kFLOAT | kHALF | kINT32) -> (kFLOAT | kHALF | kINT32)";
         return false;
       }
     }
@@ -2339,9 +2331,15 @@ struct SimpleOpTypeSetTeller : public Teller {
     if (op_type == "equal" || op_type == "not_equal") {
 #if !IS_TRT_VERSION_GE(8000)
-      VLOG(3) << "compare is not supported when TensorRT < 8.0";
+      VLOG(3) << "equal is not supported when TensorRT < 8.0";
       return false;
 #else
+      // TRT does not support kEQUAL/kGREATER/kLESS work with implicit batch
+      if (!with_dynamic_shape) {
+        VLOG(3) << "the equal does not support "
+                   "static shape yet";
+        return false;
+      }
       int axis = PADDLE_GET_CONST(int, desc.GetAttr("axis"));
       if (axis == 0) {
         return false;
...
@@ -92,6 +92,7 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
   AddTensorToBlockDesc(block_, "y", std::vector<int64_t>({4, 6}));
   AddTensorToBlockDesc(block_, "y0", std::vector<int64_t>({6, 8}));
   AddTensorToBlockDesc(block_, "z", std::vector<int64_t>({2, 6}));
+  AddTensorToBlockDesc(block_, "z0", std::vector<int64_t>({8, 1, 1}));

   // It is wired, need to copy manually.
   *block_->add_ops() = *fc0->Proto();
...
@@ -59,6 +59,7 @@ class TrtConvertArgMaxTest(TrtLayerAutoScanTest):
                     "flatten": flatten,
                     "dtype": dtype,
                 },
+                "outputs_dtype": {"arg_max_out": np.int32},
             }
         ]
         ops = self.generate_op_config(ops_config)
...
@@ -59,6 +59,7 @@ class TrtConvertArgMinTest(TrtLayerAutoScanTest):
                     "flatten": flatten,
                     "dtype": dtype,
                 },
+                "outputs_dtype": {"arg_min_out": np.int32},
             }
         ]
         ops = self.generate_op_config(ops_config)
...
@@ -21,6 +21,7 @@ from program_config import ProgramConfig, TensorConfig
 from trt_layer_auto_scan_test import TrtLayerAutoScanTest

 import paddle.inference as paddle_infer
+from paddle.framework import convert_np_dtype_to_dtype_


 class TrtConvertCastTest(TrtLayerAutoScanTest):
@@ -28,40 +29,46 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
         attrs = [
             program_config.ops[i].attrs for i in range(len(program_config.ops))
         ]
-        if attrs[0]['in_dtype'] == 0:
-            return False
-        if attrs[0]['in_dtype'] in [4, 5] and attrs[0]['out_dtype'] == 4:
+        if attrs[0]['in_dtype'] not in [0, 1, 2, 4, 5] or attrs[0][
+            'out_dtype'
+        ] not in [0, 1, 2, 4, 5]:
             return False
-        out_dtype = [2, 4, 5]
-        ver = paddle_infer.get_trt_compile_version()
-        if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 > 8400:
-            out_dtype.insert(3, 0)
+        compile_version = paddle_infer.get_trt_compile_version()
+        runtime_version = paddle_infer.get_trt_runtime_version()
         if (
-            attrs[0]['in_dtype'] not in [2, 4, 5]
-            or attrs[0]['out_dtype'] not in out_dtype
+            compile_version[0] * 1000
+            + compile_version[1] * 100
+            + compile_version[2] * 10
+            < 8400
+        ):
+            return False
+        if (
+            runtime_version[0] * 1000
+            + runtime_version[1] * 100
+            + runtime_version[2] * 10
+            < 8400
         ):
             return False

         return True

     def sample_program_configs(self):
         def generate_input(type):
-            if type == 0:
-                return np.ones([1, 3, 64, 64]).astype(np.bool)
-            elif type == 2:
-                return np.ones([1, 3, 64, 64]).astype(np.int32)
-            elif type == 4:
-                return np.ones([1, 3, 64, 64]).astype(np.float16)
-            else:
-                return np.ones([1, 3, 64, 64]).astype(np.float32)
+            return np.ones([1, 3, 64, 64]).astype(type)

-        for in_dtype in [0, 2, 5, 6]:
-            for out_dtype in [0, 2, 5, 6]:
-                self.out_dtype = out_dtype
+        for in_dtype in [np.bool_, np.int32, np.float32, np.float64]:
+            for out_dtype in [np.bool_, np.int32, np.float32, np.float64]:
+                self.has_bool_dtype = (in_dtype == np.bool_) or (
+                    out_dtype == np.bool_
+                )
                 dics = [
-                    {"in_dtype": in_dtype, "out_dtype": out_dtype},
-                    {"in_dtype": out_dtype, "out_dtype": in_dtype},
+                    {
+                        "in_dtype": convert_np_dtype_to_dtype_(in_dtype),
+                        "out_dtype": convert_np_dtype_to_dtype_(out_dtype),
+                    },
+                    {
+                        "in_dtype": convert_np_dtype_to_dtype_(out_dtype),
+                        "out_dtype": convert_np_dtype_to_dtype_(in_dtype),
+                    },
                 ]

                 ops_config = [
@@ -70,12 +77,14 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
                         "op_inputs": {"X": ["input_data"]},
                         "op_outputs": {"Out": ["cast_output_data0"]},
                         "op_attrs": dics[0],
+                        "outputs_dtype": {"cast_output_data0": out_dtype},
                     },
                     {
                         "op_type": "cast",
                         "op_inputs": {"X": ["cast_output_data0"]},
                         "op_outputs": {"Out": ["cast_output_data1"]},
                         "op_attrs": dics[1],
+                        "outputs_dtype": {"cast_output_data1": in_dtype},
                     },
                 ]
@@ -108,7 +117,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
             self.dynamic_shape.opt_input_shape = {}

         def generate_trt_nodes_num(attrs, dynamic_shape):
-            if not dynamic_shape and self.out_dtype == 0:
+            if not dynamic_shape and self.has_bool_dtype:
                 return 0, 4
             return 1, 2
...
@@ -53,7 +53,7 @@ class TrtConvertLogicalTest(TrtLayerAutoScanTest):
                     "op_inputs": {"X": ["input_data2"]},
                     "op_outputs": {"Out": ["cast_output_data3"]},
                     "op_attrs": dics[1],
-                    "outputs_dtype": {"cast_output_data1": np.bool},
+                    "outputs_dtype": {"cast_output_data3": np.bool},
                 },
                 {
                     "op_type": op_type,
@@ -345,12 +345,14 @@ class TrtConvertLessEqualTest(TrtLayerAutoScanTest):
                     "op_inputs": {"X": ["input_data1"]},
                     "op_outputs": {"Out": ["cast_output_data1"]},
                     "op_attrs": dics[1],
+                    "outputs_dtype": {"cast_output_data1": np.int32},
                 },
                 {
                     "op_type": "cast",
                     "op_inputs": {"X": ["input_data2"]},
                     "op_outputs": {"Out": ["cast_output_data2"]},
                     "op_attrs": dics[1],
+                    "outputs_dtype": {"cast_output_data2": np.int32},
                 },
                 {
                     "op_type": op_type,
...
@@ -71,6 +71,11 @@ class TrtConvertElementwiseTest_one_input_special_case0(TrtLayerAutoScanTest):
                 },
                 "op_outputs": {"Out": ["output_data"]},
                 "op_attrs": dics[0],
+                "outputs_dtype": {
+                    "output_data": np.float32
+                    if op_type != "elementwise_floordiv"
+                    else np.int32
+                },
             }
         ]
         ops = self.generate_op_config(ops_config)
@@ -196,6 +201,11 @@ class TrtConvertElementwiseTest_one_input_special_case1(TrtLayerAutoScanTest):
                 "op_inputs": {"X": ["input_data"], "Y": ["weight"]},
                 "op_outputs": {"Out": ["output_data"]},
                 "op_attrs": dics[0],
+                "outputs_dtype": {
+                    "output_data": np.float32
+                    if op_type != "elementwise_floordiv"
+                    else np.int32
+                },
             }
         ]
         ops = self.generate_op_config(ops_config)
@@ -321,6 +331,11 @@ class TrtConvertElementwiseTest_one_input(TrtLayerAutoScanTest):
                 },
                 "op_outputs": {"Out": ["output_data"]},
                 "op_attrs": dics[0],
+                "outputs_dtype": {
+                    "output_data": np.float32
+                    if op_type != "elementwise_floordiv"
+                    else np.int32
+                },
             }
         ]
         ops = self.generate_op_config(ops_config)
@@ -455,6 +470,11 @@ class TrtConvertElementwiseTest_two_input_without_broadcast(
                 },
                 "op_outputs": {"Out": ["output_data"]},
                 "op_attrs": dics[0],
+                "outputs_dtype": {
+                    "output_data": np.float32
+                    if op_type != "elementwise_floordiv"
+                    else np.int32
+                },
             }
         ]
         ops = self.generate_op_config(ops_config)
@@ -647,6 +667,11 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest):
                 },
                 "op_outputs": {"Out": ["output_data"]},
                 "op_attrs": dics[0],
+                "outputs_dtype": {
+                    "output_data": np.float32
+                    if op_type != "elementwise_floordiv"
+                    else np.int32
+                },
             }
         ]
         ops = self.generate_op_config(ops_config)
@@ -782,6 +807,11 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
                 },
                 "op_outputs": {"Out": ["output_data"]},
                 "op_attrs": dics[0],
+                "outputs_dtype": {
+                    "output_data": np.float32
+                    if op_type != "elementwise_floordiv"
+                    else np.int32
+                },
             }
         ]
         ops = self.generate_op_config(ops_config)
...
@@ -54,12 +54,16 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
                     },
                     "op_outputs": {"Out": ["compare_output_data"]},
                     "op_attrs": dics[0],
+                    "outputs_dtype": {
+                        "compare_output_data": np.bool_
+                    },
                 },
                 {
                     "op_type": "cast",
                     "op_inputs": {"X": ["compare_output_data"]},
                     "op_outputs": {"Out": ["output_data"]},
                     "op_attrs": dics[1],
+                    "outputs_dtype": {"output_data": np.float32},
                 },
             ]
             ops = self.generate_op_config(ops_config)
@@ -77,7 +81,6 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
                 },
                 outputs=["output_data"],
             )
-
             yield program_config

     def sample_predictor_configs(
@@ -104,8 +107,8 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
                 "input_data2": [1, 1, 4],
             }
             self.dynamic_shape.max_input_shape = {
-                "input_data1": [4, 1, 256],
-                "input_data2": [1, 1, 256],
+                "input_data1": [4, 1, 32],
+                "input_data2": [4, 1, 32],
             }
             self.dynamic_shape.opt_input_shape = {
                 "input_data1": [2, 1, 16],
@@ -117,8 +120,8 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
                 "input_data2": [1, 1, 4, 4],
             }
             self.dynamic_shape.max_input_shape = {
-                "input_data1": [4, 1, 128, 256],
-                "input_data2": [4, 1, 128, 256],
+                "input_data1": [4, 1, 64, 32],
+                "input_data2": [4, 1, 64, 32],
             }
             self.dynamic_shape.opt_input_shape = {
                 "input_data1": [2, 1, 32, 16],
@@ -131,9 +134,11 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
             self.dynamic_shape.opt_input_shape = {}

         def generate_trt_nodes_num(attrs, dynamic_shape):
+            if not dynamic_shape:
+                return 0, 5
             if self.dims == 1:
                 return 0, 3
-            return 1, 2
+            return 1, 3

         attrs = [
             program_config.ops[i].attrs for i in range(len(program_config.ops))
@@ -162,6 +167,7 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
             ), 1e-3

     def test(self):
+        self.trt_param.workspace_size = 1 << 20
         self.run_test()
...
@@ -104,6 +104,7 @@ class TrtConvertMulticlassNMS3Test(TrtLayerAutoScanTest):
                     "normalized": False,
                     "nms_eta": nms_eta,
                 },
+                "outputs_dtype": {"nms_output_index": np.int32},
             }
         ]
         ops = self.generate_op_config(ops_config)
...
@@ -54,6 +54,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
                     "Indices": ["indices_data"],
                 },
                 "op_attrs": dics[0],
+                "outputs_dtype": {"indices_data": np.int32},
             }
         ]
         ops = self.generate_op_config(ops_config)
...
@@ -71,6 +71,9 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
                     "Indices": ["indices_data"],
                 },
                 "op_attrs": dics[0],
+                "outputs_dtype": {
+                    "indices_data": np.int32
+                },
             }
         ]
         ops = self.generate_op_config(ops_config)
...