diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index c602f2ff071240d8e00e3f45e14118c32602f651..0c2a4d473e65ebbb80123b40d8f6ba60cfe00926 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -2091,12 +2091,14 @@ bool OpTeller::Tell(const framework::ir::Node* node, VLOG(3) << "unsupport data type conversion"; return false; } - if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2 || - in_dtype == 0) && + if (in_dtype == 0) { + VLOG(3) << "do not support input data type as bool now"; + return false; + } + if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2) && (out_dtype == 5 || out_dtype == 4 || out_dtype == 2))) { - VLOG(3) - << "only valid conversions are: " - "(kFLOAT | kHALF | kINT32 | kBOOL) -> (kFLOAT | kHALF | kINT32)"; + VLOG(3) << "only valid conversions are: " + "(kFLOAT | kHALF | kINT32) -> (kFLOAT | kHALF | kINT32)"; return false; } } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py index c381dbc2d6ab4d2e8e0f0c2e41ef1e0cdac03da6..8dca14c02aa74a2606d03eff9f853b3aaef5d04d 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py @@ -27,10 +27,12 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): attrs = [ program_config.ops[i].attrs for i in range(len(program_config.ops)) ] + if attrs[0]['in_dtype'] == 0: + return False if attrs[0]['in_dtype'] in [4, 5] and attrs[0]['out_dtype'] == 4: return False if attrs[0]['in_dtype'] not in [ - 0, 2, 4, 5 + 2, 4, 5 ] or attrs[0]['out_dtype'] not in [2, 4, 5]: return False return True