未验证 提交 18adbbd0 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT]Fix cast converter bug , use setOutputType() instaead (#46289)

* fix cast bug
上级 42d9fe2f
...@@ -43,13 +43,13 @@ class CastOpConverter : public OpConverter { ...@@ -43,13 +43,13 @@ class CastOpConverter : public OpConverter {
switch (out_dtype) { switch (out_dtype) {
case 2: // INT32 = 2 case 2: // INT32 = 2
layer->getOutput(0)->setType(nvinfer1::DataType::kINT32); layer->setOutputType(0, nvinfer1::DataType::kINT32);
break; break;
case 4: // FP16 = 4 case 4: // FP16 = 4
layer->getOutput(0)->setType(nvinfer1::DataType::kHALF); layer->setOutputType(0, nvinfer1::DataType::kHALF);
break; break;
case 5: // FP32 = 5 case 5: // FP32 = 5
layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT); layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
break; break;
default: default:
LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype
......
...@@ -47,18 +47,28 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): ...@@ -47,18 +47,28 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
else: else:
return np.ones([1, 3, 64, 64]).astype(np.float32) return np.ones([1, 3, 64, 64]).astype(np.float32)
for in_dtype in [0, 2, 4, 5, 6]: for in_dtype in [0, 2, 5, 6]:
for out_dtype in [0, 2, 4, 5, 6]: for out_dtype in [0, 2, 5, 6]:
dics = [{"in_dtype": in_dtype, "out_dtype": out_dtype}] dics = [
{"in_dtype": in_dtype, "out_dtype": out_dtype},
{"in_dtype": out_dtype, "out_dtype": in_dtype},
]
ops_config = [ ops_config = [
{ {
"op_type": "cast", "op_type": "cast",
"op_inputs": {"X": ["input_data"]}, "op_inputs": {"X": ["input_data"]},
"op_outputs": {"Out": ["cast_output_data"]}, "op_outputs": {"Out": ["cast_output_data0"]},
"op_attrs": dics[0], "op_attrs": dics[0],
} },
{
"op_type": "cast",
"op_inputs": {"X": ["cast_output_data0"]},
"op_outputs": {"Out": ["cast_output_data1"]},
"op_attrs": dics[1],
},
] ]
ops = self.generate_op_config(ops_config) ops = self.generate_op_config(ops_config)
program_config = ProgramConfig( program_config = ProgramConfig(
...@@ -69,7 +79,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): ...@@ -69,7 +79,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
data_gen=partial(generate_input, in_dtype) data_gen=partial(generate_input, in_dtype)
) )
}, },
outputs=["cast_output_data"], outputs=["cast_output_data1"],
) )
yield program_config yield program_config
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册