未验证 提交 a5d9f244 编写于 作者: Y yuehuayingxueluo 提交者: GitHub

fix ci bugs in fused_linear (#54605)

上级 b11f0b7a
......@@ -136,7 +136,7 @@ class TestFuseGemmEpilogueFWDBase(unittest.TestCase):
"_matmul_y": self.matmul_y_arr,
"_ele_y": self.ele_y_arr,
}
self.reference = self.exe.run(
self.reference = paddle.static.Executor(self.place).run(
self.main_prog, feed=self.feed, fetch_list=[self.loss.name]
)
......@@ -207,10 +207,6 @@ class TestFuseGemmEpilogueReluFWDFP16(TestFuseGemmEpilogueReluFWDFP32):
self.place, self.main_prog, to_fp16_var_names=fp16_var_list
)
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
......@@ -240,10 +236,6 @@ class TestFuseGemmEpilogueGeluFWDFP16(TestFuseGemmEpilogueGeluFWDFP32):
self.place, self.main_prog, to_fp16_var_names=fp16_var_list
)
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
......@@ -319,7 +311,8 @@ class TestFuseGemmEpilogueBWDBase(unittest.TestCase):
f'{multi_layer.linear3.full_name()}.w_0@GRAD',
f'{multi_layer.linear3.full_name()}.b_0@GRAD',
]
self.outs_ref = self.exe.run(
self.outs_ref = paddle.static.Executor(self.place).run(
self.main_prog, feed=self.feed, fetch_list=self.fetch
)
......@@ -403,10 +396,6 @@ class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32):
self.place, self.main_prog, to_fp16_var_names=fp16_var_list
)
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
......@@ -436,10 +425,6 @@ class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32):
self.place, self.main_prog, to_fp16_var_names=fp16_var_list
)
self.data_arr = self.data_arr.astype("float16")
self.matmul_y_arr = self.matmul_y_arr.astype("float16")
self.ele_y_arr = self.ele_y_arr.astype("float16")
if __name__ == "__main__":
np.random.seed(0)
......
......@@ -89,7 +89,9 @@ class TestFuseGemmEpilogueOpReluMMFP16(TestFuseGemmBase):
self.place
):
return
self.check_output_with_place(self.place, atol=self.atol)
self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op")
......@@ -146,7 +148,9 @@ class TestFuseGemmEpilogueOpReluMTMFP16(TestFuseGemmBase):
self.place
):
return
self.check_output_with_place(self.place, atol=self.atol)
self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op")
......@@ -203,7 +207,9 @@ class TestFuseGemmEpilogueOpReluMMTFP16(TestFuseGemmBase):
self.place
):
return
self.check_output_with_place(self.place, atol=self.atol)
self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op")
......@@ -260,7 +266,9 @@ class TestFuseGemmEpilogueOpReluMTMTFP16(TestFuseGemmBase):
self.place
):
return
self.check_output_with_place(self.place, atol=self.atol)
self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op")
......@@ -317,7 +325,9 @@ class TestFuseGemmEpilogueOpReluMMFP16MultiDimX(TestFuseGemmBase):
self.place
):
return
self.check_output_with_place(self.place, atol=self.atol)
self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op")
......@@ -378,7 +388,9 @@ class TestFuseGemmEpilogueOpReluMTMFP16MultiDimX(TestFuseGemmBase):
self.place
):
return
self.check_output_with_place(self.place, atol=self.atol)
self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op")
......@@ -438,7 +450,9 @@ class TestFuseGemmEpilogueOpGeluMMFP16(TestFuseGemmBase):
self.place
):
return
self.check_output_with_place(self.place, atol=self.atol)
self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op")
......@@ -494,7 +508,9 @@ class TestFuseGemmEpilogueOpNoneMMFP16(TestFuseGemmBase):
self.place
):
return
self.check_output_with_place(self.place, atol=self.atol)
self.check_output_with_place(
self.place, atol=self.atol, check_dygraph=False
)
@skip_check_grad_ci(reason="no grap op")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册