未验证 提交 ac894ced 编写于 作者: Z zlsh80826 提交者: GitHub

Test only trt group norm (#39561)

上级 b2986bab
......@@ -28,25 +28,14 @@ class TRTGroupNormTest(InferencePassTest):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 512, 12, 12], dtype="float32")
relu_out = fluid.layers.relu(data)
relu6_out = fluid.layers.relu6(relu_out)
tanh_out = fluid.layers.tanh(relu6_out)
conv_out = fluid.layers.conv2d(
input=tanh_out,
num_filters=512,
filter_size=3,
groups=1,
padding=[1, 1],
bias_attr=False,
act=None)
out = self.append_group_norm(conv_out)
out = self.append_group_norm(data)
self.feeds = {
"data": np.random.random([1, 512, 12, 12]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TRTGroupNormTest.TensorRTParam(
1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False)
1 << 30, 1, 1, AnalysisConfig.Precision.Float32, False, False)
self.dynamic_shape_params = TRTGroupNormTest.DynamicShapeParam({
'data': [1, 512, 12, 12]
}, {'data': [1, 512, 12, 12]}, {'data': [1, 512, 12, 12]}, False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册