未验证 提交 38325636 编写于 作者: W wenbin 提交者: GitHub

copyright (#45866)

上级 749667e5
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......
......@@ -61,6 +61,24 @@ class TestLayernormShiftPartitionPass(PassAutoScanTest):
})
yield config, ['layernorm_shift_partition'], (1e-5, 1e-5)
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Half,
use_static=False,
use_calib_mode=False)
config.set_trt_dynamic_shape_info({
"input_data": [1, 9, 96],
}, {
"input_data": [4, 3136, 768],
}, {
"input_data": [1, 784, 384],
})
yield config, ['layernorm_shift_partition'], (1e-3, 1e-3)
def sample_program_config(self, draw):
axis = [0, 1, 3, 2, 4, 5]
epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
......@@ -198,10 +216,10 @@ class TestLayernormShiftPartitionPass(PassAutoScanTest):
def test(self):
self.run_and_statis(quant=False,
max_examples=20,
max_examples=50,
passes=["layernorm_shift_partition_fuse_pass"],
max_duration=250,
min_success_num=20)
min_success_num=50)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册