未验证 提交 55ccb429 编写于 作者: W wenbin 提交者: GitHub

fix cast issue (#49909)

* fix cast issue

* add ut
上级 26140ec8
......@@ -46,6 +46,7 @@ class CastOpConverter : public OpConverter {
layer->setOutputType(0, nvinfer1::DataType::kBOOL);
break;
case 2: // INT32 = 2
case 3: // INT64 = 3 there is no int64 in tensorrt subgraph
layer->setOutputType(0, nvinfer1::DataType::kINT32);
break;
case 4: // FP16 = 4
......
......@@ -29,9 +29,9 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
if attrs[0]['in_dtype'] not in [0, 1, 2, 4, 5] or attrs[0][
if attrs[0]['in_dtype'] not in [0, 1, 2, 3, 4, 5] or attrs[0][
'out_dtype'
] not in [0, 1, 2, 4, 5]:
] not in [0, 1, 2, 3, 4, 5]:
return False
compile_version = paddle_infer.get_trt_compile_version()
runtime_version = paddle_infer.get_trt_runtime_version()
......@@ -55,8 +55,14 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
def generate_input(type):
return np.ones([1, 3, 64, 64]).astype(type)
for in_dtype in [np.bool_, np.int32, np.float32, np.float64]:
for out_dtype in [np.bool_, np.int32, np.float32, np.float64]:
for in_dtype in [np.bool_, np.int32, np.float32, np.float64, np.int64]:
for out_dtype in [
np.bool_,
np.int32,
np.float32,
np.float64,
np.int64,
]:
self.has_bool_dtype = (in_dtype == np.bool_) or (
out_dtype == np.bool_
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册