From 18adbbd0fe9d7bf9449086f56db0e6200dd74137 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Tue, 8 Nov 2022 10:51:14 +0800 Subject: [PATCH] [Paddle-TRT]Fix cast converter bug , use setOutputType() instaead (#46289) * fix cast bug --- .../inference/tensorrt/convert/cast_op.cc | 6 ++--- .../ir/inference/test_trt_convert_cast.py | 22 ++++++++++++++----- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/cast_op.cc b/paddle/fluid/inference/tensorrt/convert/cast_op.cc index ab62c43d85..b2b06744d9 100644 --- a/paddle/fluid/inference/tensorrt/convert/cast_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/cast_op.cc @@ -43,13 +43,13 @@ class CastOpConverter : public OpConverter { switch (out_dtype) { case 2: // INT32 = 2 - layer->getOutput(0)->setType(nvinfer1::DataType::kINT32); + layer->setOutputType(0, nvinfer1::DataType::kINT32); break; case 4: // FP16 = 4 - layer->getOutput(0)->setType(nvinfer1::DataType::kHALF); + layer->setOutputType(0, nvinfer1::DataType::kHALF); break; case 5: // FP32 = 5 - layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT); + layer->setOutputType(0, nvinfer1::DataType::kFLOAT); break; default: LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py index 76b46313f9..c063019a8f 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py @@ -47,18 +47,28 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): else: return np.ones([1, 3, 64, 64]).astype(np.float32) - for in_dtype in [0, 2, 4, 5, 6]: - for out_dtype in [0, 2, 4, 5, 6]: - dics = [{"in_dtype": in_dtype, "out_dtype": out_dtype}] + for in_dtype in [0, 2, 5, 6]: + for out_dtype in [0, 2, 5, 6]: + dics = [ + {"in_dtype": in_dtype, "out_dtype": out_dtype}, + {"in_dtype": out_dtype, "out_dtype": in_dtype}, + ] ops_config = [ { "op_type": "cast", "op_inputs": {"X": ["input_data"]}, - "op_outputs": {"Out": ["cast_output_data"]}, + "op_outputs": {"Out": ["cast_output_data0"]}, "op_attrs": dics[0], - } + }, + { + "op_type": "cast", + "op_inputs": {"X": ["cast_output_data0"]}, + "op_outputs": {"Out": ["cast_output_data1"]}, + "op_attrs": dics[1], + }, ] + ops = self.generate_op_config(ops_config) program_config = ProgramConfig( @@ -69,7 +79,7 @@ class TrtConvertCastTest(TrtLayerAutoScanTest): data_gen=partial(generate_input, in_dtype) ) }, - outputs=["cast_output_data"], + outputs=["cast_output_data1"], ) yield program_config -- GitLab