diff --git a/paddle/fluid/inference/tensorrt/convert/cast_op.cc b/paddle/fluid/inference/tensorrt/convert/cast_op.cc index ab62c43d851eb84ab4c34dfb3629e1ab6ed22d26..b2b06744d984ab685de90558c4bbb2481d5ddd1b 100644 --- a/paddle/fluid/inference/tensorrt/convert/cast_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/cast_op.cc @@ -43,13 +43,13 @@ class CastOpConverter : public OpConverter { switch (out_dtype) { case 2: // INT32 = 2 - layer->getOutput(0)->setType(nvinfer1::DataType::kINT32); + layer->setOutputType(0, nvinfer1::DataType::kINT32); break; case 4: // FP16 = 4 - layer->getOutput(0)->setType(nvinfer1::DataType::kHALF); + layer->setOutputType(0, nvinfer1::DataType::kHALF); break; case 5: // FP32 = 5 - layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT); + layer->setOutputType(0, nvinfer1::DataType::kFLOAT); break; default: LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py index 76b46313f9590751c3ee31b5ef673e6336ec8a36..c063019a8f475d194ca06e8df4ca7ab01dfc6715 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py @@ -47,18 +47,28 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): else: return np.ones([1, 3, 64, 64]).astype(np.float32) - for in_dtype in [0, 2, 4, 5, 6]: - for out_dtype in [0, 2, 4, 5, 6]: - dics = [{"in_dtype": in_dtype, "out_dtype": out_dtype}] + for in_dtype in [0, 2, 5, 6]: + for out_dtype in [0, 2, 5, 6]: + dics = [ + {"in_dtype": in_dtype, "out_dtype": out_dtype}, + {"in_dtype": out_dtype, "out_dtype": in_dtype}, + ] ops_config = [ { "op_type": "cast", "op_inputs": {"X": ["input_data"]}, - "op_outputs": {"Out": ["cast_output_data"]}, + "op_outputs": {"Out": ["cast_output_data0"]}, "op_attrs": dics[0], - } + }, + { + "op_type": "cast", + "op_inputs": {"X": ["cast_output_data0"]}, + "op_outputs": {"Out": ["cast_output_data1"]}, + "op_attrs": dics[1], + }, ] + ops = self.generate_op_config(ops_config) program_config = ProgramConfig( @@ -69,7 +79,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): data_gen=partial(generate_input, in_dtype) ) }, - outputs=["cast_output_data"], + outputs=["cast_output_data1"], ) yield program_config