未验证 提交 53cfac94 编写于 作者: L liu zhengxi 提交者: GitHub

Fix trt fc fuse test (#23852)

* fix trt fc fuse test, test=develop

* fix trt_transpose_flatten_concat shape, test=develop
上级 477cb1fd
......@@ -25,7 +25,8 @@ from paddle.fluid.core import AnalysisConfig
class FCFusePassTRTTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name="data", shape=[32, 128], dtype="float32")
data = fluid.data(
name="data", shape=[32, 128, 2, 2], dtype="float32")
fc_out1 = fluid.layers.fc(input=data,
size=128,
num_flatten_dims=1,
......@@ -35,10 +36,12 @@ class FCFusePassTRTTest(InferencePassTest):
num_flatten_dims=1)
out = fluid.layers.softmax(input=fc_out2)
self.feeds = {"data": np.random.random((32, 128)).astype("float32")}
self.feeds = {
"data": np.random.random((32, 128, 2, 2)).astype("float32")
}
self.enable_trt = True
self.trt_parameters = FCFusePassTRTTest.TensorRTParam(
1 << 20, 1, 3, AnalysisConfig.Precision.Float32, False, False)
1 << 30, 32, 3, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):
......
......@@ -42,7 +42,7 @@ class TransposeFlattenConcatFusePassTRTTest(InferencePassTest):
}
self.enable_trt = True
self.trt_parameters = TransposeFlattenConcatFusePassTRTTest.TensorRTParam(
1 << 20, 1, 3, AnalysisConfig.Precision.Float32, False, False)
1 << 20, 8, 3, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册