未验证 提交 18adbbd0 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT]Fix cast converter bug , use setOutputType() instaead (#46289)

* fix cast bug
上级 42d9fe2f
......@@ -43,13 +43,13 @@ class CastOpConverter : public OpConverter {
switch (out_dtype) {
case 2: // INT32 = 2
layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
layer->setOutputType(0, nvinfer1::DataType::kINT32);
break;
case 4: // FP16 = 4
layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);
layer->setOutputType(0, nvinfer1::DataType::kHALF);
break;
case 5: // FP32 = 5
layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
break;
default:
LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype
......
......@@ -47,18 +47,28 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
else:
return np.ones([1, 3, 64, 64]).astype(np.float32)
for in_dtype in [0, 2, 4, 5, 6]:
for out_dtype in [0, 2, 4, 5, 6]:
dics = [{"in_dtype": in_dtype, "out_dtype": out_dtype}]
for in_dtype in [0, 2, 5, 6]:
for out_dtype in [0, 2, 5, 6]:
dics = [
{"in_dtype": in_dtype, "out_dtype": out_dtype},
{"in_dtype": out_dtype, "out_dtype": in_dtype},
]
ops_config = [
{
"op_type": "cast",
"op_inputs": {"X": ["input_data"]},
"op_outputs": {"Out": ["cast_output_data"]},
"op_outputs": {"Out": ["cast_output_data0"]},
"op_attrs": dics[0],
}
},
{
"op_type": "cast",
"op_inputs": {"X": ["cast_output_data0"]},
"op_outputs": {"Out": ["cast_output_data1"]},
"op_attrs": dics[1],
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
......@@ -69,7 +79,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
data_gen=partial(generate_input, in_dtype)
)
},
outputs=["cast_output_data"],
outputs=["cast_output_data1"],
)
yield program_config
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册