diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 2829a740236d271ccf3af511e2afd731f3ab7cf5..bfe3dfc85eecdd966fbfa18d128a66373cd75dd7 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -703,8 +703,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, return false; } // Paddle-TRT does not support the input tensors: Shape and ShapeTensor - if (desc.Input("Shape").size() >= 1 || - desc.Input("ShapeTensor").size() >= 1) { + auto reshape_inputs = desc.Inputs(); + if (reshape_inputs.find("Shape") != reshape_inputs.end() || + reshape_inputs.find("ShapeTensor") != reshape_inputs.end()) { return false; } std::vector shape = diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py index 85054be534eeba5400f08e9c13b986fd5e192df0..76dc605c3ecd27d3435ee82aadee88a68ac3a666 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py @@ -80,6 +80,33 @@ class TRTReshapeTest1(TRTReshapeTest): class TRTReshapeTest2(TRTReshapeTest): + def setUp(self): + self.bs = 2 + self.input_shape = [23, 13, 24] + self.reshape = [2, 0, -1, 12] + self.data_shape = [ + self.bs, self.input_shape[0], self.input_shape[1], + self.input_shape[2] + ] + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name='data', shape=self.data_shape, dtype='float32') + actual_reshape = fluid.data( + name='actual_reshape', shape=[4], dtype='int32') + reshape_out = fluid.layers.reshape( + x=data, shape=self.reshape, actual_shape=actual_reshape) + out = fluid.layers.batch_norm(reshape_out, is_test=True) + self.feeds = { + 'data': np.random.random(self.data_shape).astype('float32'), + 'actual_reshape': np.array([2, 0, -1, 12]).astype('int32') + } + self.enable_trt = True + self.trt_parameters = TRTReshapeTest.TensorRTParam( + 1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + +class TRTReshapeTest3(TRTReshapeTest): def setUp(self): self.bs = 1 self.input_shape = [14, 48, 27]