diff --git a/paddle/fluid/inference/tensorrt/convert/cast_op.cc b/paddle/fluid/inference/tensorrt/convert/cast_op.cc index ab62c43d851eb84ab4c34dfb3629e1ab6ed22d26..b2b06744d984ab685de90558c4bbb2481d5ddd1b 100644 --- a/paddle/fluid/inference/tensorrt/convert/cast_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/cast_op.cc @@ -43,13 +43,13 @@ class CastOpConverter : public OpConverter { switch (out_dtype) { case 2: // INT32 = 2 - layer->getOutput(0)->setType(nvinfer1::DataType::kINT32); + layer->setOutputType(0, nvinfer1::DataType::kINT32); break; case 4: // FP16 = 4 - layer->getOutput(0)->setType(nvinfer1::DataType::kHALF); + layer->setOutputType(0, nvinfer1::DataType::kHALF); break; case 5: // FP32 = 5 - layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT); + layer->setOutputType(0, nvinfer1::DataType::kFLOAT); break; default: LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py index 8dca14c02aa74a2606d03eff9f853b3aaef5d04d..3d01a0712aecc15830113d7b4fd528dbcac050e6 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py @@ -49,9 +49,15 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): else: return np.ones([1, 3, 64, 64]).astype(np.float32) - for in_dtype in [0, 2, 4, 5, 6]: - for out_dtype in [0, 2, 4, 5, 6]: - dics = [{"in_dtype": in_dtype, "out_dtype": out_dtype}] + for in_dtype in [0, 2, 5, 6]: + for out_dtype in [0, 2, 5, 6]: + dics = [{ + "in_dtype": in_dtype, + "out_dtype": out_dtype + }, { + "in_dtype": out_dtype, + "out_dtype": in_dtype + }] ops_config = [{ "op_type": "cast", @@ -59,10 +65,20 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): "X": ["input_data"] }, "op_outputs": { - "Out": ["cast_output_data"] + "Out": ["cast_output_data0"] }, "op_attrs": dics[0] + }, { + "op_type": "cast", + "op_inputs": { + "X": ["cast_output_data0"] + }, + "op_outputs": { + "Out": ["cast_output_data1"] + }, + "op_attrs": dics[1] }] + ops = self.generate_op_config(ops_config) program_config = ProgramConfig( @@ -72,7 +88,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): "input_data": TensorConfig(data_gen=partial(generate_input, in_dtype)) }, - outputs=["cast_output_data"]) + outputs=["cast_output_data1"]) yield program_config