未验证 提交 ff44df18 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Support cast trt converter of bool input and output . (#48043)

* add_cast_bool

* cast
上级 bf6af816
......@@ -42,14 +42,21 @@ class CastOpConverter : public OpConverter {
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input);
switch (out_dtype) {
case 0: // BOOL = 0
layer->setOutputType(0, nvinfer1::DataType::kBOOL);
layer->getOutput(0)->setType(nvinfer1::DataType::kBOOL);
break;
case 2: // INT32 = 2
layer->setOutputType(0, nvinfer1::DataType::kINT32);
layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
break;
case 4: // FP16 = 4
layer->setOutputType(0, nvinfer1::DataType::kHALF);
layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);
break;
case 5: // FP32 = 5
layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
break;
default:
LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype
......
......@@ -2124,10 +2124,15 @@ struct SimpleOpTypeSetTeller : public Teller {
VLOG(3) << "unsupport data type conversion";
return false;
}
if (in_dtype == 0) {
VLOG(3) << "do not support input data type as bool now";
return false;
#if IS_TRT_VERSION_GE(8400)
if (in_dtype == 0 || out_dtype == 0) {
if (with_dynamic_shape) {
VLOG(3) << "the cast op supports inputs and outputs of BOOL by "
"trt8.4 above ";
return true;
}
}
#endif
if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2) &&
(out_dtype == 5 || out_dtype == 4 || out_dtype == 2))) {
VLOG(3) << "only valid conversions are: "
......
......@@ -30,9 +30,16 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
return False
if attrs[0]['in_dtype'] in [4, 5] and attrs[0]['out_dtype'] == 4:
return False
if attrs[0]['in_dtype'] not in [2, 4, 5] or attrs[0][
'out_dtype'
] not in [2, 4, 5]:
out_dtype = [2, 4, 5]
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 > 8400:
out_dtype.insert(3, 0)
if (
attrs[0]['in_dtype'] not in [2, 4, 5]
or attrs[0]['out_dtype'] not in out_dtype
):
return False
return True
......@@ -49,6 +56,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
for in_dtype in [0, 2, 5, 6]:
for out_dtype in [0, 2, 5, 6]:
self.out_dtype = out_dtype
dics = [
{"in_dtype": in_dtype, "out_dtype": out_dtype},
{"in_dtype": out_dtype, "out_dtype": in_dtype},
......@@ -89,7 +97,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 64, 64]}
self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
self.dynamic_shape.max_input_shape = {"input_data": [1, 3, 64, 64]}
self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]}
def clear_dynamic_shape():
......@@ -98,6 +106,8 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
if not dynamic_shape and self.out_dtype == 0:
return 0, 4
return 1, 2
attrs = [
......
......@@ -195,7 +195,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
def generate_trt_nodes_num(attrs, dynamic_shape):
if not dynamic_shape:
return 0, 6
return 1, 5
return 1, 4
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册