From 5d94618da05c69fae6e107bc701f04b645b85b2e Mon Sep 17 00:00:00 2001 From: ccrrong <101700995+ccrrong@users.noreply.github.com> Date: Fri, 29 Jul 2022 15:05:19 +0800 Subject: [PATCH] skip cast trt convert when input dtype is bool (#44716) * skip cast trt convert when input dtype is bool --- paddle/fluid/inference/tensorrt/op_teller.cc | 12 +++++++----- .../unittests/ir/inference/test_trt_convert_cast.py | 4 +++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index c602f2ff071..0c2a4d473e6 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -2091,12 +2091,14 @@ bool OpTeller::Tell(const framework::ir::Node* node, VLOG(3) << "unsupport data type conversion"; return false; } - if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2 || - in_dtype == 0) && + if (in_dtype == 0) { + VLOG(3) << "do not support input data type as bool now"; + return false; + } + if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2) && (out_dtype == 5 || out_dtype == 4 || out_dtype == 2))) { - VLOG(3) - << "only valid conversions are: " - "(kFLOAT | kHALF | kINT32 | kBOOL) -> (kFLOAT | kHALF | kINT32)"; + VLOG(3) << "only valid conversions are: " + "(kFLOAT | kHALF | kINT32) -> (kFLOAT | kHALF | kINT32)"; return false; } } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py index c381dbc2d6a..8dca14c02aa 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py @@ -27,10 +27,12 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): attrs = [ program_config.ops[i].attrs for i in range(len(program_config.ops)) ] + if attrs[0]['in_dtype'] == 0: + return False if attrs[0]['in_dtype'] in [4, 5] and attrs[0]['out_dtype'] == 4: return False if attrs[0]['in_dtype'] not in [ - 0, 2, 4, 5 + 2, 4, 5 ] or attrs[0]['out_dtype'] not in [2, 4, 5]: return False return True -- GitLab