未验证 提交 230b9a82 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT][Cherry-Pick]Fix cast bug (#46293)

* fix cast bug
上级 a43f960e
......@@ -43,13 +43,13 @@ class CastOpConverter : public OpConverter {
switch (out_dtype) {
case 2: // INT32 = 2
layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
layer->setOutputType(0, nvinfer1::DataType::kINT32);
break;
case 4: // FP16 = 4
layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);
layer->setOutputType(0, nvinfer1::DataType::kHALF);
break;
case 5: // FP32 = 5
layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
break;
default:
LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype
......
......@@ -49,9 +49,15 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
else:
return np.ones([1, 3, 64, 64]).astype(np.float32)
for in_dtype in [0, 2, 4, 5, 6]:
for out_dtype in [0, 2, 4, 5, 6]:
dics = [{"in_dtype": in_dtype, "out_dtype": out_dtype}]
for in_dtype in [0, 2, 5, 6]:
for out_dtype in [0, 2, 5, 6]:
dics = [{
"in_dtype": in_dtype,
"out_dtype": out_dtype
}, {
"in_dtype": out_dtype,
"out_dtype": in_dtype
}]
ops_config = [{
"op_type": "cast",
......@@ -59,10 +65,20 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
"X": ["input_data"]
},
"op_outputs": {
"Out": ["cast_output_data"]
"Out": ["cast_output_data0"]
},
"op_attrs": dics[0]
}, {
"op_type": "cast",
"op_inputs": {
"X": ["cast_output_data0"]
},
"op_outputs": {
"Out": ["cast_output_data1"]
},
"op_attrs": dics[1]
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
......@@ -72,7 +88,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
"input_data":
TensorConfig(data_gen=partial(generate_input, in_dtype))
},
outputs=["cast_output_data"])
outputs=["cast_output_data1"])
yield program_config
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册