未验证 提交 55ccb429 编写于 作者: W wenbin 提交者: GitHub

fix cast issue (#49909)

* fix cast issue

* add ut
上级 26140ec8
...@@ -46,6 +46,7 @@ class CastOpConverter : public OpConverter { ...@@ -46,6 +46,7 @@ class CastOpConverter : public OpConverter {
layer->setOutputType(0, nvinfer1::DataType::kBOOL); layer->setOutputType(0, nvinfer1::DataType::kBOOL);
break; break;
case 2: // INT32 = 2 case 2: // INT32 = 2
case 3: // INT64 = 3 there is no int64 in tensorrt subgraph
layer->setOutputType(0, nvinfer1::DataType::kINT32); layer->setOutputType(0, nvinfer1::DataType::kINT32);
break; break;
case 4: // FP16 = 4 case 4: // FP16 = 4
......
...@@ -29,9 +29,9 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): ...@@ -29,9 +29,9 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
attrs = [ attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops)) program_config.ops[i].attrs for i in range(len(program_config.ops))
] ]
if attrs[0]['in_dtype'] not in [0, 1, 2, 4, 5] or attrs[0][ if attrs[0]['in_dtype'] not in [0, 1, 2, 3, 4, 5] or attrs[0][
'out_dtype' 'out_dtype'
] not in [0, 1, 2, 4, 5]: ] not in [0, 1, 2, 3, 4, 5]:
return False return False
compile_version = paddle_infer.get_trt_compile_version() compile_version = paddle_infer.get_trt_compile_version()
runtime_version = paddle_infer.get_trt_runtime_version() runtime_version = paddle_infer.get_trt_runtime_version()
...@@ -55,8 +55,14 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): ...@@ -55,8 +55,14 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
def generate_input(type): def generate_input(type):
return np.ones([1, 3, 64, 64]).astype(type) return np.ones([1, 3, 64, 64]).astype(type)
for in_dtype in [np.bool_, np.int32, np.float32, np.float64]: for in_dtype in [np.bool_, np.int32, np.float32, np.float64, np.int64]:
for out_dtype in [np.bool_, np.int32, np.float32, np.float64]: for out_dtype in [
np.bool_,
np.int32,
np.float32,
np.float64,
np.int64,
]:
self.has_bool_dtype = (in_dtype == np.bool_) or ( self.has_bool_dtype = (in_dtype == np.bool_) or (
out_dtype == np.bool_ out_dtype == np.bool_
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册