未验证 提交 ded7d190 编写于 作者: Y yuehuayingxueluo 提交者: GitHub

[BugFix]: Fix ci test bugs in test_fuse_gemm_epilogue_pass.py and...

[BugFix]: Fix ci test bugs in test_fuse_gemm_epilogue_pass.py and test_fused_gemm_epilogue_op.py (#54519)

* fix ci bugs in fused_linear

* fix code style
上级 40bfe0eb
...@@ -135,7 +135,7 @@ class TestFuseGemmEpilogueFWDBase(unittest.TestCase): ...@@ -135,7 +135,7 @@ class TestFuseGemmEpilogueFWDBase(unittest.TestCase):
"_matmul_y": self.matmul_y_arr, "_matmul_y": self.matmul_y_arr,
"_ele_y": self.ele_y_arr, "_ele_y": self.ele_y_arr,
} }
self.reference = self.exe.run( self.reference = paddle.static.Executor(self.place).run(
self.main_prog, feed=self.feed, fetch_list=[self.loss.name] self.main_prog, feed=self.feed, fetch_list=[self.loss.name]
) )
...@@ -206,10 +206,6 @@ class TestFuseGemmEpilogueReluFWDFP16(TestFuseGemmEpilogueReluFWDFP32): ...@@ -206,10 +206,6 @@ class TestFuseGemmEpilogueReluFWDFP16(TestFuseGemmEpilogueReluFWDFP32):
self.place, self.main_prog, to_fp16_var_names=fp16_var_list self.place, self.main_prog, to_fp16_var_names=fp16_var_list
) )
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA" not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
...@@ -239,10 +235,6 @@ class TestFuseGemmEpilogueGeluFWDFP16(TestFuseGemmEpilogueGeluFWDFP32): ...@@ -239,10 +235,6 @@ class TestFuseGemmEpilogueGeluFWDFP16(TestFuseGemmEpilogueGeluFWDFP32):
self.place, self.main_prog, to_fp16_var_names=fp16_var_list self.place, self.main_prog, to_fp16_var_names=fp16_var_list
) )
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA" not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
...@@ -318,7 +310,8 @@ class TestFuseGemmEpilogueBWDBase(unittest.TestCase): ...@@ -318,7 +310,8 @@ class TestFuseGemmEpilogueBWDBase(unittest.TestCase):
f'{multi_layer.linear3.full_name()}.w_0@GRAD', f'{multi_layer.linear3.full_name()}.w_0@GRAD',
f'{multi_layer.linear3.full_name()}.b_0@GRAD', f'{multi_layer.linear3.full_name()}.b_0@GRAD',
] ]
self.outs_ref = self.exe.run(
self.outs_ref = paddle.static.Executor(self.place).run(
self.main_prog, feed=self.feed, fetch_list=self.fetch self.main_prog, feed=self.feed, fetch_list=self.fetch
) )
...@@ -402,10 +395,6 @@ class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32): ...@@ -402,10 +395,6 @@ class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
self.place, self.main_prog, to_fp16_var_names=fp16_var_list self.place, self.main_prog, to_fp16_var_names=fp16_var_list
) )
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA" not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
...@@ -435,10 +424,6 @@ class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32): ...@@ -435,10 +424,6 @@ class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
self.place, self.main_prog, to_fp16_var_names=fp16_var_list self.place, self.main_prog, to_fp16_var_names=fp16_var_list
) )
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
if __name__ == "__main__": if __name__ == "__main__":
np.random.seed(0) np.random.seed(0)
......
...@@ -89,7 +89,9 @@ class TestFuseGemmEpilogueOpReluMMFP16(TestFuseGemmBase): ...@@ -89,7 +89,9 @@ class TestFuseGemmEpilogueOpReluMMFP16(TestFuseGemmBase):
self.place self.place
): ):
return return
self.check_output_with_place(self.place, atol=self.atol) self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
...@@ -146,7 +148,9 @@ class TestFuseGemmEpilogueOpReluMTMFP16(TestFuseGemmBase): ...@@ -146,7 +148,9 @@ class TestFuseGemmEpilogueOpReluMTMFP16(TestFuseGemmBase):
self.place self.place
): ):
return return
self.check_output_with_place(self.place, atol=self.atol) self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
...@@ -203,7 +207,9 @@ class TestFuseGemmEpilogueOpReluMMTFP16(TestFuseGemmBase): ...@@ -203,7 +207,9 @@ class TestFuseGemmEpilogueOpReluMMTFP16(TestFuseGemmBase):
self.place self.place
): ):
return return
self.check_output_with_place(self.place, atol=self.atol) self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
...@@ -260,7 +266,9 @@ class TestFuseGemmEpilogueOpReluMTMTFP16(TestFuseGemmBase): ...@@ -260,7 +266,9 @@ class TestFuseGemmEpilogueOpReluMTMTFP16(TestFuseGemmBase):
self.place self.place
): ):
return return
self.check_output_with_place(self.place, atol=self.atol) self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
...@@ -317,7 +325,9 @@ class TestFuseGemmEpilogueOpReluMMFP16MultiDimX(TestFuseGemmBase): ...@@ -317,7 +325,9 @@ class TestFuseGemmEpilogueOpReluMMFP16MultiDimX(TestFuseGemmBase):
self.place self.place
): ):
return return
self.check_output_with_place(self.place, atol=self.atol) self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
...@@ -378,7 +388,9 @@ class TestFuseGemmEpilogueOpReluMTMFP16MultiDimX(TestFuseGemmBase): ...@@ -378,7 +388,9 @@ class TestFuseGemmEpilogueOpReluMTMFP16MultiDimX(TestFuseGemmBase):
self.place self.place
): ):
return return
self.check_output_with_place(self.place, atol=self.atol) self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
...@@ -438,7 +450,9 @@ class TestFuseGemmEpilogueOpGeluMMFP16(TestFuseGemmBase): ...@@ -438,7 +450,9 @@ class TestFuseGemmEpilogueOpGeluMMFP16(TestFuseGemmBase):
self.place self.place
): ):
return return
self.check_output_with_place(self.place, atol=self.atol) self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
...@@ -494,7 +508,9 @@ class TestFuseGemmEpilogueOpNoneMMFP16(TestFuseGemmBase): ...@@ -494,7 +508,9 @@ class TestFuseGemmEpilogueOpNoneMMFP16(TestFuseGemmBase):
self.place self.place
): ):
return return
self.check_output_with_place(self.place, atol=self.atol) self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op") @skip_check_grad_ci(reason="no grap op")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册