From ded7d19025e9d8623d33ce9c5b2fac3a1be1e016 Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Wed, 14 Jun 2023 11:06:08 +0800
Subject: [PATCH] [BugFix]: Fix ci test bugs in test_fuse_gemm_epilogue_pass.py and test_fused_gemm_epilogue_op.py (#54519)

* fix ci bugs in fused_linear

* fix code style
---
 .../test_fuse_gemm_epilogue_pass.py | 21 ++----------
 .../test_fused_gemm_epilogue_op.py  | 32 ++++++++++++++-----
 2 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/test/legacy_test/test_fuse_gemm_epilogue_pass.py b/test/legacy_test/test_fuse_gemm_epilogue_pass.py
index 262c29b428e..13480e3d75d 100644
--- a/test/legacy_test/test_fuse_gemm_epilogue_pass.py
+++ b/test/legacy_test/test_fuse_gemm_epilogue_pass.py
@@ -135,7 +135,7 @@ class TestFuseGemmEpilogueFWDBase(unittest.TestCase):
             "_matmul_y": self.matmul_y_arr,
             "_ele_y": self.ele_y_arr,
         }
-        self.reference = self.exe.run(
+        self.reference = paddle.static.Executor(self.place).run(
             self.main_prog, feed=self.feed, fetch_list=[self.loss.name]
         )
 
@@ -206,10 +206,6 @@ class TestFuseGemmEpilogueReluFWDFP16(TestFuseGemmEpilogueReluFWDFP32):
             self.place, self.main_prog, to_fp16_var_names=fp16_var_list
         )
 
-        self.data_arr = self.data_arr.astype("float16")
-        self.matmul_y_arr = self.matmul_y_arr.astype("float16")
-        self.ele_y_arr = self.ele_y_arr.astype("float16")
-
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
@@ -239,10 +235,6 @@ class TestFuseGemmEpilogueGeluFWDFP16(TestFuseGemmEpilogueGeluFWDFP32):
             self.place, self.main_prog, to_fp16_var_names=fp16_var_list
         )
 
-        self.data_arr = self.data_arr.astype("float16")
-        self.matmul_y_arr = self.matmul_y_arr.astype("float16")
-        self.ele_y_arr = self.ele_y_arr.astype("float16")
-
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
@@ -318,7 +310,8 @@ class TestFuseGemmEpilogueBWDBase(unittest.TestCase):
             f'{multi_layer.linear3.full_name()}.w_0@GRAD',
             f'{multi_layer.linear3.full_name()}.b_0@GRAD',
         ]
-        self.outs_ref = self.exe.run(
+
+        self.outs_ref = paddle.static.Executor(self.place).run(
             self.main_prog, feed=self.feed, fetch_list=self.fetch
         )
 
@@ -402,10 +395,6 @@ class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
             self.place, self.main_prog, to_fp16_var_names=fp16_var_list
         )
 
-        self.data_arr = self.data_arr.astype("float16")
-        self.matmul_y_arr = self.matmul_y_arr.astype("float16")
-        self.ele_y_arr = self.ele_y_arr.astype("float16")
-
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
@@ -435,10 +424,6 @@ class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
             self.place, self.main_prog, to_fp16_var_names=fp16_var_list
         )
 
-        self.data_arr = self.data_arr.astype("float16")
-        self.matmul_y_arr = self.matmul_y_arr.astype("float16")
-        self.ele_y_arr = self.ele_y_arr.astype("float16")
-
 
 if __name__ == "__main__":
     np.random.seed(0)
diff --git a/test/legacy_test/test_fused_gemm_epilogue_op.py b/test/legacy_test/test_fused_gemm_epilogue_op.py
index 08047d1f110..49095020279 100644
--- a/test/legacy_test/test_fused_gemm_epilogue_op.py
+++ b/test/legacy_test/test_fused_gemm_epilogue_op.py
@@ -89,7 +89,9 @@ class TestFuseGemmEpilogueOpReluMMFP16(TestFuseGemmBase):
             self.place
         ):
             return
-        self.check_output_with_place(self.place, atol=self.atol)
+        self.check_output_with_place(
+            self.place, atol=self.atol, check_dygraph=False
+        )
 
 
 @skip_check_grad_ci(reason="no grap op")
@@ -146,7 +148,9 @@ class TestFuseGemmEpilogueOpReluMTMFP16(TestFuseGemmBase):
             self.place
         ):
             return
-        self.check_output_with_place(self.place, atol=self.atol)
+        self.check_output_with_place(
+            self.place, atol=self.atol, check_dygraph=False
+        )
 
 
 @skip_check_grad_ci(reason="no grap op")
@@ -203,7 +207,9 @@ class TestFuseGemmEpilogueOpReluMMTFP16(TestFuseGemmBase):
             self.place
         ):
             return
-        self.check_output_with_place(self.place, atol=self.atol)
+        self.check_output_with_place(
+            self.place, atol=self.atol, check_dygraph=False
+        )
 
 
 @skip_check_grad_ci(reason="no grap op")
@@ -260,7 +266,9 @@ class TestFuseGemmEpilogueOpReluMTMTFP16(TestFuseGemmBase):
             self.place
         ):
             return
-        self.check_output_with_place(self.place, atol=self.atol)
+        self.check_output_with_place(
+            self.place, atol=self.atol, check_dygraph=False
+        )
 
 
 @skip_check_grad_ci(reason="no grap op")
@@ -317,7 +325,9 @@ class TestFuseGemmEpilogueOpReluMMFP16MultiDimX(TestFuseGemmBase):
             self.place
         ):
             return
-        self.check_output_with_place(self.place, atol=self.atol)
+        self.check_output_with_place(
+            self.place, atol=self.atol, check_dygraph=False
+        )
 
 
 @skip_check_grad_ci(reason="no grap op")
@@ -378,7 +388,9 @@ class TestFuseGemmEpilogueOpReluMTMFP16MultiDimX(TestFuseGemmBase):
             self.place
         ):
             return
-        self.check_output_with_place(self.place, atol=self.atol)
+        self.check_output_with_place(
+            self.place, atol=self.atol, check_dygraph=False
+        )
 
 
 @skip_check_grad_ci(reason="no grap op")
@@ -438,7 +450,9 @@ class TestFuseGemmEpilogueOpGeluMMFP16(TestFuseGemmBase):
             self.place
         ):
             return
-        self.check_output_with_place(self.place, atol=self.atol)
+        self.check_output_with_place(
+            self.place, atol=self.atol, check_dygraph=False
+        )
 
 
 @skip_check_grad_ci(reason="no grap op")
@@ -494,7 +508,9 @@ class TestFuseGemmEpilogueOpNoneMMFP16(TestFuseGemmBase):
             self.place
         ):
             return
-        self.check_output_with_place(self.place, atol=self.atol)
+        self.check_output_with_place(
+            self.place, atol=self.atol, check_dygraph=False
+        )
 
 
 @skip_check_grad_ci(reason="no grap op")
-- 
GitLab