未验证 提交 5d94618d 编写于 作者: C ccrrong 提交者: GitHub

skip cast trt convert when input dtype is bool (#44716)

* skip cast trt convert when input dtype is bool
上级 72f2ed43
...@@ -2091,12 +2091,14 @@ bool OpTeller::Tell(const framework::ir::Node* node, ...@@ -2091,12 +2091,14 @@ bool OpTeller::Tell(const framework::ir::Node* node,
VLOG(3) << "unsupport data type conversion"; VLOG(3) << "unsupport data type conversion";
return false; return false;
} }
if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2 || if (in_dtype == 0) {
in_dtype == 0) && VLOG(3) << "do not support input data type as bool now";
return false;
}
if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2) &&
(out_dtype == 5 || out_dtype == 4 || out_dtype == 2))) { (out_dtype == 5 || out_dtype == 4 || out_dtype == 2))) {
VLOG(3) VLOG(3) << "only valid conversions are: "
<< "only valid conversions are: " "(kFLOAT | kHALF | kINT32) -> (kFLOAT | kHALF | kINT32)";
"(kFLOAT | kHALF | kINT32 | kBOOL) -> (kFLOAT | kHALF | kINT32)";
return false; return false;
} }
} }
......
...@@ -27,10 +27,12 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): ...@@ -27,10 +27,12 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
attrs = [ attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops)) program_config.ops[i].attrs for i in range(len(program_config.ops))
] ]
if attrs[0]['in_dtype'] == 0:
return False
if attrs[0]['in_dtype'] in [4, 5] and attrs[0]['out_dtype'] == 4: if attrs[0]['in_dtype'] in [4, 5] and attrs[0]['out_dtype'] == 4:
return False return False
if attrs[0]['in_dtype'] not in [ if attrs[0]['in_dtype'] not in [
0, 2, 4, 5 2, 4, 5
] or attrs[0]['out_dtype'] not in [2, 4, 5]: ] or attrs[0]['out_dtype'] not in [2, 4, 5]:
return False return False
return True return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册