未验证 提交 ac894ced 编写于 作者: Z zlsh80826 提交者: GitHub

Test only trt group norm (#39561)

上级 b2986bab
...@@ -28,25 +28,14 @@ class TRTGroupNormTest(InferencePassTest): ...@@ -28,25 +28,14 @@ class TRTGroupNormTest(InferencePassTest):
with fluid.program_guard(self.main_program, self.startup_program): with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data( data = fluid.data(
name="data", shape=[-1, 512, 12, 12], dtype="float32") name="data", shape=[-1, 512, 12, 12], dtype="float32")
relu_out = fluid.layers.relu(data) out = self.append_group_norm(data)
relu6_out = fluid.layers.relu6(relu_out)
tanh_out = fluid.layers.tanh(relu6_out)
conv_out = fluid.layers.conv2d(
input=tanh_out,
num_filters=512,
filter_size=3,
groups=1,
padding=[1, 1],
bias_attr=False,
act=None)
out = self.append_group_norm(conv_out)
self.feeds = { self.feeds = {
"data": np.random.random([1, 512, 12, 12]).astype("float32"), "data": np.random.random([1, 512, 12, 12]).astype("float32"),
} }
self.enable_trt = True self.enable_trt = True
self.trt_parameters = TRTGroupNormTest.TensorRTParam( self.trt_parameters = TRTGroupNormTest.TensorRTParam(
1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False) 1 << 30, 1, 1, AnalysisConfig.Precision.Float32, False, False)
self.dynamic_shape_params = TRTGroupNormTest.DynamicShapeParam({ self.dynamic_shape_params = TRTGroupNormTest.DynamicShapeParam({
'data': [1, 512, 12, 12] 'data': [1, 512, 12, 12]
}, {'data': [1, 512, 12, 12]}, {'data': [1, 512, 12, 12]}, False) }, {'data': [1, 512, 12, 12]}, {'data': [1, 512, 12, 12]}, False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册