未验证 提交 230b9a82 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT][Cherry-Pick]Fix cast bug (#46293)

* fix cast bug
上级 a43f960e
...@@ -43,13 +43,13 @@ class CastOpConverter : public OpConverter { ...@@ -43,13 +43,13 @@ class CastOpConverter : public OpConverter {
switch (out_dtype) { switch (out_dtype) {
case 2: // INT32 = 2 case 2: // INT32 = 2
layer->getOutput(0)->setType(nvinfer1::DataType::kINT32); layer->setOutputType(0, nvinfer1::DataType::kINT32);
break; break;
case 4: // FP16 = 4 case 4: // FP16 = 4
layer->getOutput(0)->setType(nvinfer1::DataType::kHALF); layer->setOutputType(0, nvinfer1::DataType::kHALF);
break; break;
case 5: // FP32 = 5 case 5: // FP32 = 5
layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT); layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
break; break;
default: default:
LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype
......
...@@ -49,9 +49,15 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): ...@@ -49,9 +49,15 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
else: else:
return np.ones([1, 3, 64, 64]).astype(np.float32) return np.ones([1, 3, 64, 64]).astype(np.float32)
for in_dtype in [0, 2, 4, 5, 6]: for in_dtype in [0, 2, 5, 6]:
for out_dtype in [0, 2, 4, 5, 6]: for out_dtype in [0, 2, 5, 6]:
dics = [{"in_dtype": in_dtype, "out_dtype": out_dtype}] dics = [{
"in_dtype": in_dtype,
"out_dtype": out_dtype
}, {
"in_dtype": out_dtype,
"out_dtype": in_dtype
}]
ops_config = [{ ops_config = [{
"op_type": "cast", "op_type": "cast",
...@@ -59,10 +65,20 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): ...@@ -59,10 +65,20 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
"X": ["input_data"] "X": ["input_data"]
}, },
"op_outputs": { "op_outputs": {
"Out": ["cast_output_data"] "Out": ["cast_output_data0"]
}, },
"op_attrs": dics[0] "op_attrs": dics[0]
}, {
"op_type": "cast",
"op_inputs": {
"X": ["cast_output_data0"]
},
"op_outputs": {
"Out": ["cast_output_data1"]
},
"op_attrs": dics[1]
}] }]
ops = self.generate_op_config(ops_config) ops = self.generate_op_config(ops_config)
program_config = ProgramConfig( program_config = ProgramConfig(
...@@ -72,7 +88,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): ...@@ -72,7 +88,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest):
"input_data": "input_data":
TensorConfig(data_gen=partial(generate_input, in_dtype)) TensorConfig(data_gen=partial(generate_input, in_dtype))
}, },
outputs=["cast_output_data"]) outputs=["cast_output_data1"])
yield program_config yield program_config
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册