Unverified commit ff44df18, authored by xiaoxiaohehe001, committed by GitHub

[Paddle Inference] Support cast trt converter of bool input and output. (#48043)

* add_cast_bool

* cast
Parent bf6af816
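For orientation, here is a minimal sketch of how the new path would be exercised from the Python side: load a program that contains a cast-to-bool op and route it through the TensorRT subgraph engine. The model filenames and shape ranges are illustrative assumptions, not part of this commit; the Config calls themselves are standard Paddle Inference APIs, and bool cast additionally requires TensorRT >= 8.4 with dynamic shape enabled, per the teller change below.

    import paddle.inference as paddle_infer

    # Hypothetical model files; any program containing a cast-to-bool op would do.
    config = paddle_infer.Config("cast_model/model.pdmodel", "cast_model/model.pdiparams")
    config.enable_use_gpu(256, 0)
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,
        max_batch_size=1,
        min_subgraph_size=1,
        precision_mode=paddle_infer.PrecisionType.Float32,
        use_static=False,
        use_calib_mode=False,
    )
    # The teller only admits bool cast under dynamic shape, so declare shape ranges.
    config.set_trt_dynamic_shape_info(
        {"input_data": [1, 3, 64, 64]},  # min
        {"input_data": [4, 3, 64, 64]},  # max
        {"input_data": [1, 3, 64, 64]},  # opt
    )
    predictor = paddle_infer.create_predictor(config)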
@@ -42,14 +42,21 @@ class CastOpConverter : public OpConverter {
     auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input);
     switch (out_dtype) {
+      case 0:  // BOOL = 0
+        layer->setOutputType(0, nvinfer1::DataType::kBOOL);
+        layer->getOutput(0)->setType(nvinfer1::DataType::kBOOL);
+        break;
       case 2:  // INT32 = 2
         layer->setOutputType(0, nvinfer1::DataType::kINT32);
+        layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
         break;
       case 4:  // FP16 = 4
         layer->setOutputType(0, nvinfer1::DataType::kHALF);
+        layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);
         break;
       case 5:  // FP32 = 5
         layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
+        layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
         break;
       default:
         LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype
...
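The converter implements cast as a TensorRT Identity layer whose output type is forced to the target dtype; the newly added getOutput(0)->setType(...) lines pin the output tensor's type as well, which is what makes the bool case take effect. The same pattern, sketched through TensorRT's Python API for illustration (the network and input tensor are assumed to already exist; add_cast is a hypothetical helper, not Paddle code):

    import tensorrt as trt

    def add_cast(network: trt.INetworkDefinition,
                 inp: trt.ITensor,
                 out_dtype: trt.DataType) -> trt.ITensor:
        # An Identity layer doubles as a cast once its output type is pinned.
        layer = network.add_identity(inp)
        layer.set_output_type(0, out_dtype)    # request the conversion
        layer.get_output(0).dtype = out_dtype  # pin the tensor's dtype too
        return layer.get_output(0)

    # e.g. add_cast(network, x, trt.bool) mirrors the new BOOL branch above.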
@@ -2124,10 +2124,15 @@ struct SimpleOpTypeSetTeller : public Teller {
         VLOG(3) << "unsupport data type conversion";
         return false;
       }
-      if (in_dtype == 0) {
-        VLOG(3) << "do not support input data type as bool now";
-        return false;
-      }
+#if IS_TRT_VERSION_GE(8400)
+      if (in_dtype == 0 || out_dtype == 0) {
+        if (with_dynamic_shape) {
+          VLOG(3) << "the cast op supports inputs and outputs of BOOL by "
+                     "trt8.4 above ";
+          return true;
+        }
+      }
+#endif
       if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2) &&
             (out_dtype == 5 || out_dtype == 4 || out_dtype == 2))) {
         VLOG(3) << "only valid conversions are: "
...
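The dtype codes come from Paddle's VarType proto: BOOL = 0, INT32 = 2, FP16 = 4, FP32 = 5. For readability, the teller's decision can be restated as a standalone predicate; this is a sketch, and the function name and flattened trt_version parameter mirror, but are not, the teller's actual interface:

    # Paddle VarType codes used by the teller.
    BOOL, INT32, FP16, FP32 = 0, 2, 4, 5

    def cast_is_convertible(in_dtype, out_dtype, trt_version, with_dynamic_shape):
        # BOOL on either end is admitted only on TRT >= 8.4 under dynamic shape.
        if trt_version >= 8400 and BOOL in (in_dtype, out_dtype):
            if with_dynamic_shape:
                return True
        # Otherwise both ends must be one of INT32 / FP16 / FP32.
        return (in_dtype in (INT32, FP16, FP32)
                and out_dtype in (INT32, FP16, FP32))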
@@ -30,9 +30,16 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
             return False
         if attrs[0]['in_dtype'] in [4, 5] and attrs[0]['out_dtype'] == 4:
             return False
-        if attrs[0]['in_dtype'] not in [2, 4, 5] or attrs[0][
-            'out_dtype'
-        ] not in [2, 4, 5]:
+        out_dtype = [2, 4, 5]
+        ver = paddle_infer.get_trt_compile_version()
+        if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 > 8400:
+            out_dtype.insert(3, 0)
+        if (
+            attrs[0]['in_dtype'] not in [2, 4, 5]
+            or attrs[0]['out_dtype'] not in out_dtype
+        ):
             return False
         return True
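paddle_infer.get_trt_compile_version() returns a (major, minor, patch) tuple, which the test flattens into the same integer encoding the C++ IS_TRT_VERSION_GE macro uses. A worked example with an assumed version:

    ver = (8, 4, 1)  # e.g. TensorRT 8.4.1
    code = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10  # 8000 + 400 + 10 = 8410
    assert code > 8400  # BOOL (dtype 0) is then appended to the accepted list

Note the strict > here: TensorRT 8.4.0 encodes to exactly 8400, so this test-side check excludes it even though the C++ teller's IS_TRT_VERSION_GE(8400) accepts it.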
@@ -49,6 +56,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
         for in_dtype in [0, 2, 5, 6]:
             for out_dtype in [0, 2, 5, 6]:
+                self.out_dtype = out_dtype
                 dics = [
                     {"in_dtype": in_dtype, "out_dtype": out_dtype},
                     {"in_dtype": out_dtype, "out_dtype": in_dtype},
@@ -89,7 +97,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
     ) -> (paddle_infer.Config, List[int], float):
         def generate_dynamic_shape(attrs):
             self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 64, 64]}
-            self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
+            self.dynamic_shape.max_input_shape = {"input_data": [1, 3, 64, 64]}
             self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]}

         def clear_dynamic_shape():
@@ -98,6 +106,8 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
             self.dynamic_shape.opt_input_shape = {}

         def generate_trt_nodes_num(attrs, dynamic_shape):
+            if not dynamic_shape and self.out_dtype == 0:
+                return 0, 4
             return 1, 2

         attrs = [
...
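This expectation follows from the teller change above: bool cast is only admitted under dynamic shape, so in the static-shape configuration with out_dtype == 0 the op falls back to Paddle entirely. In the harness's (trt_engine_num, paddle_op_num) convention, that leaves zero TensorRT engines and four framework ops instead of one engine plus two ops.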
@@ -195,7 +195,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
         def generate_trt_nodes_num(attrs, dynamic_shape):
             if not dynamic_shape:
                 return 0, 6
-            return 1, 5
+            return 1, 4

         attrs = [
             program_config.ops[i].attrs for i in range(len(program_config.ops))
...