diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py index f66822171cb58a7671ecffc294a8386c6e42ebc4..0f035d60262a26e11169f87c2b0c7e812705f481 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py @@ -25,7 +25,8 @@ from paddle.fluid.core import AnalysisConfig class FCFusePassTRTTest(InferencePassTest): def setUp(self): with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data(name="data", shape=[32, 128], dtype="float32") + data = fluid.data( + name="data", shape=[32, 128, 2, 2], dtype="float32") fc_out1 = fluid.layers.fc(input=data, size=128, num_flatten_dims=1, @@ -35,10 +36,12 @@ class FCFusePassTRTTest(InferencePassTest): num_flatten_dims=1) out = fluid.layers.softmax(input=fc_out2) - self.feeds = {"data": np.random.random((32, 128)).astype("float32")} + self.feeds = { + "data": np.random.random((32, 128, 2, 2)).astype("float32") + } self.enable_trt = True self.trt_parameters = FCFusePassTRTTest.TensorRTParam( - 1 << 20, 1, 3, AnalysisConfig.Precision.Float32, False, False) + 1 << 30, 32, 3, AnalysisConfig.Precision.Float32, False, False) self.fetch_list = [out] def test_check_output(self): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py index c85c54c74187f0ca77cd62928d15040a4b70adde..41f02b0427d68216f7363236b77ddf3229e92143 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py @@ -42,7 +42,7 @@ class TransposeFlattenConcatFusePassTRTTest(InferencePassTest): } self.enable_trt = True self.trt_parameters = TransposeFlattenConcatFusePassTRTTest.TensorRTParam( - 1 << 20, 1, 3, AnalysisConfig.Precision.Float32, False, False) + 1 << 20, 8, 3, AnalysisConfig.Precision.Float32, False, False) self.fetch_list = [out] def test_check_output(self):