Unverified · Commit 79b5db13 · Authored by QingshuChen, committed by GitHub

bug: fix mul unittest bug (#27852)

*test=kunlun
Parent 345574a6
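In short, the patch removes a duplicated CPU-only TestMulOp class, renames the XPU test classes, switches them to float32, seeds NumPy's RNG from the wall clock so each run draws fresh inputs, and retunes the output/gradient tolerances. As a minimal standalone sketch (NumPy only, shapes taken from the 2-D test in this patch), the seeding-plus-reference pattern the new setUp() methods use looks like this:

    import time
    import numpy as np

    # Seed NumPy from the wall clock, as the patched setUp() does, so every
    # run of the test draws different random inputs.
    np.random.seed(int(time.time()))

    # Reference output for the 2-D case: mul(X, Y) reduces to a plain matmul.
    x = np.random.random((20, 5)).astype(np.float32)
    y = np.random.random((5, 21)).astype(np.float32)
    out = np.dot(x, y)  # expected 'Out', shape (20, 21)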
@@ -23,35 +23,9 @@ sys.path.append("..")
 from op_test import OpTest
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
+import time
 
-class TestMulOp(OpTest):
-    def setUp(self):
-        self.op_type = "mul"
-        self.dtype = np.float64
-        self.init_dtype_type()
-        self.inputs = {
-            'X': np.random.random((20, 5)).astype(self.dtype),
-            'Y': np.random.random((5, 21)).astype(self.dtype)
-        }
-        self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
-
-    def init_dtype_type(self):
-        pass
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out')
-
-    def test_check_grad_ingore_x(self):
-        self.check_grad(
-            ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
-
-    def test_check_grad_ingore_y(self):
-        self.check_grad(
-            ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
-
 paddle.enable_static()
 
 class TestMulOpError(unittest.TestCase):
@@ -69,11 +43,13 @@ class TestMulOpError(unittest.TestCase):
         self.assertRaises(TypeError, fluid.layers.mul, x3, x4)
 
-class TestMulOp2(OpTest):
+class TestXPUMulOp1(OpTest):
     def setUp(self):
         self.op_type = "mul"
-        self.dtype = np.float64
+        self.dtype = np.float32
+        self.use_xpu = True
         self.init_dtype_type()
+        np.random.seed((int)(time.time()))
         self.inputs = {
             'X': np.random.random((3, 4, 2, 9)).astype(self.dtype),
             'Y': np.random.random((3, 6, 1, 2, 3)).astype(self.dtype)
@@ -94,67 +70,53 @@ class TestMulOp2(OpTest):
         self.check_output()
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out')
+        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1)
 
     def test_check_grad_ingore_x(self):
         self.check_grad(
-            ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set('X'))
+            ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set('X'))
 
     def test_check_grad_ignore_y(self):
         self.check_grad(
-            ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
+            ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
 
 @unittest.skipIf(not paddle.is_compiled_with_xpu(),
                  "core is not compiled with XPU")
-class TestXPUMulOp1(TestMulOp):
-    def init_dtype_type(self):
+class TestXPUMulOp2(OpTest):
+    def setUp(self):
+        self.op_type = "mul"
+        self.use_xpu = True
         self.dtype = np.float32
+        self.init_dtype_type()
+        np.random.seed((int)(time.time()))
+        self.inputs = {
+            'X': np.random.random((20, 5)).astype(self.dtype),
+            'Y': np.random.random((5, 21)).astype(self.dtype)
+        }
+        self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
-
-    def test_check_output(self):
-        place = paddle.XPUPlace(0)
-        self.check_output_with_place(place, atol=1e-1)
-
-    def test_check_grad_normal(self):
-        place = paddle.XPUPlace(0)
-        self.check_grad_with_place(
-            place, ['X', 'Y'], 'Out', max_relative_error=0.5)
-
-    def test_check_grad_ingore_x(self):
-        place = paddle.XPUPlace(0)
-        self.check_grad_with_place(
-            place, ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
-
-    def test_check_grad_ingore_y(self):
-        place = paddle.XPUPlace(0)
-        self.check_grad_with_place(
-            place, ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
-
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestXPUMulOp2(TestMulOp2):
-    def init_dtype_type(self):
-        self.dtype = np.float32
 
     def test_check_output(self):
         place = paddle.XPUPlace(0)
-        self.check_output_with_place(place, atol=2e-1)
+        self.check_output_with_place(place, atol=0.01)
 
     def test_check_grad_normal(self):
        place = paddle.XPUPlace(0)
         self.check_grad_with_place(
-            place, ['X', 'Y'], 'Out', max_relative_error=0.9)
+            place, ['X', 'Y'], 'Out', max_relative_error=0.1)
 
     def test_check_grad_ingore_x(self):
         place = paddle.XPUPlace(0)
         self.check_grad_with_place(
-            place, ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
+            place, ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X"))
 
     def test_check_grad_ingore_y(self):
         place = paddle.XPUPlace(0)
         self.check_grad_with_place(
-            place, ['X'], 'Out', max_relative_error=0.9, no_grad_set=set('Y'))
+            place, ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
 
 if __name__ == "__main__":
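For readers who want to reproduce the forward check outside the OpTest harness, a minimal sketch follows. It is an illustration, not part of the commit, and it assumes a Paddle build compiled with XPU support (paddle.is_compiled_with_xpu() returns True); on other builds, substitute fluid.CPUPlace() for the XPU place. The shapes and atol mirror the patched TestXPUMulOp2:

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    # Build a tiny static program that runs the mul op, mirroring the
    # (20, 5) x (5, 21) case in TestXPUMulOp2.
    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        x = fluid.data(name='x', shape=[20, 5], dtype='float32')
        y = fluid.data(name='y', shape=[5, 21], dtype='float32')
        out = fluid.layers.mul(x, y)

    place = paddle.XPUPlace(0)  # assumes an XPU build; else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup)

    x_np = np.random.random((20, 5)).astype('float32')
    y_np = np.random.random((5, 21)).astype('float32')
    out_np, = exe.run(main, feed={'x': x_np, 'y': y_np}, fetch_list=[out])

    # Same forward tolerance the patched test uses (atol=0.01).
    np.testing.assert_allclose(out_np, np.dot(x_np, y_np), atol=0.01)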