提交 6461e800 编写于 作者: K Krzysztof Binias

Inheritance added for MKLDNN tests

上级 d8bd436f
......@@ -507,58 +507,46 @@ class TestSwish(OpTest):
#--------------------test MKLDNN--------------------
class TestMKLDNNRelu(OpTest):
class TestMKLDNNRelu(TestRelu):
def setUp(self):
self.op_type = "relu"
super(TestMKLDNNRelu, self).setUp()
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Out': np.maximum(self.inputs['X'], 0)}
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
self.check_output()
out = np.maximum(x, 0)
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True}
class TestMKLDNNTanh(OpTest):
class TestMKLDNNTanh(TestTanh):
def setUp(self):
self.op_type = "tanh"
super(TestMKLDNNTanh, self).setUp()
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
}
self.outputs = {'Out': np.tanh(self.inputs['X'])}
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestMKLDNNSqrt(OpTest):
class TestMKLDNNSqrt(TestSqrt):
def setUp(self):
self.op_type = "sqrt"
super(TestMKLDNNSqrt, self).setUp()
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
}
self.outputs = {'Out': np.sqrt(self.inputs['X'])}
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestMKLDNNAbs(OpTest):
class TestMKLDNNAbs(TestAbs):
def setUp(self):
self.op_type = "abs"
super(TestMKLDNNAbs, self).setUp()
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
......@@ -566,12 +554,6 @@ class TestMKLDNNAbs(OpTest):
self.outputs = {'Out': np.abs(self.inputs['X'])}
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册