未验证 提交 7d3b08d9 编写于 作者: z8hanghuan's avatar z8hanghuan 提交者: GitHub

refactor mean op, *test=kunlun (#44000)

* refactor mean op, *test=kunlun

* refactor mean op, *test=kunlun

* refactor mean op,*test=kunlun

* refactor mean op,*test=kunlun
上级 75c975f0
...@@ -16,9 +16,6 @@ if(WITH_XPU_BKCL) ...@@ -16,9 +16,6 @@ if(WITH_XPU_BKCL)
list(APPEND DIST_TEST_OPS test_gen_bkcl_id_op) list(APPEND DIST_TEST_OPS test_gen_bkcl_id_op)
endif() endif()
list(REMOVE_ITEM TEST_OPS test_concat_op_xpu)
list(REMOVE_ITEM TEST_OPS test_mean_op_xpu)
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach() endforeach()
......
...@@ -28,29 +28,66 @@ from paddle.fluid import Program, program_guard ...@@ -28,29 +28,66 @@ from paddle.fluid import Program, program_guard
np.random.seed(10) np.random.seed(10)
import op_test
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestMeanOp(XPUOpTestWrapper):
class TestMeanOp(XPUOpTest): def __init__(self):
self.op_name = 'mean'
self.use_dynamic_create_class = False
class TestMeanOp(XPUOpTest):
def setUp(self): def setUp(self):
self.init_dtype()
self.set_xpu()
self.op_type = "mean" self.op_type = "mean"
self.init_dtype_type() self.place = paddle.XPUPlace(0)
self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)} self.set_shape()
self.inputs = {'X': np.random.random(self.shape).astype(self.dtype)}
self.outputs = {'Out': np.mean(self.inputs["X"]).astype(np.float16)} self.outputs = {'Out': np.mean(self.inputs["X"]).astype(np.float16)}
def init_dtype_type(self): def init_dtype(self):
self.dtype = np.float32 self.dtype = self.in_type
def set_shape(self):
self.shape = (10, 10)
def set_xpu(self):
self.__class__.use_xpu = True
self.__class__.no_need_check_grad = True
self.__class__.op_type = self.dtype
def test_check_output(self): def test_check_output(self):
if paddle.is_compiled_with_xpu(): self.check_output_with_place(self.place)
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place, atol=2e-3)
def test_checkout_grad(self): def test_checkout_grad(self):
if paddle.is_compiled_with_xpu(): self.check_grad_with_place(self.place, ['X'], 'Out')
paddle.enable_static()
place = paddle.XPUPlace(0) class TestMeanOp1(TestMeanOp):
self.check_grad_with_place(place, ['X'], 'Out')
def set_shape(self):
self.shape = (5)
class TestMeanOp2(TestMeanOp):
def set_shape(self):
self.shape = (5, 7, 8)
class TestMeanOp3(TestMeanOp):
def set_shape(self):
self.shape = (10, 5, 7, 8)
class TestMeanOp4(TestMeanOp):
def set_shape(self):
self.shape = (2, 2, 3, 3, 3)
class TestMeanOpError(unittest.TestCase): class TestMeanOpError(unittest.TestCase):
...@@ -71,43 +108,9 @@ class TestMeanOpError(unittest.TestCase): ...@@ -71,43 +108,9 @@ class TestMeanOpError(unittest.TestCase):
fluid.layers.softmax(input3) fluid.layers.softmax(input3)
class TestXPUMeanOp(TestMeanOp): support_types = get_xpu_op_support_types('mean')
for stype in support_types:
def init_dtype_type(self): create_test_class(globals(), XPUTestMeanOp, stype)
self.dtype = np.float32
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_checkout_grad(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
class TestXPUMeanOpFp16(TestMeanOp):
def init_dtype_type(self):
self.dtype = np.float16
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_checkout_grad(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'],
'Out',
max_relative_error=1.e1)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册