未验证 提交 dca81a43 编写于 作者: W Wang Bojun 提交者: GitHub

[TRT] Fix conv2d filter of trt elementwiseadd_trans fusion UT (#51294)

* fix conv2d filter
上级 82a7c33e
...@@ -29,12 +29,12 @@ class TrtConvertElementwiseaddTransposeTest(TrtLayerAutoScanTest): ...@@ -29,12 +29,12 @@ class TrtConvertElementwiseaddTransposeTest(TrtLayerAutoScanTest):
def sample_program_configs(self): def sample_program_configs(self):
def conv_filter_datagen(dics): def conv_filter_datagen(dics):
c = dics["c"] c = dics["c"]
x = (np.random.randn(c, c, 1, 1)) / np.sqrt(c) x = (np.random.randn(c, c, 1, 1)) * np.sqrt(2 / c) * 0.1
return x.astype(np.float32) return x.astype(np.float32)
def conv_elementwise_bias_datagen(dics): def conv_elementwise_bias_datagen(dics):
c = dics["c"] c = dics["c"]
x = np.random.random([dics["c"]]) * 0.1 x = np.random.random([dics["c"]]) * 0.01
return x.astype(np.float32) return x.astype(np.float32)
def ele1_input_datagen(dics): def ele1_input_datagen(dics):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册