From 4a52c0cc941dd21cebf8c390f86780b2ce62759a Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 5 Aug 2021 21:21:56 +0800 Subject: [PATCH] [paddle-trt] fix_teller_reshape (#34583) --- paddle/fluid/inference/tensorrt/op_teller.cc | 5 ++-- .../ir/inference/test_trt_reshape_op.py | 27 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 2829a740236..bfe3dfc85ee 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -703,8 +703,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, return false; } // Paddle-TRT does not support the input tensors: Shape and ShapeTensor - if (desc.Input("Shape").size() >= 1 || - desc.Input("ShapeTensor").size() >= 1) { + auto reshape_inputs = desc.Inputs(); + if (reshape_inputs.find("Shape") != reshape_inputs.end() || + reshape_inputs.find("ShapeTensor") != reshape_inputs.end()) { return false; } std::vector shape = diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py index 85054be534e..76dc605c3ec 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py @@ -80,6 +80,33 @@ class TRTReshapeTest1(TRTReshapeTest): class TRTReshapeTest2(TRTReshapeTest): + def setUp(self): + self.bs = 2 + self.input_shape = [23, 13, 24] + self.reshape = [2, 0, -1, 12] + self.data_shape = [ + self.bs, self.input_shape[0], self.input_shape[1], + self.input_shape[2] + ] + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name='data', shape=self.data_shape, dtype='float32') + actual_reshape = fluid.data( + name='actual_reshape', shape=[4], dtype='int32') + reshape_out = fluid.layers.reshape( + x=data, shape=self.reshape, actual_shape=actual_reshape) + out = fluid.layers.batch_norm(reshape_out, is_test=True) + self.feeds = { + 'data': np.random.random(self.data_shape).astype('float32'), + 'actual_reshape': np.array([2, 0, -1, 12]).astype('int32') + } + self.enable_trt = True + self.trt_parameters = TRTReshapeTest.TensorRTParam( + 1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + +class TRTReshapeTest3(TRTReshapeTest): def setUp(self): self.bs = 1 self.input_shape = [14, 48, 27] -- GitLab