未验证 提交 5d94618d 编写于 作者: C ccrrong 提交者: GitHub

skip cast trt convert when input dtype is bool (#44716)

* skip cast trt convert when input dtype is bool
上级 72f2ed43
......@@ -2091,12 +2091,14 @@ bool OpTeller::Tell(const framework::ir::Node* node,
VLOG(3) << "unsupport data type conversion";
return false;
}
if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2 ||
in_dtype == 0) &&
if (in_dtype == 0) {
VLOG(3) << "do not support input data type as bool now";
return false;
}
if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2) &&
(out_dtype == 5 || out_dtype == 4 || out_dtype == 2))) {
VLOG(3)
<< "only valid conversions are: "
"(kFLOAT | kHALF | kINT32 | kBOOL) -> (kFLOAT | kHALF | kINT32)";
VLOG(3) << "only valid conversions are: "
"(kFLOAT | kHALF | kINT32) -> (kFLOAT | kHALF | kINT32)";
return false;
}
}
......
......@@ -27,10 +27,12 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
if attrs[0]['in_dtype'] == 0:
return False
if attrs[0]['in_dtype'] in [4, 5] and attrs[0]['out_dtype'] == 4:
return False
if attrs[0]['in_dtype'] not in [
0, 2, 4, 5
2, 4, 5
] or attrs[0]['out_dtype'] not in [2, 4, 5]:
return False
return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册